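"""AI provider abstractions.

Defines a common BaseProvider interface plus OpenAI, Ollama, generic
local-model, and Lorem Ipsum implementations, and a ProviderFactory that
selects one via the AI_PROVIDER environment variable.
"""
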
from abc import ABC, abstractmethod
import os
import logging
import random
import requests
from typing import Dict, Any, Optional
from openai import OpenAI
from dotenv import load_dotenv

# Load environment variables from a local .env file, if present
load_dotenv()


class BaseProvider(ABC):
    """Abstract base class for all AI providers."""

    @abstractmethod
    def generate_response(self, system_role: str, prompt: str) -> Dict[str, Any]:
        """
        Generate a response from the AI provider.

        Returns:
            Dict containing:
                - response: str - the generated text
                - usage: Dict with prompt_tokens, completion_tokens, total_tokens
        """
        pass


class OpenAIProvider(BaseProvider):
    """OpenAI API provider."""

    def __init__(self):
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise RuntimeError("OPENAI_API_KEY is missing in .env")

        self.client = OpenAI(api_key=api_key)
        self.model = os.getenv("OPENAI_MODEL", "gpt-3.5-turbo")

    def generate_response(self, system_role: str, prompt: str) -> Dict[str, Any]:
        response = self.client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": system_role},
                {"role": "user", "content": prompt}
            ],
            timeout=10
        )

        return {
            "response": response.choices[0].message.content,
            "usage": {
                "prompt_tokens": response.usage.prompt_tokens,
                "completion_tokens": response.usage.completion_tokens,
                "total_tokens": response.usage.total_tokens
            }
        }


class OllamaProvider(BaseProvider):
    """Ollama API provider for local models."""

    def __init__(self):
        self.base_url = os.getenv("OLLAMA_URL", "http://localhost:11434")
        self.model = os.getenv("OLLAMA_MODEL", "llama2")
        self.temperature = float(os.getenv("OLLAMA_TEMPERATURE", "0.7"))
        self.timeout = int(os.getenv("OLLAMA_TIMEOUT", "60"))

    def generate_response(self, system_role: str, prompt: str) -> Dict[str, Any]:
        try:
            # Ollama chat completions API (similar to OpenAI)
            response = requests.post(
                f"{self.base_url}/api/chat",
                json={
                    "model": self.model,
                    "messages": [
                        {"role": "system", "content": system_role},
                        {"role": "user", "content": prompt}
                    ],
                    "stream": False,
                    "options": {
                        "temperature": self.temperature
                    }
                },
                timeout=self.timeout
            )
            response.raise_for_status()

            result = response.json()
            message = result.get("message", {})
            generated_text = message.get("content", "")

            # Extract token usage if available, otherwise estimate
            eval_count = result.get("eval_count", 0)
            prompt_eval_count = result.get("prompt_eval_count", 0)

            if eval_count == 0 or prompt_eval_count == 0:
                # Estimate if not provided
                prompt_tokens = len(system_role.split()) + len(prompt.split())
                completion_tokens = len(generated_text.split())
            else:
                prompt_tokens = prompt_eval_count
                completion_tokens = eval_count

            return {
                "response": generated_text,
                "usage": {
                    "prompt_tokens": prompt_tokens,
                    "completion_tokens": completion_tokens,
                    "total_tokens": prompt_tokens + completion_tokens
                }
            }
        except requests.exceptions.ConnectionError:
            logging.error(f"Could not connect to Ollama at {self.base_url}. Is Ollama running?")
            raise RuntimeError(f"Ollama connection failed. Ensure Ollama is running at {self.base_url}")
        except Exception as e:
            logging.error(f"Ollama error: {e}")
            raise


class LocalModelProvider(BaseProvider):
    """Generic local model provider (e.g., llama.cpp, text-generation-webui, etc.)."""

    def __init__(self):
        self.base_url = os.getenv("LOCAL_MODEL_URL", "http://localhost:5000")
        self.api_path = os.getenv("LOCAL_MODEL_API_PATH", "/v1/completions")
        self.model = os.getenv("LOCAL_MODEL_NAME", "local-model")

    def generate_response(self, system_role: str, prompt: str) -> Dict[str, Any]:
        full_prompt = f"{system_role}\n\nUser: {prompt}\nAssistant:"

        try:
            response = requests.post(
                f"{self.base_url}{self.api_path}",
                json={
                    "prompt": full_prompt,
                    "max_tokens": 500,
                    "temperature": 0.7
                },
                timeout=30
            )
            response.raise_for_status()

            result = response.json()
            generated_text = result.get("choices", [{}])[0].get("text", "")

            # Estimate token usage (word count * 1.3 as a rough tokens-per-word factor)
            prompt_tokens = len(full_prompt.split()) * 1.3
            completion_tokens = len(generated_text.split()) * 1.3

            return {
                "response": generated_text.strip(),
                "usage": {
                    "prompt_tokens": int(prompt_tokens),
                    "completion_tokens": int(completion_tokens),
                    "total_tokens": int(prompt_tokens + completion_tokens)
                }
            }
        except Exception as e:
            logging.error(f"Local model error: {e}")
            raise


class LoremIpsumProvider(BaseProvider):
    """Lorem Ipsum generator for testing."""

    def __init__(self):
        self.lorem_words = [
            "lorem", "ipsum", "dolor", "sit", "amet", "consectetur", "adipiscing", "elit",
            "sed", "do", "eiusmod", "tempor", "incididunt", "ut", "labore", "et", "dolore",
            "magna", "aliqua", "enim", "ad", "minim", "veniam", "quis", "nostrud",
            "exercitation", "ullamco", "laboris", "nisi", "aliquip", "ex", "ea", "commodo",
            "consequat", "duis", "aute", "irure", "in", "reprehenderit", "voluptate",
            "velit", "esse", "cillum", "fugiat", "nulla", "pariatur", "excepteur", "sint",
            "occaecat", "cupidatat", "non", "proident", "sunt", "culpa", "qui", "officia",
            "deserunt", "mollit", "anim", "id", "est", "laborum"
        ]

    def generate_response(self, system_role: str, prompt: str) -> Dict[str, Any]:
        # Generate random lorem ipsum text
        word_count = random.randint(50, 200)
        words = []

        for i in range(word_count):
            word = random.choice(self.lorem_words)
            # Capitalize first word of sentence
            if i == 0 or (i > 0 and words[-1].endswith('.')):
                word = word.capitalize()
            words.append(word)

            # Add punctuation
            if random.random() > 0.85:
                words[-1] += random.choice(['.', ',', ';'])

        # Ensure last word has period
        if not words[-1].endswith('.'):
            words[-1] += '.'

        response_text = ' '.join(words)

        # Calculate token usage
        prompt_tokens = len(system_role.split()) + len(prompt.split())
        completion_tokens = len(response_text.split())

        return {
            "response": response_text,
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens
            }
        }


class ProviderFactory:
    """Factory for creating AI providers."""

    _providers = {
        "openai": OpenAIProvider,
        "ollama": OllamaProvider,
        "local": LocalModelProvider,
        "lorem": LoremIpsumProvider
    }

    @classmethod
    def get_provider(cls, provider_name: Optional[str] = None) -> BaseProvider:
        """Get a provider instance by name, falling back to the AI_PROVIDER env var."""
        if provider_name is None:
            provider_name = os.getenv("AI_PROVIDER", "lorem")

        provider_name = provider_name.lower()
        if provider_name not in cls._providers:
            raise ValueError(f"Unknown provider: {provider_name}. Available: {list(cls._providers.keys())}")

        return cls._providers[provider_name]()

    @classmethod
    def register_provider(cls, name: str, provider_class: type):
        """Register a new provider type."""
        if not issubclass(provider_class, BaseProvider):
            raise ValueError("Provider must inherit from BaseProvider")
        cls._providers[name] = provider_class
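

# Illustrative usage sketch (not part of the provider API): resolves the provider
# from AI_PROVIDER, which defaults to the offline "lorem" provider, so running this
# module directly needs no API key or local server. The system role and prompt
# strings below are placeholders, not values assumed by the providers themselves.
if __name__ == "__main__":
    provider = ProviderFactory.get_provider()
    result = provider.generate_response(
        system_role="You are a helpful assistant.",
        prompt="Say hello in one sentence."
    )
    print(result["response"])
    print(result["usage"])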