1830 lines
76 KiB
Python
1830 lines
76 KiB
Python
"""Supabase database client."""
|
|
import asyncio
|
|
from datetime import datetime, date, timezone
|
|
from typing import Optional, List, Dict, Any
|
|
from uuid import UUID
|
|
from supabase import create_client, Client
|
|
from loguru import logger
|
|
|
|
from src.config import settings
|
|
from src.database.models import (
|
|
LinkedInProfile, LinkedInPost, Topic,
|
|
ProfileAnalysis, ResearchResult, GeneratedPost, PostType,
|
|
User, Profile, Company, Invitation, ExamplePost, ReferenceProfile,
|
|
ApiUsageLog, LicenseKey, CompanyDailyQuota, LicenseKeyOffer
|
|
)
|
|
from src.services.cache_service import (
|
|
cache,
|
|
PROFILE_TTL, USER_TTL, LINKEDIN_ACCOUNT_TTL, PROFILE_ANALYSIS_TTL,
|
|
COMPANY_TTL, LINKEDIN_POSTS_TTL, POST_TYPES_TTL, POST_TYPE_TTL,
|
|
GEN_POST_TTL, GEN_POSTS_TTL,
|
|
)
|
|
|
|
|
|
class DatabaseClient:
|
|
"""Supabase database client wrapper."""
|
|
|
|
def __init__(self):
|
|
"""Initialize Supabase client."""
|
|
if settings.supabase_service_role_key:
|
|
self.client: Client = create_client(
|
|
settings.supabase_url,
|
|
settings.supabase_service_role_key
|
|
)
|
|
logger.info("Supabase client initialized with service role key (fast mode)")
|
|
else:
|
|
self.client: Client = create_client(
|
|
settings.supabase_url,
|
|
settings.supabase_key
|
|
)
|
|
logger.warning("Supabase client initialized with anon key (slow mode - RLS enabled)")
|
|
|
|
self.admin_client: Optional[Client] = self.client if settings.supabase_service_role_key else None
|
|
|
|
# ==================== LINKEDIN PROFILES ====================
|
|
|
|
async def save_linkedin_profile(self, profile: LinkedInProfile) -> LinkedInProfile:
|
|
"""Save or update LinkedIn profile."""
|
|
data = profile.model_dump(exclude={"id", "scraped_at"}, exclude_none=True)
|
|
if "user_id" in data:
|
|
data["user_id"] = str(data["user_id"])
|
|
|
|
existing = await asyncio.to_thread(
|
|
lambda: self.client.table("linkedin_profiles").select("*").eq(
|
|
"user_id", str(profile.user_id)
|
|
).execute()
|
|
)
|
|
|
|
if existing.data:
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("linkedin_profiles").update(data).eq(
|
|
"user_id", str(profile.user_id)
|
|
).execute()
|
|
)
|
|
else:
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("linkedin_profiles").insert(data).execute()
|
|
)
|
|
|
|
logger.info(f"Saved LinkedIn profile for user: {profile.user_id}")
|
|
return LinkedInProfile(**result.data[0])
|
|
|
|
async def get_linkedin_profile(self, user_id: UUID) -> Optional[LinkedInProfile]:
|
|
"""Get LinkedIn profile for user."""
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("linkedin_profiles").select("*").eq(
|
|
"user_id", str(user_id)
|
|
).execute()
|
|
)
|
|
if result.data:
|
|
return LinkedInProfile(**result.data[0])
|
|
return None
|
|
|
|
# ==================== LINKEDIN POSTS ====================
|
|
|
|
async def save_linkedin_posts(self, posts: List[LinkedInPost]) -> List[LinkedInPost]:
|
|
"""Save LinkedIn posts (bulk)."""
|
|
seen = set()
|
|
unique_posts = []
|
|
for p in posts:
|
|
key = (str(p.user_id), p.post_url)
|
|
if key not in seen:
|
|
seen.add(key)
|
|
unique_posts.append(p)
|
|
|
|
if len(posts) != len(unique_posts):
|
|
logger.warning(f"Removed {len(posts) - len(unique_posts)} duplicate posts from batch")
|
|
|
|
data = []
|
|
for p in unique_posts:
|
|
# Use JSON mode so UUIDs/datetimes are serialized before the Supabase client
|
|
# builds its request payload.
|
|
post_dict = p.model_dump(mode="json", exclude={"id", "scraped_at"}, exclude_none=True)
|
|
data.append(post_dict)
|
|
|
|
if not data:
|
|
logger.warning("No posts to save")
|
|
return []
|
|
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("linkedin_posts").upsert(
|
|
data,
|
|
on_conflict="user_id,post_url"
|
|
).execute()
|
|
)
|
|
logger.info(f"Saved {len(result.data)} LinkedIn posts")
|
|
saved = [LinkedInPost(**item) for item in result.data]
|
|
# Invalidate cache for all affected users
|
|
affected_user_ids = {str(p.user_id) for p in saved}
|
|
for uid in affected_user_ids:
|
|
await cache.invalidate_linkedin_posts(uid)
|
|
return saved
|
|
|
|
async def get_linkedin_posts(self, user_id: UUID) -> List[LinkedInPost]:
|
|
"""Get all LinkedIn posts for user."""
|
|
key = cache.linkedin_posts_key(str(user_id))
|
|
if (hit := await cache.get(key)) is not None:
|
|
return [LinkedInPost(**item) for item in hit]
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("linkedin_posts").select("*").eq(
|
|
"user_id", str(user_id)
|
|
).order("post_date", desc=True).execute()
|
|
)
|
|
posts = [LinkedInPost(**item) for item in result.data]
|
|
await cache.set(key, [p.model_dump(mode="json") for p in posts], LINKEDIN_POSTS_TTL)
|
|
return posts
|
|
|
|
async def copy_posts_to_user(self, source_user_id: UUID, target_user_id: UUID) -> List[LinkedInPost]:
|
|
"""Copy all LinkedIn posts from one user to another."""
|
|
source_posts = await self.get_linkedin_posts(source_user_id)
|
|
if not source_posts:
|
|
return []
|
|
|
|
new_posts = []
|
|
for p in source_posts:
|
|
new_post = LinkedInPost(
|
|
user_id=target_user_id,
|
|
post_url=p.post_url,
|
|
post_text=p.post_text,
|
|
post_date=p.post_date,
|
|
likes=p.likes,
|
|
comments=p.comments,
|
|
shares=p.shares,
|
|
raw_data=p.raw_data,
|
|
)
|
|
new_posts.append(new_post)
|
|
|
|
saved = await self.save_linkedin_posts(new_posts)
|
|
logger.info(f"Copied {len(saved)} posts from user {source_user_id} to {target_user_id}")
|
|
return saved
|
|
|
|
async def delete_linkedin_post(self, post_id: UUID) -> None:
|
|
"""Delete a LinkedIn post."""
|
|
# Fetch user_id before deleting so we can invalidate the cache
|
|
lookup = await asyncio.to_thread(
|
|
lambda: self.client.table("linkedin_posts").select("user_id").eq("id", str(post_id)).execute()
|
|
)
|
|
user_id = lookup.data[0]["user_id"] if lookup.data else None
|
|
await asyncio.to_thread(
|
|
lambda: self.client.table("linkedin_posts").delete().eq("id", str(post_id)).execute()
|
|
)
|
|
if user_id:
|
|
await cache.invalidate_linkedin_posts(user_id)
|
|
logger.info(f"Deleted LinkedIn post: {post_id}")
|
|
|
|
async def get_linkedin_post(self, post_id: UUID) -> Optional[LinkedInPost]:
|
|
"""Get a single LinkedIn post by ID."""
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("linkedin_posts").select("*").eq("id", str(post_id)).limit(1).execute()
|
|
)
|
|
if not result.data:
|
|
return None
|
|
return LinkedInPost(**result.data[0])
|
|
|
|
async def get_unclassified_posts(self, user_id: UUID) -> List[LinkedInPost]:
|
|
"""Get all LinkedIn posts without a post_type_id."""
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("linkedin_posts").select("*").eq(
|
|
"user_id", str(user_id)
|
|
).is_("post_type_id", "null").execute()
|
|
)
|
|
return [LinkedInPost(**item) for item in result.data]
|
|
|
|
async def get_posts_by_type(self, user_id: UUID, post_type_id: UUID) -> List[LinkedInPost]:
|
|
"""Get all LinkedIn posts for a specific post type."""
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("linkedin_posts").select("*").eq(
|
|
"user_id", str(user_id)
|
|
).eq("post_type_id", str(post_type_id)).order("post_date", desc=True).execute()
|
|
)
|
|
return [LinkedInPost(**item) for item in result.data]
|
|
|
|
# ==================== LINKEDIN POST INSIGHTS ====================
|
|
|
|
async def upsert_post_insights_posts(self, posts: List['LinkedInPostInsightPost']) -> List['LinkedInPostInsightPost']:
|
|
"""Upsert LinkedIn post insights posts."""
|
|
from datetime import datetime
|
|
from src.database.models import LinkedInPostInsightPost
|
|
|
|
data = []
|
|
for p in posts:
|
|
post_dict = p.model_dump(exclude={"id", "created_at", "updated_at", "first_seen_at"}, exclude_none=True)
|
|
post_dict["user_id"] = str(post_dict["user_id"])
|
|
if post_dict.get("linkedin_account_id"):
|
|
post_dict["linkedin_account_id"] = str(post_dict["linkedin_account_id"])
|
|
if "post_date" in post_dict and isinstance(post_dict["post_date"], datetime):
|
|
post_dict["post_date"] = post_dict["post_date"].isoformat()
|
|
data.append(post_dict)
|
|
|
|
if not data:
|
|
return []
|
|
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("linkedin_post_insights_posts").upsert(
|
|
data,
|
|
on_conflict="user_id,post_urn"
|
|
).execute()
|
|
)
|
|
return [LinkedInPostInsightPost(**item) for item in result.data]
|
|
|
|
async def upsert_post_insights_daily(self, snapshots: List['LinkedInPostInsightDaily']) -> List['LinkedInPostInsightDaily']:
|
|
"""Upsert daily snapshots for post insights."""
|
|
from src.database.models import LinkedInPostInsightDaily
|
|
|
|
data = []
|
|
for s in snapshots:
|
|
snap_dict = s.model_dump(exclude={"id", "created_at", "updated_at"}, exclude_none=True)
|
|
snap_dict["user_id"] = str(snap_dict["user_id"])
|
|
snap_dict["post_id"] = str(snap_dict["post_id"])
|
|
if "snapshot_date" in snap_dict and hasattr(snap_dict["snapshot_date"], "isoformat"):
|
|
snap_dict["snapshot_date"] = snap_dict["snapshot_date"].isoformat()
|
|
data.append(snap_dict)
|
|
|
|
if not data:
|
|
return []
|
|
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("linkedin_post_insights_daily").upsert(
|
|
data,
|
|
on_conflict="post_id,snapshot_date"
|
|
).execute()
|
|
)
|
|
return [LinkedInPostInsightDaily(**item) for item in result.data]
|
|
|
|
async def get_post_insights_posts(self, user_id: UUID) -> List['LinkedInPostInsightPost']:
|
|
"""Get all LinkedIn post insights posts for user."""
|
|
from src.database.models import LinkedInPostInsightPost
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("linkedin_post_insights_posts").select("*").eq(
|
|
"user_id", str(user_id)
|
|
).order("post_date", desc=True).execute()
|
|
)
|
|
return [LinkedInPostInsightPost(**item) for item in result.data]
|
|
|
|
async def get_post_insights_daily(self, user_id: UUID, since_date: Optional[str] = None) -> List['LinkedInPostInsightDaily']:
|
|
"""Get daily post insights snapshots for user."""
|
|
from src.database.models import LinkedInPostInsightDaily
|
|
|
|
def _query():
|
|
q = self.client.table("linkedin_post_insights_daily").select("*").eq("user_id", str(user_id))
|
|
if since_date:
|
|
q = q.gte("snapshot_date", since_date)
|
|
return q.order("snapshot_date", desc=False).execute()
|
|
|
|
result = await asyncio.to_thread(_query)
|
|
return [LinkedInPostInsightDaily(**item) for item in result.data]
|
|
|
|
async def update_post_classification(
|
|
self,
|
|
post_id: UUID,
|
|
post_type_id: UUID,
|
|
classification_method: str,
|
|
classification_confidence: float
|
|
) -> None:
|
|
"""Update a single post's classification."""
|
|
await asyncio.to_thread(
|
|
lambda: self.client.table("linkedin_posts").update({
|
|
"post_type_id": str(post_type_id),
|
|
"classification_method": classification_method,
|
|
"classification_confidence": classification_confidence
|
|
}).eq("id", str(post_id)).execute()
|
|
)
|
|
logger.debug(f"Updated classification for post {post_id}")
|
|
|
|
async def update_linkedin_post(self, post_id: UUID, updates: Dict[str, Any]) -> None:
|
|
"""Update a LinkedIn post with arbitrary fields."""
|
|
if "post_type_id" in updates and updates["post_type_id"]:
|
|
updates["post_type_id"] = str(updates["post_type_id"])
|
|
if "user_id" in updates and updates["user_id"]:
|
|
updates["user_id"] = str(updates["user_id"])
|
|
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("linkedin_posts").update(updates).eq("id", str(post_id)).execute()
|
|
)
|
|
if result.data:
|
|
await cache.invalidate_linkedin_posts(result.data[0]["user_id"])
|
|
logger.debug(f"Updated LinkedIn post {post_id}")
|
|
|
|
async def update_posts_classification_bulk(
|
|
self,
|
|
classifications: List[Dict[str, Any]]
|
|
) -> int:
|
|
"""Bulk update post classifications."""
|
|
count = 0
|
|
for classification in classifications:
|
|
try:
|
|
await asyncio.to_thread(
|
|
lambda c=classification: self.client.table("linkedin_posts").update({
|
|
"post_type_id": str(c["post_type_id"]),
|
|
"classification_method": c["classification_method"],
|
|
"classification_confidence": c["classification_confidence"]
|
|
}).eq("id", str(c["post_id"])).execute()
|
|
)
|
|
count += 1
|
|
except Exception as e:
|
|
logger.warning(f"Failed to update classification for post {classification['post_id']}: {e}")
|
|
logger.info(f"Bulk updated classifications for {count} posts")
|
|
return count
|
|
|
|
# ==================== POST TYPES ====================
|
|
|
|
async def create_post_type(self, post_type: PostType) -> PostType:
|
|
"""Create a new post type."""
|
|
data = post_type.model_dump(exclude={"id", "created_at", "updated_at"}, exclude_none=True)
|
|
if "user_id" in data:
|
|
data["user_id"] = str(data["user_id"])
|
|
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("post_types").insert(data).execute()
|
|
)
|
|
logger.info(f"Created post type: {result.data[0]['name']}")
|
|
created = PostType(**result.data[0])
|
|
await cache.invalidate_post_types(str(created.user_id))
|
|
return created
|
|
|
|
async def create_post_types_bulk(self, post_types: List[PostType]) -> List[PostType]:
|
|
"""Create multiple post types at once."""
|
|
if not post_types:
|
|
return []
|
|
|
|
data = []
|
|
for pt in post_types:
|
|
pt_dict = pt.model_dump(exclude={"id", "created_at", "updated_at"}, exclude_none=True)
|
|
if "user_id" in pt_dict:
|
|
pt_dict["user_id"] = str(pt_dict["user_id"])
|
|
data.append(pt_dict)
|
|
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("post_types").insert(data).execute()
|
|
)
|
|
logger.info(f"Created {len(result.data)} post types")
|
|
created = [PostType(**item) for item in result.data]
|
|
affected_user_ids = {str(pt.user_id) for pt in created}
|
|
for uid in affected_user_ids:
|
|
await cache.invalidate_post_types(uid)
|
|
return created
|
|
|
|
async def get_post_types(self, user_id: UUID, active_only: bool = True) -> List[PostType]:
|
|
"""Get all post types for a user."""
|
|
key = cache.post_types_key(str(user_id), active_only)
|
|
if (hit := await cache.get(key)) is not None:
|
|
return [PostType(**item) for item in hit]
|
|
|
|
def _query():
|
|
query = self.client.table("post_types").select("*").eq("user_id", str(user_id))
|
|
if active_only:
|
|
query = query.eq("is_active", True)
|
|
return query.order("name").execute()
|
|
|
|
result = await asyncio.to_thread(_query)
|
|
post_types = [PostType(**item) for item in result.data]
|
|
await cache.set(key, [pt.model_dump(mode="json") for pt in post_types], POST_TYPES_TTL)
|
|
return post_types
|
|
|
|
    # Alias for get_post_types
    async def get_customer_post_types(self, user_id: UUID, active_only: bool = True) -> List[PostType]:
        """Alias for get_post_types; kept for callers that use the older name."""
        return await self.get_post_types(user_id, active_only)
|
|
|
|
async def get_post_type(self, post_type_id: UUID) -> Optional[PostType]:
|
|
"""Get a single post type by ID."""
|
|
key = cache.post_type_key(str(post_type_id))
|
|
if (hit := await cache.get(key)) is not None:
|
|
return PostType(**hit)
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("post_types").select("*").eq(
|
|
"id", str(post_type_id)
|
|
).execute()
|
|
)
|
|
if result.data:
|
|
pt = PostType(**result.data[0])
|
|
await cache.set(key, pt.model_dump(mode="json"), POST_TYPE_TTL)
|
|
return pt
|
|
return None
|
|
|
|
async def update_post_type(self, post_type_id: UUID, updates: Dict[str, Any]) -> PostType:
|
|
"""Update a post type."""
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("post_types").update(updates).eq(
|
|
"id", str(post_type_id)
|
|
).execute()
|
|
)
|
|
logger.info(f"Updated post type: {post_type_id}")
|
|
pt = PostType(**result.data[0])
|
|
await cache.invalidate_post_type(str(post_type_id), str(pt.user_id))
|
|
return pt
|
|
|
|
async def update_post_type_analysis(
|
|
self,
|
|
post_type_id: UUID,
|
|
analysis: Dict[str, Any],
|
|
analyzed_post_count: int
|
|
) -> PostType:
|
|
"""Update the analysis for a post type."""
|
|
from datetime import datetime
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("post_types").update({
|
|
"analysis": analysis,
|
|
"analysis_generated_at": datetime.now(timezone.utc).isoformat(),
|
|
"analyzed_post_count": analyzed_post_count
|
|
}).eq("id", str(post_type_id)).execute()
|
|
)
|
|
logger.info(f"Updated analysis for post type: {post_type_id}")
|
|
pt = PostType(**result.data[0])
|
|
await cache.invalidate_post_type(str(post_type_id), str(pt.user_id))
|
|
return pt
|
|
|
|
async def delete_post_type(self, post_type_id: UUID, soft: bool = True) -> None:
|
|
"""Delete a post type (soft delete by default)."""
|
|
if soft:
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("post_types").update({
|
|
"is_active": False
|
|
}).eq("id", str(post_type_id)).execute()
|
|
)
|
|
logger.info(f"Soft deleted post type: {post_type_id}")
|
|
if result.data:
|
|
await cache.invalidate_post_type(str(post_type_id), result.data[0]["user_id"])
|
|
else:
|
|
# Fetch user_id before hard delete for cache invalidation
|
|
lookup = await asyncio.to_thread(
|
|
lambda: self.client.table("post_types").select("user_id").eq(
|
|
"id", str(post_type_id)
|
|
).execute()
|
|
)
|
|
user_id = lookup.data[0]["user_id"] if lookup.data else None
|
|
await asyncio.to_thread(
|
|
lambda: self.client.table("post_types").delete().eq(
|
|
"id", str(post_type_id)
|
|
).execute()
|
|
)
|
|
if user_id:
|
|
await cache.invalidate_post_type(str(post_type_id), user_id)
|
|
logger.info(f"Hard deleted post type: {post_type_id}")
|
|
|
|
# ==================== TOPICS ====================
|
|
|
|
async def save_topics(self, topics: List[Topic]) -> List[Topic]:
|
|
"""Save extracted topics."""
|
|
if not topics:
|
|
logger.warning("No topics to save")
|
|
return []
|
|
|
|
data = []
|
|
for t in topics:
|
|
topic_dict = t.model_dump(exclude={"id", "created_at"}, exclude_none=True)
|
|
if "user_id" in topic_dict:
|
|
topic_dict["user_id"] = str(topic_dict["user_id"])
|
|
if "extracted_from_post_id" in topic_dict and topic_dict["extracted_from_post_id"]:
|
|
topic_dict["extracted_from_post_id"] = str(topic_dict["extracted_from_post_id"])
|
|
if "target_post_type_id" in topic_dict and topic_dict["target_post_type_id"]:
|
|
topic_dict["target_post_type_id"] = str(topic_dict["target_post_type_id"])
|
|
data.append(topic_dict)
|
|
|
|
try:
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("topics").insert(data).execute()
|
|
)
|
|
logger.info(f"Saved {len(result.data)} topics to database")
|
|
return [Topic(**item) for item in result.data]
|
|
except Exception as e:
|
|
logger.error(f"Error saving topics: {e}", exc_info=True)
|
|
saved = []
|
|
for topic_data in data:
|
|
try:
|
|
result = await asyncio.to_thread(
|
|
lambda td=topic_data: self.client.table("topics").insert(td).execute()
|
|
)
|
|
saved.extend([Topic(**item) for item in result.data])
|
|
except Exception as single_error:
|
|
logger.warning(f"Skipped duplicate topic: {topic_data.get('title')}")
|
|
logger.info(f"Saved {len(saved)} topics individually")
|
|
return saved
|
|
|
|
async def get_topics(
|
|
self,
|
|
user_id: UUID,
|
|
unused_only: bool = False,
|
|
post_type_id: Optional[UUID] = None
|
|
) -> List[Topic]:
|
|
"""Get topics for user, optionally filtered by post type."""
|
|
def _query():
|
|
query = self.client.table("topics").select("*").eq("user_id", str(user_id))
|
|
if unused_only:
|
|
query = query.eq("is_used", False)
|
|
if post_type_id:
|
|
query = query.eq("target_post_type_id", str(post_type_id))
|
|
return query.order("created_at", desc=True).execute()
|
|
|
|
result = await asyncio.to_thread(_query)
|
|
return [Topic(**item) for item in result.data]
|
|
|
|
async def mark_topic_used(self, topic_id: UUID) -> None:
|
|
"""Mark topic as used."""
|
|
await asyncio.to_thread(
|
|
lambda: self.client.table("topics").update({
|
|
"is_used": True,
|
|
"used_at": "now()"
|
|
}).eq("id", str(topic_id)).execute()
|
|
)
|
|
logger.info(f"Marked topic {topic_id} as used")
|
|
|
|
# ==================== PROFILE ANALYSIS ====================
|
|
|
|
async def save_profile_analysis(self, analysis: ProfileAnalysis) -> ProfileAnalysis:
|
|
"""Save profile analysis."""
|
|
data = analysis.model_dump(exclude={"id", "created_at"}, exclude_none=True)
|
|
if "user_id" in data:
|
|
data["user_id"] = str(data["user_id"])
|
|
|
|
existing = await asyncio.to_thread(
|
|
lambda: self.client.table("profile_analyses").select("*").eq(
|
|
"user_id", str(analysis.user_id)
|
|
).execute()
|
|
)
|
|
|
|
if existing.data:
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("profile_analyses").update(data).eq(
|
|
"user_id", str(analysis.user_id)
|
|
).execute()
|
|
)
|
|
else:
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("profile_analyses").insert(data).execute()
|
|
)
|
|
|
|
logger.info(f"Saved profile analysis for user: {analysis.user_id}")
|
|
saved = ProfileAnalysis(**result.data[0])
|
|
await cache.delete(cache.profile_analysis_key(str(analysis.user_id)))
|
|
return saved
|
|
|
|
async def get_profile_analysis(self, user_id: UUID) -> Optional[ProfileAnalysis]:
|
|
"""Get profile analysis for user."""
|
|
key = cache.profile_analysis_key(str(user_id))
|
|
if (hit := await cache.get(key)) is not None:
|
|
return ProfileAnalysis(**hit)
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("profile_analyses").select("*").eq(
|
|
"user_id", str(user_id)
|
|
).execute()
|
|
)
|
|
if result.data:
|
|
pa = ProfileAnalysis(**result.data[0])
|
|
await cache.set(key, pa.model_dump(mode="json"), PROFILE_ANALYSIS_TTL)
|
|
return pa
|
|
return None
|
|
|
|
# ==================== RESEARCH RESULTS ====================
|
|
|
|
async def save_research_result(self, research: ResearchResult) -> ResearchResult:
|
|
"""Save research result."""
|
|
data = research.model_dump(exclude={"id", "created_at"}, exclude_none=True)
|
|
if "user_id" in data:
|
|
data["user_id"] = str(data["user_id"])
|
|
if "target_post_type_id" in data and data["target_post_type_id"]:
|
|
data["target_post_type_id"] = str(data["target_post_type_id"])
|
|
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("research_results").insert(data).execute()
|
|
)
|
|
logger.info(f"Saved research result for user: {research.user_id}")
|
|
return ResearchResult(**result.data[0])
|
|
|
|
async def get_latest_research(self, user_id: UUID) -> Optional[ResearchResult]:
|
|
"""Get latest research result for user."""
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("research_results").select("*").eq(
|
|
"user_id", str(user_id)
|
|
).order("created_at", desc=True).limit(1).execute()
|
|
)
|
|
if result.data:
|
|
return ResearchResult(**result.data[0])
|
|
return None
|
|
|
|
async def get_all_research(
|
|
self,
|
|
user_id: UUID,
|
|
post_type_id: Optional[UUID] = None
|
|
) -> List[ResearchResult]:
|
|
"""Get all research results for user, optionally filtered by post type."""
|
|
def _query():
|
|
query = self.client.table("research_results").select("*").eq(
|
|
"user_id", str(user_id)
|
|
)
|
|
if post_type_id:
|
|
query = query.eq("target_post_type_id", str(post_type_id))
|
|
return query.order("created_at", desc=True).execute()
|
|
|
|
result = await asyncio.to_thread(_query)
|
|
return [ResearchResult(**item) for item in result.data]
|
|
|
|
# ==================== GENERATED POSTS ====================
|
|
|
|
async def save_generated_post(self, post: GeneratedPost) -> GeneratedPost:
|
|
"""Save generated post."""
|
|
data = post.model_dump(exclude={"id", "created_at"}, exclude_none=True)
|
|
if "user_id" in data:
|
|
data["user_id"] = str(data["user_id"])
|
|
if "topic_id" in data and data["topic_id"]:
|
|
data["topic_id"] = str(data["topic_id"])
|
|
if "post_type_id" in data and data["post_type_id"]:
|
|
data["post_type_id"] = str(data["post_type_id"])
|
|
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("generated_posts").insert(data).execute()
|
|
)
|
|
logger.info(f"Saved generated post: {result.data[0]['id']}")
|
|
saved = GeneratedPost(**result.data[0])
|
|
await cache.invalidate_gen_posts(str(saved.user_id))
|
|
return saved
|
|
|
|
async def update_generated_post(self, post_id: UUID, updates: Dict[str, Any]) -> GeneratedPost:
|
|
"""Update generated post."""
|
|
# Convert datetime fields to isoformat strings
|
|
datetime_fields = ['published_at', 'approved_at', 'rejected_at', 'scheduled_at']
|
|
for field in datetime_fields:
|
|
if field in updates and isinstance(updates[field], datetime):
|
|
updates[field] = updates[field].isoformat()
|
|
|
|
# Handle metadata dict - ensure all nested datetime values are serialized
|
|
if 'metadata' in updates and isinstance(updates['metadata'], dict):
|
|
serialized_metadata = {}
|
|
for k, value in updates['metadata'].items():
|
|
serialized_metadata[k] = value.isoformat() if isinstance(value, datetime) else value
|
|
updates['metadata'] = serialized_metadata
|
|
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("generated_posts").update(updates).eq(
|
|
"id", str(post_id)
|
|
).execute()
|
|
)
|
|
logger.info(f"Updated generated post: {post_id}")
|
|
updated = GeneratedPost(**result.data[0])
|
|
await cache.invalidate_gen_post(str(post_id), str(updated.user_id))
|
|
return updated
|
|
|
|
async def get_generated_posts(self, user_id: UUID) -> List[GeneratedPost]:
|
|
"""Get all generated posts for user."""
|
|
key = cache.gen_posts_key(str(user_id))
|
|
if (hit := await cache.get(key)) is not None:
|
|
return [GeneratedPost(**item) for item in hit]
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("generated_posts").select("*").eq(
|
|
"user_id", str(user_id)
|
|
).order("created_at", desc=True).execute()
|
|
)
|
|
posts = [GeneratedPost(**item) for item in result.data]
|
|
await cache.set(key, [p.model_dump(mode="json") for p in posts], GEN_POSTS_TTL)
|
|
return posts
|
|
|
|
async def get_generated_post(self, post_id: UUID) -> Optional[GeneratedPost]:
|
|
"""Get a single generated post by ID."""
|
|
key = cache.gen_post_key(str(post_id))
|
|
if (hit := await cache.get(key)) is not None:
|
|
return GeneratedPost(**hit)
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("generated_posts").select("*").eq(
|
|
"id", str(post_id)
|
|
).execute()
|
|
)
|
|
if result.data:
|
|
post = GeneratedPost(**result.data[0])
|
|
await cache.set(key, post.model_dump(mode="json"), GEN_POST_TTL)
|
|
return post
|
|
return None
|
|
|
|
async def get_scheduled_posts_due(self) -> List[GeneratedPost]:
|
|
"""Get all posts that are scheduled and due for publishing."""
|
|
from datetime import datetime
|
|
now = datetime.now(timezone.utc).isoformat()
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("generated_posts").select("*").eq(
|
|
"status", "scheduled"
|
|
).lte("scheduled_at", now).execute()
|
|
)
|
|
return [GeneratedPost(**item) for item in result.data]
|
|
|
|
async def get_scheduled_posts_for_company(self, company_id: UUID) -> List[Dict[str, Any]]:
|
|
"""Get all scheduled posts for a company with employee info."""
|
|
# Get all profiles (employees) belonging to this company
|
|
profiles_result = await asyncio.to_thread(
|
|
lambda: self.client.table("profiles").select("id, display_name").eq(
|
|
"company_id", str(company_id)
|
|
).execute()
|
|
)
|
|
|
|
if not profiles_result.data:
|
|
return []
|
|
|
|
user_ids = [p["id"] for p in profiles_result.data]
|
|
profile_map = {p["id"]: p for p in profiles_result.data}
|
|
|
|
posts_result = await asyncio.to_thread(
|
|
lambda: self.client.table("generated_posts").select(
|
|
"id, user_id, topic_title, post_content, status, scheduled_at, created_at"
|
|
).in_(
|
|
"user_id", user_ids
|
|
).in_(
|
|
"status", ["ready", "scheduled", "published"]
|
|
).order("scheduled_at", desc=False, nullsfirst=False).execute()
|
|
)
|
|
|
|
enriched_posts = []
|
|
for post in posts_result.data:
|
|
profile = profile_map.get(post["user_id"])
|
|
enriched_posts.append({
|
|
**post,
|
|
"employee_name": profile["display_name"] if profile else "Unknown",
|
|
"employee_user_id": post["user_id"]
|
|
})
|
|
|
|
return enriched_posts
|
|
|
|
async def schedule_post(self, post_id: UUID, scheduled_at: datetime, scheduled_by_user_id: UUID) -> GeneratedPost:
|
|
"""Schedule a post for publishing."""
|
|
from datetime import datetime as dt
|
|
updates = {
|
|
"status": "scheduled",
|
|
"scheduled_at": scheduled_at.isoformat() if isinstance(scheduled_at, dt) else scheduled_at,
|
|
"scheduled_by_user_id": str(scheduled_by_user_id)
|
|
}
|
|
return await self.update_generated_post(post_id, updates)
|
|
|
|
async def unschedule_post(self, post_id: UUID) -> GeneratedPost:
|
|
"""Remove scheduling from a post (back to approved status)."""
|
|
updates = {
|
|
"status": "approved",
|
|
"scheduled_at": None,
|
|
"scheduled_by_user_id": None
|
|
}
|
|
return await self.update_generated_post(post_id, updates)
|
|
|
|
|
|
# ==================== PROFILES / USERS ====================
|
|
|
|
async def create_profile(self, user_id: UUID, profile: Profile) -> Profile:
|
|
"""Create a profile for a user."""
|
|
data = profile.model_dump(exclude={"created_at", "updated_at"}, exclude_none=True)
|
|
data["id"] = str(user_id)
|
|
if "company_id" in data and data["company_id"]:
|
|
data["company_id"] = str(data["company_id"])
|
|
if "account_type" in data:
|
|
data["account_type"] = data["account_type"].value if hasattr(data["account_type"], "value") else data["account_type"]
|
|
if "onboarding_status" in data:
|
|
data["onboarding_status"] = data["onboarding_status"].value if hasattr(data["onboarding_status"], "value") else data["onboarding_status"]
|
|
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("profiles").upsert(data).execute()
|
|
)
|
|
logger.info(f"Created/updated profile: {result.data[0]['id']}")
|
|
return Profile(**result.data[0])
|
|
|
|
async def get_profile(self, user_id: UUID) -> Optional[Profile]:
|
|
"""Get profile by user ID."""
|
|
key = cache.profile_key(str(user_id))
|
|
if (hit := await cache.get(key)) is not None:
|
|
return Profile(**hit)
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("profiles").select("*").eq("id", str(user_id)).execute()
|
|
)
|
|
if result.data:
|
|
profile = Profile(**result.data[0])
|
|
await cache.set(key, profile.model_dump(mode="json"), PROFILE_TTL)
|
|
return profile
|
|
return None
|
|
|
|
async def get_profiles_by_linkedin_url(self, linkedin_url: str) -> List[Profile]:
|
|
"""Get all profiles with this LinkedIn URL."""
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("profiles").select("*").eq("linkedin_url", linkedin_url).execute()
|
|
)
|
|
return [Profile(**item) for item in result.data]
|
|
|
|
async def update_profile(self, user_id: UUID, updates: Dict[str, Any]) -> Profile:
|
|
"""Update profile fields."""
|
|
if "company_id" in updates and updates["company_id"]:
|
|
updates["company_id"] = str(updates["company_id"])
|
|
for k in ["account_type", "onboarding_status"]:
|
|
if k in updates and hasattr(updates[k], "value"):
|
|
updates[k] = updates[k].value
|
|
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("profiles").update(updates).eq(
|
|
"id", str(user_id)
|
|
).execute()
|
|
)
|
|
logger.info(f"Updated profile: {user_id}")
|
|
profile = Profile(**result.data[0])
|
|
await cache.invalidate_profile(str(user_id))
|
|
return profile
|
|
|
|
# ==================== LINKEDIN ACCOUNTS ====================
|
|
|
|
async def get_linkedin_account(self, user_id: UUID) -> Optional['LinkedInAccount']:
|
|
"""Get LinkedIn account for user."""
|
|
from src.database.models import LinkedInAccount
|
|
key = cache.linkedin_account_key(str(user_id))
|
|
if (hit := await cache.get(key)) is not None:
|
|
return LinkedInAccount(**hit)
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("linkedin_accounts").select("*")
|
|
.eq("user_id", str(user_id)).eq("is_active", True).execute()
|
|
)
|
|
if result.data:
|
|
account = LinkedInAccount(**result.data[0])
|
|
await cache.set(key, account.model_dump(mode="json"), LINKEDIN_ACCOUNT_TTL)
|
|
return account
|
|
return None
|
|
|
|
async def get_linkedin_account_by_id(self, account_id: UUID) -> Optional['LinkedInAccount']:
|
|
"""Get LinkedIn account by ID."""
|
|
from src.database.models import LinkedInAccount
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("linkedin_accounts").select("*")
|
|
.eq("id", str(account_id)).execute()
|
|
)
|
|
if result.data:
|
|
return LinkedInAccount(**result.data[0])
|
|
return None
|
|
|
|
async def list_linkedin_accounts(self, active_only: bool = True) -> list['LinkedInAccount']:
|
|
"""List LinkedIn accounts (optionally only active ones)."""
|
|
from src.database.models import LinkedInAccount
|
|
|
|
def _query():
|
|
q = self.client.table("linkedin_accounts").select("*")
|
|
if active_only:
|
|
q = q.eq("is_active", True)
|
|
return q.execute()
|
|
|
|
result = await asyncio.to_thread(_query)
|
|
return [LinkedInAccount(**item) for item in result.data]
|
|
|
|
async def create_linkedin_account(self, account: 'LinkedInAccount') -> 'LinkedInAccount':
|
|
"""Create LinkedIn account connection."""
|
|
from src.database.models import LinkedInAccount
|
|
data = account.model_dump(exclude={'id', 'created_at', 'updated_at'})
|
|
data['user_id'] = str(data['user_id'])
|
|
data['token_expires_at'] = data['token_expires_at'].isoformat()
|
|
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("linkedin_accounts").insert(data).execute()
|
|
)
|
|
logger.info(f"Created LinkedIn account for user: {account.user_id}")
|
|
created = LinkedInAccount(**result.data[0])
|
|
await cache.invalidate_linkedin_account(str(created.user_id))
|
|
return created
|
|
|
|
async def update_linkedin_account(self, account_id: UUID, updates: Dict) -> 'LinkedInAccount':
|
|
"""Update LinkedIn account."""
|
|
from src.database.models import LinkedInAccount
|
|
# Convert all datetime fields to isoformat strings
|
|
datetime_fields = ['token_expires_at', 'last_used_at', 'last_error_at']
|
|
for field in datetime_fields:
|
|
if field in updates and isinstance(updates[field], datetime):
|
|
updates[field] = updates[field].isoformat()
|
|
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("linkedin_accounts").update(updates)
|
|
.eq("id", str(account_id)).execute()
|
|
)
|
|
logger.info(f"Updated LinkedIn account: {account_id}")
|
|
updated = LinkedInAccount(**result.data[0])
|
|
await cache.invalidate_linkedin_account(str(updated.user_id))
|
|
return updated
|
|
|
|
async def delete_linkedin_account(self, account_id: UUID) -> None:
|
|
"""Delete LinkedIn account connection."""
|
|
# Fetch user_id before delete for cache invalidation
|
|
lookup = await asyncio.to_thread(
|
|
lambda: self.client.table("linkedin_accounts").select("user_id")
|
|
.eq("id", str(account_id)).execute()
|
|
)
|
|
user_id = lookup.data[0]["user_id"] if lookup.data else None
|
|
await asyncio.to_thread(
|
|
lambda: self.client.table("linkedin_accounts").delete()
|
|
.eq("id", str(account_id)).execute()
|
|
)
|
|
if user_id:
|
|
await cache.invalidate_linkedin_account(user_id)
|
|
logger.info(f"Deleted LinkedIn account: {account_id}")
|
|
|
|
# ==================== TELEGRAM ACCOUNTS ====================
|
|
|
|
async def get_telegram_account(self, user_id: UUID) -> Optional['TelegramAccount']:
|
|
"""Get Telegram account for user."""
|
|
from src.database.models import TelegramAccount
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("telegram_accounts").select("*")
|
|
.eq("user_id", str(user_id)).eq("is_active", True).execute()
|
|
)
|
|
if result.data:
|
|
return TelegramAccount(**result.data[0])
|
|
return None
|
|
|
|
async def get_telegram_account_by_chat_id(self, chat_id: str) -> Optional['TelegramAccount']:
|
|
"""Get Telegram account by chat_id."""
|
|
from src.database.models import TelegramAccount
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("telegram_accounts").select("*")
|
|
.eq("telegram_chat_id", chat_id).eq("is_active", True).execute()
|
|
)
|
|
if result.data:
|
|
return TelegramAccount(**result.data[0])
|
|
return None
|
|
|
|
async def save_telegram_account(self, account: 'TelegramAccount') -> 'TelegramAccount':
|
|
"""Create or update a Telegram account connection."""
|
|
from src.database.models import TelegramAccount
|
|
data = account.model_dump(exclude={'id', 'created_at', 'updated_at'}, exclude_none=True)
|
|
data['user_id'] = str(data['user_id'])
|
|
|
|
existing = await asyncio.to_thread(
|
|
lambda: self.client.table("telegram_accounts").select("id")
|
|
.eq("user_id", str(account.user_id)).execute()
|
|
)
|
|
|
|
if existing.data:
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("telegram_accounts").update(data)
|
|
.eq("user_id", str(account.user_id)).execute()
|
|
)
|
|
else:
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("telegram_accounts").insert(data).execute()
|
|
)
|
|
|
|
logger.info(f"Saved Telegram account for user: {account.user_id}")
|
|
return TelegramAccount(**result.data[0])
|
|
|
|
async def delete_telegram_account(self, user_id: UUID) -> bool:
|
|
"""Delete Telegram account connection for user."""
|
|
await asyncio.to_thread(
|
|
lambda: self.client.table("telegram_accounts").delete()
|
|
.eq("user_id", str(user_id)).execute()
|
|
)
|
|
logger.info(f"Deleted Telegram account for user: {user_id}")
|
|
return True
|
|
|
|
# ==================== TEAMS ACCOUNTS ====================
|
|
|
|
async def get_teams_account(self, user_id: UUID) -> Optional['TeamsAccount']:
|
|
"""Get Teams account for user."""
|
|
from src.database.models import TeamsAccount
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("teams_accounts").select("*")
|
|
.eq("user_id", str(user_id)).eq("is_active", True).execute()
|
|
)
|
|
if result.data:
|
|
return TeamsAccount(**result.data[0])
|
|
return None
|
|
|
|
async def save_teams_account(self, account: 'TeamsAccount') -> 'TeamsAccount':
|
|
"""Create or update a Teams account connection."""
|
|
from src.database.models import TeamsAccount
|
|
data = account.model_dump(exclude={'id', 'created_at', 'updated_at'}, exclude_none=True)
|
|
data['user_id'] = str(data['user_id'])
|
|
|
|
existing = await asyncio.to_thread(
|
|
lambda: self.client.table("teams_accounts").select("id")
|
|
.eq("user_id", str(account.user_id)).execute()
|
|
)
|
|
|
|
if existing.data:
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("teams_accounts").update(data)
|
|
.eq("user_id", str(account.user_id)).execute()
|
|
)
|
|
else:
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("teams_accounts").insert(data).execute()
|
|
)
|
|
|
|
logger.info(f"Saved Teams account for user: {account.user_id}")
|
|
return TeamsAccount(**result.data[0])
|
|
|
|
async def delete_teams_account(self, user_id: UUID) -> bool:
|
|
"""Delete Teams account connection for user."""
|
|
await asyncio.to_thread(
|
|
lambda: self.client.table("teams_accounts").delete()
|
|
.eq("user_id", str(user_id)).execute()
|
|
)
|
|
logger.info(f"Deleted Teams account for user: {user_id}")
|
|
return True
|
|
|
|
# ==================== USERS ====================
|
|
|
|
async def get_user(self, user_id: UUID) -> Optional[User]:
|
|
"""Get user by ID (from users view)."""
|
|
key = cache.user_key(str(user_id))
|
|
if (hit := await cache.get(key)) is not None:
|
|
return User(**hit)
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("users").select("*").eq("id", str(user_id)).execute()
|
|
)
|
|
if result.data:
|
|
user = User(**result.data[0])
|
|
await cache.set(key, user.model_dump(mode="json"), USER_TTL)
|
|
return user
|
|
return None
|
|
|
|
async def get_user_by_email(self, email: str) -> Optional[User]:
|
|
"""Get user by email (from users view)."""
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("users").select("*").eq("email", email.lower()).execute()
|
|
)
|
|
if result.data:
|
|
return User(**result.data[0])
|
|
return None
|
|
|
|
async def get_user_by_linkedin_sub(self, linkedin_sub: str) -> Optional[User]:
|
|
"""Get user by LinkedIn sub (OAuth identifier)."""
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("users").select("*").eq("linkedin_sub", linkedin_sub).execute()
|
|
)
|
|
if result.data:
|
|
return User(**result.data[0])
|
|
return None
|
|
|
|
async def update_user(self, user_id: UUID, updates: Dict[str, Any]) -> User:
|
|
"""Update user/profile fields."""
|
|
profile_fields = ["account_type", "display_name", "onboarding_status",
|
|
"onboarding_data", "company_id",
|
|
"linkedin_url", "writing_style_notes", "metadata",
|
|
"profile_picture", "creator_email", "customer_email", "is_active"]
|
|
profile_updates = {k: v for k, v in updates.items() if k in profile_fields}
|
|
|
|
if profile_updates:
|
|
await self.update_profile(user_id, profile_updates)
|
|
# update_profile already calls cache.invalidate_profile which also kills user_key
|
|
|
|
# Invalidate user view separately (in case it wasn't covered above)
|
|
await cache.delete(cache.user_key(str(user_id)))
|
|
return await self.get_user(user_id)
|
|
|
|
async def list_users(self, account_type: Optional[str] = None, company_id: Optional[UUID] = None) -> List[User]:
|
|
"""List users with optional filters (from users view)."""
|
|
def _query():
|
|
query = self.client.table("users").select("*")
|
|
if account_type:
|
|
query = query.eq("account_type", account_type)
|
|
if company_id:
|
|
query = query.eq("company_id", str(company_id))
|
|
return query.order("created_at", desc=True).execute()
|
|
|
|
result = await asyncio.to_thread(_query)
|
|
return [User(**item) for item in result.data]
|
|
|
|
async def delete_user_completely(self, user_id: UUID) -> bool:
|
|
"""Delete a user and all related data completely."""
|
|
try:
|
|
user_id_str = str(user_id)
|
|
|
|
# Delete all content data directly via user_id
|
|
for table in ["generated_posts", "linkedin_posts", "example_posts",
|
|
"reference_profiles", "profile_analyses", "linkedin_profiles",
|
|
"post_types", "research_results", "topics"]:
|
|
await asyncio.to_thread(
|
|
lambda t=table: self.client.table(t).delete().eq(
|
|
"user_id", user_id_str
|
|
).execute()
|
|
)
|
|
|
|
# Get user email for invitation deletion
|
|
user = await self.get_user(user_id)
|
|
if user and user.email:
|
|
await asyncio.to_thread(
|
|
lambda: self.client.table("invitations").delete().eq(
|
|
"email", user.email
|
|
).execute()
|
|
)
|
|
|
|
# Delete profile
|
|
await asyncio.to_thread(
|
|
lambda: self.client.table("profiles").delete().eq(
|
|
"id", user_id_str
|
|
).execute()
|
|
)
|
|
|
|
# Delete from auth.users
|
|
if self.admin_client:
|
|
await asyncio.to_thread(
|
|
lambda: self.admin_client.auth.admin.delete_user(user_id_str)
|
|
)
|
|
logger.info(f"Deleted auth user {user_id}")
|
|
else:
|
|
logger.warning(f"Cannot delete auth user {user_id} - no service role key configured")
|
|
|
|
logger.info(f"Completely deleted user {user_id} and all related data")
|
|
return True
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error deleting user {user_id}: {e}")
|
|
raise
|
|
|
|
# ==================== COMPANIES ====================
|
|
|
|
async def create_company(self, company: Company) -> Company:
|
|
"""Create a new company."""
|
|
data = company.model_dump(exclude={"id", "created_at", "updated_at"}, exclude_none=True)
|
|
if "owner_user_id" in data:
|
|
data["owner_user_id"] = str(data["owner_user_id"])
|
|
if "license_key_id" in data:
|
|
data["license_key_id"] = str(data["license_key_id"])
|
|
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("companies").insert(data).execute()
|
|
)
|
|
logger.info(f"Created company: {result.data[0]['id']}")
|
|
return Company(**result.data[0])
|
|
|
|
async def get_company(self, company_id: UUID) -> Optional[Company]:
|
|
"""Get company by ID."""
|
|
key = cache.company_key(str(company_id))
|
|
if (hit := await cache.get(key)) is not None:
|
|
return Company(**hit)
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("companies").select("*").eq("id", str(company_id)).execute()
|
|
)
|
|
if result.data:
|
|
company = Company(**result.data[0])
|
|
await cache.set(key, company.model_dump(mode="json"), COMPANY_TTL)
|
|
return company
|
|
return None
|
|
|
|
async def get_company_by_owner(self, owner_user_id: UUID) -> Optional[Company]:
|
|
"""Get company by owner user ID."""
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("companies").select("*").eq("owner_user_id", str(owner_user_id)).execute()
|
|
)
|
|
if result.data:
|
|
return Company(**result.data[0])
|
|
return None
|
|
|
|
async def get_company_employees(self, company_id: UUID) -> List[User]:
|
|
"""Get all employees of a company."""
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("users").select("*").eq(
|
|
"company_id", str(company_id)
|
|
).eq("account_type", "employee").order("created_at", desc=True).execute()
|
|
)
|
|
return [User(**item) for item in result.data]
|
|
|
|
async def update_company(self, company_id: UUID, updates: Dict[str, Any]) -> Company:
|
|
"""Update company fields."""
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("companies").update(updates).eq(
|
|
"id", str(company_id)
|
|
).execute()
|
|
)
|
|
logger.info(f"Updated company: {company_id}")
|
|
company = Company(**result.data[0])
|
|
await cache.invalidate_company(str(company_id))
|
|
return company
|
|
|
|
async def list_companies(self) -> List[Company]:
|
|
"""List all companies."""
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("companies").select("*").order("created_at", desc=True).execute()
|
|
)
|
|
return [Company(**item) for item in result.data]
|
|
|
|
# ==================== INVITATIONS ====================
|
|
|
|
async def create_invitation(self, invitation: Invitation) -> Invitation:
|
|
"""Create a new invitation."""
|
|
from datetime import datetime
|
|
data = invitation.model_dump(exclude={"id", "created_at"}, exclude_none=True)
|
|
for key in ["company_id", "invited_by_user_id", "accepted_by_user_id"]:
|
|
if key in data and data[key]:
|
|
data[key] = str(data[key])
|
|
if "expires_at" in data and isinstance(data["expires_at"], datetime):
|
|
data["expires_at"] = data["expires_at"].isoformat()
|
|
if "accepted_at" in data and isinstance(data["accepted_at"], datetime):
|
|
data["accepted_at"] = data["accepted_at"].isoformat()
|
|
if "status" in data and hasattr(data["status"], "value"):
|
|
data["status"] = data["status"].value
|
|
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("invitations").insert(data).execute()
|
|
)
|
|
logger.info(f"Created invitation: {result.data[0]['id']}")
|
|
return Invitation(**result.data[0])
|
|
|
|
async def get_invitation_by_token(self, token: str) -> Optional[Invitation]:
|
|
"""Get invitation by token."""
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("invitations").select("*").eq("token", token).execute()
|
|
)
|
|
if result.data:
|
|
return Invitation(**result.data[0])
|
|
return None
|
|
|
|
async def get_invitation(self, invitation_id: UUID) -> Optional[Invitation]:
|
|
"""Get invitation by ID."""
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("invitations").select("*").eq("id", str(invitation_id)).execute()
|
|
)
|
|
if result.data:
|
|
return Invitation(**result.data[0])
|
|
return None
|
|
|
|
async def update_invitation(self, invitation_id: UUID, updates: Dict[str, Any]) -> Invitation:
|
|
"""Update invitation fields."""
|
|
from datetime import datetime
|
|
if "accepted_by_user_id" in updates and updates["accepted_by_user_id"]:
|
|
updates["accepted_by_user_id"] = str(updates["accepted_by_user_id"])
|
|
if "accepted_at" in updates and isinstance(updates["accepted_at"], datetime):
|
|
updates["accepted_at"] = updates["accepted_at"].isoformat()
|
|
if "status" in updates and hasattr(updates["status"], "value"):
|
|
updates["status"] = updates["status"].value
|
|
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("invitations").update(updates).eq(
|
|
"id", str(invitation_id)
|
|
).execute()
|
|
)
|
|
logger.info(f"Updated invitation: {invitation_id}")
|
|
return Invitation(**result.data[0])
|
|
|
|
async def get_pending_invitations(self, company_id: UUID) -> List[Invitation]:
|
|
"""Get all pending invitations for a company."""
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("invitations").select("*").eq(
|
|
"company_id", str(company_id)
|
|
).eq("status", "pending").order("created_at", desc=True).execute()
|
|
)
|
|
return [Invitation(**item) for item in result.data]
|
|
|
|
async def get_invitations_by_email(self, email: str) -> List[Invitation]:
|
|
"""Get all invitations for an email address."""
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("invitations").select("*").eq(
|
|
"email", email.lower()
|
|
).order("created_at", desc=True).execute()
|
|
)
|
|
return [Invitation(**item) for item in result.data]
|
|
|
|
# ==================== EXAMPLE POSTS ====================
|
|
|
|
async def save_example_posts(self, posts: List[ExamplePost]) -> List[ExamplePost]:
|
|
"""Save example posts (bulk)."""
|
|
if not posts:
|
|
return []
|
|
|
|
data = []
|
|
for p in posts:
|
|
post_dict = p.model_dump(exclude={"id", "created_at"}, exclude_none=True)
|
|
if "user_id" in post_dict:
|
|
post_dict["user_id"] = str(post_dict["user_id"])
|
|
if "post_type_id" in post_dict and post_dict["post_type_id"]:
|
|
post_dict["post_type_id"] = str(post_dict["post_type_id"])
|
|
data.append(post_dict)
|
|
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("example_posts").insert(data).execute()
|
|
)
|
|
logger.info(f"Saved {len(result.data)} example posts")
|
|
return [ExamplePost(**item) for item in result.data]
|
|
|
|
async def save_example_post(self, post: ExamplePost) -> ExamplePost:
|
|
"""Save a single example post."""
|
|
posts = await self.save_example_posts([post])
|
|
return posts[0] if posts else post
|
|
|
|
async def get_example_posts(self, user_id: UUID) -> List[ExamplePost]:
|
|
"""Get all example posts for a user."""
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("example_posts").select("*").eq(
|
|
"user_id", str(user_id)
|
|
).order("created_at", desc=True).execute()
|
|
)
|
|
return [ExamplePost(**item) for item in result.data]
|
|
|
|
async def delete_example_post(self, post_id: UUID) -> None:
|
|
"""Delete an example post."""
|
|
await asyncio.to_thread(
|
|
lambda: self.client.table("example_posts").delete().eq("id", str(post_id)).execute()
|
|
)
|
|
logger.info(f"Deleted example post: {post_id}")
|
|
|
|
# ==================== REFERENCE PROFILES ====================
|
|
|
|
async def create_reference_profile(self, profile: ReferenceProfile) -> ReferenceProfile:
|
|
"""Create a new reference profile."""
|
|
data = profile.model_dump(exclude={"id", "created_at"}, exclude_none=True)
|
|
if "user_id" in data:
|
|
data["user_id"] = str(data["user_id"])
|
|
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("reference_profiles").insert(data).execute()
|
|
)
|
|
logger.info(f"Created reference profile: {result.data[0]['id']}")
|
|
return ReferenceProfile(**result.data[0])
|
|
|
|
async def get_reference_profiles(self, user_id: UUID) -> List[ReferenceProfile]:
|
|
"""Get all reference profiles for a user."""
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("reference_profiles").select("*").eq(
|
|
"user_id", str(user_id)
|
|
).order("created_at", desc=True).execute()
|
|
)
|
|
return [ReferenceProfile(**item) for item in result.data]
|
|
|
|
async def update_reference_profile(self, profile_id: UUID, updates: Dict[str, Any]) -> ReferenceProfile:
|
|
"""Update reference profile fields."""
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("reference_profiles").update(updates).eq(
|
|
"id", str(profile_id)
|
|
).execute()
|
|
)
|
|
logger.info(f"Updated reference profile: {profile_id}")
|
|
return ReferenceProfile(**result.data[0])
|
|
|
|
async def delete_reference_profile(self, profile_id: UUID) -> None:
|
|
"""Delete a reference profile."""
|
|
await asyncio.to_thread(
|
|
lambda: self.client.table("reference_profiles").delete().eq("id", str(profile_id)).execute()
|
|
)
|
|
logger.info(f"Deleted reference profile: {profile_id}")
|
|
|
|
# ==================== API USAGE LOGS ====================
|
|
|
|
async def log_api_usage(
|
|
self,
|
|
provider: str,
|
|
model: str,
|
|
operation: str,
|
|
prompt_tokens: int,
|
|
completion_tokens: int,
|
|
total_tokens: int,
|
|
estimated_cost_usd: float,
|
|
user_id: Optional[UUID] = None,
|
|
company_id: Optional[UUID] = None
|
|
) -> None:
|
|
"""Log an API usage event."""
|
|
data = {
|
|
"provider": provider,
|
|
"model": model,
|
|
"operation": operation,
|
|
"prompt_tokens": prompt_tokens,
|
|
"completion_tokens": completion_tokens,
|
|
"total_tokens": total_tokens,
|
|
"estimated_cost_usd": estimated_cost_usd
|
|
}
|
|
if user_id:
|
|
data["user_id"] = str(user_id)
|
|
if company_id:
|
|
data["company_id"] = str(company_id)
|
|
|
|
try:
|
|
await asyncio.to_thread(
|
|
lambda: self.client.table("api_usage_logs").insert(data).execute()
|
|
)
|
|
except Exception as e:
|
|
logger.warning(f"Failed to log API usage: {e}")
|
|
|
|
async def get_usage_stats(
|
|
self,
|
|
start_date: Optional[str] = None,
|
|
end_date: Optional[str] = None,
|
|
user_id: Optional[UUID] = None,
|
|
company_id: Optional[UUID] = None
|
|
) -> List[Dict[str, Any]]:
|
|
"""Get aggregated usage stats."""
|
|
def _query():
|
|
query = self.client.table("api_usage_logs").select("*")
|
|
if user_id:
|
|
query = query.eq("user_id", str(user_id))
|
|
if company_id:
|
|
query = query.eq("company_id", str(company_id))
|
|
if start_date:
|
|
query = query.gte("created_at", start_date)
|
|
if end_date:
|
|
query = query.lte("created_at", end_date)
|
|
return query.order("created_at", desc=True).execute()
|
|
|
|
result = await asyncio.to_thread(_query)
|
|
return result.data
|
|
|
|
async def get_usage_by_day(
|
|
self,
|
|
start_date: Optional[str] = None,
|
|
end_date: Optional[str] = None,
|
|
user_id: Optional[UUID] = None,
|
|
company_id: Optional[UUID] = None
|
|
) -> List[Dict[str, Any]]:
|
|
"""Get daily usage breakdown."""
|
|
logs = await self.get_usage_stats(start_date, end_date, user_id, company_id)
|
|
|
|
from collections import defaultdict
|
|
daily = defaultdict(lambda: {"total_tokens": 0, "estimated_cost_usd": 0.0, "count": 0})
|
|
|
|
for log in logs:
|
|
day = log["created_at"][:10]
|
|
daily[day]["total_tokens"] += log.get("total_tokens", 0)
|
|
daily[day]["estimated_cost_usd"] += float(log.get("estimated_cost_usd", 0))
|
|
daily[day]["count"] += 1
|
|
|
|
return [{"date": day, **data} for day, data in sorted(daily.items())]
|
|
|
|
async def get_post_stats(
|
|
self,
|
|
start_date: Optional[str] = None,
|
|
user_id: Optional[UUID] = None,
|
|
company_id: Optional[UUID] = None
|
|
) -> List[Dict[str, Any]]:
|
|
"""Get generated posts with select fields for statistics."""
|
|
# Resolve user_ids for filtering
|
|
target_user_ids = None
|
|
if user_id:
|
|
target_user_ids = [str(user_id)]
|
|
if company_id:
|
|
# Get all users belonging to this company
|
|
profiles_result = await asyncio.to_thread(
|
|
lambda: self.client.table("profiles").select("id").eq(
|
|
"company_id", str(company_id)
|
|
).execute()
|
|
)
|
|
company_user_ids = [p["id"] for p in profiles_result.data]
|
|
if target_user_ids is not None:
|
|
target_user_ids = [u for u in target_user_ids if u in company_user_ids]
|
|
else:
|
|
target_user_ids = company_user_ids
|
|
if not target_user_ids:
|
|
return []
|
|
|
|
def _query():
|
|
q = self.client.table("generated_posts").select(
|
|
"id, created_at, status, approved_at, published_at, user_id"
|
|
)
|
|
if target_user_ids is not None:
|
|
q = q.in_("user_id", target_user_ids)
|
|
if start_date:
|
|
q = q.gte("created_at", start_date)
|
|
return q.order("created_at", desc=True).execute()
|
|
|
|
result = await asyncio.to_thread(_query)
|
|
return result.data
|
|
|
|
# ==================== LICENSE KEYS ====================
|
|
|
|
async def list_license_keys(self) -> List[LicenseKey]:
|
|
"""Get all license keys (admin only)."""
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("license_keys")
|
|
.select("*")
|
|
.order("created_at", desc=True)
|
|
.execute()
|
|
)
|
|
return [LicenseKey(**row) for row in result.data]
|
|
|
|
async def get_license_key(self, key: str) -> Optional[LicenseKey]:
|
|
"""Get license key by key string."""
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("license_keys")
|
|
.select("*")
|
|
.eq("key", key)
|
|
.execute()
|
|
)
|
|
return LicenseKey(**result.data[0]) if result.data else None
|
|
|
|
async def create_license_key(
|
|
self,
|
|
key: str,
|
|
max_employees: int,
|
|
daily_token_limit: Optional[int] = None,
|
|
description: Optional[str] = None
|
|
) -> LicenseKey:
|
|
"""Create new license key."""
|
|
data = {
|
|
"key": key,
|
|
"max_employees": max_employees,
|
|
"description": description,
|
|
}
|
|
if daily_token_limit is not None:
|
|
data["daily_token_limit"] = daily_token_limit
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("license_keys").insert(data).execute()
|
|
)
|
|
logger.info(f"Created license key: {key}")
|
|
return LicenseKey(**result.data[0])
|
|
|
|
async def mark_license_key_used(
|
|
self,
|
|
key: str,
|
|
company_id: UUID
|
|
) -> LicenseKey:
|
|
"""Mark license key as used and link to company."""
|
|
data = {
|
|
"used": True,
|
|
"company_id": str(company_id),
|
|
"used_at": datetime.now(timezone.utc).isoformat()
|
|
}
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("license_keys")
|
|
.update(data)
|
|
.eq("key", key)
|
|
.execute()
|
|
)
|
|
logger.info(f"Marked license key as used: {key}")
|
|
return LicenseKey(**result.data[0])
|
|
|
|
async def get_license_key_by_id(self, key_id: UUID) -> Optional[LicenseKey]:
|
|
"""Get license key by UUID."""
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("license_keys")
|
|
.select("*")
|
|
.eq("id", str(key_id))
|
|
.execute()
|
|
)
|
|
return LicenseKey(**result.data[0]) if result.data else None
|
|
|
|
async def update_license_key(self, key_id: UUID, updates: Dict[str, Any]) -> LicenseKey:
|
|
"""Update license key limits (admin only)."""
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("license_keys")
|
|
.update(updates)
|
|
.eq("id", str(key_id))
|
|
.execute()
|
|
)
|
|
logger.info(f"Updated license key: {key_id}")
|
|
return LicenseKey(**result.data[0]) if result.data else None
|
|
|
|
async def delete_license_key(self, key_id: UUID) -> None:
|
|
"""Delete license key (admin only)."""
|
|
await asyncio.to_thread(
|
|
lambda: self.client.table("license_keys")
|
|
.delete()
|
|
.eq("id", str(key_id))
|
|
.execute()
|
|
)
|
|
logger.info(f"Deleted license key: {key_id}")
|
|
|
|
# ==================== LICENSE KEY OFFERS ====================
|
|
|
|
async def create_license_key_offer(
|
|
self,
|
|
license_key_id: UUID,
|
|
moco_offer_id: int,
|
|
moco_offer_identifier: Optional[str],
|
|
moco_offer_url: Optional[str],
|
|
offer_title: Optional[str],
|
|
company_name: Optional[str],
|
|
price: Optional[float],
|
|
payment_frequency: Optional[str],
|
|
) -> "LicenseKeyOffer":
|
|
"""Save a MOCO offer linked to a license key."""
|
|
data = {
|
|
"license_key_id": str(license_key_id),
|
|
"moco_offer_id": moco_offer_id,
|
|
"moco_offer_identifier": moco_offer_identifier,
|
|
"moco_offer_url": moco_offer_url,
|
|
"offer_title": offer_title,
|
|
"company_name": company_name,
|
|
"price": price,
|
|
"payment_frequency": payment_frequency,
|
|
"status": "draft",
|
|
}
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("license_key_offers").insert(data).execute()
|
|
)
|
|
return LicenseKeyOffer(**result.data[0])
|
|
|
|
async def list_license_key_offers(self, license_key_id: UUID) -> list["LicenseKeyOffer"]:
|
|
"""List all MOCO offers for a license key."""
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("license_key_offers")
|
|
.select("*")
|
|
.eq("license_key_id", str(license_key_id))
|
|
.order("created_at", desc=True)
|
|
.execute()
|
|
)
|
|
return [LicenseKeyOffer(**row) for row in result.data]
|
|
|
|
async def update_license_key_offer_status(
|
|
self, offer_id: UUID, status: str
|
|
) -> "LicenseKeyOffer":
|
|
"""Update the status of a stored offer."""
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("license_key_offers")
|
|
.update({"status": status})
|
|
.eq("id", str(offer_id))
|
|
.execute()
|
|
)
|
|
return LicenseKeyOffer(**result.data[0])
|
|
|
|
# ==================== COMPANY QUOTAS ====================
|
|
|
|
async def get_company_daily_quota(
|
|
self,
|
|
company_id: UUID,
|
|
date_: Optional[date] = None
|
|
) -> CompanyDailyQuota:
|
|
"""Get or create daily quota for company."""
|
|
if date_ is None:
|
|
date_ = date.today()
|
|
|
|
# Try to get existing
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("company_daily_quotas")
|
|
.select("*")
|
|
.eq("company_id", str(company_id))
|
|
.eq("date", date_.isoformat())
|
|
.execute()
|
|
)
|
|
|
|
if result.data:
|
|
return CompanyDailyQuota(**result.data[0])
|
|
|
|
# Create new quota for today
|
|
data = {
|
|
"company_id": str(company_id),
|
|
"date": date_.isoformat(),
|
|
"tokens_used": 0
|
|
}
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("company_daily_quotas")
|
|
.insert(data)
|
|
.execute()
|
|
)
|
|
return CompanyDailyQuota(**result.data[0])
|
|
|
|
async def increment_company_tokens(self, company_id: UUID, tokens: int) -> None:
|
|
"""Increment daily token usage for company."""
|
|
try:
|
|
quota = await self.get_company_daily_quota(company_id)
|
|
new_count = quota.tokens_used + tokens
|
|
|
|
await asyncio.to_thread(
|
|
lambda: self.client.table("company_daily_quotas")
|
|
.update({"tokens_used": new_count})
|
|
.eq("id", str(quota.id))
|
|
.execute()
|
|
)
|
|
|
|
logger.info(f"Incremented token quota for company {company_id}: {quota.tokens_used} -> {new_count}")
|
|
except Exception as e:
|
|
logger.warning(f"Error incrementing token quota for company {company_id}: {e}")
|
|
|
|
async def get_company_limits(self, company_id: UUID) -> Optional[LicenseKey]:
|
|
"""Get company limits from associated license key.
|
|
|
|
Returns None if no license key is associated (uses defaults).
|
|
"""
|
|
company = await self.get_company(company_id)
|
|
if not company or not company.license_key_id:
|
|
return None
|
|
|
|
result = await asyncio.to_thread(
|
|
lambda: self.client.table("license_keys")
|
|
.select("*")
|
|
.eq("id", str(company.license_key_id))
|
|
.execute()
|
|
)
|
|
|
|
return LicenseKey(**result.data[0]) if result.data else None
|
|
|
|
async def check_company_token_limit(self, company_id: UUID) -> tuple[bool, str, int, int]:
|
|
"""Check if company has token budget remaining today.
|
|
|
|
Returns (can_proceed: bool, error_message: str, tokens_used: int, daily_limit: int)
|
|
"""
|
|
license_key = await self.get_company_limits(company_id)
|
|
if not license_key or license_key.daily_token_limit is None:
|
|
return True, "", 0, 0 # no limit = unlimited
|
|
|
|
quota = await self.get_company_daily_quota(company_id)
|
|
|
|
if quota.tokens_used >= license_key.daily_token_limit:
|
|
return False, f"Tageslimit erreicht ({license_key.daily_token_limit:,} Tokens/Tag). Morgen wieder verfügbar.", quota.tokens_used, license_key.daily_token_limit
|
|
|
|
return True, "", quota.tokens_used, license_key.daily_token_limit
|
|
|
|
async def check_company_employee_limit(self, company_id: UUID) -> tuple[bool, str]:
    """Check if company can add more employees.

    Compares the current count of employee accounts against the
    ``max_employees`` cap on the company's license key.

    Returns (can_add: bool, error_message: str)
    """
    license_key = await self.get_company_limits(company_id)
    # No license key, or a key without a configured cap, means unlimited.
    # (Checking `max_employees is None` mirrors check_company_token_limit
    # and avoids a TypeError from comparing an int against None below.)
    if not license_key or license_key.max_employees is None:
        return True, ""

    employees = await self.list_users(account_type="employee", company_id=company_id)

    if len(employees) >= license_key.max_employees:
        # Cap reached; user-facing message is German by design.
        return False, f"Maximale Mitarbeiteranzahl erreicht ({license_key.max_employees}). Bitte Lizenz upgraden."

    return True, ""
# ==================== EMAIL ACTION TOKENS ====================
|
|
|
|
async def create_email_token(self, token: str, post_id: UUID, action: str, expires_hours: int = 72) -> None:
    """Store an email action token in the database.

    The token is written unused and expires `expires_hours` hours from now.
    """
    from datetime import timedelta

    expiry = datetime.now(timezone.utc) + timedelta(hours=expires_hours)
    row = {
        "token": token,
        "post_id": str(post_id),
        "action": action,
        "expires_at": expiry.isoformat(),
        "used": False,
    }

    def _insert():
        return self.client.table("email_action_tokens").insert(row).execute()

    await asyncio.to_thread(_insert)
    logger.debug(f"Created email token for post {post_id} action={action}")
async def get_email_token(self, token: str) -> Optional[Dict[str, Any]]:
    """Retrieve email token data; returns None if not found."""
    def _query():
        return (
            self.client.table("email_action_tokens")
            .select("*")
            .eq("token", token)
            .execute()
        )

    rows = (await asyncio.to_thread(_query)).data
    return rows[0] if rows else None
async def mark_email_token_used(self, token: str) -> None:
    """Mark an email token as used."""
    def _update():
        return (
            self.client.table("email_action_tokens")
            .update({"used": True})
            .eq("token", token)
            .execute()
        )

    await asyncio.to_thread(_update)
async def mark_all_post_tokens_used(self, post_id: UUID) -> None:
    """Mark all email action tokens for a post as used (invalidate approve + reject together)."""
    def _update():
        return (
            self.client.table("email_action_tokens")
            .update({"used": True})
            .eq("post_id", str(post_id))
            .execute()
        )

    await asyncio.to_thread(_update)
async def cleanup_expired_email_tokens(self) -> None:
    """Delete expired email tokens from the database."""
    cutoff = datetime.now(timezone.utc).isoformat()

    def _delete():
        return (
            self.client.table("email_action_tokens")
            .delete()
            .lt("expires_at", cutoff)
            .execute()
        )

    result = await asyncio.to_thread(_delete)
    # Only log when something was actually removed.
    removed = len(result.data or [])
    if removed:
        logger.info(f"Cleaned up {removed} expired email tokens")
# ==================== LINKEDIN TOKEN REFRESH ====================
|
|
|
|
async def get_expiring_linkedin_accounts(self, within_days: int = 7) -> list:
    """Return active LinkedIn accounts whose tokens expire within within_days and have a refresh_token."""
    from datetime import timedelta

    from src.database.models import LinkedInAccount

    deadline = (datetime.now(timezone.utc) + timedelta(days=within_days)).isoformat()

    def _query():
        return (
            self.client.table("linkedin_accounts")
            .select("*")
            .eq("is_active", True)
            .lt("token_expires_at", deadline)
            .not_.is_("refresh_token", "null")  # only accounts we can refresh
            .execute()
        )

    result = await asyncio.to_thread(_query)
    return [LinkedInAccount(**row) for row in result.data]
# ==================== EMPLOYEE COMPANY PERMISSIONS ====================
|
|
|
|
async def get_employee_permissions(self, user_id: UUID, company_id: UUID):
    """Get permissions for an employee in a company. Returns None if no row exists (treat as all-true defaults)."""
    from src.database.models import EmployeeCompanyPermissions

    def _query():
        return (
            self.client.table("employee_company_permissions")
            .select("*")
            .eq("user_id", str(user_id))
            .eq("company_id", str(company_id))
            .execute()
        )

    rows = (await asyncio.to_thread(_query)).data
    return EmployeeCompanyPermissions(**rows[0]) if rows else None
async def upsert_employee_permissions(self, user_id: UUID, company_id: UUID, updates: Dict[str, Any]) -> None:
    """Insert or update employee-company permissions."""
    # Keys in `updates` take precedence over the identity fields.
    payload = {"user_id": str(user_id), "company_id": str(company_id)}
    payload.update(updates)

    def _upsert():
        return (
            self.client.table("employee_company_permissions")
            .upsert(payload, on_conflict="user_id,company_id")
            .execute()
        )

    await asyncio.to_thread(_upsert)
    logger.info(f"Upserted permissions for user {user_id} in company {company_id}")
async def get_scheduled_posts_for_user(self, user_id: UUID) -> List[GeneratedPost]:
    """Get scheduled/approved/published posts for an employee (for their calendar)."""
    calendar_statuses = ["approved", "ready", "scheduled", "published"]

    def _query():
        return (
            self.client.table("generated_posts")
            .select("*")
            .eq("user_id", str(user_id))
            .in_("status", calendar_statuses)
            .order("scheduled_at", desc=False, nullsfirst=False)
            .execute()
        )

    result = await asyncio.to_thread(_query)
    return [GeneratedPost(**row) for row in result.data]
# Global database client instance
# NOTE: instantiated at import time — importing this module creates the
# Supabase client (module-level side effect).
db = DatabaseClient()
|