Features: - Add LinkedIn OAuth integration and auto-posting functionality - Add scheduler service for automated post publishing - Add metadata field to generated_posts for LinkedIn URLs - Add privacy policy page for LinkedIn API compliance - Add company management features and employee accounts - Add license key system for company registrations Fixes: - Fix timezone issues (use UTC consistently across app) - Fix datetime serialization errors in database operations - Fix scheduling timezone conversion (local time to UTC) - Fix import errors (get_database -> db) Infrastructure: - Update Docker setup to use port 8001 (avoid conflicts) - Add SSL support with nginx-proxy and Let's Encrypt - Add LinkedIn setup documentation - Add migration scripts for schema updates Services: - Add linkedin_service.py for LinkedIn API integration - Add scheduler_service.py for background job processing - Add storage_service.py for Supabase Storage - Add email_service.py improvements - Add encryption utilities for token storage Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
1397 lines
56 KiB
Python
1397 lines
56 KiB
Python
"""Supabase database client."""
|
|
import asyncio
|
|
from datetime import datetime, date, timezone
|
|
from typing import Optional, List, Dict, Any
|
|
from uuid import UUID
|
|
from supabase import create_client, Client
|
|
from loguru import logger
|
|
|
|
from src.config import settings
|
|
from src.database.models import (
|
|
LinkedInProfile, LinkedInPost, Topic,
|
|
ProfileAnalysis, ResearchResult, GeneratedPost, PostType,
|
|
User, Profile, Company, Invitation, ExamplePost, ReferenceProfile,
|
|
ApiUsageLog, LicenseKey, CompanyDailyQuota
|
|
)
|
|
|
|
|
|
class DatabaseClient:
|
|
"""Supabase database client wrapper."""
|
|
|
|
def __init__(self):
    """Create the underlying Supabase client.

    The service-role key is preferred when configured; the same client is
    then also exposed as ``admin_client`` (used for auth admin calls such as
    deleting auth users). Otherwise the anon key is used and
    ``admin_client`` is ``None``.
    """
    service_role_key = settings.supabase_service_role_key

    if not service_role_key:
        self.client: Client = create_client(
            settings.supabase_url,
            settings.supabase_key
        )
        logger.warning("Supabase client initialized with anon key (slow mode - RLS enabled)")
    else:
        self.client: Client = create_client(
            settings.supabase_url,
            service_role_key
        )
        logger.info("Supabase client initialized with service role key (fast mode)")

    # Admin-only operations need the service-role client; None signals "unavailable".
    self.admin_client: Optional[Client] = self.client if service_role_key else None
|
|
|
|
# ==================== LINKEDIN PROFILES ====================
|
|
|
|
async def save_linkedin_profile(self, profile: LinkedInProfile) -> LinkedInProfile:
    """Insert or update the LinkedIn profile row belonging to ``profile.user_id``.

    Args:
        profile: Profile to persist. ``id`` and ``scraped_at`` are excluded
            from the payload and ``None`` fields are omitted.

    Returns:
        The row returned by Supabase after the insert/update.
    """
    data = profile.model_dump(exclude={"id", "scraped_at"}, exclude_none=True)
    if "user_id" in data:
        data["user_id"] = str(data["user_id"])
    user_id_str = str(profile.user_id)

    # Only existence matters, so fetch one id instead of every column of the row.
    existing = await asyncio.to_thread(
        lambda: self.client.table("linkedin_profiles").select("id").eq(
            "user_id", user_id_str
        ).limit(1).execute()
    )

    # NOTE(review): select-then-write is not atomic; two concurrent calls for
    # the same user can both take the insert branch.
    if existing.data:
        result = await asyncio.to_thread(
            lambda: self.client.table("linkedin_profiles").update(data).eq(
                "user_id", user_id_str
            ).execute()
        )
    else:
        result = await asyncio.to_thread(
            lambda: self.client.table("linkedin_profiles").insert(data).execute()
        )

    logger.info(f"Saved LinkedIn profile for user: {profile.user_id}")
    return LinkedInProfile(**result.data[0])
|
|
|
|
async def get_linkedin_profile(self, user_id: UUID) -> Optional[LinkedInProfile]:
    """Return the stored LinkedIn profile of ``user_id``, or ``None`` if there is none."""
    def _fetch():
        return (
            self.client.table("linkedin_profiles")
            .select("*")
            .eq("user_id", str(user_id))
            .execute()
        )

    response = await asyncio.to_thread(_fetch)
    rows = response.data
    return LinkedInProfile(**rows[0]) if rows else None
|
|
|
|
# ==================== LINKEDIN POSTS ====================
|
|
|
|
async def save_linkedin_posts(self, posts: List[LinkedInPost]) -> List[LinkedInPost]:
    """Upsert a batch of LinkedIn posts keyed on ``(user_id, post_url)``.

    Duplicates inside the batch are dropped first (first occurrence wins),
    since Postgres rejects an ON CONFLICT DO UPDATE that touches the same
    row twice in one statement.

    Args:
        posts: Posts to save. ``id`` and ``scraped_at`` are not sent and
            ``None`` fields are omitted.

    Returns:
        The upserted rows, or ``[]`` when there is nothing to save.
    """
    seen = set()
    unique_posts = []
    for post in posts:
        key = (str(post.user_id), post.post_url)
        if key in seen:
            continue
        seen.add(key)
        unique_posts.append(post)

    if len(posts) != len(unique_posts):
        logger.warning(f"Removed {len(posts) - len(unique_posts)} duplicate posts from batch")

    data = []
    for post in unique_posts:
        post_dict = post.model_dump(exclude={"id", "scraped_at"}, exclude_none=True)
        if "user_id" in post_dict:
            post_dict["user_id"] = str(post_dict["user_id"])
        # The request body is JSON-encoded: stringify UUIDs (e.g. post_type_id)
        # and ISO-format date/datetime values (e.g. post_date).
        for field, value in list(post_dict.items()):
            if isinstance(value, UUID):
                post_dict[field] = str(value)
            elif isinstance(value, (datetime, date)):
                post_dict[field] = value.isoformat()
        data.append(post_dict)

    if not data:
        logger.warning("No posts to save")
        return []

    result = await asyncio.to_thread(
        lambda: self.client.table("linkedin_posts").upsert(
            data,
            on_conflict="user_id,post_url"
        ).execute()
    )
    logger.info(f"Saved {len(result.data)} LinkedIn posts")
    return [LinkedInPost(**item) for item in result.data]
|
|
|
|
async def get_linkedin_posts(self, user_id: UUID) -> List[LinkedInPost]:
    """Return every LinkedIn post of ``user_id``, ordered by ``post_date`` descending."""
    def _select_posts():
        table = self.client.table("linkedin_posts")
        return table.select("*").eq("user_id", str(user_id)).order("post_date", desc=True).execute()

    response = await asyncio.to_thread(_select_posts)
    posts = [LinkedInPost(**row) for row in response.data]
    return posts
|
|
|
|
async def copy_posts_to_user(self, source_user_id: UUID, target_user_id: UUID) -> List[LinkedInPost]:
    """Duplicate the LinkedIn posts of ``source_user_id`` under ``target_user_id``.

    Copied fields are post_url, post_text, post_date, likes, comments, shares
    and raw_data; other fields (such as classification data) are not copied.
    Returns the saved copies, or ``[]`` when the source has no posts.
    """
    source_posts = await self.get_linkedin_posts(source_user_id)
    if not source_posts:
        return []

    copies = [
        LinkedInPost(
            user_id=target_user_id,
            post_url=original.post_url,
            post_text=original.post_text,
            post_date=original.post_date,
            likes=original.likes,
            comments=original.comments,
            shares=original.shares,
            raw_data=original.raw_data,
        )
        for original in source_posts
    ]

    saved = await self.save_linkedin_posts(copies)
    logger.info(f"Copied {len(saved)} posts from user {source_user_id} to {target_user_id}")
    return saved
|
|
|
|
async def delete_linkedin_post(self, post_id: UUID) -> None:
    """Delete the LinkedIn post row whose ``id`` equals ``post_id``."""
    post_id_str = str(post_id)

    def _delete():
        return self.client.table("linkedin_posts").delete().eq("id", post_id_str).execute()

    await asyncio.to_thread(_delete)
    logger.info(f"Deleted LinkedIn post: {post_id}")
|
|
|
|
async def get_unclassified_posts(self, user_id: UUID) -> List[LinkedInPost]:
    """Return the posts of ``user_id`` whose ``post_type_id`` is NULL."""
    def _select_unclassified():
        return (
            self.client.table("linkedin_posts")
            .select("*")
            .eq("user_id", str(user_id))
            .is_("post_type_id", "null")
            .execute()
        )

    response = await asyncio.to_thread(_select_unclassified)
    return [LinkedInPost(**row) for row in response.data]
|
|
|
|
async def get_posts_by_type(self, user_id: UUID, post_type_id: UUID) -> List[LinkedInPost]:
    """Return posts of ``user_id`` classified as ``post_type_id``, newest ``post_date`` first."""
    def _select_typed():
        return (
            self.client.table("linkedin_posts")
            .select("*")
            .eq("user_id", str(user_id))
            .eq("post_type_id", str(post_type_id))
            .order("post_date", desc=True)
            .execute()
        )

    response = await asyncio.to_thread(_select_typed)
    return [LinkedInPost(**row) for row in response.data]
|
|
|
|
async def update_post_classification(
    self,
    post_id: UUID,
    post_type_id: UUID,
    classification_method: str,
    classification_confidence: float
) -> None:
    """Set the post type and classification metadata of one LinkedIn post."""
    payload = {
        "post_type_id": str(post_type_id),
        "classification_method": classification_method,
        "classification_confidence": classification_confidence,
    }

    def _update():
        return self.client.table("linkedin_posts").update(payload).eq("id", str(post_id)).execute()

    await asyncio.to_thread(_update)
    logger.debug(f"Updated classification for post {post_id}")
|
|
|
|
async def update_linkedin_post(self, post_id: UUID, updates: Dict[str, Any]) -> None:
    """Apply arbitrary column updates to one LinkedIn post.

    Truthy ``post_type_id`` / ``user_id`` values are stringified so UUIDs can
    be JSON-encoded. The caller's ``updates`` dict is left untouched.
    """
    # Work on a copy: previously the caller's dict was rewritten in place.
    payload = dict(updates)
    for key in ("post_type_id", "user_id"):
        if payload.get(key):
            payload[key] = str(payload[key])

    await asyncio.to_thread(
        lambda: self.client.table("linkedin_posts").update(payload).eq("id", str(post_id)).execute()
    )
    logger.debug(f"Updated LinkedIn post {post_id}")
|
|
|
|
async def update_posts_classification_bulk(
    self,
    classifications: List[Dict[str, Any]]
) -> int:
    """Apply many post classifications, one update per entry (best effort).

    Each entry is expected to carry ``post_id``, ``post_type_id``,
    ``classification_method`` and ``classification_confidence``. Failing
    entries are logged and skipped.

    Returns:
        Number of entries that were updated without raising.
    """
    count = 0
    for classification in classifications:
        try:
            await asyncio.to_thread(
                lambda c=classification: self.client.table("linkedin_posts").update({
                    "post_type_id": str(c["post_type_id"]),
                    "classification_method": c["classification_method"],
                    "classification_confidence": c["classification_confidence"]
                }).eq("id", str(c["post_id"])).execute()
            )
            count += 1
        except Exception as e:
            # .get(): an entry missing "post_id" must not make the handler itself
            # raise KeyError and abort the remaining best-effort updates.
            logger.warning(f"Failed to update classification for post {classification.get('post_id')}: {e}")
    logger.info(f"Bulk updated classifications for {count} posts")
    return count
|
|
|
|
# ==================== POST TYPES ====================
|
|
|
|
async def create_post_type(self, post_type: PostType) -> PostType:
    """Insert a post type (id and timestamps left to the database) and return it."""
    payload = post_type.model_dump(exclude={"id", "created_at", "updated_at"}, exclude_none=True)
    if "user_id" in payload:
        payload["user_id"] = str(payload["user_id"])

    def _insert():
        return self.client.table("post_types").insert(payload).execute()

    response = await asyncio.to_thread(_insert)
    created = response.data[0]
    logger.info(f"Created post type: {created['name']}")
    return PostType(**created)
|
|
|
|
async def create_post_types_bulk(self, post_types: List[PostType]) -> List[PostType]:
    """Insert several post types in one request; returns ``[]`` for empty input."""
    if not post_types:
        return []

    def _to_row(item):
        row = item.model_dump(exclude={"id", "created_at", "updated_at"}, exclude_none=True)
        if "user_id" in row:
            row["user_id"] = str(row["user_id"])
        return row

    rows = [_to_row(item) for item in post_types]

    def _insert():
        return self.client.table("post_types").insert(rows).execute()

    response = await asyncio.to_thread(_insert)
    logger.info(f"Created {len(response.data)} post types")
    return [PostType(**row) for row in response.data]
|
|
|
|
async def get_post_types(self, user_id: UUID, active_only: bool = True) -> List[PostType]:
    """Return the post types of ``user_id`` ordered by name.

    With ``active_only`` (default) only rows with ``is_active = true`` are returned.
    """
    def _query():
        query = self.client.table("post_types").select("*").eq("user_id", str(user_id))
        query = query.eq("is_active", True) if active_only else query
        return query.order("name").execute()

    response = await asyncio.to_thread(_query)
    return [PostType(**row) for row in response.data]
|
|
|
|
# Alias for get_post_types
|
|
async def get_customer_post_types(self, user_id: UUID, active_only: bool = True) -> List[PostType]:
    """Alias kept for older callers; delegates to ``get_post_types``."""
    post_types = await self.get_post_types(user_id, active_only=active_only)
    return post_types
|
|
|
|
async def get_post_type(self, post_type_id: UUID) -> Optional[PostType]:
    """Return the post type with id ``post_type_id``, or ``None`` if not found."""
    def _fetch():
        return self.client.table("post_types").select("*").eq("id", str(post_type_id)).execute()

    response = await asyncio.to_thread(_fetch)
    if not response.data:
        return None
    return PostType(**response.data[0])
|
|
|
|
async def update_post_type(self, post_type_id: UUID, updates: Dict[str, Any]) -> PostType:
    """Apply ``updates`` to a post type and return the updated row."""
    def _update():
        return self.client.table("post_types").update(updates).eq("id", str(post_type_id)).execute()

    response = await asyncio.to_thread(_update)
    logger.info(f"Updated post type: {post_type_id}")
    return PostType(**response.data[0])
|
|
|
|
async def update_post_type_analysis(
    self,
    post_type_id: UUID,
    analysis: Dict[str, Any],
    analyzed_post_count: int
) -> PostType:
    """Store an analysis result on a post type, stamped with the current UTC time.

    Args:
        post_type_id: Post type to update.
        analysis: Analysis payload written to the ``analysis`` column.
        analyzed_post_count: Number of posts the analysis was based on.

    Returns:
        The updated post type row.
    """
    # Module-level ``datetime``/``timezone`` are used; the former function-local
    # re-import was redundant.
    payload = {
        "analysis": analysis,
        "analysis_generated_at": datetime.now(timezone.utc).isoformat(),
        "analyzed_post_count": analyzed_post_count
    }
    result = await asyncio.to_thread(
        lambda: self.client.table("post_types").update(payload).eq("id", str(post_type_id)).execute()
    )
    logger.info(f"Updated analysis for post type: {post_type_id}")
    return PostType(**result.data[0])
|
|
|
|
async def delete_post_type(self, post_type_id: UUID, soft: bool = True) -> None:
    """Delete a post type.

    With ``soft=True`` (default) the row is kept and ``is_active`` is set to
    false; otherwise the row is removed.
    """
    def _run():
        table = self.client.table("post_types")
        query = table.update({"is_active": False}) if soft else table.delete()
        return query.eq("id", str(post_type_id)).execute()

    await asyncio.to_thread(_run)
    if soft:
        logger.info(f"Soft deleted post type: {post_type_id}")
    else:
        logger.info(f"Hard deleted post type: {post_type_id}")
|
|
|
|
# ==================== TOPICS ====================
|
|
|
|
async def save_topics(self, topics: List[Topic]) -> List[Topic]:
    """Insert extracted topics, falling back to row-by-row inserts on failure.

    A single bulk insert is tried first. If it raises (e.g. one row violates
    a constraint), each topic is inserted on its own and failing rows are
    skipped with a warning.

    Returns:
        The topics that were stored (``[]`` when ``topics`` is empty).
    """
    if not topics:
        logger.warning("No topics to save")
        return []

    data = []
    for topic in topics:
        topic_dict = topic.model_dump(exclude={"id", "created_at"}, exclude_none=True)
        if "user_id" in topic_dict:
            topic_dict["user_id"] = str(topic_dict["user_id"])
        for key in ("extracted_from_post_id", "target_post_type_id"):
            if topic_dict.get(key):
                topic_dict[key] = str(topic_dict[key])
        data.append(topic_dict)

    try:
        result = await asyncio.to_thread(
            lambda: self.client.table("topics").insert(data).execute()
        )
    except Exception as e:
        # loguru has no ``exc_info`` kwarg: extra kwargs are passed to
        # str.format(), which raises on the braces typical of API error text
        # and aborted this fallback. opt(exception=True) attaches the traceback.
        logger.opt(exception=True).error(f"Error saving topics: {e}")
    else:
        # Parsing happens outside the try so a model error after a successful
        # bulk insert no longer triggers a second, duplicating insert pass.
        logger.info(f"Saved {len(result.data)} topics to database")
        return [Topic(**item) for item in result.data]

    saved = []
    for topic_data in data:
        try:
            result = await asyncio.to_thread(
                lambda td=topic_data: self.client.table("topics").insert(td).execute()
            )
        except Exception as single_error:
            logger.warning(f"Skipped duplicate topic: {topic_data.get('title')} ({single_error})")
            continue
        saved.extend(Topic(**item) for item in result.data)
    logger.info(f"Saved {len(saved)} topics individually")
    return saved
|
|
|
|
async def get_topics(
    self,
    user_id: UUID,
    unused_only: bool = False,
    post_type_id: Optional[UUID] = None
) -> List[Topic]:
    """Return the topics of ``user_id``, newest first.

    Args:
        user_id: Owner of the topics.
        unused_only: Only topics with ``is_used = false``.
        post_type_id: If truthy, only topics targeting this post type.
    """
    extra_filters = []
    if unused_only:
        extra_filters.append(("is_used", False))
    if post_type_id:
        extra_filters.append(("target_post_type_id", str(post_type_id)))

    def _query():
        query = self.client.table("topics").select("*").eq("user_id", str(user_id))
        for column, value in extra_filters:
            query = query.eq(column, value)
        return query.order("created_at", desc=True).execute()

    response = await asyncio.to_thread(_query)
    return [Topic(**row) for row in response.data]
|
|
|
|
async def mark_topic_used(self, topic_id: UUID) -> None:
    """Flag a topic as used and set ``used_at`` to the current UTC time."""
    # PostgREST sends JSON values as literals, so the old "now()" string was
    # never evaluated as a SQL function; it relied on Postgres accepting that
    # literal as timestamp input. Send an explicit UTC ISO-8601 timestamp, as
    # the rest of this client does.
    used_at = datetime.now(timezone.utc).isoformat()
    await asyncio.to_thread(
        lambda: self.client.table("topics").update({
            "is_used": True,
            "used_at": used_at
        }).eq("id", str(topic_id)).execute()
    )
    logger.info(f"Marked topic {topic_id} as used")
|
|
|
|
# ==================== PROFILE ANALYSIS ====================
|
|
|
|
async def save_profile_analysis(self, analysis: ProfileAnalysis) -> ProfileAnalysis:
    """Insert or update the profile analysis row of ``analysis.user_id``.

    ``id`` and ``created_at`` are not sent and ``None`` fields are omitted.

    Returns:
        The row returned by Supabase after the insert/update.
    """
    data = analysis.model_dump(exclude={"id", "created_at"}, exclude_none=True)
    if "user_id" in data:
        data["user_id"] = str(data["user_id"])
    user_id_str = str(analysis.user_id)

    # Only existence matters; avoid pulling the (potentially large) analysis payload.
    existing = await asyncio.to_thread(
        lambda: self.client.table("profile_analyses").select("id").eq(
            "user_id", user_id_str
        ).limit(1).execute()
    )

    # NOTE(review): select-then-write is not atomic under concurrent calls.
    if existing.data:
        result = await asyncio.to_thread(
            lambda: self.client.table("profile_analyses").update(data).eq(
                "user_id", user_id_str
            ).execute()
        )
    else:
        result = await asyncio.to_thread(
            lambda: self.client.table("profile_analyses").insert(data).execute()
        )

    logger.info(f"Saved profile analysis for user: {analysis.user_id}")
    return ProfileAnalysis(**result.data[0])
|
|
|
|
async def get_profile_analysis(self, user_id: UUID) -> Optional[ProfileAnalysis]:
    """Return the profile analysis of ``user_id``, or ``None`` if none is stored."""
    def _fetch():
        return (
            self.client.table("profile_analyses")
            .select("*")
            .eq("user_id", str(user_id))
            .execute()
        )

    response = await asyncio.to_thread(_fetch)
    rows = response.data
    return ProfileAnalysis(**rows[0]) if rows else None
|
|
|
|
# ==================== RESEARCH RESULTS ====================
|
|
|
|
async def save_research_result(self, research: ResearchResult) -> ResearchResult:
    """Insert a research result (id/created_at left to the database) and return it."""
    payload = research.model_dump(exclude={"id", "created_at"}, exclude_none=True)
    if "user_id" in payload:
        payload["user_id"] = str(payload["user_id"])
    if payload.get("target_post_type_id"):
        payload["target_post_type_id"] = str(payload["target_post_type_id"])

    def _insert():
        return self.client.table("research_results").insert(payload).execute()

    response = await asyncio.to_thread(_insert)
    logger.info(f"Saved research result for user: {research.user_id}")
    return ResearchResult(**response.data[0])
|
|
|
|
async def get_latest_research(self, user_id: UUID) -> Optional[ResearchResult]:
    """Return the most recently created research result of ``user_id``, if any."""
    def _latest():
        return (
            self.client.table("research_results")
            .select("*")
            .eq("user_id", str(user_id))
            .order("created_at", desc=True)
            .limit(1)
            .execute()
        )

    response = await asyncio.to_thread(_latest)
    if not response.data:
        return None
    return ResearchResult(**response.data[0])
|
|
|
|
async def get_all_research(
    self,
    user_id: UUID,
    post_type_id: Optional[UUID] = None
) -> List[ResearchResult]:
    """Return all research results of ``user_id``, newest first.

    If ``post_type_id`` is truthy, only results targeting that post type are returned.
    """
    def _query():
        query = self.client.table("research_results").select("*").eq("user_id", str(user_id))
        if post_type_id:
            query = query.eq("target_post_type_id", str(post_type_id))
        return query.order("created_at", desc=True).execute()

    response = await asyncio.to_thread(_query)
    return [ResearchResult(**row) for row in response.data]
|
|
|
|
# ==================== GENERATED POSTS ====================
|
|
|
|
async def save_generated_post(self, post: GeneratedPost) -> GeneratedPost:
    """Insert a generated post and return the stored row.

    ``id`` and ``created_at`` are left to the database and ``None`` fields
    are omitted. Because the request body is JSON-encoded, top-level UUID
    values are stringified, top-level date/datetime values are ISO-formatted,
    and datetimes directly inside a ``metadata`` dict are ISO-formatted too
    (matching the conversion ``update_generated_post`` already performs).
    """
    data = post.model_dump(exclude={"id", "created_at"}, exclude_none=True)
    for key, value in list(data.items()):
        if isinstance(value, UUID):
            data[key] = str(value)
        elif isinstance(value, (datetime, date)):
            data[key] = value.isoformat()

    metadata = data.get("metadata")
    if isinstance(metadata, dict):
        data["metadata"] = {
            k: v.isoformat() if isinstance(v, datetime) else v
            for k, v in metadata.items()
        }

    result = await asyncio.to_thread(
        lambda: self.client.table("generated_posts").insert(data).execute()
    )
    logger.info(f"Saved generated post: {result.data[0]['id']}")
    return GeneratedPost(**result.data[0])
|
|
|
|
async def update_generated_post(self, post_id: UUID, updates: Dict[str, Any]) -> GeneratedPost:
    """Apply ``updates`` to a generated post and return the updated row.

    The caller's dict is not modified. For JSON encoding, top-level
    date/datetime values (e.g. ``published_at``, ``scheduled_at``) become
    ISO-8601 strings, UUID values become strings, and datetimes directly
    inside a ``metadata`` dict are ISO-formatted. ``None`` passes through so
    columns can be cleared.
    """
    # Build a fresh payload instead of rewriting the caller's dict in place.
    payload: Dict[str, Any] = {}
    for key, value in updates.items():
        if isinstance(value, UUID):
            payload[key] = str(value)
        elif isinstance(value, (datetime, date)):
            payload[key] = value.isoformat()
        else:
            payload[key] = value

    metadata = payload.get("metadata")
    if isinstance(metadata, dict):
        payload["metadata"] = {
            k: v.isoformat() if isinstance(v, datetime) else v
            for k, v in metadata.items()
        }

    result = await asyncio.to_thread(
        lambda: self.client.table("generated_posts").update(payload).eq(
            "id", str(post_id)
        ).execute()
    )
    logger.info(f"Updated generated post: {post_id}")
    return GeneratedPost(**result.data[0])
|
|
|
|
async def get_generated_posts(self, user_id: UUID) -> List[GeneratedPost]:
    """Return all generated posts of ``user_id``, newest ``created_at`` first."""
    def _query():
        return (
            self.client.table("generated_posts")
            .select("*")
            .eq("user_id", str(user_id))
            .order("created_at", desc=True)
            .execute()
        )

    response = await asyncio.to_thread(_query)
    return [GeneratedPost(**row) for row in response.data]
|
|
|
|
async def get_generated_post(self, post_id: UUID) -> Optional[GeneratedPost]:
    """Return the generated post with id ``post_id``, or ``None`` if missing."""
    def _fetch():
        return self.client.table("generated_posts").select("*").eq("id", str(post_id)).execute()

    response = await asyncio.to_thread(_fetch)
    rows = response.data
    return GeneratedPost(**rows[0]) if rows else None
|
|
|
|
async def get_scheduled_posts_due(self) -> List[GeneratedPost]:
    """Return posts with status ``scheduled`` whose ``scheduled_at`` is now or earlier.

    "Now" is the current UTC time, sent as an ISO-8601 string with offset.
    """
    # Uses the module-level datetime import; the former local re-import was redundant.
    now = datetime.now(timezone.utc).isoformat()
    result = await asyncio.to_thread(
        lambda: self.client.table("generated_posts").select("*").eq(
            "status", "scheduled"
        ).lte("scheduled_at", now).execute()
    )
    return [GeneratedPost(**item) for item in result.data]
|
|
|
|
async def get_scheduled_posts_for_company(self, company_id: UUID) -> List[Dict[str, Any]]:
    """Return ready/scheduled/published posts of a company's members with author info.

    Members are the ``profiles`` rows with this ``company_id``. Each returned
    dict is a ``generated_posts`` row (selected columns) plus
    ``employee_name`` ("Unknown" if the author profile is missing) and
    ``employee_user_id``. Posts are ordered by ``scheduled_at`` ascending.
    """
    def _members():
        return self.client.table("profiles").select("id, display_name").eq(
            "company_id", str(company_id)
        ).execute()

    members = await asyncio.to_thread(_members)
    if not members.data:
        return []

    user_ids = [row["id"] for row in members.data]
    profile_map = {row["id"]: row for row in members.data}

    def _posts():
        return (
            self.client.table("generated_posts")
            .select("id, user_id, topic_title, post_content, status, scheduled_at, created_at")
            .in_("user_id", user_ids)
            .in_("status", ["ready", "scheduled", "published"])
            .order("scheduled_at", desc=False, nullsfirst=False)
            .execute()
        )

    posts = await asyncio.to_thread(_posts)

    def _enrich(post):
        owner = profile_map.get(post["user_id"])
        return {
            **post,
            "employee_name": owner["display_name"] if owner else "Unknown",
            "employee_user_id": post["user_id"],
        }

    return [_enrich(post) for post in posts.data]
|
|
|
|
async def schedule_post(self, post_id: UUID, scheduled_at: datetime, scheduled_by_user_id: UUID) -> GeneratedPost:
    """Mark a post as ``scheduled`` for ``scheduled_at`` by ``scheduled_by_user_id``.

    A datetime ``scheduled_at`` is sent in ISO-8601 form; a pre-formatted
    string is passed through unchanged.

    Returns:
        The updated generated post.
    """
    # The module-level ``datetime`` suffices; the former aliased local import was redundant.
    updates = {
        "status": "scheduled",
        "scheduled_at": scheduled_at.isoformat() if isinstance(scheduled_at, datetime) else scheduled_at,
        "scheduled_by_user_id": str(scheduled_by_user_id)
    }
    return await self.update_generated_post(post_id, updates)
|
|
|
|
async def unschedule_post(self, post_id: UUID) -> GeneratedPost:
    """Clear the schedule of a post and set its status back to ``approved``."""
    cleared = dict(status="approved", scheduled_at=None, scheduled_by_user_id=None)
    return await self.update_generated_post(post_id, cleared)
|
|
|
|
|
|
# ==================== PROFILES / USERS ====================
|
|
|
|
async def create_profile(self, user_id: UUID, profile: Profile) -> Profile:
    """Upsert the profile row whose id is ``user_id`` and return it.

    Timestamps are left to the database, ``None`` fields are omitted, and
    enum-like ``account_type`` / ``onboarding_status`` values are stored via
    their ``.value``.
    """
    row = profile.model_dump(exclude={"created_at", "updated_at"}, exclude_none=True)
    row["id"] = str(user_id)
    if row.get("company_id"):
        row["company_id"] = str(row["company_id"])
    for enum_key in ("account_type", "onboarding_status"):
        if enum_key in row and hasattr(row[enum_key], "value"):
            row[enum_key] = row[enum_key].value

    def _upsert():
        return self.client.table("profiles").upsert(row).execute()

    response = await asyncio.to_thread(_upsert)
    stored = response.data[0]
    logger.info(f"Created/updated profile: {stored['id']}")
    return Profile(**stored)
|
|
|
|
async def get_profile(self, user_id: UUID) -> Optional[Profile]:
    """Return the profile whose id is ``user_id``, or ``None``."""
    def _fetch():
        return self.client.table("profiles").select("*").eq("id", str(user_id)).execute()

    response = await asyncio.to_thread(_fetch)
    if not response.data:
        return None
    return Profile(**response.data[0])
|
|
|
|
async def get_profiles_by_linkedin_url(self, linkedin_url: str) -> List[Profile]:
    """Return every profile whose ``linkedin_url`` matches exactly."""
    def _query():
        return self.client.table("profiles").select("*").eq("linkedin_url", linkedin_url).execute()

    response = await asyncio.to_thread(_query)
    return [Profile(**row) for row in response.data]
|
|
|
|
async def update_profile(self, user_id: UUID, updates: Dict[str, Any]) -> Profile:
    """Apply ``updates`` to the profile of ``user_id`` and return the updated row.

    A truthy ``company_id`` is stringified and enum-like ``account_type`` /
    ``onboarding_status`` values are replaced by their ``.value``. The
    caller's dict is left untouched.
    """
    # Copy first: the original rewrote the caller's dict in place.
    payload = dict(updates)
    if payload.get("company_id"):
        payload["company_id"] = str(payload["company_id"])
    for key in ("account_type", "onboarding_status"):
        if key in payload and hasattr(payload[key], "value"):
            payload[key] = payload[key].value

    result = await asyncio.to_thread(
        lambda: self.client.table("profiles").update(payload).eq(
            "id", str(user_id)
        ).execute()
    )
    logger.info(f"Updated profile: {user_id}")
    return Profile(**result.data[0])
|
|
|
|
# ==================== LINKEDIN ACCOUNTS ====================
|
|
|
|
async def get_linkedin_account(self, user_id: UUID) -> Optional['LinkedInAccount']:
    """Return the active LinkedIn account connection of ``user_id``, if any."""
    from src.database.models import LinkedInAccount

    def _fetch():
        return (
            self.client.table("linkedin_accounts")
            .select("*")
            .eq("user_id", str(user_id))
            .eq("is_active", True)
            .execute()
        )

    response = await asyncio.to_thread(_fetch)
    rows = response.data
    return LinkedInAccount(**rows[0]) if rows else None
|
|
|
|
async def get_linkedin_account_by_id(self, account_id: UUID) -> Optional['LinkedInAccount']:
    """Return the LinkedIn account with id ``account_id`` (active or not), or ``None``."""
    from src.database.models import LinkedInAccount

    def _fetch():
        return self.client.table("linkedin_accounts").select("*").eq("id", str(account_id)).execute()

    response = await asyncio.to_thread(_fetch)
    if not response.data:
        return None
    return LinkedInAccount(**response.data[0])
|
|
|
|
async def create_linkedin_account(self, account: 'LinkedInAccount') -> 'LinkedInAccount':
    """Insert a LinkedIn account connection and return the stored row.

    ``id`` and timestamps are left to the database. ``None`` fields are sent
    as-is (``exclude_none`` is not used here). Every datetime value is
    ISO-formatted for JSON encoding.
    """
    from src.database.models import LinkedInAccount

    data = account.model_dump(exclude={'id', 'created_at', 'updated_at'})
    data['user_id'] = str(data['user_id'])
    # Previously only token_expires_at was converted, and unconditionally, so
    # None or a pre-formatted string raised AttributeError, while any other
    # datetime field (e.g. last_used_at / last_error_at, if the model carries
    # them) broke JSON encoding.
    for key, value in list(data.items()):
        if isinstance(value, datetime):
            data[key] = value.isoformat()

    result = await asyncio.to_thread(
        lambda: self.client.table("linkedin_accounts").insert(data).execute()
    )
    logger.info(f"Created LinkedIn account for user: {account.user_id}")
    return LinkedInAccount(**result.data[0])
|
|
|
|
async def update_linkedin_account(self, account_id: UUID, updates: Dict) -> 'LinkedInAccount':
    """Apply ``updates`` to a LinkedIn account row and return the updated row.

    Datetime values (e.g. ``token_expires_at``, ``last_used_at``,
    ``last_error_at``) are ISO-formatted for JSON encoding. The caller's dict
    is not modified.
    """
    from src.database.models import LinkedInAccount

    # Fresh payload: the original rewrote the caller's dict and only converted
    # three hard-coded fields.
    payload = {
        key: value.isoformat() if isinstance(value, datetime) else value
        for key, value in updates.items()
    }

    result = await asyncio.to_thread(
        lambda: self.client.table("linkedin_accounts").update(payload)
        .eq("id", str(account_id)).execute()
    )
    logger.info(f"Updated LinkedIn account: {account_id}")
    return LinkedInAccount(**result.data[0])
|
|
|
|
async def delete_linkedin_account(self, account_id: UUID) -> None:
    """Remove the LinkedIn account row with id ``account_id``."""
    account_id_str = str(account_id)

    def _delete():
        return self.client.table("linkedin_accounts").delete().eq("id", account_id_str).execute()

    await asyncio.to_thread(_delete)
    logger.info(f"Deleted LinkedIn account: {account_id}")
|
|
|
|
# ==================== USERS ====================
|
|
|
|
async def get_user(self, user_id: UUID) -> Optional[User]:
    """Return the ``users`` view row for ``user_id``, or ``None``."""
    def _fetch():
        return self.client.table("users").select("*").eq("id", str(user_id)).execute()

    response = await asyncio.to_thread(_fetch)
    rows = response.data
    return User(**rows[0]) if rows else None
|
|
|
|
async def get_user_by_email(self, email: str) -> Optional[User]:
    """Return the ``users`` view row whose email equals ``email`` lower-cased, or ``None``."""
    normalized = email.lower()

    def _fetch():
        return self.client.table("users").select("*").eq("email", normalized).execute()

    response = await asyncio.to_thread(_fetch)
    if not response.data:
        return None
    return User(**response.data[0])
|
|
|
|
async def get_user_by_linkedin_sub(self, linkedin_sub: str) -> Optional[User]:
    """Return the ``users`` view row with this LinkedIn OAuth ``sub``, or ``None``."""
    def _fetch():
        return self.client.table("users").select("*").eq("linkedin_sub", linkedin_sub).execute()

    response = await asyncio.to_thread(_fetch)
    rows = response.data
    return User(**rows[0]) if rows else None
|
|
|
|
async def update_user(self, user_id: UUID, updates: Dict[str, Any]) -> User:
    """Write the profile-backed subset of ``updates`` and return the refreshed user.

    Keys outside the whitelist below are ignored. ``update_profile`` is only
    called when at least one whitelisted key is present.
    """
    writable_profile_fields = {
        "account_type", "display_name", "onboarding_status",
        "onboarding_data", "company_id",
        "linkedin_url", "writing_style_notes", "metadata",
        "profile_picture", "creator_email", "customer_email", "is_active",
    }
    profile_updates = {
        field: value for field, value in updates.items() if field in writable_profile_fields
    }

    if profile_updates:
        await self.update_profile(user_id, profile_updates)

    return await self.get_user(user_id)
|
|
|
|
async def list_users(self, account_type: Optional[str] = None, company_id: Optional[UUID] = None) -> List[User]:
    """List rows of the ``users`` view, newest first, optionally filtered.

    Truthy ``account_type`` / ``company_id`` add equality filters.
    """
    filters = []
    if account_type:
        filters.append(("account_type", account_type))
    if company_id:
        filters.append(("company_id", str(company_id)))

    def _query():
        query = self.client.table("users").select("*")
        for column, value in filters:
            query = query.eq(column, value)
        return query.order("created_at", desc=True).execute()

    response = await asyncio.to_thread(_query)
    return [User(**row) for row in response.data]
|
|
|
|
async def delete_user_completely(self, user_id: UUID) -> bool:
    """Delete a user together with their content, invitations, profile and auth account.

    Steps, in order:
      1. delete rows keyed by ``user_id`` in the content tables, including the
         user's LinkedIn account connections;
      2. delete invitations addressed to the user's email;
      3. delete the ``profiles`` row;
      4. delete the auth user (only when a service-role client exists).

    Returns:
        ``True`` on success.

    Raises:
        Re-raises any exception after logging it.
    """
    try:
        user_id_str = str(user_id)

        content_tables = (
            "generated_posts", "linkedin_posts", "example_posts",
            "reference_profiles", "profile_analyses", "linkedin_profiles",
            "post_types", "research_results", "topics",
            # LinkedIn connections were previously left behind; delete them
            # explicitly instead of relying on a foreign-key cascade that may
            # not exist.
            "linkedin_accounts",
        )
        for table in content_tables:
            await asyncio.to_thread(
                lambda t=table: self.client.table(t).delete().eq(
                    "user_id", user_id_str
                ).execute()
            )

        # The profile still exists here, so the users view can resolve the email.
        user = await self.get_user(user_id)
        if user and user.email:
            await asyncio.to_thread(
                lambda: self.client.table("invitations").delete().eq(
                    "email", user.email
                ).execute()
            )

        await asyncio.to_thread(
            lambda: self.client.table("profiles").delete().eq(
                "id", user_id_str
            ).execute()
        )

        if self.admin_client:
            await asyncio.to_thread(
                lambda: self.admin_client.auth.admin.delete_user(user_id_str)
            )
            logger.info(f"Deleted auth user {user_id}")
        else:
            logger.warning(f"Cannot delete auth user {user_id} - no service role key configured")

        logger.info(f"Completely deleted user {user_id} and all related data")
        return True

    except Exception as e:
        logger.error(f"Error deleting user {user_id}: {e}")
        raise
|
|
|
|
# ==================== COMPANIES ====================
|
|
|
|
async def create_company(self, company: Company) -> Company:
|
|
"""Create a new company."""
|
|
data = company.model_dump(exclude={"id", "created_at", "updated_at"}, exclude_none=True)
|
|
if "owner_user_id" in data:
|
|
data["owner_user_id"] = str(data["owner_user_id"])
|
|
if "license_key_id" in data:
|
|
data["license_key_id"] = str(data["license_key_id"])
|
|
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("companies").insert(data).execute()
|
|
)
|
|
logger.info(f"Created company: {result.data[0]['id']}")
|
|
return Company(**result.data[0])
|
|
|
|
async def get_company(self, company_id: UUID) -> Optional[Company]:
|
|
"""Get company by ID."""
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("companies").select("*").eq("id", str(company_id)).execute()
|
|
)
|
|
if result.data:
|
|
return Company(**result.data[0])
|
|
return None
|
|
|
|
async def get_company_by_owner(self, owner_user_id: UUID) -> Optional[Company]:
|
|
"""Get company by owner user ID."""
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("companies").select("*").eq("owner_user_id", str(owner_user_id)).execute()
|
|
)
|
|
if result.data:
|
|
return Company(**result.data[0])
|
|
return None
|
|
|
|
async def get_company_employees(self, company_id: UUID) -> List[User]:
|
|
"""Get all employees of a company."""
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("users").select("*").eq(
|
|
"company_id", str(company_id)
|
|
).eq("account_type", "employee").order("created_at", desc=True).execute()
|
|
)
|
|
return [User(**item) for item in result.data]
|
|
|
|
async def update_company(self, company_id: UUID, updates: Dict[str, Any]) -> Company:
|
|
"""Update company fields."""
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("companies").update(updates).eq(
|
|
"id", str(company_id)
|
|
).execute()
|
|
)
|
|
logger.info(f"Updated company: {company_id}")
|
|
return Company(**result.data[0])
|
|
|
|
async def list_companies(self) -> List[Company]:
|
|
"""List all companies."""
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("companies").select("*").order("created_at", desc=True).execute()
|
|
)
|
|
return [Company(**item) for item in result.data]
|
|
|
|
# ==================== INVITATIONS ====================
|
|
|
|
async def create_invitation(self, invitation: Invitation) -> Invitation:
|
|
"""Create a new invitation."""
|
|
from datetime import datetime
|
|
data = invitation.model_dump(exclude={"id", "created_at"}, exclude_none=True)
|
|
for key in ["company_id", "invited_by_user_id", "accepted_by_user_id"]:
|
|
if key in data and data[key]:
|
|
data[key] = str(data[key])
|
|
if "expires_at" in data and isinstance(data["expires_at"], datetime):
|
|
data["expires_at"] = data["expires_at"].isoformat()
|
|
if "accepted_at" in data and isinstance(data["accepted_at"], datetime):
|
|
data["accepted_at"] = data["accepted_at"].isoformat()
|
|
if "status" in data and hasattr(data["status"], "value"):
|
|
data["status"] = data["status"].value
|
|
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("invitations").insert(data).execute()
|
|
)
|
|
logger.info(f"Created invitation: {result.data[0]['id']}")
|
|
return Invitation(**result.data[0])
|
|
|
|
async def get_invitation_by_token(self, token: str) -> Optional[Invitation]:
|
|
"""Get invitation by token."""
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("invitations").select("*").eq("token", token).execute()
|
|
)
|
|
if result.data:
|
|
return Invitation(**result.data[0])
|
|
return None
|
|
|
|
async def get_invitation(self, invitation_id: UUID) -> Optional[Invitation]:
|
|
"""Get invitation by ID."""
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("invitations").select("*").eq("id", str(invitation_id)).execute()
|
|
)
|
|
if result.data:
|
|
return Invitation(**result.data[0])
|
|
return None
|
|
|
|
async def update_invitation(self, invitation_id: UUID, updates: Dict[str, Any]) -> Invitation:
|
|
"""Update invitation fields."""
|
|
from datetime import datetime
|
|
if "accepted_by_user_id" in updates and updates["accepted_by_user_id"]:
|
|
updates["accepted_by_user_id"] = str(updates["accepted_by_user_id"])
|
|
if "accepted_at" in updates and isinstance(updates["accepted_at"], datetime):
|
|
updates["accepted_at"] = updates["accepted_at"].isoformat()
|
|
if "status" in updates and hasattr(updates["status"], "value"):
|
|
updates["status"] = updates["status"].value
|
|
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("invitations").update(updates).eq(
|
|
"id", str(invitation_id)
|
|
).execute()
|
|
)
|
|
logger.info(f"Updated invitation: {invitation_id}")
|
|
return Invitation(**result.data[0])
|
|
|
|
async def get_pending_invitations(self, company_id: UUID) -> List[Invitation]:
|
|
"""Get all pending invitations for a company."""
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("invitations").select("*").eq(
|
|
"company_id", str(company_id)
|
|
).eq("status", "pending").order("created_at", desc=True).execute()
|
|
)
|
|
return [Invitation(**item) for item in result.data]
|
|
|
|
async def get_invitations_by_email(self, email: str) -> List[Invitation]:
|
|
"""Get all invitations for an email address."""
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("invitations").select("*").eq(
|
|
"email", email.lower()
|
|
).order("created_at", desc=True).execute()
|
|
)
|
|
return [Invitation(**item) for item in result.data]
|
|
|
|
# ==================== EXAMPLE POSTS ====================
|
|
|
|
async def save_example_posts(self, posts: List[ExamplePost]) -> List[ExamplePost]:
|
|
"""Save example posts (bulk)."""
|
|
if not posts:
|
|
return []
|
|
|
|
data = []
|
|
for p in posts:
|
|
post_dict = p.model_dump(exclude={"id", "created_at"}, exclude_none=True)
|
|
if "user_id" in post_dict:
|
|
post_dict["user_id"] = str(post_dict["user_id"])
|
|
if "post_type_id" in post_dict and post_dict["post_type_id"]:
|
|
post_dict["post_type_id"] = str(post_dict["post_type_id"])
|
|
data.append(post_dict)
|
|
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("example_posts").insert(data).execute()
|
|
)
|
|
logger.info(f"Saved {len(result.data)} example posts")
|
|
return [ExamplePost(**item) for item in result.data]
|
|
|
|
async def save_example_post(self, post: ExamplePost) -> ExamplePost:
|
|
"""Save a single example post."""
|
|
posts = await self.save_example_posts([post])
|
|
return posts[0] if posts else post
|
|
|
|
async def get_example_posts(self, user_id: UUID) -> List[ExamplePost]:
|
|
"""Get all example posts for a user."""
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("example_posts").select("*").eq(
|
|
"user_id", str(user_id)
|
|
).order("created_at", desc=True).execute()
|
|
)
|
|
return [ExamplePost(**item) for item in result.data]
|
|
|
|
async def delete_example_post(self, post_id: UUID) -> None:
|
|
"""Delete an example post."""
|
|
await asyncio.to_thread(
|
|
lambda: self.client.table("example_posts").delete().eq("id", str(post_id)).execute()
|
|
)
|
|
logger.info(f"Deleted example post: {post_id}")
|
|
|
|
# ==================== REFERENCE PROFILES ====================
|
|
|
|
async def create_reference_profile(self, profile: ReferenceProfile) -> ReferenceProfile:
|
|
"""Create a new reference profile."""
|
|
data = profile.model_dump(exclude={"id", "created_at"}, exclude_none=True)
|
|
if "user_id" in data:
|
|
data["user_id"] = str(data["user_id"])
|
|
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("reference_profiles").insert(data).execute()
|
|
)
|
|
logger.info(f"Created reference profile: {result.data[0]['id']}")
|
|
return ReferenceProfile(**result.data[0])
|
|
|
|
async def get_reference_profiles(self, user_id: UUID) -> List[ReferenceProfile]:
|
|
"""Get all reference profiles for a user."""
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("reference_profiles").select("*").eq(
|
|
"user_id", str(user_id)
|
|
).order("created_at", desc=True).execute()
|
|
)
|
|
return [ReferenceProfile(**item) for item in result.data]
|
|
|
|
async def update_reference_profile(self, profile_id: UUID, updates: Dict[str, Any]) -> ReferenceProfile:
|
|
"""Update reference profile fields."""
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("reference_profiles").update(updates).eq(
|
|
"id", str(profile_id)
|
|
).execute()
|
|
)
|
|
logger.info(f"Updated reference profile: {profile_id}")
|
|
return ReferenceProfile(**result.data[0])
|
|
|
|
async def delete_reference_profile(self, profile_id: UUID) -> None:
|
|
"""Delete a reference profile."""
|
|
await asyncio.to_thread(
|
|
lambda: self.client.table("reference_profiles").delete().eq("id", str(profile_id)).execute()
|
|
)
|
|
logger.info(f"Deleted reference profile: {profile_id}")
|
|
|
|
# ==================== API USAGE LOGS ====================
|
|
|
|
async def log_api_usage(
|
|
self,
|
|
provider: str,
|
|
model: str,
|
|
operation: str,
|
|
prompt_tokens: int,
|
|
completion_tokens: int,
|
|
total_tokens: int,
|
|
estimated_cost_usd: float,
|
|
user_id: Optional[UUID] = None,
|
|
company_id: Optional[UUID] = None
|
|
) -> None:
|
|
"""Log an API usage event."""
|
|
data = {
|
|
"provider": provider,
|
|
"model": model,
|
|
"operation": operation,
|
|
"prompt_tokens": prompt_tokens,
|
|
"completion_tokens": completion_tokens,
|
|
"total_tokens": total_tokens,
|
|
"estimated_cost_usd": estimated_cost_usd
|
|
}
|
|
if user_id:
|
|
data["user_id"] = str(user_id)
|
|
if company_id:
|
|
data["company_id"] = str(company_id)
|
|
|
|
try:
|
|
await asyncio.to_thread(
|
|
lambda: self.client.table("api_usage_logs").insert(data).execute()
|
|
)
|
|
except Exception as e:
|
|
logger.warning(f"Failed to log API usage: {e}")
|
|
|
|
async def get_usage_stats(
|
|
self,
|
|
start_date: Optional[str] = None,
|
|
end_date: Optional[str] = None,
|
|
user_id: Optional[UUID] = None,
|
|
company_id: Optional[UUID] = None
|
|
) -> List[Dict[str, Any]]:
|
|
"""Get aggregated usage stats."""
|
|
def _query():
|
|
query = self.client.table("api_usage_logs").select("*")
|
|
if user_id:
|
|
query = query.eq("user_id", str(user_id))
|
|
if company_id:
|
|
query = query.eq("company_id", str(company_id))
|
|
if start_date:
|
|
query = query.gte("created_at", start_date)
|
|
if end_date:
|
|
query = query.lte("created_at", end_date)
|
|
return query.order("created_at", desc=True).execute()
|
|
|
|
result = await asyncio.to_thread(_query)
|
|
return result.data
|
|
|
|
async def get_usage_by_day(
|
|
self,
|
|
start_date: Optional[str] = None,
|
|
end_date: Optional[str] = None,
|
|
user_id: Optional[UUID] = None,
|
|
company_id: Optional[UUID] = None
|
|
) -> List[Dict[str, Any]]:
|
|
"""Get daily usage breakdown."""
|
|
logs = await self.get_usage_stats(start_date, end_date, user_id, company_id)
|
|
|
|
from collections import defaultdict
|
|
daily = defaultdict(lambda: {"total_tokens": 0, "estimated_cost_usd": 0.0, "count": 0})
|
|
|
|
for log in logs:
|
|
day = log["created_at"][:10]
|
|
daily[day]["total_tokens"] += log.get("total_tokens", 0)
|
|
daily[day]["estimated_cost_usd"] += float(log.get("estimated_cost_usd", 0))
|
|
daily[day]["count"] += 1
|
|
|
|
return [{"date": day, **data} for day, data in sorted(daily.items())]
|
|
|
|
async def get_post_stats(
|
|
self,
|
|
start_date: Optional[str] = None,
|
|
user_id: Optional[UUID] = None,
|
|
company_id: Optional[UUID] = None
|
|
) -> List[Dict[str, Any]]:
|
|
"""Get generated posts with select fields for statistics."""
|
|
# Resolve user_ids for filtering
|
|
target_user_ids = None
|
|
if user_id:
|
|
target_user_ids = [str(user_id)]
|
|
if company_id:
|
|
# Get all users belonging to this company
|
|
profiles_result = await asyncio.to_thread(
|
|
lambda: self.client.table("profiles").select("id").eq(
|
|
"company_id", str(company_id)
|
|
).execute()
|
|
)
|
|
company_user_ids = [p["id"] for p in profiles_result.data]
|
|
if target_user_ids is not None:
|
|
target_user_ids = [u for u in target_user_ids if u in company_user_ids]
|
|
else:
|
|
target_user_ids = company_user_ids
|
|
if not target_user_ids:
|
|
return []
|
|
|
|
def _query():
|
|
q = self.client.table("generated_posts").select(
|
|
"id, created_at, status, approved_at, published_at, user_id"
|
|
)
|
|
if target_user_ids is not None:
|
|
q = q.in_("user_id", target_user_ids)
|
|
if start_date:
|
|
q = q.gte("created_at", start_date)
|
|
return q.order("created_at", desc=True).execute()
|
|
|
|
result = await asyncio.to_thread(_query)
|
|
return result.data
|
|
|
|
# ==================== LICENSE KEYS ====================
|
|
|
|
async def list_license_keys(self) -> List[LicenseKey]:
|
|
"""Get all license keys (admin only)."""
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("license_keys")
|
|
.select("*")
|
|
.order("created_at", desc=True)
|
|
.execute()
|
|
)
|
|
return [LicenseKey(**row) for row in result.data]
|
|
|
|
async def get_license_key(self, key: str) -> Optional[LicenseKey]:
|
|
"""Get license key by key string."""
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("license_keys")
|
|
.select("*")
|
|
.eq("key", key)
|
|
.execute()
|
|
)
|
|
return LicenseKey(**result.data[0]) if result.data else None
|
|
|
|
async def create_license_key(
|
|
self,
|
|
key: str,
|
|
max_employees: int,
|
|
max_posts_per_day: int,
|
|
max_researches_per_day: int,
|
|
description: Optional[str] = None
|
|
) -> LicenseKey:
|
|
"""Create new license key."""
|
|
data = {
|
|
"key": key,
|
|
"max_employees": max_employees,
|
|
"max_posts_per_day": max_posts_per_day,
|
|
"max_researches_per_day": max_researches_per_day,
|
|
"description": description,
|
|
}
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("license_keys").insert(data).execute()
|
|
)
|
|
logger.info(f"Created license key: {key}")
|
|
return LicenseKey(**result.data[0])
|
|
|
|
async def mark_license_key_used(
|
|
self,
|
|
key: str,
|
|
company_id: UUID
|
|
) -> LicenseKey:
|
|
"""Mark license key as used and link to company."""
|
|
data = {
|
|
"used": True,
|
|
"company_id": str(company_id),
|
|
"used_at": datetime.now(timezone.utc).isoformat()
|
|
}
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("license_keys")
|
|
.update(data)
|
|
.eq("key", key)
|
|
.execute()
|
|
)
|
|
logger.info(f"Marked license key as used: {key}")
|
|
return LicenseKey(**result.data[0])
|
|
|
|
async def update_license_key(self, key_id: UUID, updates: Dict[str, Any]) -> LicenseKey:
|
|
"""Update license key limits (admin only)."""
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("license_keys")
|
|
.update(updates)
|
|
.eq("id", str(key_id))
|
|
.execute()
|
|
)
|
|
logger.info(f"Updated license key: {key_id}")
|
|
return LicenseKey(**result.data[0]) if result.data else None
|
|
|
|
async def delete_license_key(self, key_id: UUID) -> None:
|
|
"""Delete license key (admin only)."""
|
|
await asyncio.to_thread(
|
|
lambda: self.client.table("license_keys")
|
|
.delete()
|
|
.eq("id", str(key_id))
|
|
.execute()
|
|
)
|
|
logger.info(f"Deleted license key: {key_id}")
|
|
|
|
# ==================== COMPANY QUOTAS ====================
|
|
|
|
async def get_company_daily_quota(
|
|
self,
|
|
company_id: UUID,
|
|
date_: Optional[date] = None
|
|
) -> CompanyDailyQuota:
|
|
"""Get or create daily quota for company."""
|
|
if date_ is None:
|
|
date_ = date.today()
|
|
|
|
# Try to get existing
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("company_daily_quotas")
|
|
.select("*")
|
|
.eq("company_id", str(company_id))
|
|
.eq("date", date_.isoformat())
|
|
.execute()
|
|
)
|
|
|
|
if result.data:
|
|
return CompanyDailyQuota(**result.data[0])
|
|
|
|
# Create new quota for today
|
|
data = {
|
|
"company_id": str(company_id),
|
|
"date": date_.isoformat(),
|
|
"posts_created": 0,
|
|
"researches_created": 0
|
|
}
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("company_daily_quotas")
|
|
.insert(data)
|
|
.execute()
|
|
)
|
|
return CompanyDailyQuota(**result.data[0])
|
|
|
|
async def increment_company_posts_quota(self, company_id: UUID) -> None:
|
|
"""Increment daily posts count for company."""
|
|
try:
|
|
quota = await self.get_company_daily_quota(company_id)
|
|
new_count = quota.posts_created + 1
|
|
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("company_daily_quotas")
|
|
.update({"posts_created": new_count})
|
|
.eq("id", str(quota.id))
|
|
.execute()
|
|
)
|
|
|
|
logger.info(f"Incremented posts quota for company {company_id}: {quota.posts_created} -> {new_count}")
|
|
|
|
if not result.data:
|
|
logger.error(f"Failed to increment posts quota - no data returned")
|
|
except Exception as e:
|
|
logger.error(f"Error incrementing posts quota for company {company_id}: {e}")
|
|
raise
|
|
|
|
async def increment_company_researches_quota(self, company_id: UUID) -> None:
|
|
"""Increment daily researches count for company."""
|
|
try:
|
|
quota = await self.get_company_daily_quota(company_id)
|
|
new_count = quota.researches_created + 1
|
|
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("company_daily_quotas")
|
|
.update({"researches_created": new_count})
|
|
.eq("id", str(quota.id))
|
|
.execute()
|
|
)
|
|
|
|
logger.info(f"Incremented researches quota for company {company_id}: {quota.researches_created} -> {new_count}")
|
|
|
|
if not result.data:
|
|
logger.error(f"Failed to increment researches quota - no data returned")
|
|
except Exception as e:
|
|
logger.error(f"Error incrementing researches quota for company {company_id}: {e}")
|
|
raise
|
|
|
|
async def get_company_limits(self, company_id: UUID) -> Optional[LicenseKey]:
|
|
"""Get company limits from associated license key.
|
|
|
|
Returns None if no license key is associated (uses defaults).
|
|
"""
|
|
company = await self.get_company(company_id)
|
|
if not company or not company.license_key_id:
|
|
return None
|
|
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("license_keys")
|
|
.select("*")
|
|
.eq("id", str(company.license_key_id))
|
|
.execute()
|
|
)
|
|
|
|
return LicenseKey(**result.data[0]) if result.data else None
|
|
|
|
async def check_company_post_limit(self, company_id: UUID) -> tuple[bool, str]:
|
|
"""Check if company can create more posts today.
|
|
|
|
Returns (can_create: bool, error_message: str)
|
|
"""
|
|
license_key = await self.get_company_limits(company_id)
|
|
if not license_key:
|
|
# No license key, use defaults (unlimited)
|
|
return True, ""
|
|
|
|
quota = await self.get_company_daily_quota(company_id)
|
|
|
|
if quota.posts_created >= license_key.max_posts_per_day:
|
|
return False, f"Tageslimit erreicht ({license_key.max_posts_per_day} Posts/Tag). Versuche es morgen wieder."
|
|
|
|
return True, ""
|
|
|
|
async def check_company_research_limit(self, company_id: UUID) -> tuple[bool, str]:
|
|
"""Check if company can create more researches today.
|
|
|
|
Returns (can_create: bool, error_message: str)
|
|
"""
|
|
license_key = await self.get_company_limits(company_id)
|
|
if not license_key:
|
|
# No license key, use defaults (unlimited)
|
|
return True, ""
|
|
|
|
quota = await self.get_company_daily_quota(company_id)
|
|
|
|
if quota.researches_created >= license_key.max_researches_per_day:
|
|
return False, f"Tageslimit erreicht ({license_key.max_researches_per_day} Researches/Tag). Versuche es morgen wieder."
|
|
|
|
return True, ""
|
|
|
|
async def check_company_employee_limit(self, company_id: UUID) -> tuple[bool, str]:
|
|
"""Check if company can add more employees.
|
|
|
|
Returns (can_add: bool, error_message: str)
|
|
"""
|
|
license_key = await self.get_company_limits(company_id)
|
|
if not license_key:
|
|
# No license key, use defaults (unlimited)
|
|
return True, ""
|
|
|
|
employees = await self.list_users(account_type="employee", company_id=company_id)
|
|
|
|
if len(employees) >= license_key.max_employees:
|
|
return False, f"Maximale Mitarbeiteranzahl erreicht ({license_key.max_employees}). Bitte Lizenz upgraden."
|
|
|
|
return True, ""
|
|
|
|
|
|
# Global database client instance
|
|
db = DatabaseClient()
|