Added scalability and performance improvements (Redis caching, HTTP caching, etc.)

This commit is contained in:
2026-02-19 17:19:41 +01:00
parent d8d054c9a8
commit 4b15b552d6
12 changed files with 763 additions and 88 deletions

View File

@@ -13,6 +13,12 @@ from src.database.models import (
User, Profile, Company, Invitation, ExamplePost, ReferenceProfile,
ApiUsageLog, LicenseKey, CompanyDailyQuota, LicenseKeyOffer
)
from src.services.cache_service import (
cache,
PROFILE_TTL, USER_TTL, LINKEDIN_ACCOUNT_TTL, PROFILE_ANALYSIS_TTL,
COMPANY_TTL, LINKEDIN_POSTS_TTL, POST_TYPES_TTL, POST_TYPE_TTL,
GEN_POST_TTL, GEN_POSTS_TTL,
)
class DatabaseClient:
@@ -111,16 +117,26 @@ class DatabaseClient:
).execute()
)
logger.info(f"Saved {len(result.data)} LinkedIn posts")
return [LinkedInPost(**item) for item in result.data]
saved = [LinkedInPost(**item) for item in result.data]
# Invalidate cache for all affected users
affected_user_ids = {str(p.user_id) for p in saved}
for uid in affected_user_ids:
await cache.invalidate_linkedin_posts(uid)
return saved
async def get_linkedin_posts(self, user_id: UUID) -> List[LinkedInPost]:
"""Get all LinkedIn posts for user."""
key = cache.linkedin_posts_key(str(user_id))
if (hit := await cache.get(key)) is not None:
return [LinkedInPost(**item) for item in hit]
result = await asyncio.to_thread(
lambda: self.client.table("linkedin_posts").select("*").eq(
"user_id", str(user_id)
).order("post_date", desc=True).execute()
)
return [LinkedInPost(**item) for item in result.data]
posts = [LinkedInPost(**item) for item in result.data]
await cache.set(key, [p.model_dump(mode="json") for p in posts], LINKEDIN_POSTS_TTL)
return posts
async def copy_posts_to_user(self, source_user_id: UUID, target_user_id: UUID) -> List[LinkedInPost]:
"""Copy all LinkedIn posts from one user to another."""
@@ -148,9 +164,16 @@ class DatabaseClient:
async def delete_linkedin_post(self, post_id: UUID) -> None:
"""Delete a LinkedIn post."""
# Fetch user_id before deleting so we can invalidate the cache
lookup = await asyncio.to_thread(
lambda: self.client.table("linkedin_posts").select("user_id").eq("id", str(post_id)).execute()
)
user_id = lookup.data[0]["user_id"] if lookup.data else None
await asyncio.to_thread(
lambda: self.client.table("linkedin_posts").delete().eq("id", str(post_id)).execute()
)
if user_id:
await cache.invalidate_linkedin_posts(user_id)
logger.info(f"Deleted LinkedIn post: {post_id}")
async def get_unclassified_posts(self, user_id: UUID) -> List[LinkedInPost]:
@@ -195,9 +218,11 @@ class DatabaseClient:
if "user_id" in updates and updates["user_id"]:
updates["user_id"] = str(updates["user_id"])
await asyncio.to_thread(
result = await asyncio.to_thread(
lambda: self.client.table("linkedin_posts").update(updates).eq("id", str(post_id)).execute()
)
if result.data:
await cache.invalidate_linkedin_posts(result.data[0]["user_id"])
logger.debug(f"Updated LinkedIn post {post_id}")
async def update_posts_classification_bulk(
@@ -233,7 +258,9 @@ class DatabaseClient:
lambda: self.client.table("post_types").insert(data).execute()
)
logger.info(f"Created post type: {result.data[0]['name']}")
return PostType(**result.data[0])
created = PostType(**result.data[0])
await cache.invalidate_post_types(str(created.user_id))
return created
async def create_post_types_bulk(self, post_types: List[PostType]) -> List[PostType]:
"""Create multiple post types at once."""
@@ -251,10 +278,18 @@ class DatabaseClient:
lambda: self.client.table("post_types").insert(data).execute()
)
logger.info(f"Created {len(result.data)} post types")
return [PostType(**item) for item in result.data]
created = [PostType(**item) for item in result.data]
affected_user_ids = {str(pt.user_id) for pt in created}
for uid in affected_user_ids:
await cache.invalidate_post_types(uid)
return created
async def get_post_types(self, user_id: UUID, active_only: bool = True) -> List[PostType]:
"""Get all post types for a user."""
key = cache.post_types_key(str(user_id), active_only)
if (hit := await cache.get(key)) is not None:
return [PostType(**item) for item in hit]
def _query():
query = self.client.table("post_types").select("*").eq("user_id", str(user_id))
if active_only:
@@ -262,7 +297,9 @@ class DatabaseClient:
return query.order("name").execute()
result = await asyncio.to_thread(_query)
return [PostType(**item) for item in result.data]
post_types = [PostType(**item) for item in result.data]
await cache.set(key, [pt.model_dump(mode="json") for pt in post_types], POST_TYPES_TTL)
return post_types
# Alias for get_post_types
async def get_customer_post_types(self, user_id: UUID, active_only: bool = True) -> List[PostType]:
@@ -271,13 +308,18 @@ class DatabaseClient:
async def get_post_type(self, post_type_id: UUID) -> Optional[PostType]:
"""Get a single post type by ID."""
key = cache.post_type_key(str(post_type_id))
if (hit := await cache.get(key)) is not None:
return PostType(**hit)
result = await asyncio.to_thread(
lambda: self.client.table("post_types").select("*").eq(
"id", str(post_type_id)
).execute()
)
if result.data:
return PostType(**result.data[0])
pt = PostType(**result.data[0])
await cache.set(key, pt.model_dump(mode="json"), POST_TYPE_TTL)
return pt
return None
async def update_post_type(self, post_type_id: UUID, updates: Dict[str, Any]) -> PostType:
@@ -288,7 +330,9 @@ class DatabaseClient:
).execute()
)
logger.info(f"Updated post type: {post_type_id}")
return PostType(**result.data[0])
pt = PostType(**result.data[0])
await cache.invalidate_post_type(str(post_type_id), str(pt.user_id))
return pt
async def update_post_type_analysis(
self,
@@ -306,23 +350,36 @@ class DatabaseClient:
}).eq("id", str(post_type_id)).execute()
)
logger.info(f"Updated analysis for post type: {post_type_id}")
return PostType(**result.data[0])
pt = PostType(**result.data[0])
await cache.invalidate_post_type(str(post_type_id), str(pt.user_id))
return pt
async def delete_post_type(self, post_type_id: UUID, soft: bool = True) -> None:
"""Delete a post type (soft delete by default)."""
if soft:
await asyncio.to_thread(
result = await asyncio.to_thread(
lambda: self.client.table("post_types").update({
"is_active": False
}).eq("id", str(post_type_id)).execute()
)
logger.info(f"Soft deleted post type: {post_type_id}")
if result.data:
await cache.invalidate_post_type(str(post_type_id), result.data[0]["user_id"])
else:
# Fetch user_id before hard delete for cache invalidation
lookup = await asyncio.to_thread(
lambda: self.client.table("post_types").select("user_id").eq(
"id", str(post_type_id)
).execute()
)
user_id = lookup.data[0]["user_id"] if lookup.data else None
await asyncio.to_thread(
lambda: self.client.table("post_types").delete().eq(
"id", str(post_type_id)
).execute()
)
if user_id:
await cache.invalidate_post_type(str(post_type_id), user_id)
logger.info(f"Hard deleted post type: {post_type_id}")
# ==================== TOPICS ====================
@@ -418,17 +475,24 @@ class DatabaseClient:
)
logger.info(f"Saved profile analysis for user: {analysis.user_id}")
return ProfileAnalysis(**result.data[0])
saved = ProfileAnalysis(**result.data[0])
await cache.delete(cache.profile_analysis_key(str(analysis.user_id)))
return saved
async def get_profile_analysis(self, user_id: UUID) -> Optional[ProfileAnalysis]:
"""Get profile analysis for user."""
key = cache.profile_analysis_key(str(user_id))
if (hit := await cache.get(key)) is not None:
return ProfileAnalysis(**hit)
result = await asyncio.to_thread(
lambda: self.client.table("profile_analyses").select("*").eq(
"user_id", str(user_id)
).execute()
)
if result.data:
return ProfileAnalysis(**result.data[0])
pa = ProfileAnalysis(**result.data[0])
await cache.set(key, pa.model_dump(mode="json"), PROFILE_ANALYSIS_TTL)
return pa
return None
# ==================== RESEARCH RESULTS ====================
@@ -491,7 +555,9 @@ class DatabaseClient:
lambda: self.client.table("generated_posts").insert(data).execute()
)
logger.info(f"Saved generated post: {result.data[0]['id']}")
return GeneratedPost(**result.data[0])
saved = GeneratedPost(**result.data[0])
await cache.invalidate_gen_posts(str(saved.user_id))
return saved
async def update_generated_post(self, post_id: UUID, updates: Dict[str, Any]) -> GeneratedPost:
"""Update generated post."""
@@ -504,11 +570,8 @@ class DatabaseClient:
# Handle metadata dict - ensure all nested datetime values are serialized
if 'metadata' in updates and isinstance(updates['metadata'], dict):
serialized_metadata = {}
for key, value in updates['metadata'].items():
if isinstance(value, datetime):
serialized_metadata[key] = value.isoformat()
else:
serialized_metadata[key] = value
for k, value in updates['metadata'].items():
serialized_metadata[k] = value.isoformat() if isinstance(value, datetime) else value
updates['metadata'] = serialized_metadata
result = await asyncio.to_thread(
@@ -517,26 +580,38 @@ class DatabaseClient:
).execute()
)
logger.info(f"Updated generated post: {post_id}")
return GeneratedPost(**result.data[0])
updated = GeneratedPost(**result.data[0])
await cache.invalidate_gen_post(str(post_id), str(updated.user_id))
return updated
async def get_generated_posts(self, user_id: UUID) -> List[GeneratedPost]:
"""Get all generated posts for user."""
key = cache.gen_posts_key(str(user_id))
if (hit := await cache.get(key)) is not None:
return [GeneratedPost(**item) for item in hit]
result = await asyncio.to_thread(
lambda: self.client.table("generated_posts").select("*").eq(
"user_id", str(user_id)
).order("created_at", desc=True).execute()
)
return [GeneratedPost(**item) for item in result.data]
posts = [GeneratedPost(**item) for item in result.data]
await cache.set(key, [p.model_dump(mode="json") for p in posts], GEN_POSTS_TTL)
return posts
async def get_generated_post(self, post_id: UUID) -> Optional[GeneratedPost]:
"""Get a single generated post by ID."""
key = cache.gen_post_key(str(post_id))
if (hit := await cache.get(key)) is not None:
return GeneratedPost(**hit)
result = await asyncio.to_thread(
lambda: self.client.table("generated_posts").select("*").eq(
"id", str(post_id)
).execute()
)
if result.data:
return GeneratedPost(**result.data[0])
post = GeneratedPost(**result.data[0])
await cache.set(key, post.model_dump(mode="json"), GEN_POST_TTL)
return post
return None
async def get_scheduled_posts_due(self) -> List[GeneratedPost]:
@@ -627,11 +702,16 @@ class DatabaseClient:
async def get_profile(self, user_id: UUID) -> Optional[Profile]:
"""Get profile by user ID."""
key = cache.profile_key(str(user_id))
if (hit := await cache.get(key)) is not None:
return Profile(**hit)
result = await asyncio.to_thread(
lambda: self.client.table("profiles").select("*").eq("id", str(user_id)).execute()
)
if result.data:
return Profile(**result.data[0])
profile = Profile(**result.data[0])
await cache.set(key, profile.model_dump(mode="json"), PROFILE_TTL)
return profile
return None
async def get_profiles_by_linkedin_url(self, linkedin_url: str) -> List[Profile]:
@@ -645,9 +725,9 @@ class DatabaseClient:
"""Update profile fields."""
if "company_id" in updates and updates["company_id"]:
updates["company_id"] = str(updates["company_id"])
for key in ["account_type", "onboarding_status"]:
if key in updates and hasattr(updates[key], "value"):
updates[key] = updates[key].value
for k in ["account_type", "onboarding_status"]:
if k in updates and hasattr(updates[k], "value"):
updates[k] = updates[k].value
result = await asyncio.to_thread(
lambda: self.client.table("profiles").update(updates).eq(
@@ -655,19 +735,26 @@ class DatabaseClient:
).execute()
)
logger.info(f"Updated profile: {user_id}")
return Profile(**result.data[0])
profile = Profile(**result.data[0])
await cache.invalidate_profile(str(user_id))
return profile
# ==================== LINKEDIN ACCOUNTS ====================
async def get_linkedin_account(self, user_id: UUID) -> Optional['LinkedInAccount']:
"""Get LinkedIn account for user."""
from src.database.models import LinkedInAccount
key = cache.linkedin_account_key(str(user_id))
if (hit := await cache.get(key)) is not None:
return LinkedInAccount(**hit)
result = await asyncio.to_thread(
lambda: self.client.table("linkedin_accounts").select("*")
.eq("user_id", str(user_id)).eq("is_active", True).execute()
)
if result.data:
return LinkedInAccount(**result.data[0])
account = LinkedInAccount(**result.data[0])
await cache.set(key, account.model_dump(mode="json"), LINKEDIN_ACCOUNT_TTL)
return account
return None
async def get_linkedin_account_by_id(self, account_id: UUID) -> Optional['LinkedInAccount']:
@@ -692,7 +779,9 @@ class DatabaseClient:
lambda: self.client.table("linkedin_accounts").insert(data).execute()
)
logger.info(f"Created LinkedIn account for user: {account.user_id}")
return LinkedInAccount(**result.data[0])
created = LinkedInAccount(**result.data[0])
await cache.invalidate_linkedin_account(str(created.user_id))
return created
async def update_linkedin_account(self, account_id: UUID, updates: Dict) -> 'LinkedInAccount':
"""Update LinkedIn account."""
@@ -708,25 +797,40 @@ class DatabaseClient:
.eq("id", str(account_id)).execute()
)
logger.info(f"Updated LinkedIn account: {account_id}")
return LinkedInAccount(**result.data[0])
updated = LinkedInAccount(**result.data[0])
await cache.invalidate_linkedin_account(str(updated.user_id))
return updated
async def delete_linkedin_account(self, account_id: UUID) -> None:
"""Delete LinkedIn account connection."""
# Fetch user_id before delete for cache invalidation
lookup = await asyncio.to_thread(
lambda: self.client.table("linkedin_accounts").select("user_id")
.eq("id", str(account_id)).execute()
)
user_id = lookup.data[0]["user_id"] if lookup.data else None
await asyncio.to_thread(
lambda: self.client.table("linkedin_accounts").delete()
.eq("id", str(account_id)).execute()
)
if user_id:
await cache.invalidate_linkedin_account(user_id)
logger.info(f"Deleted LinkedIn account: {account_id}")
# ==================== USERS ====================
async def get_user(self, user_id: UUID) -> Optional[User]:
"""Get user by ID (from users view)."""
key = cache.user_key(str(user_id))
if (hit := await cache.get(key)) is not None:
return User(**hit)
result = await asyncio.to_thread(
lambda: self.client.table("users").select("*").eq("id", str(user_id)).execute()
)
if result.data:
return User(**result.data[0])
user = User(**result.data[0])
await cache.set(key, user.model_dump(mode="json"), USER_TTL)
return user
return None
async def get_user_by_email(self, email: str) -> Optional[User]:
@@ -757,7 +861,10 @@ class DatabaseClient:
if profile_updates:
await self.update_profile(user_id, profile_updates)
# update_profile already calls cache.invalidate_profile which also kills user_key
# Invalidate user view separately (in case it wasn't covered above)
await cache.delete(cache.user_key(str(user_id)))
return await self.get_user(user_id)
async def list_users(self, account_type: Optional[str] = None, company_id: Optional[UUID] = None) -> List[User]:
@@ -838,11 +945,16 @@ class DatabaseClient:
async def get_company(self, company_id: UUID) -> Optional[Company]:
"""Get company by ID."""
key = cache.company_key(str(company_id))
if (hit := await cache.get(key)) is not None:
return Company(**hit)
result = await asyncio.to_thread(
lambda: self.client.table("companies").select("*").eq("id", str(company_id)).execute()
)
if result.data:
return Company(**result.data[0])
company = Company(**result.data[0])
await cache.set(key, company.model_dump(mode="json"), COMPANY_TTL)
return company
return None
async def get_company_by_owner(self, owner_user_id: UUID) -> Optional[Company]:
@@ -871,7 +983,9 @@ class DatabaseClient:
).execute()
)
logger.info(f"Updated company: {company_id}")
return Company(**result.data[0])
company = Company(**result.data[0])
await cache.invalidate_company(str(company_id))
return company
async def list_companies(self) -> List[Company]:
"""List all companies."""