Files
Onyva-Postling/src/agents/base.py
2026-02-18 00:00:32 +01:00

197 lines
6.1 KiB
Python

"""Base agent class."""
import asyncio
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple
from openai import OpenAI
import httpx
from loguru import logger
from src.config import settings, estimate_cost
from src.database.client import db
class BaseAgent(ABC):
    """Base class for all AI agents.

    Provides shared helpers for calling the OpenAI and Perplexity chat
    APIs and for recording the token usage of every call, both in memory
    (``self._usage_logs``) and through the ``db`` client.
    """

    def __init__(self, name: str):
        """
        Initialize base agent.

        Args:
            name: Name of the agent
        """
        self.name = name
        self.openai_client = OpenAI(api_key=settings.openai_api_key)
        # Usage records collected since the last set_tracking_context() call.
        self._usage_logs: List[Dict[str, Any]] = []
        self._user_id: Optional[str] = None
        self._company_id: Optional[str] = None
        self._operation: str = "unknown"
        logger.info(f"Initialized {name} agent")

    def set_tracking_context(
        self,
        operation: str,
        user_id: Optional[str] = None,
        company_id: Optional[str] = None
    ):
        """Set context for usage tracking.

        Also clears the in-memory usage log, so records accumulated under
        a previous context are discarded.

        Args:
            operation: Label stored with every usage record.
            user_id: Optional user identifier (parsed as a UUID when usage
                is written to the database).
            company_id: Optional company identifier (parsed as a UUID when
                usage is written to the database).
        """
        self._operation = operation
        self._user_id = user_id
        self._company_id = company_id
        self._usage_logs = []

    @abstractmethod
    async def process(self, *args, **kwargs) -> Any:
        """Process the agent's task."""
        pass

    @staticmethod
    def _build_messages(system_prompt: str, user_prompt: str) -> List[Dict[str, str]]:
        """Return the system + user message list sent to both chat APIs."""
        return [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ]

    async def _log_usage(self, provider: str, model: str, prompt_tokens: int, completion_tokens: int, total_tokens: int):
        """Log API usage to database.

        Appends a usage record to ``self._usage_logs``, then writes it via
        ``db.log_api_usage`` and, when a company id is set, adds
        ``total_tokens`` to the company counter via
        ``db.increment_company_tokens``.

        Both database writes are best-effort: any exception (including an
        id that is not a valid UUID string) is logged as a warning and
        never propagated to the caller.

        Args:
            provider: API provider label (e.g. "openai", "perplexity").
            model: Model name used for the call.
            prompt_tokens: Tokens consumed by the prompt.
            completion_tokens: Tokens produced by the completion.
            total_tokens: Total tokens to record for the call.
        """
        from uuid import UUID

        cost = estimate_cost(model, prompt_tokens, completion_tokens)
        usage = {
            "provider": provider,
            "model": model,
            "operation": self._operation,
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": total_tokens,
            "estimated_cost_usd": cost
        }
        self._usage_logs.append(usage)

        try:
            # UUID parsing stays inside the try so a malformed id only
            # skips the DB write instead of failing the API call.
            await db.log_api_usage(
                provider=provider,
                model=model,
                operation=self._operation,
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=total_tokens,
                estimated_cost_usd=cost,
                user_id=UUID(self._user_id) if self._user_id else None,
                company_id=UUID(self._company_id) if self._company_id else None
            )
        except Exception as e:
            logger.warning(f"Failed to log usage to DB: {e}")

        # Increment company token quota
        if self._company_id:
            try:
                await db.increment_company_tokens(UUID(self._company_id), total_tokens)
            except Exception as e:
                logger.warning(f"Failed to increment company tokens: {e}")

    async def call_openai(
        self,
        system_prompt: str,
        user_prompt: str,
        model: str = "gpt-4o",
        temperature: float = 0.7,
        response_format: Optional[Dict[str, str]] = None
    ) -> str:
        """
        Call OpenAI API.

        Args:
            system_prompt: System message
            user_prompt: User message
            model: Model to use
            temperature: Temperature for sampling
            response_format: Optional response format (e.g., {"type": "json_object"})

        Returns:
            Assistant's response. An empty string is returned when the API
            reports no text content for the first choice.

        Raises:
            Any exception raised by the OpenAI client is propagated.
        """
        logger.info(f"[{self.name}] Calling OpenAI ({model})")

        kwargs = {
            "model": model,
            "messages": self._build_messages(system_prompt, user_prompt),
            "temperature": temperature
        }
        if response_format:
            kwargs["response_format"] = response_format

        # Run synchronous OpenAI call in thread pool to avoid blocking event loop
        response = await asyncio.to_thread(
            self.openai_client.chat.completions.create,
            **kwargs
        )

        # message.content is nullable in the API schema; normalise to ""
        # so len() below cannot raise and the declared str return holds.
        result = response.choices[0].message.content or ""
        logger.debug(f"[{self.name}] Received response (length: {len(result)})")

        # Track usage
        if response.usage:
            await self._log_usage(
                provider="openai",
                model=model,
                prompt_tokens=response.usage.prompt_tokens,
                completion_tokens=response.usage.completion_tokens,
                total_tokens=response.usage.total_tokens
            )

        return result

    async def call_perplexity(
        self,
        system_prompt: str,
        user_prompt: str,
        model: str = "sonar"
    ) -> str:
        """
        Call Perplexity API for research.

        Args:
            system_prompt: System message
            user_prompt: User message
            model: Model to use

        Returns:
            Assistant's response

        Raises:
            httpx.HTTPStatusError: If the API answers with an error status
                (via ``raise_for_status``). Transport errors and timeouts
                raised by httpx are propagated as well.
        """
        logger.info(f"[{self.name}] Calling Perplexity ({model})")

        url = "https://api.perplexity.ai/chat/completions"
        headers = {
            "Authorization": f"Bearer {settings.perplexity_api_key}",
            "Content-Type": "application/json"
        }
        payload = {
            "model": model,
            "messages": self._build_messages(system_prompt, user_prompt)
        }

        async with httpx.AsyncClient() as client:
            response = await client.post(url, json=payload, headers=headers, timeout=60.0)
            response.raise_for_status()
            result = response.json()

        content = result["choices"][0]["message"]["content"]
        logger.debug(f"[{self.name}] Received Perplexity response (length: {len(content)})")

        # Track usage
        usage = result.get("usage", {})
        if usage:
            prompt_tokens = usage.get("prompt_tokens", 0)
            completion_tokens = usage.get("completion_tokens", 0)
            await self._log_usage(
                provider="perplexity",
                model=model,
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                # Total is recorded as the sum of the two reported counts.
                total_tokens=prompt_tokens + completion_tokens
            )

        return content