Files
Onyva-Postling/src/database/client.py
2026-04-02 10:39:07 +02:00

1830 lines
76 KiB
Python

"""Supabase database client."""
import asyncio
from datetime import datetime, date, timezone
from typing import Optional, List, Dict, Any
from uuid import UUID
from supabase import create_client, Client
from loguru import logger
from src.config import settings
from src.database.models import (
LinkedInProfile, LinkedInPost, Topic,
ProfileAnalysis, ResearchResult, GeneratedPost, PostType,
User, Profile, Company, Invitation, ExamplePost, ReferenceProfile,
ApiUsageLog, LicenseKey, CompanyDailyQuota, LicenseKeyOffer
)
from src.services.cache_service import (
cache,
PROFILE_TTL, USER_TTL, LINKEDIN_ACCOUNT_TTL, PROFILE_ANALYSIS_TTL,
COMPANY_TTL, LINKEDIN_POSTS_TTL, POST_TYPES_TTL, POST_TYPE_TTL,
GEN_POST_TTL, GEN_POSTS_TTL,
)
class DatabaseClient:
"""Supabase database client wrapper."""
def __init__(self):
    """Build the Supabase client shared by all query methods.

    The service-role key from ``settings`` is used when it is set. Otherwise
    the client is created with ``settings.supabase_key`` and a warning is
    logged. ``admin_client`` is the same object as ``client`` in the
    service-role case and ``None`` otherwise.
    """
    service_key = settings.supabase_service_role_key
    if not service_key:
        self.client: Client = create_client(settings.supabase_url, settings.supabase_key)
        logger.warning("Supabase client initialized with anon key (slow mode - RLS enabled)")
        self.admin_client: Optional[Client] = None
        return

    self.client: Client = create_client(settings.supabase_url, service_key)
    logger.info("Supabase client initialized with service role key (fast mode)")
    self.admin_client: Optional[Client] = self.client
# ==================== LINKEDIN PROFILES ====================
async def save_linkedin_profile(self, profile: LinkedInProfile) -> LinkedInProfile:
    """Insert the user's LinkedIn profile, or update it if a row already exists.

    Existence is decided by querying ``linkedin_profiles`` for the profile's
    ``user_id``. ``id``, ``scraped_at`` and ``None`` fields are left out of
    the payload.

    Returns:
        The first row returned by the write, parsed as ``LinkedInProfile``.
    """
    payload = profile.model_dump(exclude={"id", "scraped_at"}, exclude_none=True)
    if "user_id" in payload:
        payload["user_id"] = str(payload["user_id"])

    def _find_existing():
        query = self.client.table("linkedin_profiles").select("*")
        return query.eq("user_id", str(profile.user_id)).execute()

    def _update():
        query = self.client.table("linkedin_profiles").update(payload)
        return query.eq("user_id", str(profile.user_id)).execute()

    def _insert():
        return self.client.table("linkedin_profiles").insert(payload).execute()

    found = await asyncio.to_thread(_find_existing)
    # Update when a row for this user is already there, insert otherwise.
    write = _update if found.data else _insert
    response = await asyncio.to_thread(write)

    logger.info(f"Saved LinkedIn profile for user: {profile.user_id}")
    return LinkedInProfile(**response.data[0])
async def get_linkedin_profile(self, user_id: UUID) -> Optional[LinkedInProfile]:
    """Fetch the LinkedIn profile stored for ``user_id``.

    Returns:
        The first matching ``linkedin_profiles`` row as ``LinkedInProfile``,
        or ``None`` when the query returns no rows.
    """
    def _fetch():
        query = self.client.table("linkedin_profiles").select("*")
        return query.eq("user_id", str(user_id)).execute()

    rows = (await asyncio.to_thread(_fetch)).data
    return LinkedInProfile(**rows[0]) if rows else None
# ==================== LINKEDIN POSTS ====================
async def save_linkedin_posts(self, posts: List[LinkedInPost]) -> List[LinkedInPost]:
    """Upsert a batch of LinkedIn posts and refresh the affected caches.

    Posts sharing the same ``(user_id, post_url)`` pair are collapsed to the
    first occurrence before the upsert, which uses the same pair as its
    conflict target.

    Returns:
        The rows returned by the upsert as ``LinkedInPost`` objects, or an
        empty list when there was nothing to save.
    """
    # dict.setdefault keeps the first post seen for each key, in input order.
    first_by_key = {}
    for post in posts:
        first_by_key.setdefault((str(post.user_id), post.post_url), post)
    deduped = list(first_by_key.values())

    if len(posts) != len(deduped):
        logger.warning(f"Removed {len(posts) - len(deduped)} duplicate posts from batch")

    # JSON mode serializes UUIDs/datetimes before the Supabase client builds
    # its request payload.
    rows = [
        post.model_dump(mode="json", exclude={"id", "scraped_at"}, exclude_none=True)
        for post in deduped
    ]
    if not rows:
        logger.warning("No posts to save")
        return []

    def _upsert():
        table = self.client.table("linkedin_posts")
        return table.upsert(rows, on_conflict="user_id,post_url").execute()

    response = await asyncio.to_thread(_upsert)
    logger.info(f"Saved {len(response.data)} LinkedIn posts")
    saved = [LinkedInPost(**row) for row in response.data]

    # One invalidation per distinct owner among the saved rows.
    for owner_id in {str(post.user_id) for post in saved}:
        await cache.invalidate_linkedin_posts(owner_id)
    return saved
async def get_linkedin_posts(self, user_id: UUID) -> List[LinkedInPost]:
    """Return every LinkedIn post of a user, newest ``post_date`` first.

    A cached list is returned when present. Otherwise the rows are read
    from ``linkedin_posts`` and cached for ``LINKEDIN_POSTS_TTL``.
    """
    cache_key = cache.linkedin_posts_key(str(user_id))
    cached = await cache.get(cache_key)
    if cached is not None:
        return [LinkedInPost(**row) for row in cached]

    def _fetch():
        query = self.client.table("linkedin_posts").select("*")
        query = query.eq("user_id", str(user_id))
        return query.order("post_date", desc=True).execute()

    response = await asyncio.to_thread(_fetch)
    posts = [LinkedInPost(**row) for row in response.data]
    serialized = [post.model_dump(mode="json") for post in posts]
    await cache.set(cache_key, serialized, LINKEDIN_POSTS_TTL)
    return posts
async def copy_posts_to_user(self, source_user_id: UUID, target_user_id: UUID) -> List[LinkedInPost]:
    """Duplicate one user's LinkedIn posts under another user.

    Only URL, text, date, engagement counters and ``raw_data`` are copied;
    the copies are stored through ``save_linkedin_posts``.

    Returns:
        The saved copies, or an empty list if the source user has no posts.
    """
    originals = await self.get_linkedin_posts(source_user_id)
    if not originals:
        return []

    copies = [
        LinkedInPost(
            user_id=target_user_id,
            post_url=original.post_url,
            post_text=original.post_text,
            post_date=original.post_date,
            likes=original.likes,
            comments=original.comments,
            shares=original.shares,
            raw_data=original.raw_data,
        )
        for original in originals
    ]
    saved = await self.save_linkedin_posts(copies)
    logger.info(f"Copied {len(saved)} posts from user {source_user_id} to {target_user_id}")
    return saved
async def delete_linkedin_post(self, post_id: UUID) -> None:
    """Delete one ``linkedin_posts`` row and invalidate its owner's post cache.

    The owner is looked up before the delete because the row is gone
    afterwards. If the lookup finds nothing, no cache is invalidated.
    """
    def _lookup_owner():
        query = self.client.table("linkedin_posts").select("user_id")
        return query.eq("id", str(post_id)).execute()

    def _remove():
        return self.client.table("linkedin_posts").delete().eq("id", str(post_id)).execute()

    lookup = await asyncio.to_thread(_lookup_owner)
    owner_id = None
    if lookup.data:
        owner_id = lookup.data[0]["user_id"]

    await asyncio.to_thread(_remove)
    if owner_id:
        await cache.invalidate_linkedin_posts(owner_id)
    logger.info(f"Deleted LinkedIn post: {post_id}")
async def get_linkedin_post(self, post_id: UUID) -> Optional[LinkedInPost]:
    """Fetch one LinkedIn post by its ``id``.

    Returns:
        The row as ``LinkedInPost``, or ``None`` if the query returns nothing.
    """
    def _fetch():
        table = self.client.table("linkedin_posts")
        return table.select("*").eq("id", str(post_id)).limit(1).execute()

    response = await asyncio.to_thread(_fetch)
    if response.data:
        return LinkedInPost(**response.data[0])
    return None
async def get_unclassified_posts(self, user_id: UUID) -> List[LinkedInPost]:
    """Return the user's LinkedIn posts whose ``post_type_id`` is NULL."""
    def _fetch():
        query = self.client.table("linkedin_posts").select("*")
        query = query.eq("user_id", str(user_id))
        return query.is_("post_type_id", "null").execute()

    response = await asyncio.to_thread(_fetch)
    unclassified = []
    for row in response.data:
        unclassified.append(LinkedInPost(**row))
    return unclassified
async def get_posts_by_type(self, user_id: UUID, post_type_id: UUID) -> List[LinkedInPost]:
    """Return the user's LinkedIn posts of one post type, newest ``post_date`` first."""
    def _fetch():
        query = self.client.table("linkedin_posts").select("*")
        query = query.eq("user_id", str(user_id))
        query = query.eq("post_type_id", str(post_type_id))
        return query.order("post_date", desc=True).execute()

    response = await asyncio.to_thread(_fetch)
    return [LinkedInPost(**row) for row in response.data]
# ==================== LINKEDIN POST INSIGHTS ====================
async def upsert_post_insights_posts(self, posts: List['LinkedInPostInsightPost']) -> List['LinkedInPostInsightPost']:
    """Upsert post-insight rows, keyed on ``(user_id, post_urn)``.

    ``id``, ``created_at``, ``updated_at``, ``first_seen_at`` and ``None``
    fields are omitted from each row. ``user_id`` and ``linkedin_account_id``
    are sent as strings and a datetime ``post_date`` as an ISO string.

    Returns:
        The rows returned by the upsert, or an empty list for empty input.
    """
    from datetime import datetime
    from src.database.models import LinkedInPostInsightPost

    def _to_row(post):
        row = post.model_dump(
            exclude={"id", "created_at", "updated_at", "first_seen_at"},
            exclude_none=True,
        )
        row["user_id"] = str(row["user_id"])
        if row.get("linkedin_account_id"):
            row["linkedin_account_id"] = str(row["linkedin_account_id"])
        post_date = row.get("post_date")
        if isinstance(post_date, datetime):
            row["post_date"] = post_date.isoformat()
        return row

    rows = [_to_row(post) for post in posts]
    if not rows:
        return []

    def _upsert():
        table = self.client.table("linkedin_post_insights_posts")
        return table.upsert(rows, on_conflict="user_id,post_urn").execute()

    response = await asyncio.to_thread(_upsert)
    return [LinkedInPostInsightPost(**row) for row in response.data]
async def upsert_post_insights_daily(self, snapshots: List['LinkedInPostInsightDaily']) -> List['LinkedInPostInsightDaily']:
    """Upsert daily insight snapshots, keyed on ``(post_id, snapshot_date)``.

    ``id``, ``created_at``, ``updated_at`` and ``None`` fields are omitted.
    ``user_id`` and ``post_id`` are sent as strings, and ``snapshot_date`` is
    converted with ``isoformat()`` when the value offers that method.

    Returns:
        The rows returned by the upsert, or an empty list for empty input.
    """
    from src.database.models import LinkedInPostInsightDaily

    rows = []
    for snapshot in snapshots:
        row = snapshot.model_dump(exclude={"id", "created_at", "updated_at"}, exclude_none=True)
        for id_column in ("user_id", "post_id"):
            row[id_column] = str(row[id_column])
        snapshot_date = row.get("snapshot_date")
        if hasattr(snapshot_date, "isoformat"):
            row["snapshot_date"] = snapshot_date.isoformat()
        rows.append(row)

    if not rows:
        return []

    def _upsert():
        table = self.client.table("linkedin_post_insights_daily")
        return table.upsert(rows, on_conflict="post_id,snapshot_date").execute()

    response = await asyncio.to_thread(_upsert)
    return [LinkedInPostInsightDaily(**row) for row in response.data]
async def get_post_insights_posts(self, user_id: UUID) -> List['LinkedInPostInsightPost']:
    """Return the user's post-insight rows, newest ``post_date`` first."""
    from src.database.models import LinkedInPostInsightPost

    def _fetch():
        query = self.client.table("linkedin_post_insights_posts").select("*")
        query = query.eq("user_id", str(user_id))
        return query.order("post_date", desc=True).execute()

    response = await asyncio.to_thread(_fetch)
    return [LinkedInPostInsightPost(**row) for row in response.data]
async def get_post_insights_daily(self, user_id: UUID, since_date: Optional[str] = None) -> List['LinkedInPostInsightDaily']:
    """Return the user's daily insight snapshots, oldest ``snapshot_date`` first.

    Args:
        user_id: Owner of the snapshots.
        since_date: When given (and non-empty), only snapshots with
            ``snapshot_date >= since_date`` are returned.
    """
    from src.database.models import LinkedInPostInsightDaily

    def _fetch():
        base = self.client.table("linkedin_post_insights_daily").select("*")
        base = base.eq("user_id", str(user_id))
        filtered = base.gte("snapshot_date", since_date) if since_date else base
        return filtered.order("snapshot_date", desc=False).execute()

    response = await asyncio.to_thread(_fetch)
    return [LinkedInPostInsightDaily(**row) for row in response.data]
async def update_post_classification(
    self,
    post_id: UUID,
    post_type_id: UUID,
    classification_method: str,
    classification_confidence: float
) -> None:
    """Store the classification of a single LinkedIn post.

    Args:
        post_id: ID of the ``linkedin_posts`` row to update.
        post_type_id: Post type assigned to the post.
        classification_method: Stored verbatim in ``classification_method``.
        classification_confidence: Stored verbatim in
            ``classification_confidence``.

    The owner's cached post list is invalidated when the update returns the
    affected row, matching ``update_linkedin_post``. Without that,
    ``get_linkedin_posts`` would keep serving the old ``post_type_id`` until
    the cache entry expires.
    """
    result = await asyncio.to_thread(
        lambda: self.client.table("linkedin_posts").update({
            "post_type_id": str(post_type_id),
            "classification_method": classification_method,
            "classification_confidence": classification_confidence
        }).eq("id", str(post_id)).execute()
    )
    # No returned row means nothing matched, so there is nothing to invalidate.
    if result.data and result.data[0].get("user_id"):
        await cache.invalidate_linkedin_posts(result.data[0]["user_id"])
    logger.debug(f"Updated classification for post {post_id}")
async def update_linkedin_post(self, post_id: UUID, updates: Dict[str, Any]) -> None:
    """Update arbitrary columns of one ``linkedin_posts`` row.

    Args:
        post_id: ID of the row to update.
        updates: Column/value pairs to write. Truthy ``post_type_id`` and
            ``user_id`` values are sent as strings. The dict passed in is
            not modified.

    The owner's cached post list is invalidated when the update returns the
    affected row; nothing is invalidated if no row matched.
    """
    # Work on a shallow copy so the caller's dict is not rewritten as a side
    # effect of stringifying the UUID columns.
    payload = dict(updates)
    for column in ("post_type_id", "user_id"):
        if payload.get(column):
            payload[column] = str(payload[column])

    result = await asyncio.to_thread(
        lambda: self.client.table("linkedin_posts").update(payload).eq("id", str(post_id)).execute()
    )
    if result.data:
        await cache.invalidate_linkedin_posts(result.data[0]["user_id"])
    logger.debug(f"Updated LinkedIn post {post_id}")
async def update_posts_classification_bulk(
    self,
    classifications: List[Dict[str, Any]]
) -> int:
    """Apply a list of post classifications, one update request per entry.

    Args:
        classifications: Dicts with the keys ``post_id``, ``post_type_id``,
            ``classification_method`` and ``classification_confidence``.

    Returns:
        The number of entries whose update request completed without
        raising. Failed entries are logged and skipped.

    The cached post list of every owner touched by a successful update is
    invalidated once at the end, so ``get_linkedin_posts`` does not keep
    serving the previous ``post_type_id`` values.
    """
    def _apply(entry: Dict[str, Any]):
        return self.client.table("linkedin_posts").update({
            "post_type_id": str(entry["post_type_id"]),
            "classification_method": entry["classification_method"],
            "classification_confidence": entry["classification_confidence"]
        }).eq("id", str(entry["post_id"])).execute()

    count = 0
    affected_user_ids = set()
    for classification in classifications:
        try:
            result = await asyncio.to_thread(_apply, classification)
        except Exception as e:
            # .get(): a malformed entry without "post_id" must not make the
            # error handler itself raise.
            logger.warning(f"Failed to update classification for post {classification.get('post_id')}: {e}")
        else:
            count += 1
            for row in result.data or []:
                if row.get("user_id"):
                    affected_user_ids.add(row["user_id"])

    for uid in affected_user_ids:
        await cache.invalidate_linkedin_posts(uid)

    logger.info(f"Bulk updated classifications for {count} posts")
    return count
# ==================== POST TYPES ====================
async def create_post_type(self, post_type: PostType) -> PostType:
    """Insert one post type and invalidate its owner's cached post-type lists.

    ``id``, ``created_at``, ``updated_at`` and ``None`` fields are not sent.

    Returns:
        The inserted row as ``PostType``.
    """
    payload = post_type.model_dump(exclude={"id", "created_at", "updated_at"}, exclude_none=True)
    if "user_id" in payload:
        payload["user_id"] = str(payload["user_id"])

    def _insert():
        return self.client.table("post_types").insert(payload).execute()

    response = await asyncio.to_thread(_insert)
    logger.info(f"Created post type: {response.data[0]['name']}")
    created = PostType(**response.data[0])
    await cache.invalidate_post_types(str(created.user_id))
    return created
async def create_post_types_bulk(self, post_types: List[PostType]) -> List[PostType]:
    """Insert several post types in a single request.

    Returns:
        The inserted rows as ``PostType`` objects, or an empty list for
        empty input. The cached post-type lists of every owner among the
        inserted rows are invalidated.
    """
    if not post_types:
        return []

    def _to_row(post_type):
        row = post_type.model_dump(exclude={"id", "created_at", "updated_at"}, exclude_none=True)
        if "user_id" in row:
            row["user_id"] = str(row["user_id"])
        return row

    rows = [_to_row(post_type) for post_type in post_types]

    def _insert():
        return self.client.table("post_types").insert(rows).execute()

    response = await asyncio.to_thread(_insert)
    logger.info(f"Created {len(response.data)} post types")
    created = [PostType(**row) for row in response.data]

    for owner_id in {str(post_type.user_id) for post_type in created}:
        await cache.invalidate_post_types(owner_id)
    return created
async def get_post_types(self, user_id: UUID, active_only: bool = True) -> List[PostType]:
    """Return a user's post types ordered by ``name``.

    Args:
        user_id: Owner of the post types.
        active_only: When true, only rows with ``is_active = True`` are
            returned. The flag is part of the cache key.

    A cached list is returned when present; otherwise the result is cached
    for ``POST_TYPES_TTL``.
    """
    cache_key = cache.post_types_key(str(user_id), active_only)
    cached = await cache.get(cache_key)
    if cached is not None:
        return [PostType(**row) for row in cached]

    def _fetch():
        base = self.client.table("post_types").select("*").eq("user_id", str(user_id))
        filtered = base.eq("is_active", True) if active_only else base
        return filtered.order("name").execute()

    response = await asyncio.to_thread(_fetch)
    found = [PostType(**row) for row in response.data]
    serialized = [post_type.model_dump(mode="json") for post_type in found]
    await cache.set(cache_key, serialized, POST_TYPES_TTL)
    return found
# Alias for get_post_types
async def get_customer_post_types(self, user_id: UUID, active_only: bool = True) -> List[PostType]:
    """Alternate name for ``get_post_types``; both arguments are forwarded unchanged."""
    post_types = await self.get_post_types(user_id, active_only)
    return post_types
async def get_post_type(self, post_type_id: UUID) -> Optional[PostType]:
    """Fetch one post type by its ``id``.

    A cached entry is returned when present. A row read from the database
    is cached for ``POST_TYPE_TTL``; a miss is not cached.

    Returns:
        The post type, or ``None`` if no row has this ID.
    """
    cache_key = cache.post_type_key(str(post_type_id))
    cached = await cache.get(cache_key)
    if cached is not None:
        return PostType(**cached)

    def _fetch():
        query = self.client.table("post_types").select("*")
        return query.eq("id", str(post_type_id)).execute()

    response = await asyncio.to_thread(_fetch)
    if not response.data:
        return None

    found = PostType(**response.data[0])
    await cache.set(cache_key, found.model_dump(mode="json"), POST_TYPE_TTL)
    return found
async def update_post_type(self, post_type_id: UUID, updates: Dict[str, Any]) -> PostType:
    """Write ``updates`` to one post type and invalidate its cache entries.

    Returns:
        The updated row as ``PostType``.

    Raises:
        IndexError: If the update returns no row (indexing ``data[0]``).
    """
    def _apply():
        query = self.client.table("post_types").update(updates)
        return query.eq("id", str(post_type_id)).execute()

    response = await asyncio.to_thread(_apply)
    logger.info(f"Updated post type: {post_type_id}")
    updated = PostType(**response.data[0])
    await cache.invalidate_post_type(str(post_type_id), str(updated.user_id))
    return updated
async def update_post_type_analysis(
    self,
    post_type_id: UUID,
    analysis: Dict[str, Any],
    analyzed_post_count: int
) -> PostType:
    """Store a new analysis for a post type and stamp it with the current UTC time.

    Args:
        post_type_id: ID of the post type to update.
        analysis: Stored in the ``analysis`` column.
        analyzed_post_count: Stored in ``analyzed_post_count``.

    Returns:
        The updated row as ``PostType``; its cache entries are invalidated.
    """
    from datetime import datetime

    def _apply():
        changes = {
            "analysis": analysis,
            "analysis_generated_at": datetime.now(timezone.utc).isoformat(),
            "analyzed_post_count": analyzed_post_count
        }
        query = self.client.table("post_types").update(changes)
        return query.eq("id", str(post_type_id)).execute()

    response = await asyncio.to_thread(_apply)
    logger.info(f"Updated analysis for post type: {post_type_id}")
    updated = PostType(**response.data[0])
    await cache.invalidate_post_type(str(post_type_id), str(updated.user_id))
    return updated
async def delete_post_type(self, post_type_id: UUID, soft: bool = True) -> None:
    """Deactivate a post type, or remove its row entirely.

    Args:
        post_type_id: ID of the post type.
        soft: When true (default) the row is kept and ``is_active`` is set
            to ``False``. When false the row is deleted.

    In both modes the post type's cache entries are invalidated if the
    owner's ``user_id`` could be determined.
    """
    row_id = str(post_type_id)

    if not soft:
        def _lookup_owner():
            return self.client.table("post_types").select("user_id").eq("id", row_id).execute()

        def _remove():
            return self.client.table("post_types").delete().eq("id", row_id).execute()

        # The owner has to be read first; after the delete the row is gone.
        lookup = await asyncio.to_thread(_lookup_owner)
        owner_id = lookup.data[0]["user_id"] if lookup.data else None
        await asyncio.to_thread(_remove)
        if owner_id:
            await cache.invalidate_post_type(row_id, owner_id)
        logger.info(f"Hard deleted post type: {post_type_id}")
        return

    def _deactivate():
        return self.client.table("post_types").update({"is_active": False}).eq("id", row_id).execute()

    response = await asyncio.to_thread(_deactivate)
    logger.info(f"Soft deleted post type: {post_type_id}")
    if response.data:
        await cache.invalidate_post_type(row_id, response.data[0]["user_id"])
# ==================== TOPICS ====================
async def save_topics(self, topics: List[Topic]) -> List[Topic]:
    """Insert extracted topics, falling back to row-by-row inserts on failure.

    The whole batch is inserted in one request first. If that request raises,
    each topic is inserted on its own so that one rejected row (for example
    a duplicate) does not discard the rest of the batch.

    Args:
        topics: Topics to store. ``id``, ``created_at`` and ``None`` fields
            are omitted from the payload.

    Returns:
        The topics returned by the insert(s). Topics whose individual insert
        failed are logged and left out.
    """
    if not topics:
        logger.warning("No topics to save")
        return []

    data = []
    for t in topics:
        topic_dict = t.model_dump(exclude={"id", "created_at"}, exclude_none=True)
        for column in ("user_id", "extracted_from_post_id", "target_post_type_id"):
            if topic_dict.get(column):
                topic_dict[column] = str(topic_dict[column])
        data.append(topic_dict)

    try:
        result = await asyncio.to_thread(
            lambda: self.client.table("topics").insert(data).execute()
        )
        logger.info(f"Saved {len(result.data)} topics to database")
        return [Topic(**item) for item in result.data]
    except Exception as e:
        # loguru runs str.format() on the message whenever extra arguments
        # are supplied, so an f-string containing the braces of an API error
        # payload would raise right here. The error is therefore passed as a
        # format argument. The stdlib-style ``exc_info=True`` does not attach
        # a traceback in loguru; ``opt(exception=...)`` does.
        logger.opt(exception=e).error("Error saving topics: {}", e)

    saved = []
    for topic_data in data:
        try:
            result = await asyncio.to_thread(
                lambda td=topic_data: self.client.table("topics").insert(td).execute()
            )
            saved.extend([Topic(**item) for item in result.data])
        except Exception as single_error:
            # Not every failure is a duplicate; keep the actual error in the log.
            logger.warning(
                "Skipped topic {!r} (insert failed, possibly a duplicate): {}",
                topic_data.get("title"),
                single_error,
            )
    logger.info(f"Saved {len(saved)} topics individually")
    return saved
async def get_topics(
    self,
    user_id: UUID,
    unused_only: bool = False,
    post_type_id: Optional[UUID] = None
) -> List[Topic]:
    """Return a user's topics, newest ``created_at`` first.

    Args:
        user_id: Owner of the topics.
        unused_only: When true, only rows with ``is_used = False``.
        post_type_id: When given, only rows whose ``target_post_type_id``
            equals it.
    """
    def _fetch():
        filters = [("user_id", str(user_id))]
        if unused_only:
            filters.append(("is_used", False))
        if post_type_id:
            filters.append(("target_post_type_id", str(post_type_id)))

        query = self.client.table("topics").select("*")
        for column, value in filters:
            query = query.eq(column, value)
        return query.order("created_at", desc=True).execute()

    response = await asyncio.to_thread(_fetch)
    return [Topic(**row) for row in response.data]
async def mark_topic_used(self, topic_id: UUID) -> None:
    """Flag one topic as used.

    ``used_at`` is sent as the literal string ``"now()"``.
    NOTE(review): presumably the database resolves that to the current
    timestamp - confirm against the column type.
    """
    changes = {
        "is_used": True,
        "used_at": "now()"
    }

    def _apply():
        return self.client.table("topics").update(changes).eq("id", str(topic_id)).execute()

    await asyncio.to_thread(_apply)
    logger.info(f"Marked topic {topic_id} as used")
# ==================== PROFILE ANALYSIS ====================
async def save_profile_analysis(self, analysis: ProfileAnalysis) -> ProfileAnalysis:
    """Insert the user's profile analysis, or update it if a row already exists.

    Existence is decided by querying ``profile_analyses`` for the analysis'
    ``user_id``. The cached analysis for that user is deleted after the write.

    Returns:
        The first row returned by the write, parsed as ``ProfileAnalysis``.
    """
    payload = analysis.model_dump(exclude={"id", "created_at"}, exclude_none=True)
    if "user_id" in payload:
        payload["user_id"] = str(payload["user_id"])

    def _find_existing():
        query = self.client.table("profile_analyses").select("*")
        return query.eq("user_id", str(analysis.user_id)).execute()

    def _update():
        query = self.client.table("profile_analyses").update(payload)
        return query.eq("user_id", str(analysis.user_id)).execute()

    def _insert():
        return self.client.table("profile_analyses").insert(payload).execute()

    found = await asyncio.to_thread(_find_existing)
    write = _update if found.data else _insert
    response = await asyncio.to_thread(write)

    logger.info(f"Saved profile analysis for user: {analysis.user_id}")
    stored = ProfileAnalysis(**response.data[0])
    await cache.delete(cache.profile_analysis_key(str(analysis.user_id)))
    return stored
async def get_profile_analysis(self, user_id: UUID) -> Optional[ProfileAnalysis]:
    """Fetch the stored profile analysis of a user.

    A cached entry is returned when present. A row read from the database
    is cached for ``PROFILE_ANALYSIS_TTL``; a miss is not cached.

    Returns:
        The analysis, or ``None`` if the user has none.
    """
    cache_key = cache.profile_analysis_key(str(user_id))
    cached = await cache.get(cache_key)
    if cached is not None:
        return ProfileAnalysis(**cached)

    def _fetch():
        query = self.client.table("profile_analyses").select("*")
        return query.eq("user_id", str(user_id)).execute()

    response = await asyncio.to_thread(_fetch)
    if not response.data:
        return None

    found = ProfileAnalysis(**response.data[0])
    await cache.set(cache_key, found.model_dump(mode="json"), PROFILE_ANALYSIS_TTL)
    return found
# ==================== RESEARCH RESULTS ====================
async def save_research_result(self, research: ResearchResult) -> ResearchResult:
    """Insert one research result.

    ``id``, ``created_at`` and ``None`` fields are not sent; ``user_id`` and
    a truthy ``target_post_type_id`` are sent as strings.

    Returns:
        The inserted row as ``ResearchResult``.
    """
    payload = research.model_dump(exclude={"id", "created_at"}, exclude_none=True)
    if "user_id" in payload:
        payload["user_id"] = str(payload["user_id"])
    if payload.get("target_post_type_id"):
        payload["target_post_type_id"] = str(payload["target_post_type_id"])

    def _insert():
        return self.client.table("research_results").insert(payload).execute()

    response = await asyncio.to_thread(_insert)
    logger.info(f"Saved research result for user: {research.user_id}")
    return ResearchResult(**response.data[0])
async def get_latest_research(self, user_id: UUID) -> Optional[ResearchResult]:
    """Return the user's research result with the newest ``created_at``, if any."""
    def _fetch():
        query = self.client.table("research_results").select("*")
        query = query.eq("user_id", str(user_id))
        return query.order("created_at", desc=True).limit(1).execute()

    rows = (await asyncio.to_thread(_fetch)).data
    return ResearchResult(**rows[0]) if rows else None
async def get_all_research(
    self,
    user_id: UUID,
    post_type_id: Optional[UUID] = None
) -> List[ResearchResult]:
    """Return a user's research results, newest ``created_at`` first.

    Args:
        user_id: Owner of the research results.
        post_type_id: When given, only rows whose ``target_post_type_id``
            equals it.
    """
    def _fetch():
        base = self.client.table("research_results").select("*").eq("user_id", str(user_id))
        filtered = base.eq("target_post_type_id", str(post_type_id)) if post_type_id else base
        return filtered.order("created_at", desc=True).execute()

    response = await asyncio.to_thread(_fetch)
    return [ResearchResult(**row) for row in response.data]
# ==================== GENERATED POSTS ====================
async def save_generated_post(self, post: GeneratedPost) -> GeneratedPost:
    """Insert one generated post and invalidate its owner's cached post list.

    ``id``, ``created_at`` and ``None`` fields are not sent; ``user_id`` and
    truthy ``topic_id`` / ``post_type_id`` values are sent as strings.

    Returns:
        The inserted row as ``GeneratedPost``.
    """
    payload = post.model_dump(exclude={"id", "created_at"}, exclude_none=True)
    if "user_id" in payload:
        payload["user_id"] = str(payload["user_id"])
    for optional_id in ("topic_id", "post_type_id"):
        if payload.get(optional_id):
            payload[optional_id] = str(payload[optional_id])

    def _insert():
        return self.client.table("generated_posts").insert(payload).execute()

    response = await asyncio.to_thread(_insert)
    logger.info(f"Saved generated post: {response.data[0]['id']}")
    stored = GeneratedPost(**response.data[0])
    await cache.invalidate_gen_posts(str(stored.user_id))
    return stored
async def update_generated_post(self, post_id: UUID, updates: Dict[str, Any]) -> GeneratedPost:
    """Write ``updates`` to one generated post and invalidate its cache entries.

    Args:
        post_id: ID of the ``generated_posts`` row.
        updates: Column/value pairs to write. Datetime values in
            ``published_at``, ``approved_at``, ``rejected_at`` and
            ``scheduled_at`` are sent as ISO strings, as are datetimes at
            any depth inside a ``metadata`` dict. The dict passed in is not
            modified.

    Returns:
        The updated row as ``GeneratedPost``.

    Raises:
        IndexError: If the update returns no row (indexing ``data[0]``).
    """
    def _jsonable(value: Any) -> Any:
        # Recursively replace datetimes inside dicts/lists with ISO strings.
        if isinstance(value, datetime):
            return value.isoformat()
        if isinstance(value, dict):
            return {k: _jsonable(v) for k, v in value.items()}
        if isinstance(value, list):
            return [_jsonable(v) for v in value]
        return value

    # Shallow copy: serialization must not rewrite the caller's dict.
    payload = dict(updates)

    for field in ('published_at', 'approved_at', 'rejected_at', 'scheduled_at'):
        if isinstance(payload.get(field), datetime):
            payload[field] = payload[field].isoformat()

    if isinstance(payload.get('metadata'), dict):
        payload['metadata'] = _jsonable(payload['metadata'])

    result = await asyncio.to_thread(
        lambda: self.client.table("generated_posts").update(payload).eq(
            "id", str(post_id)
        ).execute()
    )
    logger.info(f"Updated generated post: {post_id}")
    updated = GeneratedPost(**result.data[0])
    await cache.invalidate_gen_post(str(post_id), str(updated.user_id))
    return updated
async def get_generated_posts(self, user_id: UUID) -> List[GeneratedPost]:
    """Return every generated post of a user, newest ``created_at`` first.

    A cached list is returned when present. Otherwise the rows are read
    from ``generated_posts`` and cached for ``GEN_POSTS_TTL``.
    """
    cache_key = cache.gen_posts_key(str(user_id))
    cached = await cache.get(cache_key)
    if cached is not None:
        return [GeneratedPost(**row) for row in cached]

    def _fetch():
        query = self.client.table("generated_posts").select("*")
        query = query.eq("user_id", str(user_id))
        return query.order("created_at", desc=True).execute()

    response = await asyncio.to_thread(_fetch)
    found = [GeneratedPost(**row) for row in response.data]
    serialized = [post.model_dump(mode="json") for post in found]
    await cache.set(cache_key, serialized, GEN_POSTS_TTL)
    return found
async def get_generated_post(self, post_id: UUID) -> Optional[GeneratedPost]:
    """Fetch one generated post by its ``id``.

    A cached entry is returned when present. A row read from the database
    is cached for ``GEN_POST_TTL``; a miss is not cached.

    Returns:
        The post, or ``None`` if no row has this ID.
    """
    cache_key = cache.gen_post_key(str(post_id))
    cached = await cache.get(cache_key)
    if cached is not None:
        return GeneratedPost(**cached)

    def _fetch():
        query = self.client.table("generated_posts").select("*")
        return query.eq("id", str(post_id)).execute()

    response = await asyncio.to_thread(_fetch)
    if not response.data:
        return None

    found = GeneratedPost(**response.data[0])
    await cache.set(cache_key, found.model_dump(mode="json"), GEN_POST_TTL)
    return found
async def get_scheduled_posts_due(self) -> List[GeneratedPost]:
    """Return posts with status ``scheduled`` whose ``scheduled_at`` is not in the future.

    The cutoff is the current UTC time, taken once before the query runs.
    """
    from datetime import datetime

    cutoff = datetime.now(timezone.utc).isoformat()

    def _fetch():
        query = self.client.table("generated_posts").select("*")
        query = query.eq("status", "scheduled")
        return query.lte("scheduled_at", cutoff).execute()

    response = await asyncio.to_thread(_fetch)
    return [GeneratedPost(**row) for row in response.data]
async def get_scheduled_posts_for_company(self, company_id: UUID) -> List[Dict[str, Any]]:
    """Return the posts of all company members that are ready, scheduled or published.

    Profiles whose ``company_id`` matches are loaded first, then the
    ``generated_posts`` rows of those users with status ``ready``,
    ``scheduled`` or ``published``, ordered by ``scheduled_at`` ascending
    with NULLs last.

    Args:
        company_id: ID of the company whose members' posts are wanted.

    Returns:
        One dict per post with the selected post columns plus
        ``employee_name`` and ``employee_user_id``. ``employee_name`` falls
        back to ``"Unknown"`` when the profile has no display name. An
        empty list is returned when the company has no profiles.
    """
    profiles_result = await asyncio.to_thread(
        lambda: self.client.table("profiles").select("id, display_name").eq(
            "company_id", str(company_id)
        ).execute()
    )
    if not profiles_result.data:
        return []

    names_by_user = {p["id"]: p.get("display_name") for p in profiles_result.data}
    user_ids = list(names_by_user)

    posts_result = await asyncio.to_thread(
        lambda: self.client.table("generated_posts").select(
            "id, user_id, topic_title, post_content, status, scheduled_at, created_at"
        ).in_(
            "user_id", user_ids
        ).in_(
            "status", ["ready", "scheduled", "published"]
        ).order("scheduled_at", desc=False, nullsfirst=False).execute()
    )

    # Every post's user_id comes from names_by_user, so the only case the
    # fallback can actually hit is a NULL/empty display_name; cover it.
    return [
        {
            **post,
            "employee_name": names_by_user.get(post["user_id"]) or "Unknown",
            "employee_user_id": post["user_id"],
        }
        for post in posts_result.data
    ]
async def schedule_post(self, post_id: UUID, scheduled_at: datetime, scheduled_by_user_id: UUID) -> GeneratedPost:
    """Mark a generated post as scheduled for ``scheduled_at``.

    Args:
        post_id: ID of the post to schedule.
        scheduled_at: Publication time. A datetime is sent as an ISO string;
            any other value is passed through unchanged.
        scheduled_by_user_id: Stored, as a string, in ``scheduled_by_user_id``.

    Returns:
        The updated post, as returned by ``update_generated_post``.
    """
    from datetime import datetime as dt

    if isinstance(scheduled_at, dt):
        when = scheduled_at.isoformat()
    else:
        when = scheduled_at

    changes = {
        "status": "scheduled",
        "scheduled_at": when,
        "scheduled_by_user_id": str(scheduled_by_user_id)
    }
    return await self.update_generated_post(post_id, changes)
async def unschedule_post(self, post_id: UUID) -> GeneratedPost:
    """Clear a post's schedule and set its status back to ``approved``.

    Both ``scheduled_at`` and ``scheduled_by_user_id`` are reset to ``None``.

    Returns:
        The updated post, as returned by ``update_generated_post``.
    """
    cleared = dict(
        status="approved",
        scheduled_at=None,
        scheduled_by_user_id=None,
    )
    return await self.update_generated_post(post_id, cleared)
# ==================== PROFILES / USERS ====================
async def create_profile(self, user_id: UUID, profile: Profile) -> Profile:
    """Create the profile row for a user, or overwrite it if it already exists.

    The row is written with an upsert whose ``id`` is ``user_id``.
    ``created_at``, ``updated_at`` and ``None`` fields are not sent, and
    enum-like ``account_type`` / ``onboarding_status`` values are reduced
    to their ``.value``.

    Args:
        user_id: ID of the user; becomes the profile's ``id``.
        profile: Profile data to store.

    Returns:
        The stored row as ``Profile``.
    """
    data = profile.model_dump(exclude={"created_at", "updated_at"}, exclude_none=True)
    data["id"] = str(user_id)
    if data.get("company_id"):
        data["company_id"] = str(data["company_id"])
    for column in ("account_type", "onboarding_status"):
        if column in data and hasattr(data[column], "value"):
            data[column] = data[column].value

    result = await asyncio.to_thread(
        lambda: self.client.table("profiles").upsert(data).execute()
    )
    # The upsert can overwrite an existing row, so drop any copy that
    # get_profile cached, exactly as update_profile does.
    await cache.invalidate_profile(str(user_id))

    logger.info(f"Created/updated profile: {result.data[0]['id']}")
    return Profile(**result.data[0])
async def get_profile(self, user_id: UUID) -> Optional[Profile]:
    """Fetch the profile whose ``id`` equals ``user_id``.

    A cached entry is returned when present. A row read from the database
    is cached for ``PROFILE_TTL``; a miss is not cached.

    Returns:
        The profile, or ``None`` if no row exists.
    """
    cache_key = cache.profile_key(str(user_id))
    cached = await cache.get(cache_key)
    if cached is not None:
        return Profile(**cached)

    def _fetch():
        return self.client.table("profiles").select("*").eq("id", str(user_id)).execute()

    response = await asyncio.to_thread(_fetch)
    if not response.data:
        return None

    found = Profile(**response.data[0])
    await cache.set(cache_key, found.model_dump(mode="json"), PROFILE_TTL)
    return found
async def get_profiles_by_linkedin_url(self, linkedin_url: str) -> List[Profile]:
    """Return every profile whose ``linkedin_url`` equals the given URL exactly."""
    def _fetch():
        query = self.client.table("profiles").select("*")
        return query.eq("linkedin_url", linkedin_url).execute()

    response = await asyncio.to_thread(_fetch)
    matches = []
    for row in response.data:
        matches.append(Profile(**row))
    return matches
async def update_profile(self, user_id: UUID, updates: Dict[str, Any]) -> Profile:
    """Write ``updates`` to a user's profile and invalidate its cache entry.

    Args:
        user_id: ID of the profile row.
        updates: Column/value pairs to write. A truthy ``company_id`` is
            sent as a string, and enum-like ``account_type`` /
            ``onboarding_status`` values are reduced to their ``.value``.
            The dict passed in is not modified.

    Returns:
        The updated row as ``Profile``.

    Raises:
        IndexError: If the update returns no row (indexing ``data[0]``).
    """
    # Shallow copy: normalizing values must not rewrite the caller's dict.
    payload = dict(updates)
    if payload.get("company_id"):
        payload["company_id"] = str(payload["company_id"])
    for k in ("account_type", "onboarding_status"):
        if k in payload and hasattr(payload[k], "value"):
            payload[k] = payload[k].value

    result = await asyncio.to_thread(
        lambda: self.client.table("profiles").update(payload).eq(
            "id", str(user_id)
        ).execute()
    )
    logger.info(f"Updated profile: {user_id}")
    profile = Profile(**result.data[0])
    await cache.invalidate_profile(str(user_id))
    return profile
# ==================== LINKEDIN ACCOUNTS ====================
async def get_linkedin_account(self, user_id: UUID) -> Optional['LinkedInAccount']:
    """Fetch the active LinkedIn account connection of a user.

    Only rows with ``is_active = True`` are considered. A cached entry is
    returned when present. A row read from the database is cached for
    ``LINKEDIN_ACCOUNT_TTL``; a miss is not cached.

    Returns:
        The account, or ``None`` if the user has no active one.
    """
    from src.database.models import LinkedInAccount

    cache_key = cache.linkedin_account_key(str(user_id))
    cached = await cache.get(cache_key)
    if cached is not None:
        return LinkedInAccount(**cached)

    def _fetch():
        query = self.client.table("linkedin_accounts").select("*")
        query = query.eq("user_id", str(user_id))
        return query.eq("is_active", True).execute()

    response = await asyncio.to_thread(_fetch)
    if not response.data:
        return None

    found = LinkedInAccount(**response.data[0])
    await cache.set(cache_key, found.model_dump(mode="json"), LINKEDIN_ACCOUNT_TTL)
    return found
async def get_linkedin_account_by_id(self, account_id: UUID) -> Optional['LinkedInAccount']:
    """Fetch one LinkedIn account by its ``id``, regardless of ``is_active``.

    Returns:
        The account, or ``None`` if no row has this ID.
    """
    from src.database.models import LinkedInAccount

    def _fetch():
        query = self.client.table("linkedin_accounts").select("*")
        return query.eq("id", str(account_id)).execute()

    rows = (await asyncio.to_thread(_fetch)).data
    return LinkedInAccount(**rows[0]) if rows else None
async def list_linkedin_accounts(self, active_only: bool = True) -> list['LinkedInAccount']:
    """Return LinkedIn accounts across all users.

    Args:
        active_only: When true (default), only rows with
            ``is_active = True`` are returned.
    """
    from src.database.models import LinkedInAccount

    def _fetch():
        base = self.client.table("linkedin_accounts").select("*")
        selected = base.eq("is_active", True) if active_only else base
        return selected.execute()

    response = await asyncio.to_thread(_fetch)
    return [LinkedInAccount(**row) for row in response.data]
async def create_linkedin_account(self, account: 'LinkedInAccount') -> 'LinkedInAccount':
    """Insert a LinkedIn account connection and invalidate the owner's cached account.

    ``id``, ``created_at`` and ``updated_at`` are not sent. ``user_id`` is
    sent as a string and every datetime value as an ISO string.

    Args:
        account: Account data to store.

    Returns:
        The inserted row as ``LinkedInAccount``.
    """
    from src.database.models import LinkedInAccount

    data = account.model_dump(exclude={'id', 'created_at', 'updated_at'})
    data['user_id'] = str(data['user_id'])
    # Serialize every datetime column, not only token_expires_at, and skip
    # None instead of crashing on it. update_linkedin_account shows that
    # last_used_at and last_error_at are datetime columns too.
    data = {
        column: value.isoformat() if isinstance(value, datetime) else value
        for column, value in data.items()
    }

    result = await asyncio.to_thread(
        lambda: self.client.table("linkedin_accounts").insert(data).execute()
    )
    logger.info(f"Created LinkedIn account for user: {account.user_id}")
    created = LinkedInAccount(**result.data[0])
    await cache.invalidate_linkedin_account(str(created.user_id))
    return created
async def update_linkedin_account(self, account_id: UUID, updates: Dict) -> 'LinkedInAccount':
    """Write ``updates`` to one LinkedIn account and invalidate the owner's cached account.

    Args:
        account_id: ID of the ``linkedin_accounts`` row.
        updates: Column/value pairs to write. Datetime values in
            ``token_expires_at``, ``last_used_at`` and ``last_error_at`` are
            sent as ISO strings. The dict passed in is not modified.

    Returns:
        The updated row as ``LinkedInAccount``.

    Raises:
        IndexError: If the update returns no row (indexing ``data[0]``).
    """
    from src.database.models import LinkedInAccount

    # Shallow copy: serialization must not rewrite the caller's dict.
    payload = dict(updates)
    for field in ('token_expires_at', 'last_used_at', 'last_error_at'):
        if isinstance(payload.get(field), datetime):
            payload[field] = payload[field].isoformat()

    result = await asyncio.to_thread(
        lambda: self.client.table("linkedin_accounts").update(payload)
        .eq("id", str(account_id)).execute()
    )
    logger.info(f"Updated LinkedIn account: {account_id}")
    updated = LinkedInAccount(**result.data[0])
    await cache.invalidate_linkedin_account(str(updated.user_id))
    return updated
async def delete_linkedin_account(self, account_id: UUID) -> None:
    """Delete one LinkedIn account connection and invalidate the owner's cached account.

    The owner is looked up before the delete because the row is gone
    afterwards. If the lookup finds nothing, no cache is invalidated.
    """
    row_id = str(account_id)

    def _lookup_owner():
        return self.client.table("linkedin_accounts").select("user_id").eq("id", row_id).execute()

    def _remove():
        return self.client.table("linkedin_accounts").delete().eq("id", row_id).execute()

    lookup = await asyncio.to_thread(_lookup_owner)
    owner_id = None
    if lookup.data:
        owner_id = lookup.data[0]["user_id"]

    await asyncio.to_thread(_remove)
    if owner_id:
        await cache.invalidate_linkedin_account(owner_id)
    logger.info(f"Deleted LinkedIn account: {account_id}")
# ==================== TELEGRAM ACCOUNTS ====================
async def get_telegram_account(self, user_id: UUID) -> Optional['TelegramAccount']:
    """Return the active Telegram account of a user, or None when there is none."""
    from src.database.models import TelegramAccount

    def _fetch():
        table = self.client.table("telegram_accounts")
        return table.select("*").eq("user_id", str(user_id)).eq("is_active", True).execute()

    rows = (await asyncio.to_thread(_fetch)).data
    return TelegramAccount(**rows[0]) if rows else None
async def get_telegram_account_by_chat_id(self, chat_id: str) -> Optional['TelegramAccount']:
    """Return the active Telegram account whose ``telegram_chat_id`` matches, else None."""
    from src.database.models import TelegramAccount

    def _fetch():
        query = self.client.table("telegram_accounts").select("*")
        query = query.eq("telegram_chat_id", chat_id)
        return query.eq("is_active", True).execute()

    matches = (await asyncio.to_thread(_fetch)).data
    if not matches:
        return None
    return TelegramAccount(**matches[0])
async def save_telegram_account(self, account: 'TelegramAccount') -> 'TelegramAccount':
    """Insert the user's Telegram account, or update it if a row already exists.

    A row is considered existing when ``telegram_accounts`` has any entry for
    ``account.user_id``. ``id``, ``created_at``, ``updated_at`` and None-valued
    fields are not written.
    """
    from src.database.models import TelegramAccount

    payload = account.model_dump(exclude={'id', 'created_at', 'updated_at'}, exclude_none=True)
    payload['user_id'] = str(payload['user_id'])
    owner = str(account.user_id)

    def _lookup():
        return self.client.table("telegram_accounts").select("id").eq("user_id", owner).execute()

    def _update():
        return self.client.table("telegram_accounts").update(payload).eq("user_id", owner).execute()

    def _insert():
        return self.client.table("telegram_accounts").insert(payload).execute()

    already_linked = (await asyncio.to_thread(_lookup)).data
    write = _update if already_linked else _insert
    result = await asyncio.to_thread(write)

    logger.info(f"Saved Telegram account for user: {account.user_id}")
    return TelegramAccount(**result.data[0])
async def delete_telegram_account(self, user_id: UUID) -> bool:
    """Delete the Telegram account rows of a user; always returns True."""
    def _remove():
        table = self.client.table("telegram_accounts")
        return table.delete().eq("user_id", str(user_id)).execute()

    await asyncio.to_thread(_remove)
    logger.info(f"Deleted Telegram account for user: {user_id}")
    return True
# ==================== TEAMS ACCOUNTS ====================
async def get_teams_account(self, user_id: UUID) -> Optional['TeamsAccount']:
    """Return the active Teams account of a user, or None when there is none."""
    from src.database.models import TeamsAccount

    def _fetch():
        table = self.client.table("teams_accounts")
        return table.select("*").eq("user_id", str(user_id)).eq("is_active", True).execute()

    rows = (await asyncio.to_thread(_fetch)).data
    return TeamsAccount(**rows[0]) if rows else None
async def save_teams_account(self, account: 'TeamsAccount') -> 'TeamsAccount':
    """Insert the user's Teams account, or update it if a row already exists.

    A row is considered existing when ``teams_accounts`` has any entry for
    ``account.user_id``. ``id``, ``created_at``, ``updated_at`` and None-valued
    fields are not written.
    """
    from src.database.models import TeamsAccount

    payload = account.model_dump(exclude={'id', 'created_at', 'updated_at'}, exclude_none=True)
    payload['user_id'] = str(payload['user_id'])
    owner = str(account.user_id)

    def _lookup():
        return self.client.table("teams_accounts").select("id").eq("user_id", owner).execute()

    def _update():
        return self.client.table("teams_accounts").update(payload).eq("user_id", owner).execute()

    def _insert():
        return self.client.table("teams_accounts").insert(payload).execute()

    already_linked = (await asyncio.to_thread(_lookup)).data
    write = _update if already_linked else _insert
    result = await asyncio.to_thread(write)

    logger.info(f"Saved Teams account for user: {account.user_id}")
    return TeamsAccount(**result.data[0])
async def delete_teams_account(self, user_id: UUID) -> bool:
    """Delete the Teams account rows of a user; always returns True."""
    def _remove():
        table = self.client.table("teams_accounts")
        return table.delete().eq("user_id", str(user_id)).execute()

    await asyncio.to_thread(_remove)
    logger.info(f"Deleted Teams account for user: {user_id}")
    return True
# ==================== USERS ====================
async def get_user(self, user_id: UUID) -> Optional[User]:
    """Return a user from the ``users`` view by ID, using the cache when possible.

    On a cache miss the row is fetched and stored in the cache for ``USER_TTL``.
    Returns None when no row matches.
    """
    cache_key = cache.user_key(str(user_id))
    cached = await cache.get(cache_key)
    if cached is not None:
        return User(**cached)

    def _fetch():
        return self.client.table("users").select("*").eq("id", str(user_id)).execute()

    rows = (await asyncio.to_thread(_fetch)).data
    if not rows:
        return None

    user = User(**rows[0])
    await cache.set(cache_key, user.model_dump(mode="json"), USER_TTL)
    return user
async def get_user_by_email(self, email: str) -> Optional[User]:
    """Return the user whose email equals the lower-cased ``email``, or None."""
    def _fetch():
        users = self.client.table("users").select("*")
        return users.eq("email", email.lower()).execute()

    rows = (await asyncio.to_thread(_fetch)).data
    return User(**rows[0]) if rows else None
async def get_user_by_linkedin_sub(self, linkedin_sub: str) -> Optional[User]:
    """Return the user with the given LinkedIn ``sub`` (OAuth identifier), or None."""
    def _fetch():
        users = self.client.table("users").select("*")
        return users.eq("linkedin_sub", linkedin_sub).execute()

    rows = (await asyncio.to_thread(_fetch)).data
    if not rows:
        return None
    return User(**rows[0])
async def update_user(self, user_id: UUID, updates: Dict[str, Any]) -> User:
    """Apply profile-level updates for a user and return the refreshed user.

    Only keys listed in ``profile_fields`` are forwarded to ``update_profile``;
    other keys in ``updates`` are ignored. The cached user entry is deleted
    before the user is re-read via ``get_user``.
    """
    profile_fields = {
        "account_type", "display_name", "onboarding_status",
        "onboarding_data", "company_id",
        "linkedin_url", "writing_style_notes", "metadata",
        "profile_picture", "creator_email", "customer_email", "is_active",
    }
    profile_updates = {
        name: value for name, value in updates.items() if name in profile_fields
    }

    if profile_updates:
        # update_profile already calls cache.invalidate_profile which also kills user_key
        await self.update_profile(user_id, profile_updates)

    # Invalidate the user view separately (in case it wasn't covered above)
    await cache.delete(cache.user_key(str(user_id)))
    return await self.get_user(user_id)
async def list_users(self, account_type: Optional[str] = None, company_id: Optional[UUID] = None) -> List[User]:
    """List users from the ``users`` view, newest first.

    Args:
        account_type: When truthy, restrict to this account type.
        company_id: When truthy, restrict to this company.
    """
    filters = {}
    if account_type:
        filters["account_type"] = account_type
    if company_id:
        filters["company_id"] = str(company_id)

    def _run():
        query = self.client.table("users").select("*")
        for column, value in filters.items():
            query = query.eq(column, value)
        return query.order("created_at", desc=True).execute()

    rows = (await asyncio.to_thread(_run)).data
    return [User(**row) for row in rows]
async def delete_user_completely(self, user_id: UUID) -> bool:
    """Delete a user together with the data stored under their user ID.

    Steps, in order: content tables keyed by ``user_id``, invitations sent to
    the user's email, the ``profiles`` row, the cached user entry, and finally
    the auth user (only when an admin client is configured).

    Args:
        user_id: ID of the user to remove.

    Returns:
        True when every step completed.

    Raises:
        Exception: Any error from a step is logged and re-raised; earlier
            deletions are not rolled back.
    """
    content_tables = (
        "generated_posts", "linkedin_posts", "example_posts",
        "reference_profiles", "profile_analyses", "linkedin_profiles",
        "post_types", "research_results", "topics",
    )
    try:
        user_id_str = str(user_id)

        # Delete all content data directly via user_id
        for table in content_tables:
            await asyncio.to_thread(
                lambda t=table: self.client.table(t).delete().eq(
                    "user_id", user_id_str
                ).execute()
            )

        # Get user email for invitation deletion (must happen before the
        # profile row is removed)
        user = await self.get_user(user_id)
        if user and user.email:
            await asyncio.to_thread(
                lambda: self.client.table("invitations").delete().eq(
                    "email", user.email
                ).execute()
            )

        # Delete profile
        await asyncio.to_thread(
            lambda: self.client.table("profiles").delete().eq(
                "id", user_id_str
            ).execute()
        )

        # get_user() above may have (re)populated the cache; drop the entry so
        # the deleted user is not served from cache until USER_TTL expires.
        await cache.delete(cache.user_key(user_id_str))

        # Delete from auth.users
        if self.admin_client:
            await asyncio.to_thread(
                lambda: self.admin_client.auth.admin.delete_user(user_id_str)
            )
            logger.info(f"Deleted auth user {user_id}")
        else:
            logger.warning(f"Cannot delete auth user {user_id} - no service role key configured")

        logger.info(f"Completely deleted user {user_id} and all related data")
        return True
    except Exception as e:
        logger.error(f"Error deleting user {user_id}: {e}")
        raise
# ==================== COMPANIES ====================
async def create_company(self, company: Company) -> Company:
    """Insert a company and return the stored record.

    ``id``, ``created_at``, ``updated_at`` and None-valued fields are not
    written; ``owner_user_id`` and ``license_key_id`` are sent as strings.
    """
    payload = company.model_dump(exclude={"id", "created_at", "updated_at"}, exclude_none=True)
    for id_field in ("owner_user_id", "license_key_id"):
        if id_field in payload:
            payload[id_field] = str(payload[id_field])

    def _insert():
        return self.client.table("companies").insert(payload).execute()

    stored = (await asyncio.to_thread(_insert)).data[0]
    logger.info(f"Created company: {stored['id']}")
    return Company(**stored)
async def get_company(self, company_id: UUID) -> Optional[Company]:
    """Return a company by ID, using the cache when possible.

    On a cache miss the row is fetched and cached for ``COMPANY_TTL``.
    Returns None when no row matches.
    """
    cache_key = cache.company_key(str(company_id))
    cached = await cache.get(cache_key)
    if cached is not None:
        return Company(**cached)

    def _fetch():
        return self.client.table("companies").select("*").eq("id", str(company_id)).execute()

    rows = (await asyncio.to_thread(_fetch)).data
    if not rows:
        return None

    company = Company(**rows[0])
    await cache.set(cache_key, company.model_dump(mode="json"), COMPANY_TTL)
    return company
async def get_company_by_owner(self, owner_user_id: UUID) -> Optional[Company]:
    """Return the first company whose ``owner_user_id`` matches, or None."""
    def _fetch():
        companies = self.client.table("companies").select("*")
        return companies.eq("owner_user_id", str(owner_user_id)).execute()

    rows = (await asyncio.to_thread(_fetch)).data
    return Company(**rows[0]) if rows else None
async def get_company_employees(self, company_id: UUID) -> List[User]:
    """Return users of a company with account type ``employee``, newest first."""
    def _fetch():
        query = self.client.table("users").select("*")
        query = query.eq("company_id", str(company_id))
        query = query.eq("account_type", "employee")
        return query.order("created_at", desc=True).execute()

    rows = (await asyncio.to_thread(_fetch)).data
    return [User(**row) for row in rows]
async def update_company(self, company_id: UUID, updates: Dict[str, Any]) -> Company:
    """Write ``updates`` to a company row, invalidate its cache entry and return it."""
    def _update():
        table = self.client.table("companies")
        return table.update(updates).eq("id", str(company_id)).execute()

    rows = (await asyncio.to_thread(_update)).data
    logger.info(f"Updated company: {company_id}")
    refreshed = Company(**rows[0])
    await cache.invalidate_company(str(company_id))
    return refreshed
async def list_companies(self) -> List[Company]:
    """Return every company, newest first."""
    def _fetch():
        companies = self.client.table("companies").select("*")
        return companies.order("created_at", desc=True).execute()

    rows = (await asyncio.to_thread(_fetch)).data
    return [Company(**row) for row in rows]
# ==================== INVITATIONS ====================
async def create_invitation(self, invitation: Invitation) -> Invitation:
    """Insert an invitation and return the stored record.

    ID fields are sent as strings, datetime fields as ISO strings, and a
    ``status`` carrying a ``.value`` attribute is replaced by that value.
    """
    from datetime import datetime

    payload = invitation.model_dump(exclude={"id", "created_at"}, exclude_none=True)

    for id_field in ("company_id", "invited_by_user_id", "accepted_by_user_id"):
        if payload.get(id_field):
            payload[id_field] = str(payload[id_field])

    for ts_field in ("expires_at", "accepted_at"):
        if isinstance(payload.get(ts_field), datetime):
            payload[ts_field] = payload[ts_field].isoformat()

    status = payload.get("status")
    if hasattr(status, "value"):
        payload["status"] = status.value

    def _insert():
        return self.client.table("invitations").insert(payload).execute()

    stored = (await asyncio.to_thread(_insert)).data[0]
    logger.info(f"Created invitation: {stored['id']}")
    return Invitation(**stored)
async def get_invitation_by_token(self, token: str) -> Optional[Invitation]:
    """Return the invitation carrying ``token``, or None when none matches."""
    def _fetch():
        return self.client.table("invitations").select("*").eq("token", token).execute()

    rows = (await asyncio.to_thread(_fetch)).data
    return Invitation(**rows[0]) if rows else None
async def get_invitation(self, invitation_id: UUID) -> Optional[Invitation]:
    """Return the invitation with the given ID, or None when none matches."""
    def _fetch():
        invitations = self.client.table("invitations").select("*")
        return invitations.eq("id", str(invitation_id)).execute()

    rows = (await asyncio.to_thread(_fetch)).data
    if not rows:
        return None
    return Invitation(**rows[0])
async def update_invitation(self, invitation_id: UUID, updates: Dict[str, Any]) -> Invitation:
    """Update an invitation row and return the updated record.

    Args:
        invitation_id: Primary key of the invitation.
        updates: Column/value pairs to write. The dict is not modified; a
            serialised copy is sent instead (``accepted_by_user_id`` as a
            string, ``accepted_at`` as an ISO string, ``status`` as its
            ``.value`` when it has one).

    Returns:
        An ``Invitation`` built from the first row returned by the update.

    Raises:
        IndexError: If the update returned no rows.
    """
    # Copy first so serialisation does not rewrite the caller's dict.
    payload = dict(updates)

    if payload.get("accepted_by_user_id"):
        payload["accepted_by_user_id"] = str(payload["accepted_by_user_id"])
    if isinstance(payload.get("accepted_at"), datetime):
        payload["accepted_at"] = payload["accepted_at"].isoformat()
    if hasattr(payload.get("status"), "value"):
        payload["status"] = payload["status"].value

    result = await asyncio.to_thread(
        lambda: self.client.table("invitations").update(payload).eq(
            "id", str(invitation_id)
        ).execute()
    )
    logger.info(f"Updated invitation: {invitation_id}")
    return Invitation(**result.data[0])
async def get_pending_invitations(self, company_id: UUID) -> List[Invitation]:
    """Return a company's invitations with status ``pending``, newest first."""
    def _fetch():
        query = self.client.table("invitations").select("*")
        query = query.eq("company_id", str(company_id))
        query = query.eq("status", "pending")
        return query.order("created_at", desc=True).execute()

    rows = (await asyncio.to_thread(_fetch)).data
    return [Invitation(**row) for row in rows]
async def get_invitations_by_email(self, email: str) -> List[Invitation]:
    """Return all invitations for the lower-cased ``email``, newest first."""
    def _fetch():
        query = self.client.table("invitations").select("*")
        query = query.eq("email", email.lower())
        return query.order("created_at", desc=True).execute()

    rows = (await asyncio.to_thread(_fetch)).data
    return [Invitation(**row) for row in rows]
# ==================== EXAMPLE POSTS ====================
async def save_example_posts(self, posts: List[ExamplePost]) -> List[ExamplePost]:
    """Insert several example posts in one request and return the stored rows.

    Returns an empty list without touching the database when ``posts`` is empty.
    """
    if not posts:
        return []

    def _to_row(post):
        # id/created_at and None-valued fields are not sent with the insert.
        row = post.model_dump(exclude={"id", "created_at"}, exclude_none=True)
        if "user_id" in row:
            row["user_id"] = str(row["user_id"])
        if row.get("post_type_id"):
            row["post_type_id"] = str(row["post_type_id"])
        return row

    rows = [_to_row(post) for post in posts]

    def _insert():
        return self.client.table("example_posts").insert(rows).execute()

    stored = (await asyncio.to_thread(_insert)).data
    logger.info(f"Saved {len(stored)} example posts")
    return [ExamplePost(**row) for row in stored]
async def save_example_post(self, post: ExamplePost) -> ExamplePost:
    """Save one example post via ``save_example_posts``.

    Returns the stored post, or the input ``post`` if nothing came back.
    """
    saved = await self.save_example_posts([post])
    if saved:
        return saved[0]
    return post
async def get_example_posts(self, user_id: UUID) -> List[ExamplePost]:
    """Return all example posts of a user, newest first."""
    def _fetch():
        query = self.client.table("example_posts").select("*")
        query = query.eq("user_id", str(user_id))
        return query.order("created_at", desc=True).execute()

    rows = (await asyncio.to_thread(_fetch)).data
    return [ExamplePost(**row) for row in rows]
async def delete_example_post(self, post_id: UUID) -> None:
    """Delete the example post with the given ID."""
    def _remove():
        table = self.client.table("example_posts")
        return table.delete().eq("id", str(post_id)).execute()

    await asyncio.to_thread(_remove)
    logger.info(f"Deleted example post: {post_id}")
# ==================== REFERENCE PROFILES ====================
async def create_reference_profile(self, profile: ReferenceProfile) -> ReferenceProfile:
    """Insert a reference profile and return the stored record."""
    payload = profile.model_dump(exclude={"id", "created_at"}, exclude_none=True)
    if "user_id" in payload:
        payload["user_id"] = str(payload["user_id"])

    def _insert():
        return self.client.table("reference_profiles").insert(payload).execute()

    stored = (await asyncio.to_thread(_insert)).data[0]
    logger.info(f"Created reference profile: {stored['id']}")
    return ReferenceProfile(**stored)
async def get_reference_profiles(self, user_id: UUID) -> List[ReferenceProfile]:
    """Return all reference profiles of a user, newest first."""
    def _fetch():
        query = self.client.table("reference_profiles").select("*")
        query = query.eq("user_id", str(user_id))
        return query.order("created_at", desc=True).execute()

    rows = (await asyncio.to_thread(_fetch)).data
    return [ReferenceProfile(**row) for row in rows]
async def update_reference_profile(self, profile_id: UUID, updates: Dict[str, Any]) -> ReferenceProfile:
    """Write ``updates`` to a reference profile row and return the updated record."""
    def _update():
        table = self.client.table("reference_profiles")
        return table.update(updates).eq("id", str(profile_id)).execute()

    rows = (await asyncio.to_thread(_update)).data
    logger.info(f"Updated reference profile: {profile_id}")
    return ReferenceProfile(**rows[0])
async def delete_reference_profile(self, profile_id: UUID) -> None:
    """Delete the reference profile with the given ID."""
    def _remove():
        table = self.client.table("reference_profiles")
        return table.delete().eq("id", str(profile_id)).execute()

    await asyncio.to_thread(_remove)
    logger.info(f"Deleted reference profile: {profile_id}")
# ==================== API USAGE LOGS ====================
async def log_api_usage(
    self,
    provider: str,
    model: str,
    operation: str,
    prompt_tokens: int,
    completion_tokens: int,
    total_tokens: int,
    estimated_cost_usd: float,
    user_id: Optional[UUID] = None,
    company_id: Optional[UUID] = None
) -> None:
    """Record one API usage event in ``api_usage_logs``.

    Best effort: an insert failure is logged as a warning and not raised.
    ``user_id`` / ``company_id`` are only included when truthy.
    """
    record = {
        "provider": provider,
        "model": model,
        "operation": operation,
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": total_tokens,
        "estimated_cost_usd": estimated_cost_usd
    }
    for column, identifier in (("user_id", user_id), ("company_id", company_id)):
        if identifier:
            record[column] = str(identifier)

    def _insert():
        return self.client.table("api_usage_logs").insert(record).execute()

    try:
        await asyncio.to_thread(_insert)
    except Exception as e:
        logger.warning(f"Failed to log API usage: {e}")
async def get_usage_stats(
    self,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    user_id: Optional[UUID] = None,
    company_id: Optional[UUID] = None
) -> List[Dict[str, Any]]:
    """Return raw ``api_usage_logs`` rows matching the filters, newest first.

    Each filter is applied only when its argument is truthy; ``start_date`` and
    ``end_date`` are inclusive bounds on ``created_at``.
    """
    # (query method, column, value) in the order they are applied.
    conditions = []
    if user_id:
        conditions.append(("eq", "user_id", str(user_id)))
    if company_id:
        conditions.append(("eq", "company_id", str(company_id)))
    if start_date:
        conditions.append(("gte", "created_at", start_date))
    if end_date:
        conditions.append(("lte", "created_at", end_date))

    def _run():
        query = self.client.table("api_usage_logs").select("*")
        for method, column, value in conditions:
            query = getattr(query, method)(column, value)
        return query.order("created_at", desc=True).execute()

    return (await asyncio.to_thread(_run)).data
async def get_usage_by_day(
    self,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    user_id: Optional[UUID] = None,
    company_id: Optional[UUID] = None
) -> List[Dict[str, Any]]:
    """Aggregate API usage per calendar day.

    Args:
        start_date, end_date, user_id, company_id: Passed through unchanged to
            ``get_usage_stats``.

    Returns:
        One dict per day, sorted by date ascending, with keys ``date`` (the
        first 10 characters of ``created_at``), ``total_tokens``,
        ``estimated_cost_usd`` and ``count`` (number of log rows that day).
    """
    from collections import defaultdict

    logs = await self.get_usage_stats(start_date, end_date, user_id, company_id)

    daily = defaultdict(lambda: {"total_tokens": 0, "estimated_cost_usd": 0.0, "count": 0})
    for log in logs:
        bucket = daily[log["created_at"][:10]]
        # ``or 0`` covers both a missing key and an explicit None value, which
        # ``dict.get(key, 0)`` would pass through and crash the addition.
        bucket["total_tokens"] += log.get("total_tokens") or 0
        bucket["estimated_cost_usd"] += float(log.get("estimated_cost_usd") or 0)
        bucket["count"] += 1

    return [{"date": day, **totals} for day, totals in sorted(daily.items())]
async def get_post_stats(
    self,
    start_date: Optional[str] = None,
    user_id: Optional[UUID] = None,
    company_id: Optional[UUID] = None
) -> List[Dict[str, Any]]:
    """Return selected columns of generated posts for statistics, newest first.

    With ``user_id`` only that user's posts are returned; with ``company_id``
    the posts of users whose profile belongs to the company. When both are
    given the user must also belong to the company. With a company filter and
    no matching users the result is an empty list (no post query is issued).
    ``start_date`` is an inclusive lower bound on ``created_at``.
    """
    # None means "no user restriction"; a list restricts to those user IDs.
    allowed_ids = [str(user_id)] if user_id else None

    if company_id:
        def _members():
            profiles = self.client.table("profiles").select("id")
            return profiles.eq("company_id", str(company_id)).execute()

        member_ids = [row["id"] for row in (await asyncio.to_thread(_members)).data]
        if allowed_ids is None:
            allowed_ids = member_ids
        else:
            allowed_ids = [uid for uid in allowed_ids if uid in member_ids]
        if not allowed_ids:
            return []

    def _posts():
        query = self.client.table("generated_posts").select(
            "id, created_at, status, approved_at, published_at, user_id"
        )
        if allowed_ids is not None:
            query = query.in_("user_id", allowed_ids)
        if start_date:
            query = query.gte("created_at", start_date)
        return query.order("created_at", desc=True).execute()

    return (await asyncio.to_thread(_posts)).data
# ==================== LICENSE KEYS ====================
async def list_license_keys(self) -> List[LicenseKey]:
    """Return every license key, newest first (admin only)."""
    def _fetch():
        keys = self.client.table("license_keys").select("*")
        return keys.order("created_at", desc=True).execute()

    rows = (await asyncio.to_thread(_fetch)).data
    return [LicenseKey(**row) for row in rows]
async def get_license_key(self, key: str) -> Optional[LicenseKey]:
    """Return the license key whose ``key`` column equals ``key``, or None."""
    def _fetch():
        return self.client.table("license_keys").select("*").eq("key", key).execute()

    rows = (await asyncio.to_thread(_fetch)).data
    if not rows:
        return None
    return LicenseKey(**rows[0])
async def create_license_key(
    self,
    key: str,
    max_employees: int,
    daily_token_limit: Optional[int] = None,
    description: Optional[str] = None
) -> LicenseKey:
    """Insert a license key and return the stored record.

    ``daily_token_limit`` is only written when it is not None; ``description``
    is always written, including when it is None.
    """
    row = {
        "key": key,
        "max_employees": max_employees,
        "description": description,
    }
    if daily_token_limit is not None:
        row["daily_token_limit"] = daily_token_limit

    def _insert():
        return self.client.table("license_keys").insert(row).execute()

    stored = (await asyncio.to_thread(_insert)).data
    logger.info(f"Created license key: {key}")
    return LicenseKey(**stored[0])
async def mark_license_key_used(
    self,
    key: str,
    company_id: UUID
) -> LicenseKey:
    """Flag a license key as used, link it to a company and stamp ``used_at`` (UTC now)."""
    changes = {
        "used": True,
        "company_id": str(company_id),
        "used_at": datetime.now(timezone.utc).isoformat()
    }

    def _update():
        table = self.client.table("license_keys")
        return table.update(changes).eq("key", key).execute()

    rows = (await asyncio.to_thread(_update)).data
    logger.info(f"Marked license key as used: {key}")
    return LicenseKey(**rows[0])
async def get_license_key_by_id(self, key_id: UUID) -> Optional[LicenseKey]:
    """Return the license key with the given UUID, or None when none matches."""
    def _fetch():
        keys = self.client.table("license_keys").select("*")
        return keys.eq("id", str(key_id)).execute()

    rows = (await asyncio.to_thread(_fetch)).data
    if not rows:
        return None
    return LicenseKey(**rows[0])
async def update_license_key(self, key_id: UUID, updates: Dict[str, Any]) -> Optional[LicenseKey]:
    """Update a license key row (admin only).

    Args:
        key_id: UUID of the license key.
        updates: Column/value pairs to write.

    Returns:
        The updated ``LicenseKey``, or None when the update returned no rows
        (for example because no key has this ID).
    """
    result = await asyncio.to_thread(
        lambda: self.client.table("license_keys")
        .update(updates)
        .eq("id", str(key_id))
        .execute()
    )
    if not result.data:
        # Nothing was written; do not report a successful update.
        logger.warning(f"License key not found for update: {key_id}")
        return None
    logger.info(f"Updated license key: {key_id}")
    return LicenseKey(**result.data[0])
async def delete_license_key(self, key_id: UUID) -> None:
    """Delete the license key with the given UUID (admin only)."""
    def _remove():
        table = self.client.table("license_keys")
        return table.delete().eq("id", str(key_id)).execute()

    await asyncio.to_thread(_remove)
    logger.info(f"Deleted license key: {key_id}")
# ==================== LICENSE KEY OFFERS ====================
async def create_license_key_offer(
    self,
    license_key_id: UUID,
    moco_offer_id: int,
    moco_offer_identifier: Optional[str],
    moco_offer_url: Optional[str],
    offer_title: Optional[str],
    company_name: Optional[str],
    price: Optional[float],
    payment_frequency: Optional[str],
) -> "LicenseKeyOffer":
    """Store a MOCO offer linked to a license key; new offers start with status ``draft``."""
    row = dict(
        license_key_id=str(license_key_id),
        moco_offer_id=moco_offer_id,
        moco_offer_identifier=moco_offer_identifier,
        moco_offer_url=moco_offer_url,
        offer_title=offer_title,
        company_name=company_name,
        price=price,
        payment_frequency=payment_frequency,
        status="draft",
    )

    def _insert():
        return self.client.table("license_key_offers").insert(row).execute()

    stored = (await asyncio.to_thread(_insert)).data
    return LicenseKeyOffer(**stored[0])
async def list_license_key_offers(self, license_key_id: UUID) -> list["LicenseKeyOffer"]:
    """Return all stored MOCO offers of a license key, newest first."""
    def _fetch():
        query = self.client.table("license_key_offers").select("*")
        query = query.eq("license_key_id", str(license_key_id))
        return query.order("created_at", desc=True).execute()

    rows = (await asyncio.to_thread(_fetch)).data
    return [LicenseKeyOffer(**row) for row in rows]
async def update_license_key_offer_status(
    self, offer_id: UUID, status: str
) -> "LicenseKeyOffer":
    """Set the ``status`` of a stored offer and return the updated record."""
    def _update():
        table = self.client.table("license_key_offers")
        return table.update({"status": status}).eq("id", str(offer_id)).execute()

    rows = (await asyncio.to_thread(_update)).data
    return LicenseKeyOffer(**rows[0])
# ==================== COMPANY QUOTAS ====================
async def get_company_daily_quota(
    self,
    company_id: UUID,
    date_: Optional[date] = None
) -> CompanyDailyQuota:
    """Return the company's quota row for a day, creating it if it does not exist.

    Args:
        company_id: Company the quota belongs to.
        date_: Day to look up; defaults to ``date.today()``.

    Returns:
        The existing row, or a newly inserted one with ``tokens_used`` of 0.
    """
    day = date.today() if date_ is None else date_
    company_key = str(company_id)
    day_key = day.isoformat()

    def _find():
        query = self.client.table("company_daily_quotas").select("*")
        return query.eq("company_id", company_key).eq("date", day_key).execute()

    found = (await asyncio.to_thread(_find)).data
    if found:
        return CompanyDailyQuota(**found[0])

    # No row for this company/day yet: start a fresh counter.
    fresh = {
        "company_id": company_key,
        "date": day_key,
        "tokens_used": 0
    }

    def _create():
        return self.client.table("company_daily_quotas").insert(fresh).execute()

    created = (await asyncio.to_thread(_create)).data
    return CompanyDailyQuota(**created[0])
async def increment_company_tokens(self, company_id: UUID, tokens: int) -> None:
    """Add ``tokens`` to the company's usage counter for today.

    Best effort: any error is logged as a warning and not raised.
    """
    try:
        quota = await self.get_company_daily_quota(company_id)
        updated_total = quota.tokens_used + tokens

        def _store():
            table = self.client.table("company_daily_quotas")
            return table.update({"tokens_used": updated_total}).eq("id", str(quota.id)).execute()

        await asyncio.to_thread(_store)
        logger.info(f"Incremented token quota for company {company_id}: {quota.tokens_used} -> {updated_total}")
    except Exception as e:
        logger.warning(f"Error incrementing token quota for company {company_id}: {e}")
async def get_company_limits(self, company_id: UUID) -> Optional[LicenseKey]:
    """Return the license key that defines a company's limits.

    Returns None when the company does not exist, has no ``license_key_id``,
    or the referenced license key row is not found.
    """
    company = await self.get_company(company_id)
    license_key_id = company.license_key_id if company else None
    if not license_key_id:
        return None

    def _fetch():
        keys = self.client.table("license_keys").select("*")
        return keys.eq("id", str(license_key_id)).execute()

    rows = (await asyncio.to_thread(_fetch)).data
    if not rows:
        return None
    return LicenseKey(**rows[0])
async def check_company_token_limit(self, company_id: UUID) -> tuple[bool, str, int, int]:
    """Check whether the company still has token budget for today.

    Returns:
        ``(can_proceed, error_message, tokens_used, daily_limit)``. Without a
        license key or without a daily limit the result is
        ``(True, "", 0, 0)``, meaning unlimited.
    """
    limits = await self.get_company_limits(company_id)
    daily_limit = limits.daily_token_limit if limits else None
    if daily_limit is None:
        return True, "", 0, 0  # no limit = unlimited

    used = (await self.get_company_daily_quota(company_id)).tokens_used
    if used < daily_limit:
        return True, "", used, daily_limit

    message = f"Tageslimit erreicht ({daily_limit:,} Tokens/Tag). Morgen wieder verfügbar."
    return False, message, used, daily_limit
async def check_company_employee_limit(self, company_id: UUID) -> tuple[bool, str]:
    """Check whether the company may add another employee.

    Returns:
        ``(can_add, error_message)``. Without a license key there is no limit
        and the result is ``(True, "")``.
    """
    limits = await self.get_company_limits(company_id)
    if limits:
        staff = await self.list_users(account_type="employee", company_id=company_id)
        if len(staff) >= limits.max_employees:
            return False, f"Maximale Mitarbeiteranzahl erreicht ({limits.max_employees}). Bitte Lizenz upgraden."
    # Either no license key (defaults: unlimited) or still below the maximum.
    return True, ""
# ==================== EMAIL ACTION TOKENS ====================
async def create_email_token(self, token: str, post_id: UUID, action: str, expires_hours: int = 72) -> None:
    """Store an unused email action token that expires ``expires_hours`` from now (UTC)."""
    from datetime import timedelta

    lifetime = timedelta(hours=expires_hours)
    expiry = datetime.now(timezone.utc) + lifetime
    row = {
        "token": token,
        "post_id": str(post_id),
        "action": action,
        "expires_at": expiry.isoformat(),
        "used": False,
    }

    def _insert():
        return self.client.table("email_action_tokens").insert(row).execute()

    await asyncio.to_thread(_insert)
    logger.debug(f"Created email token for post {post_id} action={action}")
async def get_email_token(self, token: str) -> Optional[Dict[str, Any]]:
    """Return the raw ``email_action_tokens`` row for ``token``, or None if absent."""
    def _fetch():
        tokens = self.client.table("email_action_tokens").select("*")
        return tokens.eq("token", token).execute()

    rows = (await asyncio.to_thread(_fetch)).data
    return rows[0] if rows else None
async def mark_email_token_used(self, token: str) -> None:
    """Set ``used`` to True on the email action token matching ``token``."""
    def _mark():
        table = self.client.table("email_action_tokens")
        return table.update({"used": True}).eq("token", token).execute()

    await asyncio.to_thread(_mark)
async def mark_all_post_tokens_used(self, post_id: UUID) -> None:
    """Set ``used`` to True on every email action token of a post.

    This invalidates the approve and reject tokens together.
    """
    def _mark():
        table = self.client.table("email_action_tokens")
        return table.update({"used": True}).eq("post_id", str(post_id)).execute()

    await asyncio.to_thread(_mark)
async def cleanup_expired_email_tokens(self) -> None:
    """Delete email action tokens whose ``expires_at`` is before the current UTC time."""
    cutoff = datetime.now(timezone.utc).isoformat()

    def _purge():
        table = self.client.table("email_action_tokens")
        return table.delete().lt("expires_at", cutoff).execute()

    removed = (await asyncio.to_thread(_purge)).data or []
    # Stay quiet when nothing was deleted.
    if removed:
        logger.info(f"Cleaned up {len(removed)} expired email tokens")
# ==================== LINKEDIN TOKEN REFRESH ====================
async def get_expiring_linkedin_accounts(self, within_days: int = 7) -> list:
    """Return active LinkedIn accounts that can be refreshed and expire soon.

    Selected are rows with ``is_active`` true, a non-null ``refresh_token`` and
    ``token_expires_at`` earlier than now (UTC) plus ``within_days`` days.
    """
    from src.database.models import LinkedInAccount
    from datetime import timedelta

    horizon = datetime.now(timezone.utc) + timedelta(days=within_days)
    cutoff = horizon.isoformat()

    def _fetch():
        query = self.client.table("linkedin_accounts").select("*")
        query = query.eq("is_active", True)
        query = query.lt("token_expires_at", cutoff)
        return query.not_.is_("refresh_token", "null").execute()

    rows = (await asyncio.to_thread(_fetch)).data
    return [LinkedInAccount(**row) for row in rows]
# ==================== EMPLOYEE COMPANY PERMISSIONS ====================
async def get_employee_permissions(self, user_id: UUID, company_id: UUID):
    """Return the permissions row of an employee within a company.

    Returns None when no row exists; callers treat that as all-true defaults.
    """
    from src.database.models import EmployeeCompanyPermissions

    def _fetch():
        query = self.client.table("employee_company_permissions").select("*")
        query = query.eq("user_id", str(user_id))
        return query.eq("company_id", str(company_id)).execute()

    rows = (await asyncio.to_thread(_fetch)).data
    if not rows:
        return None
    return EmployeeCompanyPermissions(**rows[0])
async def upsert_employee_permissions(self, user_id: UUID, company_id: UUID, updates: Dict[str, Any]) -> None:
    """Insert or update the permissions row for a user/company pair.

    The row is matched on the ``user_id,company_id`` conflict target; keys in
    ``updates`` override the two ID entries if they collide.
    """
    row = {"user_id": str(user_id), "company_id": str(company_id)}
    row.update(updates)

    def _upsert():
        table = self.client.table("employee_company_permissions")
        return table.upsert(row, on_conflict="user_id,company_id").execute()

    await asyncio.to_thread(_upsert)
    logger.info(f"Upserted permissions for user {user_id} in company {company_id}")
async def get_scheduled_posts_for_user(self, user_id: UUID) -> List[GeneratedPost]:
    """Return a user's approved, ready, scheduled or published posts for their calendar.

    Rows are ordered by ``scheduled_at`` ascending with nulls last.
    """
    calendar_statuses = ["approved", "ready", "scheduled", "published"]

    def _fetch():
        query = self.client.table("generated_posts").select("*")
        query = query.eq("user_id", str(user_id))
        query = query.in_("status", calendar_statuses)
        return query.order("scheduled_at", desc=False, nullsfirst=False).execute()

    rows = (await asyncio.to_thread(_fetch)).data
    return [GeneratedPost(**row) for row in rows]
# Global database client instance. It is constructed when this module is
# imported, which runs DatabaseClient.__init__ and builds the Supabase client
# from the configured settings.
db = DatabaseClient()