"""Base agent class.""" import asyncio from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional, Tuple from openai import OpenAI import httpx from loguru import logger from src.config import settings, estimate_cost from src.database.client import db class BaseAgent(ABC): """Base class for all AI agents.""" def __init__(self, name: str): """ Initialize base agent. Args: name: Name of the agent """ self.name = name self.openai_client = OpenAI(api_key=settings.openai_api_key) self._usage_logs: List[Dict[str, Any]] = [] self._user_id: Optional[str] = None self._company_id: Optional[str] = None self._operation: str = "unknown" logger.info(f"Initialized {name} agent") def set_tracking_context( self, operation: str, user_id: Optional[str] = None, company_id: Optional[str] = None ): """Set context for usage tracking.""" self._operation = operation self._user_id = user_id self._company_id = company_id self._usage_logs = [] @abstractmethod async def process(self, *args, **kwargs) -> Any: """Process the agent's task.""" pass async def _log_usage(self, provider: str, model: str, prompt_tokens: int, completion_tokens: int, total_tokens: int): """Log API usage to database.""" cost = estimate_cost(model, prompt_tokens, completion_tokens) usage = { "provider": provider, "model": model, "operation": self._operation, "prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens, "total_tokens": total_tokens, "estimated_cost_usd": cost } self._usage_logs.append(usage) try: from uuid import UUID await db.log_api_usage( provider=provider, model=model, operation=self._operation, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=total_tokens, estimated_cost_usd=cost, user_id=UUID(self._user_id) if self._user_id else None, company_id=UUID(self._company_id) if self._company_id else None ) except Exception as e: logger.warning(f"Failed to log usage to DB: {e}") # Increment company token quota if self._company_id: try: from uuid import UUID await db.increment_company_tokens(UUID(self._company_id), total_tokens) except Exception as e: logger.warning(f"Failed to increment company tokens: {e}") async def call_openai( self, system_prompt: str, user_prompt: str, model: str = "gpt-4o", temperature: float = 0.7, response_format: Optional[Dict[str, str]] = None ) -> str: """ Call OpenAI API. Args: system_prompt: System message user_prompt: User message model: Model to use temperature: Temperature for sampling response_format: Optional response format (e.g., {"type": "json_object"}) Returns: Assistant's response """ logger.info(f"[{self.name}] Calling OpenAI ({model})") messages = [ {"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt} ] kwargs = { "model": model, "messages": messages, "temperature": temperature } if response_format: kwargs["response_format"] = response_format # Run synchronous OpenAI call in thread pool to avoid blocking event loop response = await asyncio.to_thread( self.openai_client.chat.completions.create, **kwargs ) result = response.choices[0].message.content logger.debug(f"[{self.name}] Received response (length: {len(result)})") # Track usage if response.usage: await self._log_usage( provider="openai", model=model, prompt_tokens=response.usage.prompt_tokens, completion_tokens=response.usage.completion_tokens, total_tokens=response.usage.total_tokens ) return result async def call_perplexity( self, system_prompt: str, user_prompt: str, model: str = "sonar" ) -> str: """ Call Perplexity API for research. Args: system_prompt: System message user_prompt: User message model: Model to use Returns: Assistant's response """ logger.info(f"[{self.name}] Calling Perplexity ({model})") url = "https://api.perplexity.ai/chat/completions" headers = { "Authorization": f"Bearer {settings.perplexity_api_key}", "Content-Type": "application/json" } payload = { "model": model, "messages": [ {"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt} ] } async with httpx.AsyncClient() as client: response = await client.post(url, json=payload, headers=headers, timeout=60.0) response.raise_for_status() result = response.json() content = result["choices"][0]["message"]["content"] logger.debug(f"[{self.name}] Received Perplexity response (length: {len(content)})") # Track usage usage = result.get("usage", {}) if usage: await self._log_usage( provider="perplexity", model=model, prompt_tokens=usage.get("prompt_tokens", 0), completion_tokens=usage.get("completion_tokens", 0), total_tokens=usage.get("prompt_tokens", 0) + usage.get("completion_tokens", 0) ) return content