maid/proxy/providers.py

from abc import ABC, abstractmethod
import logging
import os
import random
from typing import Any, Dict, Optional

import requests
from dotenv import load_dotenv
from openai import OpenAI

load_dotenv()
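
# Configuration is read from the environment (loaded from .env above).
# Illustrative values only; the variable names are the ones used below:
#
#   AI_PROVIDER=lorem                  # openai | ollama | local | lorem
#   OPENAI_API_KEY=sk-...
#   OPENAI_MODEL=gpt-3.5-turbo
#   OLLAMA_URL=http://localhost:11434
#   OLLAMA_MODEL=llama2
#   OLLAMA_TEMPERATURE=0.7
#   OLLAMA_TIMEOUT=60
#   LOCAL_MODEL_URL=http://localhost:5000
#   LOCAL_MODEL_API_PATH=/v1/completions
#   LOCAL_MODEL_NAME=local-model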

class BaseProvider(ABC):
    """Abstract base class for all AI providers"""

    @abstractmethod
    def generate_response(self, system_role: str, prompt: str) -> Dict[str, Any]:
        """
        Generate a response from the AI provider

        Returns:
            Dict containing:
                - response: str - The generated text
                - usage: Dict with prompt_tokens, completion_tokens, total_tokens
        """
        pass
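
# Example of a concrete provider (an illustrative sketch, not registered with
# the factory below): it shows the minimal contract generate_response() must
# honor, echoing the prompt back with a whitespace-based token estimate.
class EchoProvider(BaseProvider):
    """Echoes the prompt back; handy when sketching a new transport"""

    def generate_response(self, system_role: str, prompt: str) -> Dict[str, Any]:
        prompt_tokens = len(system_role.split()) + len(prompt.split())
        completion_tokens = len(prompt.split())
        return {
            "response": prompt,
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
            },
        }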

class OpenAIProvider(BaseProvider):
    """OpenAI API provider"""

    def __init__(self):
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise RuntimeError("OPENAI_API_KEY is missing in .env")
        self.client = OpenAI(api_key=api_key)
        self.model = os.getenv("OPENAI_MODEL", "gpt-3.5-turbo")

    def generate_response(self, system_role: str, prompt: str) -> Dict[str, Any]:
        response = self.client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": system_role},
                {"role": "user", "content": prompt},
            ],
            timeout=10,
        )
        return {
            "response": response.choices[0].message.content,
            "usage": {
                "prompt_tokens": response.usage.prompt_tokens,
                "completion_tokens": response.usage.completion_tokens,
                "total_tokens": response.usage.total_tokens,
            },
        }

class OllamaProvider(BaseProvider):
    """Ollama API provider for local models"""

    def __init__(self):
        self.base_url = os.getenv("OLLAMA_URL", "http://localhost:11434")
        self.model = os.getenv("OLLAMA_MODEL", "llama2")
        self.temperature = float(os.getenv("OLLAMA_TEMPERATURE", "0.7"))
        self.timeout = int(os.getenv("OLLAMA_TIMEOUT", "60"))

    def generate_response(self, system_role: str, prompt: str) -> Dict[str, Any]:
        try:
            # Ollama chat completions API (similar to OpenAI)
            response = requests.post(
                f"{self.base_url}/api/chat",
                json={
                    "model": self.model,
                    "messages": [
                        {"role": "system", "content": system_role},
                        {"role": "user", "content": prompt},
                    ],
                    "stream": False,
                    "options": {"temperature": self.temperature},
                },
                timeout=self.timeout,
            )
            response.raise_for_status()
            result = response.json()
            message = result.get("message", {})
            generated_text = message.get("content", "")

            # Extract token usage if available, otherwise estimate
            eval_count = result.get("eval_count", 0)
            prompt_eval_count = result.get("prompt_eval_count", 0)
            if eval_count == 0 or prompt_eval_count == 0:
                # Estimate if not provided
                prompt_tokens = len(system_role.split()) + len(prompt.split())
                completion_tokens = len(generated_text.split())
            else:
                prompt_tokens = prompt_eval_count
                completion_tokens = eval_count

            return {
                "response": generated_text,
                "usage": {
                    "prompt_tokens": prompt_tokens,
                    "completion_tokens": completion_tokens,
                    "total_tokens": prompt_tokens + completion_tokens,
                },
            }
        except requests.exceptions.ConnectionError:
            logging.error(f"Could not connect to Ollama at {self.base_url}. Is Ollama running?")
            raise RuntimeError(f"Ollama connection failed. Ensure Ollama is running at {self.base_url}")
        except Exception as e:
            logging.error(f"Ollama error: {e}")
            raise

class LocalModelProvider(BaseProvider):
    """Generic local model provider (e.g., llama.cpp, text-generation-webui, etc.)"""

    def __init__(self):
        self.base_url = os.getenv("LOCAL_MODEL_URL", "http://localhost:5000")
        self.api_path = os.getenv("LOCAL_MODEL_API_PATH", "/v1/completions")
        self.model = os.getenv("LOCAL_MODEL_NAME", "local-model")

    def generate_response(self, system_role: str, prompt: str) -> Dict[str, Any]:
        full_prompt = f"{system_role}\n\nUser: {prompt}\nAssistant:"
        try:
            response = requests.post(
                f"{self.base_url}{self.api_path}",
                json={
                    "prompt": full_prompt,
                    "max_tokens": 500,
                    "temperature": 0.7,
                },
                timeout=30,
            )
            response.raise_for_status()
            result = response.json()
            generated_text = result.get("choices", [{}])[0].get("text", "")

            # Estimate token usage (~1.3 tokens per whitespace-separated word)
            prompt_tokens = len(full_prompt.split()) * 1.3
            completion_tokens = len(generated_text.split()) * 1.3

            return {
                "response": generated_text.strip(),
                "usage": {
                    "prompt_tokens": int(prompt_tokens),
                    "completion_tokens": int(completion_tokens),
                    "total_tokens": int(prompt_tokens + completion_tokens),
                },
            }
        except Exception as e:
            logging.error(f"Local model error: {e}")
            raise

class LoremIpsumProvider(BaseProvider):
    """Lorem Ipsum generator for testing"""

    def __init__(self):
        self.lorem_words = [
            "lorem", "ipsum", "dolor", "sit", "amet", "consectetur", "adipiscing", "elit",
            "sed", "do", "eiusmod", "tempor", "incididunt", "ut", "labore", "et", "dolore",
            "magna", "aliqua", "enim", "ad", "minim", "veniam", "quis", "nostrud",
            "exercitation", "ullamco", "laboris", "nisi", "aliquip", "ex", "ea", "commodo",
            "consequat", "duis", "aute", "irure", "in", "reprehenderit", "voluptate",
            "velit", "esse", "cillum", "fugiat", "nulla", "pariatur", "excepteur", "sint",
            "occaecat", "cupidatat", "non", "proident", "sunt", "culpa", "qui", "officia",
            "deserunt", "mollit", "anim", "id", "est", "laborum",
        ]

    def generate_response(self, system_role: str, prompt: str) -> Dict[str, Any]:
        # Generate random lorem ipsum text
        word_count = random.randint(50, 200)
        words = []
        for i in range(word_count):
            word = random.choice(self.lorem_words)
            # Capitalize first word of sentence
            if i == 0 or words[-1].endswith('.'):
                word = word.capitalize()
            words.append(word)
            # Add punctuation
            if random.random() > 0.85:
                words[-1] += random.choice(['.', ',', ';'])
        # Ensure last word has period
        if not words[-1].endswith('.'):
            words[-1] += '.'
        response_text = ' '.join(words)

        # Calculate token usage
        prompt_tokens = len(system_role.split()) + len(prompt.split())
        completion_tokens = len(response_text.split())
        return {
            "response": response_text,
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
            },
        }

class ProviderFactory:
    """Factory for creating AI providers"""

    _providers = {
        "openai": OpenAIProvider,
        "ollama": OllamaProvider,
        "local": LocalModelProvider,
        "lorem": LoremIpsumProvider,
    }

    @classmethod
    def get_provider(cls, provider_name: Optional[str] = None) -> BaseProvider:
        """Get a provider instance by name (defaults to the AI_PROVIDER env var)"""
        if provider_name is None:
            provider_name = os.getenv("AI_PROVIDER", "lorem")
        provider_name = provider_name.lower()
        if provider_name not in cls._providers:
            raise ValueError(
                f"Unknown provider: {provider_name}. Available: {list(cls._providers.keys())}"
            )
        return cls._providers[provider_name]()

    @classmethod
    def register_provider(cls, name: str, provider_class: type):
        """Register a new provider type"""
        if not issubclass(provider_class, BaseProvider):
            raise ValueError("Provider must inherit from BaseProvider")
        cls._providers[name] = provider_class
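
# Minimal usage sketch: the "lorem" provider needs no API key or local server,
# so it doubles as a quick smoke test for the factory wiring. The second half
# registers the illustrative EchoProvider defined above.
if __name__ == "__main__":
    provider = ProviderFactory.get_provider("lorem")
    result = provider.generate_response("You are a helpful assistant.", "Say hello.")
    print(result["response"])
    print(result["usage"])

    ProviderFactory.register_provider("echo", EchoProvider)
    echo = ProviderFactory.get_provider("echo")
    print(echo.generate_response("system", "ping")["response"])  # -> "ping"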