Files
Onyva-Postling/src/database/client.py
Ruben Fischer f14515e9cf Major updates: LinkedIn auto-posting, timezone fixes, and Docker improvements
Features:
- Add LinkedIn OAuth integration and auto-posting functionality
- Add scheduler service for automated post publishing
- Add metadata field to generated_posts for LinkedIn URLs
- Add privacy policy page for LinkedIn API compliance
- Add company management features and employee accounts
- Add license key system for company registrations

Fixes:
- Fix timezone issues (use UTC consistently across app)
- Fix datetime serialization errors in database operations
- Fix scheduling timezone conversion (local time to UTC)
- Fix import errors (get_database -> db)

Infrastructure:
- Update Docker setup to use port 8001 (avoid conflicts)
- Add SSL support with nginx-proxy and Let's Encrypt
- Add LinkedIn setup documentation
- Add migration scripts for schema updates

Services:
- Add linkedin_service.py for LinkedIn API integration
- Add scheduler_service.py for background job processing
- Add storage_service.py for Supabase Storage
- Improve email_service.py
- Add encryption utilities for token storage

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
2026-02-11 11:30:20 +01:00

1397 lines
56 KiB
Python

"""Supabase database client."""
import asyncio
from datetime import datetime, date, timezone
from typing import Optional, List, Dict, Any
from uuid import UUID
from supabase import create_client, Client
from loguru import logger
from src.config import settings
from src.database.models import (
LinkedInProfile, LinkedInPost, Topic,
ProfileAnalysis, ResearchResult, GeneratedPost, PostType,
User, Profile, Company, Invitation, ExamplePost, ReferenceProfile,
ApiUsageLog, LicenseKey, CompanyDailyQuota
)
class DatabaseClient:
"""Supabase database client wrapper."""
def __init__(self):
    """Create the Supabase client used by every query method.

    When ``settings.supabase_service_role_key`` is set, the client is built
    with it and the same client is exposed as ``admin_client`` (used for
    auth-admin calls). Otherwise the anon key is used, a warning is logged,
    and ``admin_client`` is ``None``.
    """
    service_key = settings.supabase_service_role_key
    if not service_key:
        self.client: Client = create_client(
            settings.supabase_url,
            settings.supabase_key
        )
        logger.warning("Supabase client initialized with anon key (slow mode - RLS enabled)")
        self.admin_client: Optional[Client] = None
        return

    self.client = create_client(settings.supabase_url, service_key)
    logger.info("Supabase client initialized with service role key (fast mode)")
    self.admin_client = self.client
# ==================== LINKEDIN PROFILES ====================
async def save_linkedin_profile(self, profile: LinkedInProfile) -> LinkedInProfile:
    """Create or overwrite the LinkedIn profile row for ``profile.user_id``.

    A lookup by ``user_id`` decides whether an UPDATE or an INSERT is issued,
    and the first row Supabase returns is converted back into a
    ``LinkedInProfile``. ``id`` and ``scraped_at`` are never written.

    The lookup and the write are separate requests, so the operation is not
    atomic: two concurrent calls for a new user can both take the INSERT path.
    """
    payload = profile.model_dump(exclude={"id", "scraped_at"}, exclude_none=True)
    if "user_id" in payload:
        payload["user_id"] = str(payload["user_id"])

    def _lookup():
        return (
            self.client.table("linkedin_profiles")
            .select("*")
            .eq("user_id", str(profile.user_id))
            .execute()
        )

    def _update():
        return (
            self.client.table("linkedin_profiles")
            .update(payload)
            .eq("user_id", str(profile.user_id))
            .execute()
        )

    def _insert():
        return self.client.table("linkedin_profiles").insert(payload).execute()

    found = await asyncio.to_thread(_lookup)
    written = await asyncio.to_thread(_update if found.data else _insert)
    logger.info(f"Saved LinkedIn profile for user: {profile.user_id}")
    return LinkedInProfile(**written.data[0])
async def get_linkedin_profile(self, user_id: UUID) -> Optional[LinkedInProfile]:
    """Return the stored LinkedIn profile of ``user_id``, or ``None`` if absent.

    Only the first matching row is converted to a model.
    """
    def _fetch():
        query = self.client.table("linkedin_profiles").select("*")
        return query.eq("user_id", str(user_id)).execute()

    response = await asyncio.to_thread(_fetch)
    rows = response.data
    return LinkedInProfile(**rows[0]) if rows else None
# ==================== LINKEDIN POSTS ====================
async def save_linkedin_posts(self, posts: List[LinkedInPost]) -> List[LinkedInPost]:
    """Bulk upsert LinkedIn posts keyed on ``(user_id, post_url)``.

    Posts repeating an already-seen ``(user_id, post_url)`` pair are dropped
    before the upsert (the first occurrence wins). Presumably this is because a
    single Postgres ``ON CONFLICT DO UPDATE`` statement may not touch the same
    row twice.

    Args:
        posts: Posts to store. ``id`` and ``scraped_at`` are never written.

    Returns:
        The rows returned by Supabase as ``LinkedInPost`` models, or an empty
        list when there is nothing to save.
    """
    seen = set()
    unique_posts = []
    for p in posts:
        key = (str(p.user_id), p.post_url)
        if key in seen:
            continue
        seen.add(key)
        unique_posts.append(p)
    if len(posts) != len(unique_posts):
        logger.warning(f"Removed {len(posts) - len(unique_posts)} duplicate posts from batch")

    data = []
    for p in unique_posts:
        post_dict = p.model_dump(exclude={"id", "scraped_at"}, exclude_none=True)
        if "user_id" in post_dict:
            post_dict["user_id"] = str(post_dict["user_id"])
        # ``datetime`` subclasses ``date``, so both kinds are sent as ISO-8601
        # strings (plain ``date`` values were previously left unconverted).
        post_date = post_dict.get("post_date")
        if isinstance(post_date, date):
            post_dict["post_date"] = post_date.isoformat()
        data.append(post_dict)

    if not data:
        logger.warning("No posts to save")
        return []

    result = await asyncio.to_thread(
        lambda: self.client.table("linkedin_posts").upsert(
            data,
            on_conflict="user_id,post_url"
        ).execute()
    )
    logger.info(f"Saved {len(result.data)} LinkedIn posts")
    return [LinkedInPost(**item) for item in result.data]
async def get_linkedin_posts(self, user_id: UUID) -> List[LinkedInPost]:
    """Return every stored LinkedIn post of ``user_id``, newest ``post_date`` first."""
    def _fetch():
        query = self.client.table("linkedin_posts").select("*")
        query = query.eq("user_id", str(user_id))
        return query.order("post_date", desc=True).execute()

    response = await asyncio.to_thread(_fetch)
    posts = []
    for row in response.data:
        posts.append(LinkedInPost(**row))
    return posts
async def copy_posts_to_user(self, source_user_id: UUID, target_user_id: UUID) -> List[LinkedInPost]:
    """Duplicate all LinkedIn posts of ``source_user_id`` under ``target_user_id``.

    Only URL, text, date, engagement counts and ``raw_data`` are carried
    over. ``post_type_id`` and the classification fields are not passed to
    the copies. The copies are stored via ``save_linkedin_posts``.

    Returns:
        The saved copies, or an empty list if the source user has no posts.
    """
    originals = await self.get_linkedin_posts(source_user_id)
    if not originals:
        return []

    copies = [
        LinkedInPost(
            user_id=target_user_id,
            post_url=original.post_url,
            post_text=original.post_text,
            post_date=original.post_date,
            likes=original.likes,
            comments=original.comments,
            shares=original.shares,
            raw_data=original.raw_data,
        )
        for original in originals
    ]
    stored = await self.save_linkedin_posts(copies)
    logger.info(f"Copied {len(stored)} posts from user {source_user_id} to {target_user_id}")
    return stored
async def delete_linkedin_post(self, post_id: UUID) -> None:
    """Remove the LinkedIn post row whose ``id`` equals ``post_id``."""
    def _delete():
        return (
            self.client.table("linkedin_posts")
            .delete()
            .eq("id", str(post_id))
            .execute()
        )

    await asyncio.to_thread(_delete)
    logger.info(f"Deleted LinkedIn post: {post_id}")
async def get_unclassified_posts(self, user_id: UUID) -> List[LinkedInPost]:
    """Return the user's LinkedIn posts whose ``post_type_id`` IS NULL."""
    response = await asyncio.to_thread(
        lambda: (
            self.client.table("linkedin_posts")
            .select("*")
            .eq("user_id", str(user_id))
            .is_("post_type_id", "null")
            .execute()
        )
    )
    unclassified = []
    for row in response.data:
        unclassified.append(LinkedInPost(**row))
    return unclassified
async def get_posts_by_type(self, user_id: UUID, post_type_id: UUID) -> List[LinkedInPost]:
    """Return the user's LinkedIn posts classified as ``post_type_id``, newest first."""
    def _fetch():
        base = self.client.table("linkedin_posts").select("*")
        scoped = base.eq("user_id", str(user_id)).eq("post_type_id", str(post_type_id))
        return scoped.order("post_date", desc=True).execute()

    rows = (await asyncio.to_thread(_fetch)).data
    return [LinkedInPost(**row) for row in rows]
async def update_post_classification(
    self,
    post_id: UUID,
    post_type_id: UUID,
    classification_method: str,
    classification_confidence: float
) -> None:
    """Store the post-type classification result for one LinkedIn post.

    Writes ``post_type_id`` (as a string), ``classification_method`` and
    ``classification_confidence`` on the row with ``id == post_id``.
    """
    def _write():
        changes = {
            "post_type_id": str(post_type_id),
            "classification_method": classification_method,
            "classification_confidence": classification_confidence,
        }
        return (
            self.client.table("linkedin_posts")
            .update(changes)
            .eq("id", str(post_id))
            .execute()
        )

    await asyncio.to_thread(_write)
    logger.debug(f"Updated classification for post {post_id}")
async def update_linkedin_post(self, post_id: UUID, updates: Dict[str, Any]) -> None:
    """Apply arbitrary column ``updates`` to one LinkedIn post.

    Truthy ``post_type_id`` / ``user_id`` values are stringified so UUID
    objects can be passed. The caller's ``updates`` dict is copied rather
    than modified in place.

    Args:
        post_id: ``id`` of the ``linkedin_posts`` row to update.
        updates: Column -> new value mapping.
    """
    # Work on a copy so callers that reuse their dict are not surprised.
    payload = dict(updates)
    for key in ("post_type_id", "user_id"):
        if payload.get(key):
            payload[key] = str(payload[key])
    await asyncio.to_thread(
        lambda: self.client.table("linkedin_posts").update(payload).eq("id", str(post_id)).execute()
    )
    logger.debug(f"Updated LinkedIn post {post_id}")
async def update_posts_classification_bulk(
    self,
    classifications: List[Dict[str, Any]]
) -> int:
    """Best-effort update of many post classifications, one request each.

    Args:
        classifications: Dicts with ``post_id``, ``post_type_id``,
            ``classification_method`` and ``classification_confidence``.

    Returns:
        The number of posts updated successfully. Items that fail (request
        error or missing key) are logged and skipped.
    """
    count = 0
    for classification in classifications:
        try:
            await asyncio.to_thread(
                # ``c=`` binds the current item for the lambda.
                lambda c=classification: self.client.table("linkedin_posts").update({
                    "post_type_id": str(c["post_type_id"]),
                    "classification_method": c["classification_method"],
                    "classification_confidence": c["classification_confidence"]
                }).eq("id", str(c["post_id"])).execute()
            )
            count += 1
        except Exception as e:
            # ``.get`` rather than ``[...]``: if the failure was a missing
            # "post_id", indexing it again here would raise from inside the
            # handler and abort the whole best-effort batch.
            logger.warning(
                f"Failed to update classification for post {classification.get('post_id')}: {e}"
            )
    logger.info(f"Bulk updated classifications for {count} posts")
    return count
# ==================== POST TYPES ====================
async def create_post_type(self, post_type: PostType) -> PostType:
    """Insert a new post type and return the stored row.

    ``id``, ``created_at`` and ``updated_at`` are not sent, and ``user_id``
    is stringified.
    """
    row = post_type.model_dump(exclude={"id", "created_at", "updated_at"}, exclude_none=True)
    if "user_id" in row:
        row["user_id"] = str(row["user_id"])

    def _insert():
        return self.client.table("post_types").insert(row).execute()

    response = await asyncio.to_thread(_insert)
    created = response.data[0]
    logger.info(f"Created post type: {created['name']}")
    return PostType(**created)
async def create_post_types_bulk(self, post_types: List[PostType]) -> List[PostType]:
    """Insert several post types in a single request and return the stored rows.

    An empty input returns ``[]`` without touching the database.
    """
    if not post_types:
        return []

    def _as_row(pt):
        row = pt.model_dump(exclude={"id", "created_at", "updated_at"}, exclude_none=True)
        if "user_id" in row:
            row["user_id"] = str(row["user_id"])
        return row

    rows = [_as_row(pt) for pt in post_types]
    response = await asyncio.to_thread(
        lambda: self.client.table("post_types").insert(rows).execute()
    )
    logger.info(f"Created {len(response.data)} post types")
    return [PostType(**item) for item in response.data]
async def get_post_types(self, user_id: UUID, active_only: bool = True) -> List[PostType]:
    """List a user's post types ordered by ``name``.

    With ``active_only`` (the default) only rows with ``is_active = true``
    are returned.
    """
    def _run():
        scoped = self.client.table("post_types").select("*").eq("user_id", str(user_id))
        filtered = scoped.eq("is_active", True) if active_only else scoped
        return filtered.order("name").execute()

    response = await asyncio.to_thread(_run)
    return [PostType(**entry) for entry in response.data]
# Alias: delegates to get_post_types with the same arguments.
async def get_customer_post_types(self, user_id: UUID, active_only: bool = True) -> List[PostType]:
    """Return exactly what ``get_post_types(user_id, active_only)`` returns."""
    post_types = await self.get_post_types(user_id, active_only)
    return post_types
async def get_post_type(self, post_type_id: UUID) -> Optional[PostType]:
    """Look up one post type by ``id``; return ``None`` if no row matches."""
    response = await asyncio.to_thread(
        lambda: (
            self.client.table("post_types")
            .select("*")
            .eq("id", str(post_type_id))
            .execute()
        )
    )
    if not response.data:
        return None
    return PostType(**response.data[0])
async def update_post_type(self, post_type_id: UUID, updates: Dict[str, Any]) -> PostType:
    """Write ``updates`` to the post type ``post_type_id`` and return the updated row.

    ``updates`` is sent as-is; no value conversion happens here.
    """
    def _apply():
        return (
            self.client.table("post_types")
            .update(updates)
            .eq("id", str(post_type_id))
            .execute()
        )

    response = await asyncio.to_thread(_apply)
    logger.info(f"Updated post type: {post_type_id}")
    updated = response.data[0]
    return PostType(**updated)
async def update_post_type_analysis(
    self,
    post_type_id: UUID,
    analysis: Dict[str, Any],
    analyzed_post_count: int
) -> PostType:
    """Store a freshly generated analysis for a post type.

    Args:
        post_type_id: ``id`` of the post type to update.
        analysis: Analysis payload written to the ``analysis`` column.
        analyzed_post_count: Number of posts the analysis was based on.

    Returns:
        The updated post type. ``analysis_generated_at`` is stamped with the
        current UTC time as an ISO-8601 string.
    """
    # Uses the module-level ``datetime``/``timezone`` imports.
    generated_at = datetime.now(timezone.utc).isoformat()
    result = await asyncio.to_thread(
        lambda: self.client.table("post_types").update({
            "analysis": analysis,
            "analysis_generated_at": generated_at,
            "analyzed_post_count": analyzed_post_count
        }).eq("id", str(post_type_id)).execute()
    )
    logger.info(f"Updated analysis for post type: {post_type_id}")
    return PostType(**result.data[0])
async def delete_post_type(self, post_type_id: UUID, soft: bool = True) -> None:
    """Delete a post type.

    By default (``soft=True``) the row is kept and ``is_active`` is set to
    ``False``. With ``soft=False`` the row itself is removed.
    """
    def _deactivate():
        return (
            self.client.table("post_types")
            .update({"is_active": False})
            .eq("id", str(post_type_id))
            .execute()
        )

    def _remove():
        return (
            self.client.table("post_types")
            .delete()
            .eq("id", str(post_type_id))
            .execute()
        )

    if not soft:
        await asyncio.to_thread(_remove)
        logger.info(f"Hard deleted post type: {post_type_id}")
        return
    await asyncio.to_thread(_deactivate)
    logger.info(f"Soft deleted post type: {post_type_id}")
# ==================== TOPICS ====================
async def save_topics(self, topics: List[Topic]) -> List[Topic]:
    """Insert extracted topics, falling back to row-by-row inserts on failure.

    Args:
        topics: Topics to store. ``id``/``created_at`` are not sent and UUID
            references are stringified.

    Returns:
        The stored topics. On the fallback path, rows whose individual insert
        fails are logged and omitted. An empty input returns ``[]``.
    """
    if not topics:
        logger.warning("No topics to save")
        return []

    data = []
    for t in topics:
        topic_dict = t.model_dump(exclude={"id", "created_at"}, exclude_none=True)
        if "user_id" in topic_dict:
            topic_dict["user_id"] = str(topic_dict["user_id"])
        for key in ("extracted_from_post_id", "target_post_type_id"):
            if topic_dict.get(key):
                topic_dict[key] = str(topic_dict[key])
        data.append(topic_dict)

    # Only the request is guarded: if model conversion failed after a
    # successful bulk insert, falling back would insert every row twice.
    try:
        result = await asyncio.to_thread(
            lambda: self.client.table("topics").insert(data).execute()
        )
    except Exception as e:
        # ``logger`` is loguru: keyword arguments such as ``exc_info=True`` are
        # treated as format arguments, so the message is run through
        # ``str.format``. That raises when the error text contains braces
        # (e.g. a PostgREST error dict). ``logger.exception`` records the
        # traceback without formatting.
        logger.exception(f"Error saving topics: {e}")
    else:
        logger.info(f"Saved {len(result.data)} topics to database")
        return [Topic(**item) for item in result.data]

    # The bulk insert failed (presumably a duplicate rejected the whole
    # statement): insert one row at a time so the valid topics are kept.
    saved = []
    for topic_data in data:
        try:
            single = await asyncio.to_thread(
                lambda td=topic_data: self.client.table("topics").insert(td).execute()
            )
        except Exception as single_error:
            logger.warning(f"Skipped duplicate topic: {topic_data.get('title')} ({single_error})")
            continue
        saved.extend(Topic(**item) for item in single.data)
    logger.info(f"Saved {len(saved)} topics individually")
    return saved
async def get_topics(
    self,
    user_id: UUID,
    unused_only: bool = False,
    post_type_id: Optional[UUID] = None
) -> List[Topic]:
    """Return a user's topics, newest first.

    Args:
        user_id: Owner of the topics.
        unused_only: When true, only topics with ``is_used = false``.
        post_type_id: When given, only topics targeting that post type.
    """
    def _run():
        filters = [("user_id", str(user_id))]
        if unused_only:
            filters.append(("is_used", False))
        if post_type_id:
            filters.append(("target_post_type_id", str(post_type_id)))
        query = self.client.table("topics").select("*")
        for column, value in filters:
            query = query.eq(column, value)
        return query.order("created_at", desc=True).execute()

    response = await asyncio.to_thread(_run)
    return [Topic(**row) for row in response.data]
async def mark_topic_used(self, topic_id: UUID) -> None:
    """Flag a topic as used and set ``used_at``.

    ``used_at`` is sent as the literal string ``"now()"``. Presumably
    Postgres interprets it as the current time, so the timestamp comes from
    the database rather than the app clock.
    """
    changes = {"is_used": True, "used_at": "now()"}

    def _write():
        return (
            self.client.table("topics")
            .update(changes)
            .eq("id", str(topic_id))
            .execute()
        )

    await asyncio.to_thread(_write)
    logger.info(f"Marked topic {topic_id} as used")
# ==================== PROFILE ANALYSIS ====================
async def save_profile_analysis(self, analysis: ProfileAnalysis) -> ProfileAnalysis:
    """Create or overwrite the profile analysis of ``analysis.user_id``.

    A select on ``user_id`` chooses between UPDATE and INSERT. The two steps
    are separate requests, so they are not atomic. ``id`` and ``created_at``
    are never written.
    """
    record = analysis.model_dump(exclude={"id", "created_at"}, exclude_none=True)
    if "user_id" in record:
        record["user_id"] = str(record["user_id"])
    owner = str(analysis.user_id)
    table_name = "profile_analyses"

    existing = await asyncio.to_thread(
        lambda: self.client.table(table_name).select("*").eq("user_id", owner).execute()
    )

    def _write():
        if existing.data:
            return self.client.table(table_name).update(record).eq("user_id", owner).execute()
        return self.client.table(table_name).insert(record).execute()

    response = await asyncio.to_thread(_write)
    logger.info(f"Saved profile analysis for user: {analysis.user_id}")
    return ProfileAnalysis(**response.data[0])
async def get_profile_analysis(self, user_id: UUID) -> Optional[ProfileAnalysis]:
    """Return the stored profile analysis of ``user_id``, or ``None``."""
    def _fetch():
        return (
            self.client.table("profile_analyses")
            .select("*")
            .eq("user_id", str(user_id))
            .execute()
        )

    rows = (await asyncio.to_thread(_fetch)).data
    return ProfileAnalysis(**rows[0]) if rows else None
# ==================== RESEARCH RESULTS ====================
async def save_research_result(self, research: ResearchResult) -> ResearchResult:
    """Insert a research result as a new row and return the stored row.

    ``user_id`` and a truthy ``target_post_type_id`` are stringified.
    ``id``/``created_at`` are not sent.
    """
    record = research.model_dump(exclude={"id", "created_at"}, exclude_none=True)
    if "user_id" in record:
        record["user_id"] = str(record["user_id"])
    if record.get("target_post_type_id"):
        record["target_post_type_id"] = str(record["target_post_type_id"])

    def _insert():
        return self.client.table("research_results").insert(record).execute()

    response = await asyncio.to_thread(_insert)
    logger.info(f"Saved research result for user: {research.user_id}")
    return ResearchResult(**response.data[0])
async def get_latest_research(self, user_id: UUID) -> Optional[ResearchResult]:
    """Return the most recently created research result of ``user_id``, or ``None``."""
    response = await asyncio.to_thread(
        lambda: (
            self.client.table("research_results")
            .select("*")
            .eq("user_id", str(user_id))
            .order("created_at", desc=True)
            .limit(1)
            .execute()
        )
    )
    if not response.data:
        return None
    return ResearchResult(**response.data[0])
async def get_all_research(
    self,
    user_id: UUID,
    post_type_id: Optional[UUID] = None
) -> List[ResearchResult]:
    """Return all research results of ``user_id``, newest first.

    When ``post_type_id`` is given, only results targeting that post type.
    """
    def _run():
        base = (
            self.client.table("research_results")
            .select("*")
            .eq("user_id", str(user_id))
        )
        narrowed = (
            base.eq("target_post_type_id", str(post_type_id))
            if post_type_id
            else base
        )
        return narrowed.order("created_at", desc=True).execute()

    response = await asyncio.to_thread(_run)
    results: List[ResearchResult] = []
    for row in response.data:
        results.append(ResearchResult(**row))
    return results
# ==================== GENERATED POSTS ====================
async def save_generated_post(self, post: GeneratedPost) -> GeneratedPost:
    """Insert a generated post and return the stored row.

    The model is dumped in pydantic's JSON mode, so every value sent is
    JSON-compatible:
      * UUIDs (``user_id``, ``topic_id``, ``post_type_id``, ...) become strings.
      * datetimes, including ones nested inside dict fields such as
        ``metadata``, become ISO-8601 strings.
    A python-mode dump left such values as objects that the JSON request
    body cannot encode. ``id``, ``created_at`` and ``None`` fields are not
    sent.
    """
    data = post.model_dump(mode="json", exclude={"id", "created_at"}, exclude_none=True)
    result = await asyncio.to_thread(
        lambda: self.client.table("generated_posts").insert(data).execute()
    )
    logger.info(f"Saved generated post: {result.data[0]['id']}")
    return GeneratedPost(**result.data[0])
async def update_generated_post(self, post_id: UUID, updates: Dict[str, Any]) -> GeneratedPost:
    """Apply ``updates`` to a generated post and return the updated row.

    Values are made JSON-compatible before sending, recursively inside dicts
    and lists (e.g. ``metadata``):
      * ``datetime``/``date`` become ISO-8601 strings.
      * ``UUID`` values become ``str``.
    The caller's ``updates`` dict is not modified.

    Args:
        post_id: ``id`` of the ``generated_posts`` row.
        updates: Column -> new value mapping (``None`` clears a column).
    """
    def _jsonable(value: Any) -> Any:
        # ``datetime`` subclasses ``date``, so this branch covers both.
        if isinstance(value, date):
            return value.isoformat()
        if isinstance(value, UUID):
            return str(value)
        if isinstance(value, dict):
            return {key: _jsonable(item) for key, item in value.items()}
        if isinstance(value, (list, tuple)):
            return [_jsonable(item) for item in value]
        return value

    # New dict: callers may keep using their original (e.g. datetime values).
    payload = {key: _jsonable(value) for key, value in updates.items()}
    result = await asyncio.to_thread(
        lambda: self.client.table("generated_posts").update(payload).eq(
            "id", str(post_id)
        ).execute()
    )
    logger.info(f"Updated generated post: {post_id}")
    return GeneratedPost(**result.data[0])
async def get_generated_posts(self, user_id: UUID) -> List[GeneratedPost]:
    """Return every generated post of ``user_id``, newest ``created_at`` first."""
    def _fetch():
        query = self.client.table("generated_posts").select("*").eq("user_id", str(user_id))
        return query.order("created_at", desc=True).execute()

    response = await asyncio.to_thread(_fetch)
    return [GeneratedPost(**row) for row in response.data]
async def get_generated_post(self, post_id: UUID) -> Optional[GeneratedPost]:
    """Fetch one generated post by ``id``; ``None`` when it does not exist."""
    response = await asyncio.to_thread(
        lambda: (
            self.client.table("generated_posts")
            .select("*")
            .eq("id", str(post_id))
            .execute()
        )
    )
    rows = response.data
    if not rows:
        return None
    return GeneratedPost(**rows[0])
async def get_scheduled_posts_due(self) -> List[GeneratedPost]:
    """Return ``scheduled`` posts whose ``scheduled_at`` is at or before now.

    "Now" is the current UTC time as an ISO-8601 string. The comparison
    (``lte``) happens on the database side.
    """
    # Uses the module-level ``datetime``/``timezone`` imports.
    now = datetime.now(timezone.utc).isoformat()
    result = await asyncio.to_thread(
        lambda: self.client.table("generated_posts").select("*").eq(
            "status", "scheduled"
        ).lte("scheduled_at", now).execute()
    )
    return [GeneratedPost(**item) for item in result.data]
async def get_scheduled_posts_for_company(self, company_id: UUID) -> List[Dict[str, Any]]:
    """List the posts of every employee profile of a company for the planner.

    Despite the name, posts with status ``ready``, ``scheduled`` or
    ``published`` are included, ordered by ``scheduled_at`` ascending. Each
    item is a dict of the selected columns plus two keys:
      * ``employee_name``: the profile's ``display_name``, or ``"Unknown"``.
      * ``employee_user_id``: a copy of ``user_id``.
    Returns ``[]`` when no profile belongs to the company.
    """
    def _employees():
        return (
            self.client.table("profiles")
            .select("id, display_name")
            .eq("company_id", str(company_id))
            .execute()
        )

    employees = await asyncio.to_thread(_employees)
    if not employees.data:
        return []

    member_ids = [row["id"] for row in employees.data]
    by_id = {row["id"]: row for row in employees.data}

    def _posts():
        return (
            self.client.table("generated_posts")
            .select("id, user_id, topic_title, post_content, status, scheduled_at, created_at")
            .in_("user_id", member_ids)
            .in_("status", ["ready", "scheduled", "published"])
            .order("scheduled_at", desc=False, nullsfirst=False)
            .execute()
        )

    posts = await asyncio.to_thread(_posts)

    def _with_employee(post):
        owner = by_id.get(post["user_id"])
        return {
            **post,
            "employee_name": owner["display_name"] if owner else "Unknown",
            "employee_user_id": post["user_id"],
        }

    return [_with_employee(post) for post in posts.data]
async def schedule_post(self, post_id: UUID, scheduled_at: datetime, scheduled_by_user_id: UUID) -> GeneratedPost:
    """Mark a post as ``scheduled`` and record when and by whom.

    Args:
        post_id: ``id`` of the generated post.
        scheduled_at: A ``datetime`` (sent as ISO-8601) or an already
            formatted value, which is passed through unchanged. No timezone
            conversion happens here — presumably the caller supplies UTC.
        scheduled_by_user_id: User who scheduled the post.

    Returns:
        The updated post (via ``update_generated_post``).
    """
    # Uses the module-level ``datetime`` import.
    updates = {
        "status": "scheduled",
        "scheduled_at": scheduled_at.isoformat() if isinstance(scheduled_at, datetime) else scheduled_at,
        "scheduled_by_user_id": str(scheduled_by_user_id)
    }
    return await self.update_generated_post(post_id, updates)
async def unschedule_post(self, post_id: UUID) -> GeneratedPost:
    """Cancel scheduling of a post.

    Sets the status back to ``approved`` and clears ``scheduled_at`` and
    ``scheduled_by_user_id``.
    """
    return await self.update_generated_post(
        post_id,
        {
            "status": "approved",
            "scheduled_at": None,
            "scheduled_by_user_id": None,
        },
    )
# ==================== PROFILES / USERS ====================
async def create_profile(self, user_id: UUID, profile: Profile) -> Profile:
    """Upsert the ``profiles`` row for ``user_id`` from ``profile``.

    Before sending:
      * ``id`` is forced to ``user_id``.
      * A truthy ``company_id`` is stringified.
      * ``account_type`` / ``onboarding_status`` are replaced by ``.value``
        when they have one (enum members).
    """
    row = profile.model_dump(exclude={"created_at", "updated_at"}, exclude_none=True)
    row["id"] = str(user_id)
    if row.get("company_id"):
        row["company_id"] = str(row["company_id"])
    for enum_field in ("account_type", "onboarding_status"):
        if enum_field in row:
            current = row[enum_field]
            row[enum_field] = current.value if hasattr(current, "value") else current

    response = await asyncio.to_thread(
        lambda: self.client.table("profiles").upsert(row).execute()
    )
    stored = response.data[0]
    logger.info(f"Created/updated profile: {stored['id']}")
    return Profile(**stored)
async def get_profile(self, user_id: UUID) -> Optional[Profile]:
    """Return the ``profiles`` row whose ``id`` is ``user_id``, or ``None``."""
    def _fetch():
        return self.client.table("profiles").select("*").eq("id", str(user_id)).execute()

    rows = (await asyncio.to_thread(_fetch)).data
    return Profile(**rows[0]) if rows else None
async def get_profiles_by_linkedin_url(self, linkedin_url: str) -> List[Profile]:
    """Return every profile whose ``linkedin_url`` equals ``linkedin_url`` exactly."""
    def _fetch():
        return (
            self.client.table("profiles")
            .select("*")
            .eq("linkedin_url", linkedin_url)
            .execute()
        )

    response = await asyncio.to_thread(_fetch)
    matches = []
    for row in response.data:
        matches.append(Profile(**row))
    return matches
async def update_profile(self, user_id: UUID, updates: Dict[str, Any]) -> Profile:
    """Update profile columns of ``user_id`` and return the updated row.

    A truthy ``company_id`` is stringified. Enum values of ``account_type``
    / ``onboarding_status`` are replaced by their ``.value``. The caller's
    ``updates`` dict is copied rather than modified in place.
    """
    # Copy so the conversions below do not leak into the caller's dict.
    payload = dict(updates)
    if payload.get("company_id"):
        payload["company_id"] = str(payload["company_id"])
    for key in ["account_type", "onboarding_status"]:
        if key in payload and hasattr(payload[key], "value"):
            payload[key] = payload[key].value
    result = await asyncio.to_thread(
        lambda: self.client.table("profiles").update(payload).eq(
            "id", str(user_id)
        ).execute()
    )
    logger.info(f"Updated profile: {user_id}")
    return Profile(**result.data[0])
# ==================== LINKEDIN ACCOUNTS ====================
async def get_linkedin_account(self, user_id: UUID) -> Optional['LinkedInAccount']:
    """Return the active LinkedIn connection of ``user_id``, or ``None``.

    Only rows with ``is_active = true`` are considered; the first match wins.
    """
    from src.database.models import LinkedInAccount

    def _fetch():
        return (
            self.client.table("linkedin_accounts")
            .select("*")
            .eq("user_id", str(user_id))
            .eq("is_active", True)
            .execute()
        )

    rows = (await asyncio.to_thread(_fetch)).data
    if not rows:
        return None
    return LinkedInAccount(**rows[0])
async def get_linkedin_account_by_id(self, account_id: UUID) -> Optional['LinkedInAccount']:
    """Return the LinkedIn account row with ``id == account_id`` (active or not), or ``None``."""
    from src.database.models import LinkedInAccount

    response = await asyncio.to_thread(
        lambda: (
            self.client.table("linkedin_accounts")
            .select("*")
            .eq("id", str(account_id))
            .execute()
        )
    )
    rows = response.data
    return LinkedInAccount(**rows[0]) if rows else None
async def create_linkedin_account(self, account: 'LinkedInAccount') -> 'LinkedInAccount':
    """Store a new LinkedIn account connection and return the stored row.

    ``id``/``created_at``/``updated_at`` are not sent and ``user_id`` is
    stringified. The timestamp fields ``token_expires_at``, ``last_used_at``
    and ``last_error_at`` are sent as ISO-8601 strings when they are
    datetimes; other values (e.g. ``None``) pass through. Previously
    ``token_expires_at`` crashed on ``None`` and the other two were not
    serialised.
    """
    from src.database.models import LinkedInAccount
    data = account.model_dump(exclude={'id', 'created_at', 'updated_at'})
    data['user_id'] = str(data['user_id'])
    # Same timestamp columns that update_linkedin_account serialises.
    for field in ('token_expires_at', 'last_used_at', 'last_error_at'):
        if isinstance(data.get(field), datetime):
            data[field] = data[field].isoformat()
    result = await asyncio.to_thread(
        lambda: self.client.table("linkedin_accounts").insert(data).execute()
    )
    logger.info(f"Created LinkedIn account for user: {account.user_id}")
    return LinkedInAccount(**result.data[0])
async def update_linkedin_account(self, account_id: UUID, updates: Dict) -> 'LinkedInAccount':
    """Update a LinkedIn account row and return it.

    ``token_expires_at``, ``last_used_at`` and ``last_error_at`` datetimes
    are sent as ISO-8601 strings. The caller's ``updates`` dict is copied
    rather than modified in place.
    """
    from src.database.models import LinkedInAccount
    # Copy so callers keep their datetime objects.
    payload = dict(updates)
    for field in ('token_expires_at', 'last_used_at', 'last_error_at'):
        if isinstance(payload.get(field), datetime):
            payload[field] = payload[field].isoformat()
    result = await asyncio.to_thread(
        lambda: self.client.table("linkedin_accounts").update(payload)
        .eq("id", str(account_id)).execute()
    )
    logger.info(f"Updated LinkedIn account: {account_id}")
    return LinkedInAccount(**result.data[0])
async def delete_linkedin_account(self, account_id: UUID) -> None:
    """Remove the LinkedIn account connection row with ``id == account_id``."""
    def _delete():
        return (
            self.client.table("linkedin_accounts")
            .delete()
            .eq("id", str(account_id))
            .execute()
        )

    await asyncio.to_thread(_delete)
    logger.info(f"Deleted LinkedIn account: {account_id}")
# ==================== USERS ====================
async def get_user(self, user_id: UUID) -> Optional[User]:
    """Return the ``users`` view row for ``user_id``, or ``None``."""
    def _fetch():
        return self.client.table("users").select("*").eq("id", str(user_id)).execute()

    rows = (await asyncio.to_thread(_fetch)).data
    return User(**rows[0]) if rows else None
async def get_user_by_email(self, email: str) -> Optional[User]:
    """Return the ``users`` view row for ``email`` (lower-cased), or ``None``."""
    response = await asyncio.to_thread(
        lambda: (
            self.client.table("users")
            .select("*")
            .eq("email", email.lower())
            .execute()
        )
    )
    if not response.data:
        return None
    return User(**response.data[0])
async def get_user_by_linkedin_sub(self, linkedin_sub: str) -> Optional[User]:
    """Return the user whose ``linkedin_sub`` (OAuth subject id) matches, or ``None``."""
    def _fetch():
        return (
            self.client.table("users")
            .select("*")
            .eq("linkedin_sub", linkedin_sub)
            .execute()
        )

    rows = (await asyncio.to_thread(_fetch)).data
    if rows:
        return User(**rows[0])
    return None
async def update_user(self, user_id: UUID, updates: Dict[str, Any]) -> User:
    """Update the profile-backed fields of a user and return the refreshed user.

    Keys not in the profile column whitelist below are silently ignored.
    ``update_profile`` is only called when at least one key survives. The
    user is re-read afterwards via ``get_user``.
    """
    allowed = frozenset({
        "account_type", "display_name", "onboarding_status",
        "onboarding_data", "company_id",
        "linkedin_url", "writing_style_notes", "metadata",
        "profile_picture", "creator_email", "customer_email", "is_active",
    })
    forwarded = {}
    for name, value in updates.items():
        if name in allowed:
            forwarded[name] = value
    if forwarded:
        await self.update_profile(user_id, forwarded)
    return await self.get_user(user_id)
async def list_users(self, account_type: Optional[str] = None, company_id: Optional[UUID] = None) -> List[User]:
    """List users from the ``users`` view, newest ``created_at`` first.

    ``account_type`` and ``company_id`` filter the result when truthy.
    """
    def _run():
        query = self.client.table("users").select("*")
        query = query.eq("account_type", account_type) if account_type else query
        query = query.eq("company_id", str(company_id)) if company_id else query
        return query.order("created_at", desc=True).execute()

    response = await asyncio.to_thread(_run)
    return [User(**row) for row in response.data]
async def delete_user_completely(self, user_id: UUID) -> bool:
    """Delete a user and all related data completely.

    Steps, in order:
      1. delete rows with this ``user_id`` from each content table listed below;
      2. delete invitations whose ``email`` equals the user's email (the user
         is read from the ``users`` view after step 1);
      3. delete the ``profiles`` row;
      4. delete the auth user, only when ``admin_client`` is set (service role
         key configured); otherwise a warning is logged and it remains.

    Each step is a separate request, so steps already executed stay done if
    a later one fails.

    Returns:
        ``True`` once every step has run.

    Raises:
        Any exception from a step, after logging it.

    NOTE(review): ``linkedin_accounts`` (also keyed by ``user_id``) is not in
    the table list — confirm it is removed by an FK cascade or elsewhere.
    NOTE(review): invitations are matched on ``user.email`` as stored, while
    ``get_invitations_by_email`` lower-cases its input — confirm the stored
    emails share the same case.
    """
    try:
        user_id_str = str(user_id)
        # Delete all content data directly via user_id
        for table in ["generated_posts", "linkedin_posts", "example_posts",
                      "reference_profiles", "profile_analyses", "linkedin_profiles",
                      "post_types", "research_results", "topics"]:
            await asyncio.to_thread(
                # ``t=table`` binds the current loop value into the lambda.
                lambda t=table: self.client.table(t).delete().eq(
                    "user_id", user_id_str
                ).execute()
            )
        # Get user email for invitation deletion
        user = await self.get_user(user_id)
        if user and user.email:
            await asyncio.to_thread(
                lambda: self.client.table("invitations").delete().eq(
                    "email", user.email
                ).execute()
            )
        # Delete profile
        await asyncio.to_thread(
            lambda: self.client.table("profiles").delete().eq(
                "id", user_id_str
            ).execute()
        )
        # Delete from auth.users (requires the service-role client)
        if self.admin_client:
            await asyncio.to_thread(
                lambda: self.admin_client.auth.admin.delete_user(user_id_str)
            )
            logger.info(f"Deleted auth user {user_id}")
        else:
            logger.warning(f"Cannot delete auth user {user_id} - no service role key configured")
        logger.info(f"Completely deleted user {user_id} and all related data")
        return True
    except Exception as e:
        # Log for context, then propagate to the caller unchanged.
        logger.error(f"Error deleting user {user_id}: {e}")
        raise
# ==================== COMPANIES ====================
async def create_company(self, company: Company) -> Company:
    """Insert a new company and return the stored row.

    ``owner_user_id`` and ``license_key_id`` are stringified when present.
    ``id``/``created_at``/``updated_at`` are not sent.
    """
    record = company.model_dump(exclude={"id", "created_at", "updated_at"}, exclude_none=True)
    for key in ("owner_user_id", "license_key_id"):
        if key in record:
            record[key] = str(record[key])

    def _insert():
        return self.client.table("companies").insert(record).execute()

    response = await asyncio.to_thread(_insert)
    created = response.data[0]
    logger.info(f"Created company: {created['id']}")
    return Company(**created)
async def get_company(self, company_id: UUID) -> Optional[Company]:
    """Return the company with ``id == company_id``, or ``None``."""
    def _fetch():
        return self.client.table("companies").select("*").eq("id", str(company_id)).execute()

    rows = (await asyncio.to_thread(_fetch)).data
    return Company(**rows[0]) if rows else None
async def get_company_by_owner(self, owner_user_id: UUID) -> Optional[Company]:
    """Return the first company owned by ``owner_user_id``, or ``None``."""
    response = await asyncio.to_thread(
        lambda: (
            self.client.table("companies")
            .select("*")
            .eq("owner_user_id", str(owner_user_id))
            .execute()
        )
    )
    if not response.data:
        return None
    return Company(**response.data[0])
async def get_company_employees(self, company_id: UUID) -> List[User]:
    """Return users with ``account_type = 'employee'`` in the company, newest first."""
    def _fetch():
        return (
            self.client.table("users")
            .select("*")
            .eq("company_id", str(company_id))
            .eq("account_type", "employee")
            .order("created_at", desc=True)
            .execute()
        )

    response = await asyncio.to_thread(_fetch)
    employees = []
    for row in response.data:
        employees.append(User(**row))
    return employees
async def update_company(self, company_id: UUID, updates: Dict[str, Any]) -> Company:
    """Write ``updates`` to the company row and return the updated company.

    ``updates`` is sent as-is; no value conversion happens here.
    """
    def _apply():
        return (
            self.client.table("companies")
            .update(updates)
            .eq("id", str(company_id))
            .execute()
        )

    response = await asyncio.to_thread(_apply)
    logger.info(f"Updated company: {company_id}")
    updated = response.data[0]
    return Company(**updated)
async def list_companies(self) -> List[Company]:
    """Return every company, newest ``created_at`` first."""
    def _fetch():
        query = self.client.table("companies").select("*")
        return query.order("created_at", desc=True).execute()

    response = await asyncio.to_thread(_fetch)
    return [Company(**row) for row in response.data]
# ==================== INVITATIONS ====================
async def create_invitation(self, invitation: Invitation) -> Invitation:
    """Insert an invitation and return the stored row.

    Before sending:
      * truthy UUID foreign keys are stringified;
      * ``expires_at`` / ``accepted_at`` datetimes become ISO-8601 strings;
      * an enum ``status`` is replaced by its ``.value``;
      * ``email`` is lower-cased, so the row is found by
        ``get_invitations_by_email``, which queries ``email.lower()``.
    ``id``/``created_at`` are not sent.
    """
    # Uses the module-level ``datetime`` import.
    data = invitation.model_dump(exclude={"id", "created_at"}, exclude_none=True)
    for key in ["company_id", "invited_by_user_id", "accepted_by_user_id"]:
        if data.get(key):
            data[key] = str(data[key])
    for key in ("expires_at", "accepted_at"):
        if isinstance(data.get(key), datetime):
            data[key] = data[key].isoformat()
    if "status" in data and hasattr(data["status"], "value"):
        data["status"] = data["status"].value
    if isinstance(data.get("email"), str):
        data["email"] = data["email"].lower()
    result = await asyncio.to_thread(
        lambda: self.client.table("invitations").insert(data).execute()
    )
    logger.info(f"Created invitation: {result.data[0]['id']}")
    return Invitation(**result.data[0])
async def get_invitation_by_token(self, token: str) -> Optional[Invitation]:
    """Return the invitation whose ``token`` matches exactly, or ``None``."""
    def _fetch():
        return self.client.table("invitations").select("*").eq("token", token).execute()

    rows = (await asyncio.to_thread(_fetch)).data
    return Invitation(**rows[0]) if rows else None
async def get_invitation(self, invitation_id: UUID) -> Optional[Invitation]:
"""Get invitation by ID."""
result = await asyncio.to_thread(
lambda: self.client.table("invitations").select("*").eq("id", str(invitation_id)).execute()
)
if result.data:
return Invitation(**result.data[0])
return None
async def update_invitation(self, invitation_id: UUID, updates: Dict[str, Any]) -> Invitation:
"""Update invitation fields."""
from datetime import datetime
if "accepted_by_user_id" in updates and updates["accepted_by_user_id"]:
updates["accepted_by_user_id"] = str(updates["accepted_by_user_id"])
if "accepted_at" in updates and isinstance(updates["accepted_at"], datetime):
updates["accepted_at"] = updates["accepted_at"].isoformat()
if "status" in updates and hasattr(updates["status"], "value"):
updates["status"] = updates["status"].value
result = await asyncio.to_thread(
lambda: self.client.table("invitations").update(updates).eq(
"id", str(invitation_id)
).execute()
)
logger.info(f"Updated invitation: {invitation_id}")
return Invitation(**result.data[0])
async def get_pending_invitations(self, company_id: UUID) -> List[Invitation]:
"""Get all pending invitations for a company."""
result = await asyncio.to_thread(
lambda: self.client.table("invitations").select("*").eq(
"company_id", str(company_id)
).eq("status", "pending").order("created_at", desc=True).execute()
)
return [Invitation(**item) for item in result.data]
async def get_invitations_by_email(self, email: str) -> List[Invitation]:
"""Get all invitations for an email address."""
result = await asyncio.to_thread(
lambda: self.client.table("invitations").select("*").eq(
"email", email.lower()
).order("created_at", desc=True).execute()
)
return [Invitation(**item) for item in result.data]
# ==================== EXAMPLE POSTS ====================
async def save_example_posts(self, posts: List[ExamplePost]) -> List[ExamplePost]:
"""Save example posts (bulk)."""
if not posts:
return []
data = []
for p in posts:
post_dict = p.model_dump(exclude={"id", "created_at"}, exclude_none=True)
if "user_id" in post_dict:
post_dict["user_id"] = str(post_dict["user_id"])
if "post_type_id" in post_dict and post_dict["post_type_id"]:
post_dict["post_type_id"] = str(post_dict["post_type_id"])
data.append(post_dict)
result = await asyncio.to_thread(
lambda: self.client.table("example_posts").insert(data).execute()
)
logger.info(f"Saved {len(result.data)} example posts")
return [ExamplePost(**item) for item in result.data]
async def save_example_post(self, post: ExamplePost) -> ExamplePost:
"""Save a single example post."""
posts = await self.save_example_posts([post])
return posts[0] if posts else post
async def get_example_posts(self, user_id: UUID) -> List[ExamplePost]:
"""Get all example posts for a user."""
result = await asyncio.to_thread(
lambda: self.client.table("example_posts").select("*").eq(
"user_id", str(user_id)
).order("created_at", desc=True).execute()
)
return [ExamplePost(**item) for item in result.data]
async def delete_example_post(self, post_id: UUID) -> None:
"""Delete an example post."""
await asyncio.to_thread(
lambda: self.client.table("example_posts").delete().eq("id", str(post_id)).execute()
)
logger.info(f"Deleted example post: {post_id}")
# ==================== REFERENCE PROFILES ====================
async def create_reference_profile(self, profile: ReferenceProfile) -> ReferenceProfile:
"""Create a new reference profile."""
data = profile.model_dump(exclude={"id", "created_at"}, exclude_none=True)
if "user_id" in data:
data["user_id"] = str(data["user_id"])
result = await asyncio.to_thread(
lambda: self.client.table("reference_profiles").insert(data).execute()
)
logger.info(f"Created reference profile: {result.data[0]['id']}")
return ReferenceProfile(**result.data[0])
async def get_reference_profiles(self, user_id: UUID) -> List[ReferenceProfile]:
"""Get all reference profiles for a user."""
result = await asyncio.to_thread(
lambda: self.client.table("reference_profiles").select("*").eq(
"user_id", str(user_id)
).order("created_at", desc=True).execute()
)
return [ReferenceProfile(**item) for item in result.data]
async def update_reference_profile(self, profile_id: UUID, updates: Dict[str, Any]) -> ReferenceProfile:
"""Update reference profile fields."""
result = await asyncio.to_thread(
lambda: self.client.table("reference_profiles").update(updates).eq(
"id", str(profile_id)
).execute()
)
logger.info(f"Updated reference profile: {profile_id}")
return ReferenceProfile(**result.data[0])
async def delete_reference_profile(self, profile_id: UUID) -> None:
"""Delete a reference profile."""
await asyncio.to_thread(
lambda: self.client.table("reference_profiles").delete().eq("id", str(profile_id)).execute()
)
logger.info(f"Deleted reference profile: {profile_id}")
# ==================== API USAGE LOGS ====================
async def log_api_usage(
self,
provider: str,
model: str,
operation: str,
prompt_tokens: int,
completion_tokens: int,
total_tokens: int,
estimated_cost_usd: float,
user_id: Optional[UUID] = None,
company_id: Optional[UUID] = None
) -> None:
"""Log an API usage event."""
data = {
"provider": provider,
"model": model,
"operation": operation,
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": total_tokens,
"estimated_cost_usd": estimated_cost_usd
}
if user_id:
data["user_id"] = str(user_id)
if company_id:
data["company_id"] = str(company_id)
try:
await asyncio.to_thread(
lambda: self.client.table("api_usage_logs").insert(data).execute()
)
except Exception as e:
logger.warning(f"Failed to log API usage: {e}")
async def get_usage_stats(
self,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
user_id: Optional[UUID] = None,
company_id: Optional[UUID] = None
) -> List[Dict[str, Any]]:
"""Get aggregated usage stats."""
def _query():
query = self.client.table("api_usage_logs").select("*")
if user_id:
query = query.eq("user_id", str(user_id))
if company_id:
query = query.eq("company_id", str(company_id))
if start_date:
query = query.gte("created_at", start_date)
if end_date:
query = query.lte("created_at", end_date)
return query.order("created_at", desc=True).execute()
result = await asyncio.to_thread(_query)
return result.data
async def get_usage_by_day(
self,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
user_id: Optional[UUID] = None,
company_id: Optional[UUID] = None
) -> List[Dict[str, Any]]:
"""Get daily usage breakdown."""
logs = await self.get_usage_stats(start_date, end_date, user_id, company_id)
from collections import defaultdict
daily = defaultdict(lambda: {"total_tokens": 0, "estimated_cost_usd": 0.0, "count": 0})
for log in logs:
day = log["created_at"][:10]
daily[day]["total_tokens"] += log.get("total_tokens", 0)
daily[day]["estimated_cost_usd"] += float(log.get("estimated_cost_usd", 0))
daily[day]["count"] += 1
return [{"date": day, **data} for day, data in sorted(daily.items())]
async def get_post_stats(
self,
start_date: Optional[str] = None,
user_id: Optional[UUID] = None,
company_id: Optional[UUID] = None
) -> List[Dict[str, Any]]:
"""Get generated posts with select fields for statistics."""
# Resolve user_ids for filtering
target_user_ids = None
if user_id:
target_user_ids = [str(user_id)]
if company_id:
# Get all users belonging to this company
profiles_result = await asyncio.to_thread(
lambda: self.client.table("profiles").select("id").eq(
"company_id", str(company_id)
).execute()
)
company_user_ids = [p["id"] for p in profiles_result.data]
if target_user_ids is not None:
target_user_ids = [u for u in target_user_ids if u in company_user_ids]
else:
target_user_ids = company_user_ids
if not target_user_ids:
return []
def _query():
q = self.client.table("generated_posts").select(
"id, created_at, status, approved_at, published_at, user_id"
)
if target_user_ids is not None:
q = q.in_("user_id", target_user_ids)
if start_date:
q = q.gte("created_at", start_date)
return q.order("created_at", desc=True).execute()
result = await asyncio.to_thread(_query)
return result.data
# ==================== LICENSE KEYS ====================
async def list_license_keys(self) -> List[LicenseKey]:
"""Get all license keys (admin only)."""
result = await asyncio.to_thread(
lambda: self.client.table("license_keys")
.select("*")
.order("created_at", desc=True)
.execute()
)
return [LicenseKey(**row) for row in result.data]
async def get_license_key(self, key: str) -> Optional[LicenseKey]:
"""Get license key by key string."""
result = await asyncio.to_thread(
lambda: self.client.table("license_keys")
.select("*")
.eq("key", key)
.execute()
)
return LicenseKey(**result.data[0]) if result.data else None
async def create_license_key(
self,
key: str,
max_employees: int,
max_posts_per_day: int,
max_researches_per_day: int,
description: Optional[str] = None
) -> LicenseKey:
"""Create new license key."""
data = {
"key": key,
"max_employees": max_employees,
"max_posts_per_day": max_posts_per_day,
"max_researches_per_day": max_researches_per_day,
"description": description,
}
result = await asyncio.to_thread(
lambda: self.client.table("license_keys").insert(data).execute()
)
logger.info(f"Created license key: {key}")
return LicenseKey(**result.data[0])
async def mark_license_key_used(
self,
key: str,
company_id: UUID
) -> LicenseKey:
"""Mark license key as used and link to company."""
data = {
"used": True,
"company_id": str(company_id),
"used_at": datetime.now(timezone.utc).isoformat()
}
result = await asyncio.to_thread(
lambda: self.client.table("license_keys")
.update(data)
.eq("key", key)
.execute()
)
logger.info(f"Marked license key as used: {key}")
return LicenseKey(**result.data[0])
async def update_license_key(self, key_id: UUID, updates: Dict[str, Any]) -> LicenseKey:
"""Update license key limits (admin only)."""
result = await asyncio.to_thread(
lambda: self.client.table("license_keys")
.update(updates)
.eq("id", str(key_id))
.execute()
)
logger.info(f"Updated license key: {key_id}")
return LicenseKey(**result.data[0]) if result.data else None
async def delete_license_key(self, key_id: UUID) -> None:
"""Delete license key (admin only)."""
await asyncio.to_thread(
lambda: self.client.table("license_keys")
.delete()
.eq("id", str(key_id))
.execute()
)
logger.info(f"Deleted license key: {key_id}")
# ==================== COMPANY QUOTAS ====================
async def get_company_daily_quota(
self,
company_id: UUID,
date_: Optional[date] = None
) -> CompanyDailyQuota:
"""Get or create daily quota for company."""
if date_ is None:
date_ = date.today()
# Try to get existing
result = await asyncio.to_thread(
lambda: self.client.table("company_daily_quotas")
.select("*")
.eq("company_id", str(company_id))
.eq("date", date_.isoformat())
.execute()
)
if result.data:
return CompanyDailyQuota(**result.data[0])
# Create new quota for today
data = {
"company_id": str(company_id),
"date": date_.isoformat(),
"posts_created": 0,
"researches_created": 0
}
result = await asyncio.to_thread(
lambda: self.client.table("company_daily_quotas")
.insert(data)
.execute()
)
return CompanyDailyQuota(**result.data[0])
async def increment_company_posts_quota(self, company_id: UUID) -> None:
"""Increment daily posts count for company."""
try:
quota = await self.get_company_daily_quota(company_id)
new_count = quota.posts_created + 1
result = await asyncio.to_thread(
lambda: self.client.table("company_daily_quotas")
.update({"posts_created": new_count})
.eq("id", str(quota.id))
.execute()
)
logger.info(f"Incremented posts quota for company {company_id}: {quota.posts_created} -> {new_count}")
if not result.data:
logger.error(f"Failed to increment posts quota - no data returned")
except Exception as e:
logger.error(f"Error incrementing posts quota for company {company_id}: {e}")
raise
async def increment_company_researches_quota(self, company_id: UUID) -> None:
"""Increment daily researches count for company."""
try:
quota = await self.get_company_daily_quota(company_id)
new_count = quota.researches_created + 1
result = await asyncio.to_thread(
lambda: self.client.table("company_daily_quotas")
.update({"researches_created": new_count})
.eq("id", str(quota.id))
.execute()
)
logger.info(f"Incremented researches quota for company {company_id}: {quota.researches_created} -> {new_count}")
if not result.data:
logger.error(f"Failed to increment researches quota - no data returned")
except Exception as e:
logger.error(f"Error incrementing researches quota for company {company_id}: {e}")
raise
async def get_company_limits(self, company_id: UUID) -> Optional[LicenseKey]:
"""Get company limits from associated license key.
Returns None if no license key is associated (uses defaults).
"""
company = await self.get_company(company_id)
if not company or not company.license_key_id:
return None
result = await asyncio.to_thread(
lambda: self.client.table("license_keys")
.select("*")
.eq("id", str(company.license_key_id))
.execute()
)
return LicenseKey(**result.data[0]) if result.data else None
async def check_company_post_limit(self, company_id: UUID) -> tuple[bool, str]:
"""Check if company can create more posts today.
Returns (can_create: bool, error_message: str)
"""
license_key = await self.get_company_limits(company_id)
if not license_key:
# No license key, use defaults (unlimited)
return True, ""
quota = await self.get_company_daily_quota(company_id)
if quota.posts_created >= license_key.max_posts_per_day:
return False, f"Tageslimit erreicht ({license_key.max_posts_per_day} Posts/Tag). Versuche es morgen wieder."
return True, ""
async def check_company_research_limit(self, company_id: UUID) -> tuple[bool, str]:
"""Check if company can create more researches today.
Returns (can_create: bool, error_message: str)
"""
license_key = await self.get_company_limits(company_id)
if not license_key:
# No license key, use defaults (unlimited)
return True, ""
quota = await self.get_company_daily_quota(company_id)
if quota.researches_created >= license_key.max_researches_per_day:
return False, f"Tageslimit erreicht ({license_key.max_researches_per_day} Researches/Tag). Versuche es morgen wieder."
return True, ""
async def check_company_employee_limit(self, company_id: UUID) -> tuple[bool, str]:
"""Check if company can add more employees.
Returns (can_add: bool, error_message: str)
"""
license_key = await self.get_company_limits(company_id)
if not license_key:
# No license key, use defaults (unlimited)
return True, ""
employees = await self.list_users(account_type="employee", company_id=company_id)
if len(employees) >= license_key.max_employees:
return False, f"Maximale Mitarbeiteranzahl erreicht ({license_key.max_employees}). Bitte Lizenz upgraden."
return True, ""
# Global database client instance, constructed once when this module is imported.
# DatabaseClient.__init__ creates the underlying Supabase client, so importing this
# module requires the Supabase settings to be available.
db = DatabaseClient()