from abc import ABC, abstractmethod
import os
import logging
import random
import requests
from typing import Dict, Any, Optional
from openai import OpenAI
from dotenv import load_dotenv

load_dotenv()


class BaseProvider(ABC):
    """Abstract base class for all AI providers"""

    @abstractmethod
    def generate_response(self, system_role: str, prompt: str) -> Dict[str, Any]:
        """
        Generate a response from the AI provider

        Returns:
            Dict containing:
                - response: str - The generated text
                - usage: Dict with prompt_tokens, completion_tokens, total_tokens
        """
        pass


class OpenAIProvider(BaseProvider):
    """OpenAI API provider"""

    def __init__(self):
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise RuntimeError("OPENAI_API_KEY is missing in .env")
        self.client = OpenAI(api_key=api_key)
        self.model = os.getenv("OPENAI_MODEL", "gpt-3.5-turbo")

    def generate_response(self, system_role: str, prompt: str) -> Dict[str, Any]:
        response = self.client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": system_role},
                {"role": "user", "content": prompt}
            ],
            timeout=10
        )
        return {
            "response": response.choices[0].message.content,
            "usage": {
                "prompt_tokens": response.usage.prompt_tokens,
                "completion_tokens": response.usage.completion_tokens,
                "total_tokens": response.usage.total_tokens
            }
        }


class OllamaProvider(BaseProvider):
    """Ollama API provider for local models"""

    def __init__(self):
        self.base_url = os.getenv("OLLAMA_URL", "http://localhost:11434")
        self.model = os.getenv("OLLAMA_MODEL", "llama2")
        self.temperature = float(os.getenv("OLLAMA_TEMPERATURE", "0.7"))
        self.timeout = int(os.getenv("OLLAMA_TIMEOUT", "60"))

    def generate_response(self, system_role: str, prompt: str) -> Dict[str, Any]:
        try:
            # Ollama chat completions API (similar to OpenAI)
            response = requests.post(
                f"{self.base_url}/api/chat",
                json={
                    "model": self.model,
                    "messages": [
                        {"role": "system", "content": system_role},
                        {"role": "user", "content": prompt}
                    ],
                    "stream": False,
                    "options": {
                        "temperature": self.temperature
                    }
                },
                timeout=self.timeout
            )
            response.raise_for_status()
            result = response.json()

            message = result.get("message", {})
            generated_text = message.get("content", "")

            # Extract token usage if available, otherwise estimate
            eval_count = result.get("eval_count", 0)
            prompt_eval_count = result.get("prompt_eval_count", 0)

            if eval_count == 0 or prompt_eval_count == 0:
                # Estimate if not provided
                prompt_tokens = len(system_role.split()) + len(prompt.split())
                completion_tokens = len(generated_text.split())
            else:
                prompt_tokens = prompt_eval_count
                completion_tokens = eval_count

            return {
                "response": generated_text,
                "usage": {
                    "prompt_tokens": prompt_tokens,
                    "completion_tokens": completion_tokens,
                    "total_tokens": prompt_tokens + completion_tokens
                }
            }
        except requests.exceptions.ConnectionError:
            logging.error(f"Could not connect to Ollama at {self.base_url}. Is Ollama running?")
            raise RuntimeError(
                "Ollama connection failed. "
                f"Ensure Ollama is running at {self.base_url}"
            )
        except Exception as e:
            logging.error(f"Ollama error: {e}")
            raise
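
# For reference, a non-streaming /api/chat response from Ollama looks roughly
# like this (abridged; the parsing above relies only on "message.content",
# "prompt_eval_count", and "eval_count"):
#
#   {
#     "model": "llama2",
#     "message": {"role": "assistant", "content": "..."},
#     "prompt_eval_count": 26,
#     "eval_count": 298
#   }
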
class LocalModelProvider(BaseProvider):
    """Generic local model provider (e.g., llama.cpp, text-generation-webui, etc.)"""

    def __init__(self):
        self.base_url = os.getenv("LOCAL_MODEL_URL", "http://localhost:5000")
        self.api_path = os.getenv("LOCAL_MODEL_API_PATH", "/v1/completions")
        self.model = os.getenv("LOCAL_MODEL_NAME", "local-model")

    def generate_response(self, system_role: str, prompt: str) -> Dict[str, Any]:
        full_prompt = f"{system_role}\n\nUser: {prompt}\nAssistant:"
        try:
            response = requests.post(
                f"{self.base_url}{self.api_path}",
                json={
                    "prompt": full_prompt,
                    "max_tokens": 500,
                    "temperature": 0.7
                },
                timeout=30
            )
            response.raise_for_status()
            result = response.json()

            generated_text = result.get("choices", [{}])[0].get("text", "")

            # Rough token estimate: ~1.3 tokens per whitespace-separated word
            prompt_tokens = len(full_prompt.split()) * 1.3
            completion_tokens = len(generated_text.split()) * 1.3

            return {
                "response": generated_text.strip(),
                "usage": {
                    "prompt_tokens": int(prompt_tokens),
                    "completion_tokens": int(completion_tokens),
                    "total_tokens": int(prompt_tokens + completion_tokens)
                }
            }
        except Exception as e:
            logging.error(f"Local model error: {e}")
            raise


class LoremIpsumProvider(BaseProvider):
    """Lorem Ipsum generator for testing"""

    def __init__(self):
        self.lorem_words = [
            "lorem", "ipsum", "dolor", "sit", "amet", "consectetur",
            "adipiscing", "elit", "sed", "do", "eiusmod", "tempor",
            "incididunt", "ut", "labore", "et", "dolore", "magna", "aliqua",
            "enim", "ad", "minim", "veniam", "quis", "nostrud",
            "exercitation", "ullamco", "laboris", "nisi", "aliquip", "ex",
            "ea", "commodo", "consequat", "duis", "aute", "irure", "in",
            "reprehenderit", "voluptate", "velit", "esse", "cillum",
            "fugiat", "nulla", "pariatur", "excepteur", "sint", "occaecat",
            "cupidatat", "non", "proident", "sunt", "culpa", "qui",
            "officia", "deserunt", "mollit", "anim", "id", "est", "laborum"
        ]

    def generate_response(self, system_role: str, prompt: str) -> Dict[str, Any]:
        # Generate random lorem ipsum text
        word_count = random.randint(50, 200)
        words = []
        for i in range(word_count):
            word = random.choice(self.lorem_words)
            # Capitalize the first word of each sentence
            if i == 0 or words[-1].endswith('.'):
                word = word.capitalize()
            words.append(word)
            # Occasionally add punctuation
            if random.random() > 0.85:
                words[-1] += random.choice(['.', ',', ';'])

        # Ensure the text ends with a single period (dropping any trailing ',' or ';')
        words[-1] = words[-1].rstrip('.,;') + '.'

        response_text = ' '.join(words)

        # Calculate token usage
        prompt_tokens = len(system_role.split()) + len(prompt.split())
        completion_tokens = len(response_text.split())

        return {
            "response": response_text,
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens
            }
        }
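
# Example .env for this module (values illustrative; every variable is optional
# except OPENAI_API_KEY, which is required only when AI_PROVIDER=openai):
#
#   AI_PROVIDER=lorem
#   OPENAI_API_KEY=sk-...
#   OPENAI_MODEL=gpt-3.5-turbo
#   OLLAMA_URL=http://localhost:11434
#   OLLAMA_MODEL=llama2
#   OLLAMA_TEMPERATURE=0.7
#   OLLAMA_TIMEOUT=60
#   LOCAL_MODEL_URL=http://localhost:5000
#   LOCAL_MODEL_API_PATH=/v1/completions
#   LOCAL_MODEL_NAME=local-model
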
class ProviderFactory:
    """Factory for creating AI providers"""

    _providers = {
        "openai": OpenAIProvider,
        "ollama": OllamaProvider,
        "local": LocalModelProvider,
        "lorem": LoremIpsumProvider
    }

    @classmethod
    def get_provider(cls, provider_name: Optional[str] = None) -> BaseProvider:
        """Get a provider instance by name (falls back to the AI_PROVIDER env var, then "lorem")"""
        if provider_name is None:
            provider_name = os.getenv("AI_PROVIDER", "lorem")

        provider_name = provider_name.lower()
        if provider_name not in cls._providers:
            raise ValueError(
                f"Unknown provider: {provider_name}. "
                f"Available: {list(cls._providers.keys())}"
            )
        return cls._providers[provider_name]()

    @classmethod
    def register_provider(cls, name: str, provider_class: type):
        """Register a new provider type"""
        if not issubclass(provider_class, BaseProvider):
            raise ValueError("Provider must inherit from BaseProvider")
        cls._providers[name] = provider_class
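
# Minimal usage sketch (illustrative, not part of the provider API). With no
# AI_PROVIDER set, the factory falls back to the offline "lorem" provider, so
# this runs without any API keys. EchoProvider is a made-up example class.
if __name__ == "__main__":
    provider = ProviderFactory.get_provider()  # defaults to "lorem"
    result = provider.generate_response(
        system_role="You are a helpful assistant.",
        prompt="Say hello."
    )
    print(result["response"][:80], "...")
    print(result["usage"])

    class EchoProvider(BaseProvider):
        """Toy provider that echoes the prompt back; demonstrates registration."""

        def generate_response(self, system_role: str, prompt: str) -> Dict[str, Any]:
            tokens = len(prompt.split())
            return {
                "response": prompt,
                "usage": {
                    "prompt_tokens": tokens,
                    "completion_tokens": tokens,
                    "total_tokens": 2 * tokens
                }
            }

    ProviderFactory.register_provider("echo", EchoProvider)
    echo = ProviderFactory.get_provider("echo")
    print(echo.generate_response("", "hello world")["response"])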