"""Supabase database client.""" import asyncio from datetime import datetime, date, timezone from typing import Optional, List, Dict, Any from uuid import UUID from supabase import create_client, Client from loguru import logger from src.config import settings from src.database.models import ( LinkedInProfile, LinkedInPost, Topic, ProfileAnalysis, ResearchResult, GeneratedPost, PostType, User, Profile, Company, Invitation, ExamplePost, ReferenceProfile, ApiUsageLog, LicenseKey, CompanyDailyQuota ) class DatabaseClient: """Supabase database client wrapper.""" def __init__(self): """Initialize Supabase client.""" if settings.supabase_service_role_key: self.client: Client = create_client( settings.supabase_url, settings.supabase_service_role_key ) logger.info("Supabase client initialized with service role key (fast mode)") else: self.client: Client = create_client( settings.supabase_url, settings.supabase_key ) logger.warning("Supabase client initialized with anon key (slow mode - RLS enabled)") self.admin_client: Optional[Client] = self.client if settings.supabase_service_role_key else None # ==================== LINKEDIN PROFILES ==================== async def save_linkedin_profile(self, profile: LinkedInProfile) -> LinkedInProfile: """Save or update LinkedIn profile.""" data = profile.model_dump(exclude={"id", "scraped_at"}, exclude_none=True) if "user_id" in data: data["user_id"] = str(data["user_id"]) existing = await asyncio.to_thread( lambda: self.client.table("linkedin_profiles").select("*").eq( "user_id", str(profile.user_id) ).execute() ) if existing.data: result = await asyncio.to_thread( lambda: self.client.table("linkedin_profiles").update(data).eq( "user_id", str(profile.user_id) ).execute() ) else: result = await asyncio.to_thread( lambda: self.client.table("linkedin_profiles").insert(data).execute() ) logger.info(f"Saved LinkedIn profile for user: {profile.user_id}") return LinkedInProfile(**result.data[0]) async def get_linkedin_profile(self, user_id: UUID) -> Optional[LinkedInProfile]: """Get LinkedIn profile for user.""" result = await asyncio.to_thread( lambda: self.client.table("linkedin_profiles").select("*").eq( "user_id", str(user_id) ).execute() ) if result.data: return LinkedInProfile(**result.data[0]) return None # ==================== LINKEDIN POSTS ==================== async def save_linkedin_posts(self, posts: List[LinkedInPost]) -> List[LinkedInPost]: """Save LinkedIn posts (bulk).""" from datetime import datetime seen = set() unique_posts = [] for p in posts: key = (str(p.user_id), p.post_url) if key not in seen: seen.add(key) unique_posts.append(p) if len(posts) != len(unique_posts): logger.warning(f"Removed {len(posts) - len(unique_posts)} duplicate posts from batch") data = [] for p in unique_posts: post_dict = p.model_dump(exclude={"id", "scraped_at"}, exclude_none=True) if "user_id" in post_dict: post_dict["user_id"] = str(post_dict["user_id"]) if "post_date" in post_dict and isinstance(post_dict["post_date"], datetime): post_dict["post_date"] = post_dict["post_date"].isoformat() data.append(post_dict) if not data: logger.warning("No posts to save") return [] result = await asyncio.to_thread( lambda: self.client.table("linkedin_posts").upsert( data, on_conflict="user_id,post_url" ).execute() ) logger.info(f"Saved {len(result.data)} LinkedIn posts") return [LinkedInPost(**item) for item in result.data] async def get_linkedin_posts(self, user_id: UUID) -> List[LinkedInPost]: """Get all LinkedIn posts for user.""" result = await asyncio.to_thread( lambda: self.client.table("linkedin_posts").select("*").eq( "user_id", str(user_id) ).order("post_date", desc=True).execute() ) return [LinkedInPost(**item) for item in result.data] async def copy_posts_to_user(self, source_user_id: UUID, target_user_id: UUID) -> List[LinkedInPost]: """Copy all LinkedIn posts from one user to another.""" source_posts = await self.get_linkedin_posts(source_user_id) if not source_posts: return [] new_posts = [] for p in source_posts: new_post = LinkedInPost( user_id=target_user_id, post_url=p.post_url, post_text=p.post_text, post_date=p.post_date, likes=p.likes, comments=p.comments, shares=p.shares, raw_data=p.raw_data, ) new_posts.append(new_post) saved = await self.save_linkedin_posts(new_posts) logger.info(f"Copied {len(saved)} posts from user {source_user_id} to {target_user_id}") return saved async def delete_linkedin_post(self, post_id: UUID) -> None: """Delete a LinkedIn post.""" await asyncio.to_thread( lambda: self.client.table("linkedin_posts").delete().eq("id", str(post_id)).execute() ) logger.info(f"Deleted LinkedIn post: {post_id}") async def get_unclassified_posts(self, user_id: UUID) -> List[LinkedInPost]: """Get all LinkedIn posts without a post_type_id.""" result = await asyncio.to_thread( lambda: self.client.table("linkedin_posts").select("*").eq( "user_id", str(user_id) ).is_("post_type_id", "null").execute() ) return [LinkedInPost(**item) for item in result.data] async def get_posts_by_type(self, user_id: UUID, post_type_id: UUID) -> List[LinkedInPost]: """Get all LinkedIn posts for a specific post type.""" result = await asyncio.to_thread( lambda: self.client.table("linkedin_posts").select("*").eq( "user_id", str(user_id) ).eq("post_type_id", str(post_type_id)).order("post_date", desc=True).execute() ) return [LinkedInPost(**item) for item in result.data] async def update_post_classification( self, post_id: UUID, post_type_id: UUID, classification_method: str, classification_confidence: float ) -> None: """Update a single post's classification.""" await asyncio.to_thread( lambda: self.client.table("linkedin_posts").update({ "post_type_id": str(post_type_id), "classification_method": classification_method, "classification_confidence": classification_confidence }).eq("id", str(post_id)).execute() ) logger.debug(f"Updated classification for post {post_id}") async def update_linkedin_post(self, post_id: UUID, updates: Dict[str, Any]) -> None: """Update a LinkedIn post with arbitrary fields.""" if "post_type_id" in updates and updates["post_type_id"]: updates["post_type_id"] = str(updates["post_type_id"]) if "user_id" in updates and updates["user_id"]: updates["user_id"] = str(updates["user_id"]) await asyncio.to_thread( lambda: self.client.table("linkedin_posts").update(updates).eq("id", str(post_id)).execute() ) logger.debug(f"Updated LinkedIn post {post_id}") async def update_posts_classification_bulk( self, classifications: List[Dict[str, Any]] ) -> int: """Bulk update post classifications.""" count = 0 for classification in classifications: try: await asyncio.to_thread( lambda c=classification: self.client.table("linkedin_posts").update({ "post_type_id": str(c["post_type_id"]), "classification_method": c["classification_method"], "classification_confidence": c["classification_confidence"] }).eq("id", str(c["post_id"])).execute() ) count += 1 except Exception as e: logger.warning(f"Failed to update classification for post {classification['post_id']}: {e}") logger.info(f"Bulk updated classifications for {count} posts") return count # ==================== POST TYPES ==================== async def create_post_type(self, post_type: PostType) -> PostType: """Create a new post type.""" data = post_type.model_dump(exclude={"id", "created_at", "updated_at"}, exclude_none=True) if "user_id" in data: data["user_id"] = str(data["user_id"]) result = await asyncio.to_thread( lambda: self.client.table("post_types").insert(data).execute() ) logger.info(f"Created post type: {result.data[0]['name']}") return PostType(**result.data[0]) async def create_post_types_bulk(self, post_types: List[PostType]) -> List[PostType]: """Create multiple post types at once.""" if not post_types: return [] data = [] for pt in post_types: pt_dict = pt.model_dump(exclude={"id", "created_at", "updated_at"}, exclude_none=True) if "user_id" in pt_dict: pt_dict["user_id"] = str(pt_dict["user_id"]) data.append(pt_dict) result = await asyncio.to_thread( lambda: self.client.table("post_types").insert(data).execute() ) logger.info(f"Created {len(result.data)} post types") return [PostType(**item) for item in result.data] async def get_post_types(self, user_id: UUID, active_only: bool = True) -> List[PostType]: """Get all post types for a user.""" def _query(): query = self.client.table("post_types").select("*").eq("user_id", str(user_id)) if active_only: query = query.eq("is_active", True) return query.order("name").execute() result = await asyncio.to_thread(_query) return [PostType(**item) for item in result.data] # Alias for get_post_types async def get_customer_post_types(self, user_id: UUID, active_only: bool = True) -> List[PostType]: """Alias for get_post_types.""" return await self.get_post_types(user_id, active_only) async def get_post_type(self, post_type_id: UUID) -> Optional[PostType]: """Get a single post type by ID.""" result = await asyncio.to_thread( lambda: self.client.table("post_types").select("*").eq( "id", str(post_type_id) ).execute() ) if result.data: return PostType(**result.data[0]) return None async def update_post_type(self, post_type_id: UUID, updates: Dict[str, Any]) -> PostType: """Update a post type.""" result = await asyncio.to_thread( lambda: self.client.table("post_types").update(updates).eq( "id", str(post_type_id) ).execute() ) logger.info(f"Updated post type: {post_type_id}") return PostType(**result.data[0]) async def update_post_type_analysis( self, post_type_id: UUID, analysis: Dict[str, Any], analyzed_post_count: int ) -> PostType: """Update the analysis for a post type.""" from datetime import datetime result = await asyncio.to_thread( lambda: self.client.table("post_types").update({ "analysis": analysis, "analysis_generated_at": datetime.now(timezone.utc).isoformat(), "analyzed_post_count": analyzed_post_count }).eq("id", str(post_type_id)).execute() ) logger.info(f"Updated analysis for post type: {post_type_id}") return PostType(**result.data[0]) async def delete_post_type(self, post_type_id: UUID, soft: bool = True) -> None: """Delete a post type (soft delete by default).""" if soft: await asyncio.to_thread( lambda: self.client.table("post_types").update({ "is_active": False }).eq("id", str(post_type_id)).execute() ) logger.info(f"Soft deleted post type: {post_type_id}") else: await asyncio.to_thread( lambda: self.client.table("post_types").delete().eq( "id", str(post_type_id) ).execute() ) logger.info(f"Hard deleted post type: {post_type_id}") # ==================== TOPICS ==================== async def save_topics(self, topics: List[Topic]) -> List[Topic]: """Save extracted topics.""" if not topics: logger.warning("No topics to save") return [] data = [] for t in topics: topic_dict = t.model_dump(exclude={"id", "created_at"}, exclude_none=True) if "user_id" in topic_dict: topic_dict["user_id"] = str(topic_dict["user_id"]) if "extracted_from_post_id" in topic_dict and topic_dict["extracted_from_post_id"]: topic_dict["extracted_from_post_id"] = str(topic_dict["extracted_from_post_id"]) if "target_post_type_id" in topic_dict and topic_dict["target_post_type_id"]: topic_dict["target_post_type_id"] = str(topic_dict["target_post_type_id"]) data.append(topic_dict) try: result = await asyncio.to_thread( lambda: self.client.table("topics").insert(data).execute() ) logger.info(f"Saved {len(result.data)} topics to database") return [Topic(**item) for item in result.data] except Exception as e: logger.error(f"Error saving topics: {e}", exc_info=True) saved = [] for topic_data in data: try: result = await asyncio.to_thread( lambda td=topic_data: self.client.table("topics").insert(td).execute() ) saved.extend([Topic(**item) for item in result.data]) except Exception as single_error: logger.warning(f"Skipped duplicate topic: {topic_data.get('title')}") logger.info(f"Saved {len(saved)} topics individually") return saved async def get_topics( self, user_id: UUID, unused_only: bool = False, post_type_id: Optional[UUID] = None ) -> List[Topic]: """Get topics for user, optionally filtered by post type.""" def _query(): query = self.client.table("topics").select("*").eq("user_id", str(user_id)) if unused_only: query = query.eq("is_used", False) if post_type_id: query = query.eq("target_post_type_id", str(post_type_id)) return query.order("created_at", desc=True).execute() result = await asyncio.to_thread(_query) return [Topic(**item) for item in result.data] async def mark_topic_used(self, topic_id: UUID) -> None: """Mark topic as used.""" await asyncio.to_thread( lambda: self.client.table("topics").update({ "is_used": True, "used_at": "now()" }).eq("id", str(topic_id)).execute() ) logger.info(f"Marked topic {topic_id} as used") # ==================== PROFILE ANALYSIS ==================== async def save_profile_analysis(self, analysis: ProfileAnalysis) -> ProfileAnalysis: """Save profile analysis.""" data = analysis.model_dump(exclude={"id", "created_at"}, exclude_none=True) if "user_id" in data: data["user_id"] = str(data["user_id"]) existing = await asyncio.to_thread( lambda: self.client.table("profile_analyses").select("*").eq( "user_id", str(analysis.user_id) ).execute() ) if existing.data: result = await asyncio.to_thread( lambda: self.client.table("profile_analyses").update(data).eq( "user_id", str(analysis.user_id) ).execute() ) else: result = await asyncio.to_thread( lambda: self.client.table("profile_analyses").insert(data).execute() ) logger.info(f"Saved profile analysis for user: {analysis.user_id}") return ProfileAnalysis(**result.data[0]) async def get_profile_analysis(self, user_id: UUID) -> Optional[ProfileAnalysis]: """Get profile analysis for user.""" result = await asyncio.to_thread( lambda: self.client.table("profile_analyses").select("*").eq( "user_id", str(user_id) ).execute() ) if result.data: return ProfileAnalysis(**result.data[0]) return None # ==================== RESEARCH RESULTS ==================== async def save_research_result(self, research: ResearchResult) -> ResearchResult: """Save research result.""" data = research.model_dump(exclude={"id", "created_at"}, exclude_none=True) if "user_id" in data: data["user_id"] = str(data["user_id"]) if "target_post_type_id" in data and data["target_post_type_id"]: data["target_post_type_id"] = str(data["target_post_type_id"]) result = await asyncio.to_thread( lambda: self.client.table("research_results").insert(data).execute() ) logger.info(f"Saved research result for user: {research.user_id}") return ResearchResult(**result.data[0]) async def get_latest_research(self, user_id: UUID) -> Optional[ResearchResult]: """Get latest research result for user.""" result = await asyncio.to_thread( lambda: self.client.table("research_results").select("*").eq( "user_id", str(user_id) ).order("created_at", desc=True).limit(1).execute() ) if result.data: return ResearchResult(**result.data[0]) return None async def get_all_research( self, user_id: UUID, post_type_id: Optional[UUID] = None ) -> List[ResearchResult]: """Get all research results for user, optionally filtered by post type.""" def _query(): query = self.client.table("research_results").select("*").eq( "user_id", str(user_id) ) if post_type_id: query = query.eq("target_post_type_id", str(post_type_id)) return query.order("created_at", desc=True).execute() result = await asyncio.to_thread(_query) return [ResearchResult(**item) for item in result.data] # ==================== GENERATED POSTS ==================== async def save_generated_post(self, post: GeneratedPost) -> GeneratedPost: """Save generated post.""" data = post.model_dump(exclude={"id", "created_at"}, exclude_none=True) if "user_id" in data: data["user_id"] = str(data["user_id"]) if "topic_id" in data and data["topic_id"]: data["topic_id"] = str(data["topic_id"]) if "post_type_id" in data and data["post_type_id"]: data["post_type_id"] = str(data["post_type_id"]) result = await asyncio.to_thread( lambda: self.client.table("generated_posts").insert(data).execute() ) logger.info(f"Saved generated post: {result.data[0]['id']}") return GeneratedPost(**result.data[0]) async def update_generated_post(self, post_id: UUID, updates: Dict[str, Any]) -> GeneratedPost: """Update generated post.""" # Convert datetime fields to isoformat strings datetime_fields = ['published_at', 'approved_at', 'rejected_at', 'scheduled_at'] for field in datetime_fields: if field in updates and isinstance(updates[field], datetime): updates[field] = updates[field].isoformat() # Handle metadata dict - ensure all nested datetime values are serialized if 'metadata' in updates and isinstance(updates['metadata'], dict): serialized_metadata = {} for key, value in updates['metadata'].items(): if isinstance(value, datetime): serialized_metadata[key] = value.isoformat() else: serialized_metadata[key] = value updates['metadata'] = serialized_metadata result = await asyncio.to_thread( lambda: self.client.table("generated_posts").update(updates).eq( "id", str(post_id) ).execute() ) logger.info(f"Updated generated post: {post_id}") return GeneratedPost(**result.data[0]) async def get_generated_posts(self, user_id: UUID) -> List[GeneratedPost]: """Get all generated posts for user.""" result = await asyncio.to_thread( lambda: self.client.table("generated_posts").select("*").eq( "user_id", str(user_id) ).order("created_at", desc=True).execute() ) return [GeneratedPost(**item) for item in result.data] async def get_generated_post(self, post_id: UUID) -> Optional[GeneratedPost]: """Get a single generated post by ID.""" result = await asyncio.to_thread( lambda: self.client.table("generated_posts").select("*").eq( "id", str(post_id) ).execute() ) if result.data: return GeneratedPost(**result.data[0]) return None async def get_scheduled_posts_due(self) -> List[GeneratedPost]: """Get all posts that are scheduled and due for publishing.""" from datetime import datetime now = datetime.now(timezone.utc).isoformat() result = await asyncio.to_thread( lambda: self.client.table("generated_posts").select("*").eq( "status", "scheduled" ).lte("scheduled_at", now).execute() ) return [GeneratedPost(**item) for item in result.data] async def get_scheduled_posts_for_company(self, company_id: UUID) -> List[Dict[str, Any]]: """Get all scheduled posts for a company with employee info.""" # Get all profiles (employees) belonging to this company profiles_result = await asyncio.to_thread( lambda: self.client.table("profiles").select("id, display_name").eq( "company_id", str(company_id) ).execute() ) if not profiles_result.data: return [] user_ids = [p["id"] for p in profiles_result.data] profile_map = {p["id"]: p for p in profiles_result.data} posts_result = await asyncio.to_thread( lambda: self.client.table("generated_posts").select( "id, user_id, topic_title, post_content, status, scheduled_at, created_at" ).in_( "user_id", user_ids ).in_( "status", ["ready", "scheduled", "published"] ).order("scheduled_at", desc=False, nullsfirst=False).execute() ) enriched_posts = [] for post in posts_result.data: profile = profile_map.get(post["user_id"]) enriched_posts.append({ **post, "employee_name": profile["display_name"] if profile else "Unknown", "employee_user_id": post["user_id"] }) return enriched_posts async def schedule_post(self, post_id: UUID, scheduled_at: datetime, scheduled_by_user_id: UUID) -> GeneratedPost: """Schedule a post for publishing.""" from datetime import datetime as dt updates = { "status": "scheduled", "scheduled_at": scheduled_at.isoformat() if isinstance(scheduled_at, dt) else scheduled_at, "scheduled_by_user_id": str(scheduled_by_user_id) } return await self.update_generated_post(post_id, updates) async def unschedule_post(self, post_id: UUID) -> GeneratedPost: """Remove scheduling from a post (back to approved status).""" updates = { "status": "approved", "scheduled_at": None, "scheduled_by_user_id": None } return await self.update_generated_post(post_id, updates) # ==================== PROFILES / USERS ==================== async def create_profile(self, user_id: UUID, profile: Profile) -> Profile: """Create a profile for a user.""" data = profile.model_dump(exclude={"created_at", "updated_at"}, exclude_none=True) data["id"] = str(user_id) if "company_id" in data and data["company_id"]: data["company_id"] = str(data["company_id"]) if "account_type" in data: data["account_type"] = data["account_type"].value if hasattr(data["account_type"], "value") else data["account_type"] if "onboarding_status" in data: data["onboarding_status"] = data["onboarding_status"].value if hasattr(data["onboarding_status"], "value") else data["onboarding_status"] result = await asyncio.to_thread( lambda: self.client.table("profiles").upsert(data).execute() ) logger.info(f"Created/updated profile: {result.data[0]['id']}") return Profile(**result.data[0]) async def get_profile(self, user_id: UUID) -> Optional[Profile]: """Get profile by user ID.""" result = await asyncio.to_thread( lambda: self.client.table("profiles").select("*").eq("id", str(user_id)).execute() ) if result.data: return Profile(**result.data[0]) return None async def get_profiles_by_linkedin_url(self, linkedin_url: str) -> List[Profile]: """Get all profiles with this LinkedIn URL.""" result = await asyncio.to_thread( lambda: self.client.table("profiles").select("*").eq("linkedin_url", linkedin_url).execute() ) return [Profile(**item) for item in result.data] async def update_profile(self, user_id: UUID, updates: Dict[str, Any]) -> Profile: """Update profile fields.""" if "company_id" in updates and updates["company_id"]: updates["company_id"] = str(updates["company_id"]) for key in ["account_type", "onboarding_status"]: if key in updates and hasattr(updates[key], "value"): updates[key] = updates[key].value result = await asyncio.to_thread( lambda: self.client.table("profiles").update(updates).eq( "id", str(user_id) ).execute() ) logger.info(f"Updated profile: {user_id}") return Profile(**result.data[0]) # ==================== LINKEDIN ACCOUNTS ==================== async def get_linkedin_account(self, user_id: UUID) -> Optional['LinkedInAccount']: """Get LinkedIn account for user.""" from src.database.models import LinkedInAccount result = await asyncio.to_thread( lambda: self.client.table("linkedin_accounts").select("*") .eq("user_id", str(user_id)).eq("is_active", True).execute() ) if result.data: return LinkedInAccount(**result.data[0]) return None async def get_linkedin_account_by_id(self, account_id: UUID) -> Optional['LinkedInAccount']: """Get LinkedIn account by ID.""" from src.database.models import LinkedInAccount result = await asyncio.to_thread( lambda: self.client.table("linkedin_accounts").select("*") .eq("id", str(account_id)).execute() ) if result.data: return LinkedInAccount(**result.data[0]) return None async def create_linkedin_account(self, account: 'LinkedInAccount') -> 'LinkedInAccount': """Create LinkedIn account connection.""" from src.database.models import LinkedInAccount data = account.model_dump(exclude={'id', 'created_at', 'updated_at'}) data['user_id'] = str(data['user_id']) data['token_expires_at'] = data['token_expires_at'].isoformat() result = await asyncio.to_thread( lambda: self.client.table("linkedin_accounts").insert(data).execute() ) logger.info(f"Created LinkedIn account for user: {account.user_id}") return LinkedInAccount(**result.data[0]) async def update_linkedin_account(self, account_id: UUID, updates: Dict) -> 'LinkedInAccount': """Update LinkedIn account.""" from src.database.models import LinkedInAccount # Convert all datetime fields to isoformat strings datetime_fields = ['token_expires_at', 'last_used_at', 'last_error_at'] for field in datetime_fields: if field in updates and isinstance(updates[field], datetime): updates[field] = updates[field].isoformat() result = await asyncio.to_thread( lambda: self.client.table("linkedin_accounts").update(updates) .eq("id", str(account_id)).execute() ) logger.info(f"Updated LinkedIn account: {account_id}") return LinkedInAccount(**result.data[0]) async def delete_linkedin_account(self, account_id: UUID) -> None: """Delete LinkedIn account connection.""" await asyncio.to_thread( lambda: self.client.table("linkedin_accounts").delete() .eq("id", str(account_id)).execute() ) logger.info(f"Deleted LinkedIn account: {account_id}") # ==================== USERS ==================== async def get_user(self, user_id: UUID) -> Optional[User]: """Get user by ID (from users view).""" result = await asyncio.to_thread( lambda: self.client.table("users").select("*").eq("id", str(user_id)).execute() ) if result.data: return User(**result.data[0]) return None async def get_user_by_email(self, email: str) -> Optional[User]: """Get user by email (from users view).""" result = await asyncio.to_thread( lambda: self.client.table("users").select("*").eq("email", email.lower()).execute() ) if result.data: return User(**result.data[0]) return None async def get_user_by_linkedin_sub(self, linkedin_sub: str) -> Optional[User]: """Get user by LinkedIn sub (OAuth identifier).""" result = await asyncio.to_thread( lambda: self.client.table("users").select("*").eq("linkedin_sub", linkedin_sub).execute() ) if result.data: return User(**result.data[0]) return None async def update_user(self, user_id: UUID, updates: Dict[str, Any]) -> User: """Update user/profile fields.""" profile_fields = ["account_type", "display_name", "onboarding_status", "onboarding_data", "company_id", "linkedin_url", "writing_style_notes", "metadata", "profile_picture", "creator_email", "customer_email", "is_active"] profile_updates = {k: v for k, v in updates.items() if k in profile_fields} if profile_updates: await self.update_profile(user_id, profile_updates) return await self.get_user(user_id) async def list_users(self, account_type: Optional[str] = None, company_id: Optional[UUID] = None) -> List[User]: """List users with optional filters (from users view).""" def _query(): query = self.client.table("users").select("*") if account_type: query = query.eq("account_type", account_type) if company_id: query = query.eq("company_id", str(company_id)) return query.order("created_at", desc=True).execute() result = await asyncio.to_thread(_query) return [User(**item) for item in result.data] async def delete_user_completely(self, user_id: UUID) -> bool: """Delete a user and all related data completely.""" try: user_id_str = str(user_id) # Delete all content data directly via user_id for table in ["generated_posts", "linkedin_posts", "example_posts", "reference_profiles", "profile_analyses", "linkedin_profiles", "post_types", "research_results", "topics"]: await asyncio.to_thread( lambda t=table: self.client.table(t).delete().eq( "user_id", user_id_str ).execute() ) # Get user email for invitation deletion user = await self.get_user(user_id) if user and user.email: await asyncio.to_thread( lambda: self.client.table("invitations").delete().eq( "email", user.email ).execute() ) # Delete profile await asyncio.to_thread( lambda: self.client.table("profiles").delete().eq( "id", user_id_str ).execute() ) # Delete from auth.users if self.admin_client: await asyncio.to_thread( lambda: self.admin_client.auth.admin.delete_user(user_id_str) ) logger.info(f"Deleted auth user {user_id}") else: logger.warning(f"Cannot delete auth user {user_id} - no service role key configured") logger.info(f"Completely deleted user {user_id} and all related data") return True except Exception as e: logger.error(f"Error deleting user {user_id}: {e}") raise # ==================== COMPANIES ==================== async def create_company(self, company: Company) -> Company: """Create a new company.""" data = company.model_dump(exclude={"id", "created_at", "updated_at"}, exclude_none=True) if "owner_user_id" in data: data["owner_user_id"] = str(data["owner_user_id"]) if "license_key_id" in data: data["license_key_id"] = str(data["license_key_id"]) result = await asyncio.to_thread( lambda: self.client.table("companies").insert(data).execute() ) logger.info(f"Created company: {result.data[0]['id']}") return Company(**result.data[0]) async def get_company(self, company_id: UUID) -> Optional[Company]: """Get company by ID.""" result = await asyncio.to_thread( lambda: self.client.table("companies").select("*").eq("id", str(company_id)).execute() ) if result.data: return Company(**result.data[0]) return None async def get_company_by_owner(self, owner_user_id: UUID) -> Optional[Company]: """Get company by owner user ID.""" result = await asyncio.to_thread( lambda: self.client.table("companies").select("*").eq("owner_user_id", str(owner_user_id)).execute() ) if result.data: return Company(**result.data[0]) return None async def get_company_employees(self, company_id: UUID) -> List[User]: """Get all employees of a company.""" result = await asyncio.to_thread( lambda: self.client.table("users").select("*").eq( "company_id", str(company_id) ).eq("account_type", "employee").order("created_at", desc=True).execute() ) return [User(**item) for item in result.data] async def update_company(self, company_id: UUID, updates: Dict[str, Any]) -> Company: """Update company fields.""" result = await asyncio.to_thread( lambda: self.client.table("companies").update(updates).eq( "id", str(company_id) ).execute() ) logger.info(f"Updated company: {company_id}") return Company(**result.data[0]) async def list_companies(self) -> List[Company]: """List all companies.""" result = await asyncio.to_thread( lambda: self.client.table("companies").select("*").order("created_at", desc=True).execute() ) return [Company(**item) for item in result.data] # ==================== INVITATIONS ==================== async def create_invitation(self, invitation: Invitation) -> Invitation: """Create a new invitation.""" from datetime import datetime data = invitation.model_dump(exclude={"id", "created_at"}, exclude_none=True) for key in ["company_id", "invited_by_user_id", "accepted_by_user_id"]: if key in data and data[key]: data[key] = str(data[key]) if "expires_at" in data and isinstance(data["expires_at"], datetime): data["expires_at"] = data["expires_at"].isoformat() if "accepted_at" in data and isinstance(data["accepted_at"], datetime): data["accepted_at"] = data["accepted_at"].isoformat() if "status" in data and hasattr(data["status"], "value"): data["status"] = data["status"].value result = await asyncio.to_thread( lambda: self.client.table("invitations").insert(data).execute() ) logger.info(f"Created invitation: {result.data[0]['id']}") return Invitation(**result.data[0]) async def get_invitation_by_token(self, token: str) -> Optional[Invitation]: """Get invitation by token.""" result = await asyncio.to_thread( lambda: self.client.table("invitations").select("*").eq("token", token).execute() ) if result.data: return Invitation(**result.data[0]) return None async def get_invitation(self, invitation_id: UUID) -> Optional[Invitation]: """Get invitation by ID.""" result = await asyncio.to_thread( lambda: self.client.table("invitations").select("*").eq("id", str(invitation_id)).execute() ) if result.data: return Invitation(**result.data[0]) return None async def update_invitation(self, invitation_id: UUID, updates: Dict[str, Any]) -> Invitation: """Update invitation fields.""" from datetime import datetime if "accepted_by_user_id" in updates and updates["accepted_by_user_id"]: updates["accepted_by_user_id"] = str(updates["accepted_by_user_id"]) if "accepted_at" in updates and isinstance(updates["accepted_at"], datetime): updates["accepted_at"] = updates["accepted_at"].isoformat() if "status" in updates and hasattr(updates["status"], "value"): updates["status"] = updates["status"].value result = await asyncio.to_thread( lambda: self.client.table("invitations").update(updates).eq( "id", str(invitation_id) ).execute() ) logger.info(f"Updated invitation: {invitation_id}") return Invitation(**result.data[0]) async def get_pending_invitations(self, company_id: UUID) -> List[Invitation]: """Get all pending invitations for a company.""" result = await asyncio.to_thread( lambda: self.client.table("invitations").select("*").eq( "company_id", str(company_id) ).eq("status", "pending").order("created_at", desc=True).execute() ) return [Invitation(**item) for item in result.data] async def get_invitations_by_email(self, email: str) -> List[Invitation]: """Get all invitations for an email address.""" result = await asyncio.to_thread( lambda: self.client.table("invitations").select("*").eq( "email", email.lower() ).order("created_at", desc=True).execute() ) return [Invitation(**item) for item in result.data] # ==================== EXAMPLE POSTS ==================== async def save_example_posts(self, posts: List[ExamplePost]) -> List[ExamplePost]: """Save example posts (bulk).""" if not posts: return [] data = [] for p in posts: post_dict = p.model_dump(exclude={"id", "created_at"}, exclude_none=True) if "user_id" in post_dict: post_dict["user_id"] = str(post_dict["user_id"]) if "post_type_id" in post_dict and post_dict["post_type_id"]: post_dict["post_type_id"] = str(post_dict["post_type_id"]) data.append(post_dict) result = await asyncio.to_thread( lambda: self.client.table("example_posts").insert(data).execute() ) logger.info(f"Saved {len(result.data)} example posts") return [ExamplePost(**item) for item in result.data] async def save_example_post(self, post: ExamplePost) -> ExamplePost: """Save a single example post.""" posts = await self.save_example_posts([post]) return posts[0] if posts else post async def get_example_posts(self, user_id: UUID) -> List[ExamplePost]: """Get all example posts for a user.""" result = await asyncio.to_thread( lambda: self.client.table("example_posts").select("*").eq( "user_id", str(user_id) ).order("created_at", desc=True).execute() ) return [ExamplePost(**item) for item in result.data] async def delete_example_post(self, post_id: UUID) -> None: """Delete an example post.""" await asyncio.to_thread( lambda: self.client.table("example_posts").delete().eq("id", str(post_id)).execute() ) logger.info(f"Deleted example post: {post_id}") # ==================== REFERENCE PROFILES ==================== async def create_reference_profile(self, profile: ReferenceProfile) -> ReferenceProfile: """Create a new reference profile.""" data = profile.model_dump(exclude={"id", "created_at"}, exclude_none=True) if "user_id" in data: data["user_id"] = str(data["user_id"]) result = await asyncio.to_thread( lambda: self.client.table("reference_profiles").insert(data).execute() ) logger.info(f"Created reference profile: {result.data[0]['id']}") return ReferenceProfile(**result.data[0]) async def get_reference_profiles(self, user_id: UUID) -> List[ReferenceProfile]: """Get all reference profiles for a user.""" result = await asyncio.to_thread( lambda: self.client.table("reference_profiles").select("*").eq( "user_id", str(user_id) ).order("created_at", desc=True).execute() ) return [ReferenceProfile(**item) for item in result.data] async def update_reference_profile(self, profile_id: UUID, updates: Dict[str, Any]) -> ReferenceProfile: """Update reference profile fields.""" result = await asyncio.to_thread( lambda: self.client.table("reference_profiles").update(updates).eq( "id", str(profile_id) ).execute() ) logger.info(f"Updated reference profile: {profile_id}") return ReferenceProfile(**result.data[0]) async def delete_reference_profile(self, profile_id: UUID) -> None: """Delete a reference profile.""" await asyncio.to_thread( lambda: self.client.table("reference_profiles").delete().eq("id", str(profile_id)).execute() ) logger.info(f"Deleted reference profile: {profile_id}") # ==================== API USAGE LOGS ==================== async def log_api_usage( self, provider: str, model: str, operation: str, prompt_tokens: int, completion_tokens: int, total_tokens: int, estimated_cost_usd: float, user_id: Optional[UUID] = None, company_id: Optional[UUID] = None ) -> None: """Log an API usage event.""" data = { "provider": provider, "model": model, "operation": operation, "prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens, "total_tokens": total_tokens, "estimated_cost_usd": estimated_cost_usd } if user_id: data["user_id"] = str(user_id) if company_id: data["company_id"] = str(company_id) try: await asyncio.to_thread( lambda: self.client.table("api_usage_logs").insert(data).execute() ) except Exception as e: logger.warning(f"Failed to log API usage: {e}") async def get_usage_stats( self, start_date: Optional[str] = None, end_date: Optional[str] = None, user_id: Optional[UUID] = None, company_id: Optional[UUID] = None ) -> List[Dict[str, Any]]: """Get aggregated usage stats.""" def _query(): query = self.client.table("api_usage_logs").select("*") if user_id: query = query.eq("user_id", str(user_id)) if company_id: query = query.eq("company_id", str(company_id)) if start_date: query = query.gte("created_at", start_date) if end_date: query = query.lte("created_at", end_date) return query.order("created_at", desc=True).execute() result = await asyncio.to_thread(_query) return result.data async def get_usage_by_day( self, start_date: Optional[str] = None, end_date: Optional[str] = None, user_id: Optional[UUID] = None, company_id: Optional[UUID] = None ) -> List[Dict[str, Any]]: """Get daily usage breakdown.""" logs = await self.get_usage_stats(start_date, end_date, user_id, company_id) from collections import defaultdict daily = defaultdict(lambda: {"total_tokens": 0, "estimated_cost_usd": 0.0, "count": 0}) for log in logs: day = log["created_at"][:10] daily[day]["total_tokens"] += log.get("total_tokens", 0) daily[day]["estimated_cost_usd"] += float(log.get("estimated_cost_usd", 0)) daily[day]["count"] += 1 return [{"date": day, **data} for day, data in sorted(daily.items())] async def get_post_stats( self, start_date: Optional[str] = None, user_id: Optional[UUID] = None, company_id: Optional[UUID] = None ) -> List[Dict[str, Any]]: """Get generated posts with select fields for statistics.""" # Resolve user_ids for filtering target_user_ids = None if user_id: target_user_ids = [str(user_id)] if company_id: # Get all users belonging to this company profiles_result = await asyncio.to_thread( lambda: self.client.table("profiles").select("id").eq( "company_id", str(company_id) ).execute() ) company_user_ids = [p["id"] for p in profiles_result.data] if target_user_ids is not None: target_user_ids = [u for u in target_user_ids if u in company_user_ids] else: target_user_ids = company_user_ids if not target_user_ids: return [] def _query(): q = self.client.table("generated_posts").select( "id, created_at, status, approved_at, published_at, user_id" ) if target_user_ids is not None: q = q.in_("user_id", target_user_ids) if start_date: q = q.gte("created_at", start_date) return q.order("created_at", desc=True).execute() result = await asyncio.to_thread(_query) return result.data # ==================== LICENSE KEYS ==================== async def list_license_keys(self) -> List[LicenseKey]: """Get all license keys (admin only).""" result = await asyncio.to_thread( lambda: self.client.table("license_keys") .select("*") .order("created_at", desc=True) .execute() ) return [LicenseKey(**row) for row in result.data] async def get_license_key(self, key: str) -> Optional[LicenseKey]: """Get license key by key string.""" result = await asyncio.to_thread( lambda: self.client.table("license_keys") .select("*") .eq("key", key) .execute() ) return LicenseKey(**result.data[0]) if result.data else None async def create_license_key( self, key: str, max_employees: int, max_posts_per_day: int, max_researches_per_day: int, description: Optional[str] = None ) -> LicenseKey: """Create new license key.""" data = { "key": key, "max_employees": max_employees, "max_posts_per_day": max_posts_per_day, "max_researches_per_day": max_researches_per_day, "description": description, } result = await asyncio.to_thread( lambda: self.client.table("license_keys").insert(data).execute() ) logger.info(f"Created license key: {key}") return LicenseKey(**result.data[0]) async def mark_license_key_used( self, key: str, company_id: UUID ) -> LicenseKey: """Mark license key as used and link to company.""" data = { "used": True, "company_id": str(company_id), "used_at": datetime.now(timezone.utc).isoformat() } result = await asyncio.to_thread( lambda: self.client.table("license_keys") .update(data) .eq("key", key) .execute() ) logger.info(f"Marked license key as used: {key}") return LicenseKey(**result.data[0]) async def update_license_key(self, key_id: UUID, updates: Dict[str, Any]) -> LicenseKey: """Update license key limits (admin only).""" result = await asyncio.to_thread( lambda: self.client.table("license_keys") .update(updates) .eq("id", str(key_id)) .execute() ) logger.info(f"Updated license key: {key_id}") return LicenseKey(**result.data[0]) if result.data else None async def delete_license_key(self, key_id: UUID) -> None: """Delete license key (admin only).""" await asyncio.to_thread( lambda: self.client.table("license_keys") .delete() .eq("id", str(key_id)) .execute() ) logger.info(f"Deleted license key: {key_id}") # ==================== COMPANY QUOTAS ==================== async def get_company_daily_quota( self, company_id: UUID, date_: Optional[date] = None ) -> CompanyDailyQuota: """Get or create daily quota for company.""" if date_ is None: date_ = date.today() # Try to get existing result = await asyncio.to_thread( lambda: self.client.table("company_daily_quotas") .select("*") .eq("company_id", str(company_id)) .eq("date", date_.isoformat()) .execute() ) if result.data: return CompanyDailyQuota(**result.data[0]) # Create new quota for today data = { "company_id": str(company_id), "date": date_.isoformat(), "posts_created": 0, "researches_created": 0 } result = await asyncio.to_thread( lambda: self.client.table("company_daily_quotas") .insert(data) .execute() ) return CompanyDailyQuota(**result.data[0]) async def increment_company_posts_quota(self, company_id: UUID) -> None: """Increment daily posts count for company.""" try: quota = await self.get_company_daily_quota(company_id) new_count = quota.posts_created + 1 result = await asyncio.to_thread( lambda: self.client.table("company_daily_quotas") .update({"posts_created": new_count}) .eq("id", str(quota.id)) .execute() ) logger.info(f"Incremented posts quota for company {company_id}: {quota.posts_created} -> {new_count}") if not result.data: logger.error(f"Failed to increment posts quota - no data returned") except Exception as e: logger.error(f"Error incrementing posts quota for company {company_id}: {e}") raise async def increment_company_researches_quota(self, company_id: UUID) -> None: """Increment daily researches count for company.""" try: quota = await self.get_company_daily_quota(company_id) new_count = quota.researches_created + 1 result = await asyncio.to_thread( lambda: self.client.table("company_daily_quotas") .update({"researches_created": new_count}) .eq("id", str(quota.id)) .execute() ) logger.info(f"Incremented researches quota for company {company_id}: {quota.researches_created} -> {new_count}") if not result.data: logger.error(f"Failed to increment researches quota - no data returned") except Exception as e: logger.error(f"Error incrementing researches quota for company {company_id}: {e}") raise async def get_company_limits(self, company_id: UUID) -> Optional[LicenseKey]: """Get company limits from associated license key. Returns None if no license key is associated (uses defaults). """ company = await self.get_company(company_id) if not company or not company.license_key_id: return None result = await asyncio.to_thread( lambda: self.client.table("license_keys") .select("*") .eq("id", str(company.license_key_id)) .execute() ) return LicenseKey(**result.data[0]) if result.data else None async def check_company_post_limit(self, company_id: UUID) -> tuple[bool, str]: """Check if company can create more posts today. Returns (can_create: bool, error_message: str) """ license_key = await self.get_company_limits(company_id) if not license_key: # No license key, use defaults (unlimited) return True, "" quota = await self.get_company_daily_quota(company_id) if quota.posts_created >= license_key.max_posts_per_day: return False, f"Tageslimit erreicht ({license_key.max_posts_per_day} Posts/Tag). Versuche es morgen wieder." return True, "" async def check_company_research_limit(self, company_id: UUID) -> tuple[bool, str]: """Check if company can create more researches today. Returns (can_create: bool, error_message: str) """ license_key = await self.get_company_limits(company_id) if not license_key: # No license key, use defaults (unlimited) return True, "" quota = await self.get_company_daily_quota(company_id) if quota.researches_created >= license_key.max_researches_per_day: return False, f"Tageslimit erreicht ({license_key.max_researches_per_day} Researches/Tag). Versuche es morgen wieder." return True, "" async def check_company_employee_limit(self, company_id: UUID) -> tuple[bool, str]: """Check if company can add more employees. Returns (can_add: bool, error_message: str) """ license_key = await self.get_company_limits(company_id) if not license_key: # No license key, use defaults (unlimited) return True, "" employees = await self.list_users(account_type="employee", company_id=company_id) if len(employees) >= license_key.max_employees: return False, f"Maximale Mitarbeiteranzahl erreicht ({license_key.max_employees}). Bitte Lizenz upgraden." return True, "" # Global database client instance db = DatabaseClient()