added scalability and performance improvements (redis, http caching etc)
This commit is contained in:
@@ -13,6 +13,12 @@ from src.database.models import (
|
||||
User, Profile, Company, Invitation, ExamplePost, ReferenceProfile,
|
||||
ApiUsageLog, LicenseKey, CompanyDailyQuota, LicenseKeyOffer
|
||||
)
|
||||
from src.services.cache_service import (
|
||||
cache,
|
||||
PROFILE_TTL, USER_TTL, LINKEDIN_ACCOUNT_TTL, PROFILE_ANALYSIS_TTL,
|
||||
COMPANY_TTL, LINKEDIN_POSTS_TTL, POST_TYPES_TTL, POST_TYPE_TTL,
|
||||
GEN_POST_TTL, GEN_POSTS_TTL,
|
||||
)
|
||||
|
||||
|
||||
class DatabaseClient:
|
||||
@@ -111,16 +117,26 @@ class DatabaseClient:
|
||||
).execute()
|
||||
)
|
||||
logger.info(f"Saved {len(result.data)} LinkedIn posts")
|
||||
return [LinkedInPost(**item) for item in result.data]
|
||||
saved = [LinkedInPost(**item) for item in result.data]
|
||||
# Invalidate cache for all affected users
|
||||
affected_user_ids = {str(p.user_id) for p in saved}
|
||||
for uid in affected_user_ids:
|
||||
await cache.invalidate_linkedin_posts(uid)
|
||||
return saved
|
||||
|
||||
async def get_linkedin_posts(self, user_id: UUID) -> List[LinkedInPost]:
|
||||
"""Get all LinkedIn posts for user."""
|
||||
key = cache.linkedin_posts_key(str(user_id))
|
||||
if (hit := await cache.get(key)) is not None:
|
||||
return [LinkedInPost(**item) for item in hit]
|
||||
result = await asyncio.to_thread(
|
||||
lambda: self.client.table("linkedin_posts").select("*").eq(
|
||||
"user_id", str(user_id)
|
||||
).order("post_date", desc=True).execute()
|
||||
)
|
||||
return [LinkedInPost(**item) for item in result.data]
|
||||
posts = [LinkedInPost(**item) for item in result.data]
|
||||
await cache.set(key, [p.model_dump(mode="json") for p in posts], LINKEDIN_POSTS_TTL)
|
||||
return posts
|
||||
|
||||
async def copy_posts_to_user(self, source_user_id: UUID, target_user_id: UUID) -> List[LinkedInPost]:
|
||||
"""Copy all LinkedIn posts from one user to another."""
|
||||
@@ -148,9 +164,16 @@ class DatabaseClient:
|
||||
|
||||
async def delete_linkedin_post(self, post_id: UUID) -> None:
|
||||
"""Delete a LinkedIn post."""
|
||||
# Fetch user_id before deleting so we can invalidate the cache
|
||||
lookup = await asyncio.to_thread(
|
||||
lambda: self.client.table("linkedin_posts").select("user_id").eq("id", str(post_id)).execute()
|
||||
)
|
||||
user_id = lookup.data[0]["user_id"] if lookup.data else None
|
||||
await asyncio.to_thread(
|
||||
lambda: self.client.table("linkedin_posts").delete().eq("id", str(post_id)).execute()
|
||||
)
|
||||
if user_id:
|
||||
await cache.invalidate_linkedin_posts(user_id)
|
||||
logger.info(f"Deleted LinkedIn post: {post_id}")
|
||||
|
||||
async def get_unclassified_posts(self, user_id: UUID) -> List[LinkedInPost]:
|
||||
@@ -195,9 +218,11 @@ class DatabaseClient:
|
||||
if "user_id" in updates and updates["user_id"]:
|
||||
updates["user_id"] = str(updates["user_id"])
|
||||
|
||||
await asyncio.to_thread(
|
||||
result = await asyncio.to_thread(
|
||||
lambda: self.client.table("linkedin_posts").update(updates).eq("id", str(post_id)).execute()
|
||||
)
|
||||
if result.data:
|
||||
await cache.invalidate_linkedin_posts(result.data[0]["user_id"])
|
||||
logger.debug(f"Updated LinkedIn post {post_id}")
|
||||
|
||||
async def update_posts_classification_bulk(
|
||||
@@ -233,7 +258,9 @@ class DatabaseClient:
|
||||
lambda: self.client.table("post_types").insert(data).execute()
|
||||
)
|
||||
logger.info(f"Created post type: {result.data[0]['name']}")
|
||||
return PostType(**result.data[0])
|
||||
created = PostType(**result.data[0])
|
||||
await cache.invalidate_post_types(str(created.user_id))
|
||||
return created
|
||||
|
||||
async def create_post_types_bulk(self, post_types: List[PostType]) -> List[PostType]:
|
||||
"""Create multiple post types at once."""
|
||||
@@ -251,10 +278,18 @@ class DatabaseClient:
|
||||
lambda: self.client.table("post_types").insert(data).execute()
|
||||
)
|
||||
logger.info(f"Created {len(result.data)} post types")
|
||||
return [PostType(**item) for item in result.data]
|
||||
created = [PostType(**item) for item in result.data]
|
||||
affected_user_ids = {str(pt.user_id) for pt in created}
|
||||
for uid in affected_user_ids:
|
||||
await cache.invalidate_post_types(uid)
|
||||
return created
|
||||
|
||||
async def get_post_types(self, user_id: UUID, active_only: bool = True) -> List[PostType]:
|
||||
"""Get all post types for a user."""
|
||||
key = cache.post_types_key(str(user_id), active_only)
|
||||
if (hit := await cache.get(key)) is not None:
|
||||
return [PostType(**item) for item in hit]
|
||||
|
||||
def _query():
|
||||
query = self.client.table("post_types").select("*").eq("user_id", str(user_id))
|
||||
if active_only:
|
||||
@@ -262,7 +297,9 @@ class DatabaseClient:
|
||||
return query.order("name").execute()
|
||||
|
||||
result = await asyncio.to_thread(_query)
|
||||
return [PostType(**item) for item in result.data]
|
||||
post_types = [PostType(**item) for item in result.data]
|
||||
await cache.set(key, [pt.model_dump(mode="json") for pt in post_types], POST_TYPES_TTL)
|
||||
return post_types
|
||||
|
||||
# Alias for get_post_types
|
||||
async def get_customer_post_types(self, user_id: UUID, active_only: bool = True) -> List[PostType]:
|
||||
@@ -271,13 +308,18 @@ class DatabaseClient:
|
||||
|
||||
async def get_post_type(self, post_type_id: UUID) -> Optional[PostType]:
|
||||
"""Get a single post type by ID."""
|
||||
key = cache.post_type_key(str(post_type_id))
|
||||
if (hit := await cache.get(key)) is not None:
|
||||
return PostType(**hit)
|
||||
result = await asyncio.to_thread(
|
||||
lambda: self.client.table("post_types").select("*").eq(
|
||||
"id", str(post_type_id)
|
||||
).execute()
|
||||
)
|
||||
if result.data:
|
||||
return PostType(**result.data[0])
|
||||
pt = PostType(**result.data[0])
|
||||
await cache.set(key, pt.model_dump(mode="json"), POST_TYPE_TTL)
|
||||
return pt
|
||||
return None
|
||||
|
||||
async def update_post_type(self, post_type_id: UUID, updates: Dict[str, Any]) -> PostType:
|
||||
@@ -288,7 +330,9 @@ class DatabaseClient:
|
||||
).execute()
|
||||
)
|
||||
logger.info(f"Updated post type: {post_type_id}")
|
||||
return PostType(**result.data[0])
|
||||
pt = PostType(**result.data[0])
|
||||
await cache.invalidate_post_type(str(post_type_id), str(pt.user_id))
|
||||
return pt
|
||||
|
||||
async def update_post_type_analysis(
|
||||
self,
|
||||
@@ -306,23 +350,36 @@ class DatabaseClient:
|
||||
}).eq("id", str(post_type_id)).execute()
|
||||
)
|
||||
logger.info(f"Updated analysis for post type: {post_type_id}")
|
||||
return PostType(**result.data[0])
|
||||
pt = PostType(**result.data[0])
|
||||
await cache.invalidate_post_type(str(post_type_id), str(pt.user_id))
|
||||
return pt
|
||||
|
||||
async def delete_post_type(self, post_type_id: UUID, soft: bool = True) -> None:
|
||||
"""Delete a post type (soft delete by default)."""
|
||||
if soft:
|
||||
await asyncio.to_thread(
|
||||
result = await asyncio.to_thread(
|
||||
lambda: self.client.table("post_types").update({
|
||||
"is_active": False
|
||||
}).eq("id", str(post_type_id)).execute()
|
||||
)
|
||||
logger.info(f"Soft deleted post type: {post_type_id}")
|
||||
if result.data:
|
||||
await cache.invalidate_post_type(str(post_type_id), result.data[0]["user_id"])
|
||||
else:
|
||||
# Fetch user_id before hard delete for cache invalidation
|
||||
lookup = await asyncio.to_thread(
|
||||
lambda: self.client.table("post_types").select("user_id").eq(
|
||||
"id", str(post_type_id)
|
||||
).execute()
|
||||
)
|
||||
user_id = lookup.data[0]["user_id"] if lookup.data else None
|
||||
await asyncio.to_thread(
|
||||
lambda: self.client.table("post_types").delete().eq(
|
||||
"id", str(post_type_id)
|
||||
).execute()
|
||||
)
|
||||
if user_id:
|
||||
await cache.invalidate_post_type(str(post_type_id), user_id)
|
||||
logger.info(f"Hard deleted post type: {post_type_id}")
|
||||
|
||||
# ==================== TOPICS ====================
|
||||
@@ -418,17 +475,24 @@ class DatabaseClient:
|
||||
)
|
||||
|
||||
logger.info(f"Saved profile analysis for user: {analysis.user_id}")
|
||||
return ProfileAnalysis(**result.data[0])
|
||||
saved = ProfileAnalysis(**result.data[0])
|
||||
await cache.delete(cache.profile_analysis_key(str(analysis.user_id)))
|
||||
return saved
|
||||
|
||||
async def get_profile_analysis(self, user_id: UUID) -> Optional[ProfileAnalysis]:
|
||||
"""Get profile analysis for user."""
|
||||
key = cache.profile_analysis_key(str(user_id))
|
||||
if (hit := await cache.get(key)) is not None:
|
||||
return ProfileAnalysis(**hit)
|
||||
result = await asyncio.to_thread(
|
||||
lambda: self.client.table("profile_analyses").select("*").eq(
|
||||
"user_id", str(user_id)
|
||||
).execute()
|
||||
)
|
||||
if result.data:
|
||||
return ProfileAnalysis(**result.data[0])
|
||||
pa = ProfileAnalysis(**result.data[0])
|
||||
await cache.set(key, pa.model_dump(mode="json"), PROFILE_ANALYSIS_TTL)
|
||||
return pa
|
||||
return None
|
||||
|
||||
# ==================== RESEARCH RESULTS ====================
|
||||
@@ -491,7 +555,9 @@ class DatabaseClient:
|
||||
lambda: self.client.table("generated_posts").insert(data).execute()
|
||||
)
|
||||
logger.info(f"Saved generated post: {result.data[0]['id']}")
|
||||
return GeneratedPost(**result.data[0])
|
||||
saved = GeneratedPost(**result.data[0])
|
||||
await cache.invalidate_gen_posts(str(saved.user_id))
|
||||
return saved
|
||||
|
||||
async def update_generated_post(self, post_id: UUID, updates: Dict[str, Any]) -> GeneratedPost:
|
||||
"""Update generated post."""
|
||||
@@ -504,11 +570,8 @@ class DatabaseClient:
|
||||
# Handle metadata dict - ensure all nested datetime values are serialized
|
||||
if 'metadata' in updates and isinstance(updates['metadata'], dict):
|
||||
serialized_metadata = {}
|
||||
for key, value in updates['metadata'].items():
|
||||
if isinstance(value, datetime):
|
||||
serialized_metadata[key] = value.isoformat()
|
||||
else:
|
||||
serialized_metadata[key] = value
|
||||
for k, value in updates['metadata'].items():
|
||||
serialized_metadata[k] = value.isoformat() if isinstance(value, datetime) else value
|
||||
updates['metadata'] = serialized_metadata
|
||||
|
||||
result = await asyncio.to_thread(
|
||||
@@ -517,26 +580,38 @@ class DatabaseClient:
|
||||
).execute()
|
||||
)
|
||||
logger.info(f"Updated generated post: {post_id}")
|
||||
return GeneratedPost(**result.data[0])
|
||||
updated = GeneratedPost(**result.data[0])
|
||||
await cache.invalidate_gen_post(str(post_id), str(updated.user_id))
|
||||
return updated
|
||||
|
||||
async def get_generated_posts(self, user_id: UUID) -> List[GeneratedPost]:
|
||||
"""Get all generated posts for user."""
|
||||
key = cache.gen_posts_key(str(user_id))
|
||||
if (hit := await cache.get(key)) is not None:
|
||||
return [GeneratedPost(**item) for item in hit]
|
||||
result = await asyncio.to_thread(
|
||||
lambda: self.client.table("generated_posts").select("*").eq(
|
||||
"user_id", str(user_id)
|
||||
).order("created_at", desc=True).execute()
|
||||
)
|
||||
return [GeneratedPost(**item) for item in result.data]
|
||||
posts = [GeneratedPost(**item) for item in result.data]
|
||||
await cache.set(key, [p.model_dump(mode="json") for p in posts], GEN_POSTS_TTL)
|
||||
return posts
|
||||
|
||||
async def get_generated_post(self, post_id: UUID) -> Optional[GeneratedPost]:
|
||||
"""Get a single generated post by ID."""
|
||||
key = cache.gen_post_key(str(post_id))
|
||||
if (hit := await cache.get(key)) is not None:
|
||||
return GeneratedPost(**hit)
|
||||
result = await asyncio.to_thread(
|
||||
lambda: self.client.table("generated_posts").select("*").eq(
|
||||
"id", str(post_id)
|
||||
).execute()
|
||||
)
|
||||
if result.data:
|
||||
return GeneratedPost(**result.data[0])
|
||||
post = GeneratedPost(**result.data[0])
|
||||
await cache.set(key, post.model_dump(mode="json"), GEN_POST_TTL)
|
||||
return post
|
||||
return None
|
||||
|
||||
async def get_scheduled_posts_due(self) -> List[GeneratedPost]:
|
||||
@@ -627,11 +702,16 @@ class DatabaseClient:
|
||||
|
||||
async def get_profile(self, user_id: UUID) -> Optional[Profile]:
|
||||
"""Get profile by user ID."""
|
||||
key = cache.profile_key(str(user_id))
|
||||
if (hit := await cache.get(key)) is not None:
|
||||
return Profile(**hit)
|
||||
result = await asyncio.to_thread(
|
||||
lambda: self.client.table("profiles").select("*").eq("id", str(user_id)).execute()
|
||||
)
|
||||
if result.data:
|
||||
return Profile(**result.data[0])
|
||||
profile = Profile(**result.data[0])
|
||||
await cache.set(key, profile.model_dump(mode="json"), PROFILE_TTL)
|
||||
return profile
|
||||
return None
|
||||
|
||||
async def get_profiles_by_linkedin_url(self, linkedin_url: str) -> List[Profile]:
|
||||
@@ -645,9 +725,9 @@ class DatabaseClient:
|
||||
"""Update profile fields."""
|
||||
if "company_id" in updates and updates["company_id"]:
|
||||
updates["company_id"] = str(updates["company_id"])
|
||||
for key in ["account_type", "onboarding_status"]:
|
||||
if key in updates and hasattr(updates[key], "value"):
|
||||
updates[key] = updates[key].value
|
||||
for k in ["account_type", "onboarding_status"]:
|
||||
if k in updates and hasattr(updates[k], "value"):
|
||||
updates[k] = updates[k].value
|
||||
|
||||
result = await asyncio.to_thread(
|
||||
lambda: self.client.table("profiles").update(updates).eq(
|
||||
@@ -655,19 +735,26 @@ class DatabaseClient:
|
||||
).execute()
|
||||
)
|
||||
logger.info(f"Updated profile: {user_id}")
|
||||
return Profile(**result.data[0])
|
||||
profile = Profile(**result.data[0])
|
||||
await cache.invalidate_profile(str(user_id))
|
||||
return profile
|
||||
|
||||
# ==================== LINKEDIN ACCOUNTS ====================
|
||||
|
||||
async def get_linkedin_account(self, user_id: UUID) -> Optional['LinkedInAccount']:
|
||||
"""Get LinkedIn account for user."""
|
||||
from src.database.models import LinkedInAccount
|
||||
key = cache.linkedin_account_key(str(user_id))
|
||||
if (hit := await cache.get(key)) is not None:
|
||||
return LinkedInAccount(**hit)
|
||||
result = await asyncio.to_thread(
|
||||
lambda: self.client.table("linkedin_accounts").select("*")
|
||||
.eq("user_id", str(user_id)).eq("is_active", True).execute()
|
||||
)
|
||||
if result.data:
|
||||
return LinkedInAccount(**result.data[0])
|
||||
account = LinkedInAccount(**result.data[0])
|
||||
await cache.set(key, account.model_dump(mode="json"), LINKEDIN_ACCOUNT_TTL)
|
||||
return account
|
||||
return None
|
||||
|
||||
async def get_linkedin_account_by_id(self, account_id: UUID) -> Optional['LinkedInAccount']:
|
||||
@@ -692,7 +779,9 @@ class DatabaseClient:
|
||||
lambda: self.client.table("linkedin_accounts").insert(data).execute()
|
||||
)
|
||||
logger.info(f"Created LinkedIn account for user: {account.user_id}")
|
||||
return LinkedInAccount(**result.data[0])
|
||||
created = LinkedInAccount(**result.data[0])
|
||||
await cache.invalidate_linkedin_account(str(created.user_id))
|
||||
return created
|
||||
|
||||
async def update_linkedin_account(self, account_id: UUID, updates: Dict) -> 'LinkedInAccount':
|
||||
"""Update LinkedIn account."""
|
||||
@@ -708,25 +797,40 @@ class DatabaseClient:
|
||||
.eq("id", str(account_id)).execute()
|
||||
)
|
||||
logger.info(f"Updated LinkedIn account: {account_id}")
|
||||
return LinkedInAccount(**result.data[0])
|
||||
updated = LinkedInAccount(**result.data[0])
|
||||
await cache.invalidate_linkedin_account(str(updated.user_id))
|
||||
return updated
|
||||
|
||||
async def delete_linkedin_account(self, account_id: UUID) -> None:
|
||||
"""Delete LinkedIn account connection."""
|
||||
# Fetch user_id before delete for cache invalidation
|
||||
lookup = await asyncio.to_thread(
|
||||
lambda: self.client.table("linkedin_accounts").select("user_id")
|
||||
.eq("id", str(account_id)).execute()
|
||||
)
|
||||
user_id = lookup.data[0]["user_id"] if lookup.data else None
|
||||
await asyncio.to_thread(
|
||||
lambda: self.client.table("linkedin_accounts").delete()
|
||||
.eq("id", str(account_id)).execute()
|
||||
)
|
||||
if user_id:
|
||||
await cache.invalidate_linkedin_account(user_id)
|
||||
logger.info(f"Deleted LinkedIn account: {account_id}")
|
||||
|
||||
# ==================== USERS ====================
|
||||
|
||||
async def get_user(self, user_id: UUID) -> Optional[User]:
|
||||
"""Get user by ID (from users view)."""
|
||||
key = cache.user_key(str(user_id))
|
||||
if (hit := await cache.get(key)) is not None:
|
||||
return User(**hit)
|
||||
result = await asyncio.to_thread(
|
||||
lambda: self.client.table("users").select("*").eq("id", str(user_id)).execute()
|
||||
)
|
||||
if result.data:
|
||||
return User(**result.data[0])
|
||||
user = User(**result.data[0])
|
||||
await cache.set(key, user.model_dump(mode="json"), USER_TTL)
|
||||
return user
|
||||
return None
|
||||
|
||||
async def get_user_by_email(self, email: str) -> Optional[User]:
|
||||
@@ -757,7 +861,10 @@ class DatabaseClient:
|
||||
|
||||
if profile_updates:
|
||||
await self.update_profile(user_id, profile_updates)
|
||||
# update_profile already calls cache.invalidate_profile which also kills user_key
|
||||
|
||||
# Invalidate user view separately (in case it wasn't covered above)
|
||||
await cache.delete(cache.user_key(str(user_id)))
|
||||
return await self.get_user(user_id)
|
||||
|
||||
async def list_users(self, account_type: Optional[str] = None, company_id: Optional[UUID] = None) -> List[User]:
|
||||
@@ -838,11 +945,16 @@ class DatabaseClient:
|
||||
|
||||
async def get_company(self, company_id: UUID) -> Optional[Company]:
|
||||
"""Get company by ID."""
|
||||
key = cache.company_key(str(company_id))
|
||||
if (hit := await cache.get(key)) is not None:
|
||||
return Company(**hit)
|
||||
result = await asyncio.to_thread(
|
||||
lambda: self.client.table("companies").select("*").eq("id", str(company_id)).execute()
|
||||
)
|
||||
if result.data:
|
||||
return Company(**result.data[0])
|
||||
company = Company(**result.data[0])
|
||||
await cache.set(key, company.model_dump(mode="json"), COMPANY_TTL)
|
||||
return company
|
||||
return None
|
||||
|
||||
async def get_company_by_owner(self, owner_user_id: UUID) -> Optional[Company]:
|
||||
@@ -871,7 +983,9 @@ class DatabaseClient:
|
||||
).execute()
|
||||
)
|
||||
logger.info(f"Updated company: {company_id}")
|
||||
return Company(**result.data[0])
|
||||
company = Company(**result.data[0])
|
||||
await cache.invalidate_company(str(company_id))
|
||||
return company
|
||||
|
||||
async def list_companies(self) -> List[Company]:
|
||||
"""List all companies."""
|
||||
|
||||
Reference in New Issue
Block a user