added scalability and performance improvements (redis, http caching etc)

This commit is contained in:
2026-02-19 17:19:41 +01:00
parent d8d054c9a8
commit 4b15b552d6
12 changed files with 763 additions and 88 deletions

View File

@@ -65,6 +65,10 @@ class Settings(BaseSettings):
moco_api_key: str = "" # Token für Authorization-Header
moco_domain: str = "" # Subdomain: {domain}.mocoapp.com
# Redis
redis_url: str = "redis://redis:6379/0"
scheduler_enabled: bool = False # True only on dedicated scheduler container
model_config = SettingsConfigDict(
env_file=".env",
env_file_encoding="utf-8",

View File

@@ -13,6 +13,12 @@ from src.database.models import (
User, Profile, Company, Invitation, ExamplePost, ReferenceProfile,
ApiUsageLog, LicenseKey, CompanyDailyQuota, LicenseKeyOffer
)
from src.services.cache_service import (
cache,
PROFILE_TTL, USER_TTL, LINKEDIN_ACCOUNT_TTL, PROFILE_ANALYSIS_TTL,
COMPANY_TTL, LINKEDIN_POSTS_TTL, POST_TYPES_TTL, POST_TYPE_TTL,
GEN_POST_TTL, GEN_POSTS_TTL,
)
class DatabaseClient:
@@ -111,16 +117,26 @@ class DatabaseClient:
).execute()
)
logger.info(f"Saved {len(result.data)} LinkedIn posts")
return [LinkedInPost(**item) for item in result.data]
saved = [LinkedInPost(**item) for item in result.data]
# Invalidate cache for all affected users
affected_user_ids = {str(p.user_id) for p in saved}
for uid in affected_user_ids:
await cache.invalidate_linkedin_posts(uid)
return saved
async def get_linkedin_posts(self, user_id: UUID) -> List[LinkedInPost]:
"""Get all LinkedIn posts for user."""
key = cache.linkedin_posts_key(str(user_id))
if (hit := await cache.get(key)) is not None:
return [LinkedInPost(**item) for item in hit]
result = await asyncio.to_thread(
lambda: self.client.table("linkedin_posts").select("*").eq(
"user_id", str(user_id)
).order("post_date", desc=True).execute()
)
return [LinkedInPost(**item) for item in result.data]
posts = [LinkedInPost(**item) for item in result.data]
await cache.set(key, [p.model_dump(mode="json") for p in posts], LINKEDIN_POSTS_TTL)
return posts
async def copy_posts_to_user(self, source_user_id: UUID, target_user_id: UUID) -> List[LinkedInPost]:
"""Copy all LinkedIn posts from one user to another."""
@@ -148,9 +164,16 @@ class DatabaseClient:
async def delete_linkedin_post(self, post_id: UUID) -> None:
"""Delete a LinkedIn post."""
# Fetch user_id before deleting so we can invalidate the cache
lookup = await asyncio.to_thread(
lambda: self.client.table("linkedin_posts").select("user_id").eq("id", str(post_id)).execute()
)
user_id = lookup.data[0]["user_id"] if lookup.data else None
await asyncio.to_thread(
lambda: self.client.table("linkedin_posts").delete().eq("id", str(post_id)).execute()
)
if user_id:
await cache.invalidate_linkedin_posts(user_id)
logger.info(f"Deleted LinkedIn post: {post_id}")
async def get_unclassified_posts(self, user_id: UUID) -> List[LinkedInPost]:
@@ -195,9 +218,11 @@ class DatabaseClient:
if "user_id" in updates and updates["user_id"]:
updates["user_id"] = str(updates["user_id"])
await asyncio.to_thread(
result = await asyncio.to_thread(
lambda: self.client.table("linkedin_posts").update(updates).eq("id", str(post_id)).execute()
)
if result.data:
await cache.invalidate_linkedin_posts(result.data[0]["user_id"])
logger.debug(f"Updated LinkedIn post {post_id}")
async def update_posts_classification_bulk(
@@ -233,7 +258,9 @@ class DatabaseClient:
lambda: self.client.table("post_types").insert(data).execute()
)
logger.info(f"Created post type: {result.data[0]['name']}")
return PostType(**result.data[0])
created = PostType(**result.data[0])
await cache.invalidate_post_types(str(created.user_id))
return created
async def create_post_types_bulk(self, post_types: List[PostType]) -> List[PostType]:
"""Create multiple post types at once."""
@@ -251,10 +278,18 @@ class DatabaseClient:
lambda: self.client.table("post_types").insert(data).execute()
)
logger.info(f"Created {len(result.data)} post types")
return [PostType(**item) for item in result.data]
created = [PostType(**item) for item in result.data]
affected_user_ids = {str(pt.user_id) for pt in created}
for uid in affected_user_ids:
await cache.invalidate_post_types(uid)
return created
async def get_post_types(self, user_id: UUID, active_only: bool = True) -> List[PostType]:
"""Get all post types for a user."""
key = cache.post_types_key(str(user_id), active_only)
if (hit := await cache.get(key)) is not None:
return [PostType(**item) for item in hit]
def _query():
query = self.client.table("post_types").select("*").eq("user_id", str(user_id))
if active_only:
@@ -262,7 +297,9 @@ class DatabaseClient:
return query.order("name").execute()
result = await asyncio.to_thread(_query)
return [PostType(**item) for item in result.data]
post_types = [PostType(**item) for item in result.data]
await cache.set(key, [pt.model_dump(mode="json") for pt in post_types], POST_TYPES_TTL)
return post_types
# Alias for get_post_types
async def get_customer_post_types(self, user_id: UUID, active_only: bool = True) -> List[PostType]:
@@ -271,13 +308,18 @@ class DatabaseClient:
async def get_post_type(self, post_type_id: UUID) -> Optional[PostType]:
"""Get a single post type by ID."""
key = cache.post_type_key(str(post_type_id))
if (hit := await cache.get(key)) is not None:
return PostType(**hit)
result = await asyncio.to_thread(
lambda: self.client.table("post_types").select("*").eq(
"id", str(post_type_id)
).execute()
)
if result.data:
return PostType(**result.data[0])
pt = PostType(**result.data[0])
await cache.set(key, pt.model_dump(mode="json"), POST_TYPE_TTL)
return pt
return None
async def update_post_type(self, post_type_id: UUID, updates: Dict[str, Any]) -> PostType:
@@ -288,7 +330,9 @@ class DatabaseClient:
).execute()
)
logger.info(f"Updated post type: {post_type_id}")
return PostType(**result.data[0])
pt = PostType(**result.data[0])
await cache.invalidate_post_type(str(post_type_id), str(pt.user_id))
return pt
async def update_post_type_analysis(
self,
@@ -306,23 +350,36 @@ class DatabaseClient:
}).eq("id", str(post_type_id)).execute()
)
logger.info(f"Updated analysis for post type: {post_type_id}")
return PostType(**result.data[0])
pt = PostType(**result.data[0])
await cache.invalidate_post_type(str(post_type_id), str(pt.user_id))
return pt
async def delete_post_type(self, post_type_id: UUID, soft: bool = True) -> None:
"""Delete a post type (soft delete by default)."""
if soft:
await asyncio.to_thread(
result = await asyncio.to_thread(
lambda: self.client.table("post_types").update({
"is_active": False
}).eq("id", str(post_type_id)).execute()
)
logger.info(f"Soft deleted post type: {post_type_id}")
if result.data:
await cache.invalidate_post_type(str(post_type_id), result.data[0]["user_id"])
else:
# Fetch user_id before hard delete for cache invalidation
lookup = await asyncio.to_thread(
lambda: self.client.table("post_types").select("user_id").eq(
"id", str(post_type_id)
).execute()
)
user_id = lookup.data[0]["user_id"] if lookup.data else None
await asyncio.to_thread(
lambda: self.client.table("post_types").delete().eq(
"id", str(post_type_id)
).execute()
)
if user_id:
await cache.invalidate_post_type(str(post_type_id), user_id)
logger.info(f"Hard deleted post type: {post_type_id}")
# ==================== TOPICS ====================
@@ -418,17 +475,24 @@ class DatabaseClient:
)
logger.info(f"Saved profile analysis for user: {analysis.user_id}")
return ProfileAnalysis(**result.data[0])
saved = ProfileAnalysis(**result.data[0])
await cache.delete(cache.profile_analysis_key(str(analysis.user_id)))
return saved
async def get_profile_analysis(self, user_id: UUID) -> Optional[ProfileAnalysis]:
"""Get profile analysis for user."""
key = cache.profile_analysis_key(str(user_id))
if (hit := await cache.get(key)) is not None:
return ProfileAnalysis(**hit)
result = await asyncio.to_thread(
lambda: self.client.table("profile_analyses").select("*").eq(
"user_id", str(user_id)
).execute()
)
if result.data:
return ProfileAnalysis(**result.data[0])
pa = ProfileAnalysis(**result.data[0])
await cache.set(key, pa.model_dump(mode="json"), PROFILE_ANALYSIS_TTL)
return pa
return None
# ==================== RESEARCH RESULTS ====================
@@ -491,7 +555,9 @@ class DatabaseClient:
lambda: self.client.table("generated_posts").insert(data).execute()
)
logger.info(f"Saved generated post: {result.data[0]['id']}")
return GeneratedPost(**result.data[0])
saved = GeneratedPost(**result.data[0])
await cache.invalidate_gen_posts(str(saved.user_id))
return saved
async def update_generated_post(self, post_id: UUID, updates: Dict[str, Any]) -> GeneratedPost:
"""Update generated post."""
@@ -504,11 +570,8 @@ class DatabaseClient:
# Handle metadata dict - ensure all nested datetime values are serialized
if 'metadata' in updates and isinstance(updates['metadata'], dict):
serialized_metadata = {}
for key, value in updates['metadata'].items():
if isinstance(value, datetime):
serialized_metadata[key] = value.isoformat()
else:
serialized_metadata[key] = value
for k, value in updates['metadata'].items():
serialized_metadata[k] = value.isoformat() if isinstance(value, datetime) else value
updates['metadata'] = serialized_metadata
result = await asyncio.to_thread(
@@ -517,26 +580,38 @@ class DatabaseClient:
).execute()
)
logger.info(f"Updated generated post: {post_id}")
return GeneratedPost(**result.data[0])
updated = GeneratedPost(**result.data[0])
await cache.invalidate_gen_post(str(post_id), str(updated.user_id))
return updated
async def get_generated_posts(self, user_id: UUID) -> List[GeneratedPost]:
"""Get all generated posts for user."""
key = cache.gen_posts_key(str(user_id))
if (hit := await cache.get(key)) is not None:
return [GeneratedPost(**item) for item in hit]
result = await asyncio.to_thread(
lambda: self.client.table("generated_posts").select("*").eq(
"user_id", str(user_id)
).order("created_at", desc=True).execute()
)
return [GeneratedPost(**item) for item in result.data]
posts = [GeneratedPost(**item) for item in result.data]
await cache.set(key, [p.model_dump(mode="json") for p in posts], GEN_POSTS_TTL)
return posts
async def get_generated_post(self, post_id: UUID) -> Optional[GeneratedPost]:
"""Get a single generated post by ID."""
key = cache.gen_post_key(str(post_id))
if (hit := await cache.get(key)) is not None:
return GeneratedPost(**hit)
result = await asyncio.to_thread(
lambda: self.client.table("generated_posts").select("*").eq(
"id", str(post_id)
).execute()
)
if result.data:
return GeneratedPost(**result.data[0])
post = GeneratedPost(**result.data[0])
await cache.set(key, post.model_dump(mode="json"), GEN_POST_TTL)
return post
return None
async def get_scheduled_posts_due(self) -> List[GeneratedPost]:
@@ -627,11 +702,16 @@ class DatabaseClient:
async def get_profile(self, user_id: UUID) -> Optional[Profile]:
"""Get profile by user ID."""
key = cache.profile_key(str(user_id))
if (hit := await cache.get(key)) is not None:
return Profile(**hit)
result = await asyncio.to_thread(
lambda: self.client.table("profiles").select("*").eq("id", str(user_id)).execute()
)
if result.data:
return Profile(**result.data[0])
profile = Profile(**result.data[0])
await cache.set(key, profile.model_dump(mode="json"), PROFILE_TTL)
return profile
return None
async def get_profiles_by_linkedin_url(self, linkedin_url: str) -> List[Profile]:
@@ -645,9 +725,9 @@ class DatabaseClient:
"""Update profile fields."""
if "company_id" in updates and updates["company_id"]:
updates["company_id"] = str(updates["company_id"])
for key in ["account_type", "onboarding_status"]:
if key in updates and hasattr(updates[key], "value"):
updates[key] = updates[key].value
for k in ["account_type", "onboarding_status"]:
if k in updates and hasattr(updates[k], "value"):
updates[k] = updates[k].value
result = await asyncio.to_thread(
lambda: self.client.table("profiles").update(updates).eq(
@@ -655,19 +735,26 @@ class DatabaseClient:
).execute()
)
logger.info(f"Updated profile: {user_id}")
return Profile(**result.data[0])
profile = Profile(**result.data[0])
await cache.invalidate_profile(str(user_id))
return profile
# ==================== LINKEDIN ACCOUNTS ====================
async def get_linkedin_account(self, user_id: UUID) -> Optional['LinkedInAccount']:
"""Get LinkedIn account for user."""
from src.database.models import LinkedInAccount
key = cache.linkedin_account_key(str(user_id))
if (hit := await cache.get(key)) is not None:
return LinkedInAccount(**hit)
result = await asyncio.to_thread(
lambda: self.client.table("linkedin_accounts").select("*")
.eq("user_id", str(user_id)).eq("is_active", True).execute()
)
if result.data:
return LinkedInAccount(**result.data[0])
account = LinkedInAccount(**result.data[0])
await cache.set(key, account.model_dump(mode="json"), LINKEDIN_ACCOUNT_TTL)
return account
return None
async def get_linkedin_account_by_id(self, account_id: UUID) -> Optional['LinkedInAccount']:
@@ -692,7 +779,9 @@ class DatabaseClient:
lambda: self.client.table("linkedin_accounts").insert(data).execute()
)
logger.info(f"Created LinkedIn account for user: {account.user_id}")
return LinkedInAccount(**result.data[0])
created = LinkedInAccount(**result.data[0])
await cache.invalidate_linkedin_account(str(created.user_id))
return created
async def update_linkedin_account(self, account_id: UUID, updates: Dict) -> 'LinkedInAccount':
"""Update LinkedIn account."""
@@ -708,25 +797,40 @@ class DatabaseClient:
.eq("id", str(account_id)).execute()
)
logger.info(f"Updated LinkedIn account: {account_id}")
return LinkedInAccount(**result.data[0])
updated = LinkedInAccount(**result.data[0])
await cache.invalidate_linkedin_account(str(updated.user_id))
return updated
async def delete_linkedin_account(self, account_id: UUID) -> None:
"""Delete LinkedIn account connection."""
# Fetch user_id before delete for cache invalidation
lookup = await asyncio.to_thread(
lambda: self.client.table("linkedin_accounts").select("user_id")
.eq("id", str(account_id)).execute()
)
user_id = lookup.data[0]["user_id"] if lookup.data else None
await asyncio.to_thread(
lambda: self.client.table("linkedin_accounts").delete()
.eq("id", str(account_id)).execute()
)
if user_id:
await cache.invalidate_linkedin_account(user_id)
logger.info(f"Deleted LinkedIn account: {account_id}")
# ==================== USERS ====================
async def get_user(self, user_id: UUID) -> Optional[User]:
"""Get user by ID (from users view)."""
key = cache.user_key(str(user_id))
if (hit := await cache.get(key)) is not None:
return User(**hit)
result = await asyncio.to_thread(
lambda: self.client.table("users").select("*").eq("id", str(user_id)).execute()
)
if result.data:
return User(**result.data[0])
user = User(**result.data[0])
await cache.set(key, user.model_dump(mode="json"), USER_TTL)
return user
return None
async def get_user_by_email(self, email: str) -> Optional[User]:
@@ -757,7 +861,10 @@ class DatabaseClient:
if profile_updates:
await self.update_profile(user_id, profile_updates)
# update_profile already calls cache.invalidate_profile which also kills user_key
# Invalidate user view separately (in case it wasn't covered above)
await cache.delete(cache.user_key(str(user_id)))
return await self.get_user(user_id)
async def list_users(self, account_type: Optional[str] = None, company_id: Optional[UUID] = None) -> List[User]:
@@ -838,11 +945,16 @@ class DatabaseClient:
async def get_company(self, company_id: UUID) -> Optional[Company]:
"""Get company by ID."""
key = cache.company_key(str(company_id))
if (hit := await cache.get(key)) is not None:
return Company(**hit)
result = await asyncio.to_thread(
lambda: self.client.table("companies").select("*").eq("id", str(company_id)).execute()
)
if result.data:
return Company(**result.data[0])
company = Company(**result.data[0])
await cache.set(key, company.model_dump(mode="json"), COMPANY_TTL)
return company
return None
async def get_company_by_owner(self, owner_user_id: UUID) -> Optional[Company]:
@@ -871,7 +983,9 @@ class DatabaseClient:
).execute()
)
logger.info(f"Updated company: {company_id}")
return Company(**result.data[0])
company = Company(**result.data[0])
await cache.invalidate_company(str(company_id))
return company
async def list_companies(self) -> List[Company]:
"""List all companies."""

View File

@@ -156,8 +156,10 @@ class BackgroundJobManager:
logger.info(f"Cleaned up {len(to_remove)} old background jobs")
# Global instance
job_manager = BackgroundJobManager()
# Global instance — backed by Supabase DB + Redis pub/sub for multi-worker safety.
# db_job_manager imports BackgroundJob/JobType/JobStatus from this module, so
# this import must stay at the bottom to avoid a circular-import issue.
from src.services.db_job_manager import job_manager # noqa: F401
async def run_post_scraping(user_id: UUID, linkedin_url: str, job_id: str):
@@ -385,6 +387,11 @@ async def run_post_categorization(user_id: UUID, job_id: str):
message=f"{len(classifications)} Posts kategorisiert!"
)
# Invalidate cached LinkedIn posts — classifications changed but bulk update
# doesn't have user_id per-row, so we invalidate explicitly here.
from src.services.cache_service import cache as _cache
await _cache.invalidate_linkedin_posts(str(user_id))
logger.info(f"Post categorization completed for user {user_id}: {len(classifications)} posts")
except Exception as e:
@@ -480,6 +487,9 @@ async def run_post_recategorization(user_id: UUID, job_id: str):
message=f"{len(classifications)} Posts re-kategorisiert!"
)
from src.services.cache_service import cache as _cache
await _cache.invalidate_linkedin_posts(str(user_id))
logger.info(f"Post re-categorization completed for user {user_id}: {len(classifications)} posts")
except Exception as e:
@@ -556,21 +566,15 @@ async def run_full_analysis_pipeline(user_id: UUID):
logger.info(f"Starting full analysis pipeline for user {user_id}")
# 1. Profile Analysis
job1 = job_manager.create_job(JobType.PROFILE_ANALYSIS, str(user_id))
job1 = await job_manager.create_job(JobType.PROFILE_ANALYSIS, str(user_id))
await run_profile_analysis(user_id, job1.id)
if job1.status == JobStatus.FAILED:
logger.warning(f"Profile analysis failed, continuing with categorization")
# 2. Post Categorization
job2 = job_manager.create_job(JobType.POST_CATEGORIZATION, str(user_id))
# 2. Post Categorization (always continue regardless of previous step outcome)
job2 = await job_manager.create_job(JobType.POST_CATEGORIZATION, str(user_id))
await run_post_categorization(user_id, job2.id)
if job2.status == JobStatus.FAILED:
logger.warning(f"Post categorization failed, continuing with post type analysis")
# 3. Post Type Analysis
job3 = job_manager.create_job(JobType.POST_TYPE_ANALYSIS, str(user_id))
job3 = await job_manager.create_job(JobType.POST_TYPE_ANALYSIS, str(user_id))
await run_post_type_analysis(user_id, job3.id)
logger.info(f"Full analysis pipeline completed for user {user_id}")

View File

@@ -0,0 +1,157 @@
"""Typed cache helpers backed by Redis.
All failures are silent (logged as warnings) so Redis being down never causes an outage.
Key design:
profile:{user_id} — Profile row
user:{user_id} — User view row
linkedin_account:{user_id} — Active LinkedInAccount row
profile_analysis:{user_id} — ProfileAnalysis row
company:{company_id} — Company row
linkedin_posts:{user_id} — List[LinkedInPost] (scraped reference posts)
post_types:{user_id}:1 or :0 — List[PostType] (active_only=True/False)
post_type:{post_type_id} — Single PostType row
gen_post:{post_id} — Single GeneratedPost row
gen_posts:{user_id} — List[GeneratedPost] for user
"""
import json
from typing import Any, Optional
from loguru import logger
# TTL constants (seconds)
PROFILE_TTL = 300 # 5 min — updated on settings / onboarding changes
USER_TTL = 300 # 5 min
LINKEDIN_ACCOUNT_TTL = 300 # 5 min — updated only on OAuth connect/disconnect
PROFILE_ANALYSIS_TTL = 600 # 10 min — computed infrequently by background job
COMPANY_TTL = 300 # 5 min — company settings
LINKEDIN_POSTS_TTL = 600 # 10 min — scraped reference data, rarely changes
POST_TYPES_TTL = 600 # 10 min — strategy config, rarely changes
POST_TYPE_TTL = 600 # 10 min
GEN_POST_TTL = 120 # 2 min — status/content changes frequently
GEN_POSTS_TTL = 120 # 2 min
class CacheService:
"""Redis-backed cache with typed key helpers and silent failure semantics."""
# ------------------------------------------------------------------
# Key helpers
# ------------------------------------------------------------------
def profile_key(self, user_id: str) -> str:
return f"profile:{user_id}"
def user_key(self, user_id: str) -> str:
return f"user:{user_id}"
def linkedin_account_key(self, user_id: str) -> str:
return f"linkedin_account:{user_id}"
def profile_analysis_key(self, user_id: str) -> str:
return f"profile_analysis:{user_id}"
def company_key(self, company_id: str) -> str:
return f"company:{company_id}"
def linkedin_posts_key(self, user_id: str) -> str:
return f"linkedin_posts:{user_id}"
def post_types_key(self, user_id: str, active_only: bool = True) -> str:
return f"post_types:{user_id}:{'1' if active_only else '0'}"
def post_type_key(self, post_type_id: str) -> str:
return f"post_type:{post_type_id}"
def gen_post_key(self, post_id: str) -> str:
return f"gen_post:{post_id}"
def gen_posts_key(self, user_id: str) -> str:
return f"gen_posts:{user_id}"
# ------------------------------------------------------------------
# Core operations
# ------------------------------------------------------------------
async def get(self, key: str) -> Optional[Any]:
try:
from src.services.redis_client import get_redis
r = await get_redis()
value = await r.get(key)
if value is not None:
return json.loads(value)
return None
except Exception as e:
logger.warning(f"Cache get failed for {key}: {e}")
return None
async def set(self, key: str, value: Any, ttl: int):
try:
from src.services.redis_client import get_redis
r = await get_redis()
await r.setex(key, ttl, json.dumps(value, default=str))
except Exception as e:
logger.warning(f"Cache set failed for {key}: {e}")
async def delete(self, *keys: str):
if not keys:
return
try:
from src.services.redis_client import get_redis
r = await get_redis()
await r.delete(*keys)
except Exception as e:
logger.warning(f"Cache delete failed for {keys}: {e}")
# ------------------------------------------------------------------
# Compound invalidation helpers
# ------------------------------------------------------------------
async def invalidate_profile(self, user_id: str):
"""Invalidate profile, user view, profile_analysis, and linkedin_account."""
await self.delete(
self.profile_key(user_id),
self.user_key(user_id),
self.profile_analysis_key(user_id),
self.linkedin_account_key(user_id),
)
async def invalidate_company(self, company_id: str):
await self.delete(self.company_key(company_id))
async def invalidate_linkedin_posts(self, user_id: str):
"""Invalidate the scraped LinkedIn posts list for a user."""
await self.delete(self.linkedin_posts_key(user_id))
async def invalidate_post_types(self, user_id: str):
"""Invalidate both active_only variants of the post types list."""
await self.delete(
self.post_types_key(user_id, True),
self.post_types_key(user_id, False),
)
async def invalidate_post_type(self, post_type_id: str, user_id: str):
"""Invalidate a single PostType row + both list variants."""
await self.delete(
self.post_type_key(post_type_id),
self.post_types_key(user_id, True),
self.post_types_key(user_id, False),
)
async def invalidate_gen_post(self, post_id: str, user_id: str):
"""Invalidate a single GeneratedPost + the user's post list."""
await self.delete(
self.gen_post_key(post_id),
self.gen_posts_key(user_id),
)
async def invalidate_gen_posts(self, user_id: str):
"""Invalidate only the user's GeneratedPost list (not a single entry)."""
await self.delete(self.gen_posts_key(user_id))
async def invalidate_linkedin_account(self, user_id: str):
"""Invalidate the LinkedIn account cache for a user."""
await self.delete(self.linkedin_account_key(user_id))
# Global singleton
cache = CacheService()

View File

@@ -0,0 +1,208 @@
"""Database-backed job manager using Supabase + Redis pub/sub for cross-worker job updates."""
import asyncio
import json
from datetime import datetime, timezone, timedelta
from typing import Dict, Any, Optional
from loguru import logger
# BackgroundJob, JobType, JobStatus are defined in background_jobs.py.
# We import them here; background_jobs.py imports job_manager from this module at
# its bottom — Python handles this circular reference safely because background_jobs.py
# defines these symbols *before* it reaches the import-from-db_job_manager line.
from src.services.background_jobs import BackgroundJob, JobType, JobStatus
def _parse_ts(value: Optional[str]) -> Optional[datetime]:
"""Parse an ISO-8601 / Supabase timestamp string to a timezone-aware datetime."""
if not value:
return None
try:
# Supabase returns strings like "2024-01-01T12:00:00+00:00" or ending in "Z"
return datetime.fromisoformat(value.replace("Z", "+00:00"))
except ValueError:
return None
class DBJobManager:
"""Manages background jobs backed by the Supabase `background_jobs` table.
Publishes job-update payloads to Redis channel ``job_updates:{user_id}`` so
every worker process can push SSE events to its own connected clients.
"""
def _db(self):
"""Lazy DatabaseClient — avoids import at module load time."""
from src.database.client import DatabaseClient
if not hasattr(self, "_db_client"):
self._db_client = DatabaseClient()
return self._db_client
def _row_to_job(self, row: dict) -> BackgroundJob:
return BackgroundJob(
id=row["id"],
job_type=JobType(row["job_type"]),
user_id=row["user_id"],
status=JobStatus(row["status"]),
progress=row.get("progress") or 0,
message=row.get("message") or "",
error=row.get("error"),
result=row.get("result"),
created_at=_parse_ts(row.get("created_at")) or datetime.now(timezone.utc),
started_at=_parse_ts(row.get("started_at")),
completed_at=_parse_ts(row.get("completed_at")),
)
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
async def create_job(self, job_type: JobType, user_id: str) -> BackgroundJob:
"""Create a new background job row in the database."""
db = self._db()
try:
resp = await asyncio.to_thread(
lambda: db.client.table("background_jobs").insert({
"job_type": job_type.value,
"user_id": user_id,
"status": JobStatus.PENDING.value,
"progress": 0,
"message": "",
}).execute()
)
job = self._row_to_job(resp.data[0])
logger.info(f"Created background job {job.id} of type {job_type} for user {user_id}")
return job
except Exception as e:
logger.error(f"Failed to create job in DB: {e}")
raise
async def get_job(self, job_id: str) -> Optional[BackgroundJob]:
"""Fetch a single job by ID from the database."""
db = self._db()
try:
resp = await asyncio.to_thread(
lambda: db.client.table("background_jobs").select("*").eq("id", job_id).execute()
)
if resp.data:
return self._row_to_job(resp.data[0])
return None
except Exception as e:
logger.warning(f"Failed to get job {job_id}: {e}")
return None
async def get_user_jobs(self, user_id: str) -> list[BackgroundJob]:
"""Fetch the 50 most-recent jobs for a user from the database."""
db = self._db()
try:
resp = await asyncio.to_thread(
lambda: db.client.table("background_jobs").select("*")
.eq("user_id", user_id)
.order("created_at", desc=True)
.limit(50)
.execute()
)
return [self._row_to_job(r) for r in resp.data]
except Exception as e:
logger.warning(f"Failed to get jobs for user {user_id}: {e}")
return []
async def get_active_jobs(self, user_id: str) -> list[BackgroundJob]:
"""Fetch pending/running jobs for a user from the database."""
db = self._db()
try:
resp = await asyncio.to_thread(
lambda: db.client.table("background_jobs").select("*")
.eq("user_id", user_id)
.in_("status", [JobStatus.PENDING.value, JobStatus.RUNNING.value])
.order("created_at", desc=True)
.execute()
)
return [self._row_to_job(r) for r in resp.data]
except Exception as e:
logger.warning(f"Failed to get active jobs for user {user_id}: {e}")
return []
async def update_job(
self,
job_id: str,
status: Optional[JobStatus] = None,
progress: Optional[int] = None,
message: Optional[str] = None,
error: Optional[str] = None,
result: Optional[Dict[str, Any]] = None,
):
"""Update a job row in the database, then publish the new state to Redis."""
db = self._db()
update_data: dict = {}
if status is not None:
update_data["status"] = status.value
if status == JobStatus.RUNNING:
update_data["started_at"] = datetime.now(timezone.utc).isoformat()
elif status in (JobStatus.COMPLETED, JobStatus.FAILED):
update_data["completed_at"] = datetime.now(timezone.utc).isoformat()
if progress is not None:
update_data["progress"] = progress
if message is not None:
update_data["message"] = message
if error is not None:
update_data["error"] = error
if result is not None:
update_data["result"] = result
if not update_data:
return
try:
resp = await asyncio.to_thread(
lambda: db.client.table("background_jobs")
.update(update_data)
.eq("id", job_id)
.execute()
)
if resp.data:
job = self._row_to_job(resp.data[0])
await self._publish_update(job)
except Exception as e:
logger.error(f"Failed to update job {job_id} in DB: {e}")
async def cleanup_old_jobs(self, max_age_hours: int = 24):
"""Delete completed/failed jobs older than *max_age_hours* from the database."""
db = self._db()
try:
cutoff = (datetime.now(timezone.utc) - timedelta(hours=max_age_hours)).isoformat()
await asyncio.to_thread(
lambda: db.client.table("background_jobs")
.delete()
.in_("status", [JobStatus.COMPLETED.value, JobStatus.FAILED.value])
.lt("completed_at", cutoff)
.execute()
)
logger.info(f"Cleaned up background jobs older than {max_age_hours}h")
except Exception as e:
logger.warning(f"Failed to cleanup old jobs: {e}")
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
async def _publish_update(self, job: BackgroundJob):
"""Publish a job-update payload to Redis pub/sub so all workers can forward it."""
try:
from src.services.redis_client import get_redis
r = await get_redis()
payload = json.dumps({
"id": job.id,
"job_type": job.job_type.value,
"status": job.status.value,
"progress": job.progress,
"message": job.message,
"error": job.error,
})
await r.publish(f"job_updates:{job.user_id}", payload)
except Exception as e:
logger.warning(f"Failed to publish job update to Redis: {e}")
# Global singleton — replaces the old BackgroundJobManager instance
job_manager = DBJobManager()

View File

@@ -0,0 +1,27 @@
"""Single async Redis connection pool for the whole app."""
import redis.asyncio as aioredis
from loguru import logger
from src.config import settings
_pool: aioredis.Redis | None = None
async def get_redis() -> aioredis.Redis:
global _pool
if _pool is None:
_pool = aioredis.from_url(
settings.redis_url,
encoding="utf-8",
decode_responses=True,
max_connections=20,
)
return _pool
async def close_redis():
global _pool
if _pool:
await _pool.aclose()
_pool = None
logger.info("Redis connection pool closed")

View File

@@ -0,0 +1,42 @@
"""Standalone scheduler process entry point.
Run with:
python -m src.services.scheduler_runner
This module intentionally does NOT import the FastAPI app — it only starts the
SchedulerService so it can run in its own container without duplicating work in
the main web-worker containers (which set SCHEDULER_ENABLED=false).
"""
import asyncio
import signal
from loguru import logger
from src.database.client import DatabaseClient
from src.services.scheduler_service import init_scheduler
async def main():
db = DatabaseClient()
scheduler = init_scheduler(db, check_interval=60)
stop_event = asyncio.Event()
def handle_signal():
logger.info("Scheduler received shutdown signal — stopping…")
stop_event.set()
loop = asyncio.get_running_loop()
for sig in (signal.SIGINT, signal.SIGTERM):
loop.add_signal_handler(sig, handle_signal)
await scheduler.start()
logger.info("Scheduler started (dedicated process)")
await stop_event.wait()
await scheduler.stop()
logger.info("Scheduler stopped")
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -5,6 +5,7 @@ from pathlib import Path
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from fastapi.responses import RedirectResponse
from starlette.middleware.base import BaseHTTPMiddleware
from loguru import logger
from src.config import settings
@@ -14,35 +15,57 @@ from src.web.admin import admin_router
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Manage application lifecycle - startup and shutdown."""
# Startup
logger.info("Starting LinkedIn Post Creation System...")
# Initialize and start scheduler if enabled
# Warm up Redis connection pool
from src.services.redis_client import get_redis, close_redis
await get_redis()
# Start scheduler only when this process is the dedicated scheduler container
scheduler = None
if settings.user_frontend_enabled:
if settings.scheduler_enabled:
try:
from src.database.client import DatabaseClient
from src.services.scheduler_service import init_scheduler
db = DatabaseClient()
scheduler = init_scheduler(db, check_interval=60) # Check every 60 seconds
scheduler = init_scheduler(db, check_interval=60)
await scheduler.start()
logger.info("Scheduler service started")
logger.info("Scheduler started (dedicated process)")
except Exception as e:
logger.error(f"Failed to start scheduler: {e}")
yield # Application runs here
# Shutdown
logger.info("Shutting down LinkedIn Post Creation System...")
if scheduler:
await scheduler.stop()
logger.info("Scheduler service stopped")
await close_redis()
# Setup
app = FastAPI(title="LinkedIn Post Creation System", lifespan=lifespan)
class StaticCacheMiddleware(BaseHTTPMiddleware):
"""Set long-lived Cache-Control headers on static assets."""
async def dispatch(self, request, call_next):
response = await call_next(request)
if request.url.path.startswith("/static/"):
if request.url.path.endswith((".css", ".js")):
response.headers["Cache-Control"] = "public, max-age=86400, stale-while-revalidate=3600"
elif request.url.path.endswith((".png", ".jpg", ".jpeg", ".svg", ".ico", ".webp")):
response.headers["Cache-Control"] = "public, max-age=604800, immutable"
else:
response.headers["Cache-Control"] = "public, max-age=3600"
return response
app.add_middleware(StaticCacheMiddleware)
# Static files
app.mount("/static", StaticFiles(directory=Path(__file__).parent / "static"), name="static")

View File

@@ -32,10 +32,11 @@ from src.services.email_service import (
mark_token_used,
)
from src.services.background_jobs import (
job_manager, JobType, JobStatus,
JobType, JobStatus,
run_post_scraping, run_profile_analysis, run_post_categorization, run_post_type_analysis,
run_full_analysis_pipeline, run_post_recategorization
)
from src.services.db_job_manager import job_manager
from src.services.storage_service import storage
# Router for user frontend
@@ -93,6 +94,7 @@ async def get_user_avatar(session: UserSession, user_id: UUID) -> Optional[str]:
return None
def require_user_session(request: Request) -> Optional[UserSession]:
"""Check if user is authenticated, redirect to login if not."""
session = get_user_session(request)
@@ -676,7 +678,7 @@ async def onboarding_profile_submit(
logger.info(f"Skipping scraping - {len(existing_posts)} posts already exist for user {user_id}")
if should_scrape:
job = job_manager.create_job(JobType.POST_SCRAPING, str(user_id))
job = await job_manager.create_job(JobType.POST_SCRAPING, str(user_id))
background_tasks.add_task(run_post_scraping, user_id, linkedin_url, job.id)
logger.info(f"Started background scraping for user {user_id}")
@@ -829,7 +831,7 @@ async def api_rescrape(request: Request, background_tasks: BackgroundTasks):
return JSONResponse({"error": "No LinkedIn URL found"}, status_code=400)
# Create job and start scraping
job = job_manager.create_job(JobType.POST_SCRAPING, session.user_id)
job = await job_manager.create_job(JobType.POST_SCRAPING, session.user_id)
background_tasks.add_task(run_post_scraping, user_id, profile.linkedin_url, job.id)
return JSONResponse({"success": True, "job_id": job.id})
@@ -1451,53 +1453,45 @@ async def api_categorize_post(request: Request):
@user_router.get("/api/job-updates")
async def job_updates_sse(request: Request):
"""Server-Sent Events endpoint for job updates."""
"""Server-Sent Events endpoint for job updates (Redis pub/sub — works across workers)."""
session = require_user_session(request)
tracking_id = getattr(session, 'user_id', None) or getattr(session, 'company_id', None)
if not session or not tracking_id:
return JSONResponse({"error": "Not authenticated"}, status_code=401)
async def event_generator():
queue = asyncio.Queue()
async def on_job_update(job):
await queue.put(job)
# Register listener
job_manager.add_listener(tracking_id, on_job_update)
from src.services.redis_client import get_redis
r = await get_redis()
pubsub = r.pubsub()
await pubsub.subscribe(f"job_updates:{tracking_id}")
try:
# Send initial active jobs
active_jobs = job_manager.get_active_jobs(tracking_id)
for job in active_jobs:
# Send any currently active jobs as the initial state
for job in await job_manager.get_active_jobs(tracking_id):
data = {
"id": job.id,
"job_type": job.job_type.value,
"status": job.status.value,
"progress": job.progress,
"message": job.message,
"error": job.error
"error": job.error,
}
yield f"data: {json.dumps(data)}\n\n"
# Stream updates
# Stream pub/sub messages, keepalive on timeout
while True:
try:
job = await asyncio.wait_for(queue.get(), timeout=30)
data = {
"id": job.id,
"job_type": job.job_type.value,
"status": job.status.value,
"progress": job.progress,
"message": job.message,
"error": job.error
}
yield f"data: {json.dumps(data)}\n\n"
msg = await asyncio.wait_for(
pubsub.get_message(ignore_subscribe_messages=True), timeout=30
)
if msg and msg.get("type") == "message":
yield f"data: {msg['data']}\n\n"
else:
yield ": keepalive\n\n"
except asyncio.TimeoutError:
# Send keepalive
yield ": keepalive\n\n"
finally:
job_manager.remove_listener(tracking_id, on_job_update)
await pubsub.unsubscribe(f"job_updates:{tracking_id}")
await pubsub.aclose()
return StreamingResponse(
event_generator(),
@@ -1505,8 +1499,8 @@ async def job_updates_sse(request: Request):
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no"
}
"X-Accel-Buffering": "no",
},
)
@@ -1521,7 +1515,7 @@ async def api_run_post_type_analysis(request: Request, background_tasks: Backgro
user_id = UUID(session.user_id)
# Create job
job = job_manager.create_job(JobType.POST_TYPE_ANALYSIS, session.user_id)
job = await job_manager.create_job(JobType.POST_TYPE_ANALYSIS, session.user_id)
# Run in background
background_tasks.add_task(run_post_type_analysis, user_id, job.id)
@@ -3278,13 +3272,13 @@ async def save_all_and_reanalyze(request: Request, background_tasks: BackgroundT
# Only trigger re-categorization and analysis if there were structural changes
if has_structural_changes:
# Create background job for post re-categorization (ALL posts)
categorization_job = job_manager.create_job(
categorization_job = await job_manager.create_job(
job_type=JobType.POST_CATEGORIZATION,
user_id=user_id_str
)
# Create background job for post type analysis
analysis_job = job_manager.create_job(
analysis_job = await job_manager.create_job(
job_type=JobType.POST_TYPE_ANALYSIS,
user_id=user_id_str
)