Added scalability and performance improvements (Redis caching, HTTP caching, etc.)
@@ -65,6 +65,10 @@ class Settings(BaseSettings):
     moco_api_key: str = ""  # Token for the Authorization header
     moco_domain: str = ""  # Subdomain: {domain}.mocoapp.com
 
+    # Redis
+    redis_url: str = "redis://redis:6379/0"
+    scheduler_enabled: bool = False  # True only on dedicated scheduler container
+
     model_config = SettingsConfigDict(
         env_file=".env",
         env_file_encoding="utf-8",
@@ -13,6 +13,12 @@ from src.database.models import (
     User, Profile, Company, Invitation, ExamplePost, ReferenceProfile,
     ApiUsageLog, LicenseKey, CompanyDailyQuota, LicenseKeyOffer
 )
+from src.services.cache_service import (
+    cache,
+    PROFILE_TTL, USER_TTL, LINKEDIN_ACCOUNT_TTL, PROFILE_ANALYSIS_TTL,
+    COMPANY_TTL, LINKEDIN_POSTS_TTL, POST_TYPES_TTL, POST_TYPE_TTL,
+    GEN_POST_TTL, GEN_POSTS_TTL,
+)
 
 
 class DatabaseClient:
@@ -111,16 +117,26 @@ class DatabaseClient:
             ).execute()
         )
         logger.info(f"Saved {len(result.data)} LinkedIn posts")
-        return [LinkedInPost(**item) for item in result.data]
+        saved = [LinkedInPost(**item) for item in result.data]
+        # Invalidate cache for all affected users
+        affected_user_ids = {str(p.user_id) for p in saved}
+        for uid in affected_user_ids:
+            await cache.invalidate_linkedin_posts(uid)
+        return saved
 
     async def get_linkedin_posts(self, user_id: UUID) -> List[LinkedInPost]:
         """Get all LinkedIn posts for user."""
+        key = cache.linkedin_posts_key(str(user_id))
+        if (hit := await cache.get(key)) is not None:
+            return [LinkedInPost(**item) for item in hit]
         result = await asyncio.to_thread(
             lambda: self.client.table("linkedin_posts").select("*").eq(
                 "user_id", str(user_id)
             ).order("post_date", desc=True).execute()
         )
-        return [LinkedInPost(**item) for item in result.data]
+        posts = [LinkedInPost(**item) for item in result.data]
+        await cache.set(key, [p.model_dump(mode="json") for p in posts], LINKEDIN_POSTS_TTL)
+        return posts
 
     async def copy_posts_to_user(self, source_user_id: UUID, target_user_id: UUID) -> List[LinkedInPost]:
         """Copy all LinkedIn posts from one user to another."""
@@ -148,9 +164,16 @@ class DatabaseClient:
 
     async def delete_linkedin_post(self, post_id: UUID) -> None:
         """Delete a LinkedIn post."""
+        # Fetch user_id before deleting so we can invalidate the cache
+        lookup = await asyncio.to_thread(
+            lambda: self.client.table("linkedin_posts").select("user_id").eq("id", str(post_id)).execute()
+        )
+        user_id = lookup.data[0]["user_id"] if lookup.data else None
         await asyncio.to_thread(
             lambda: self.client.table("linkedin_posts").delete().eq("id", str(post_id)).execute()
         )
+        if user_id:
+            await cache.invalidate_linkedin_posts(user_id)
         logger.info(f"Deleted LinkedIn post: {post_id}")
 
     async def get_unclassified_posts(self, user_id: UUID) -> List[LinkedInPost]:
@@ -195,9 +218,11 @@ class DatabaseClient:
         if "user_id" in updates and updates["user_id"]:
             updates["user_id"] = str(updates["user_id"])
 
-        await asyncio.to_thread(
+        result = await asyncio.to_thread(
             lambda: self.client.table("linkedin_posts").update(updates).eq("id", str(post_id)).execute()
         )
+        if result.data:
+            await cache.invalidate_linkedin_posts(result.data[0]["user_id"])
         logger.debug(f"Updated LinkedIn post {post_id}")
 
     async def update_posts_classification_bulk(
@@ -233,7 +258,9 @@ class DatabaseClient:
             lambda: self.client.table("post_types").insert(data).execute()
         )
         logger.info(f"Created post type: {result.data[0]['name']}")
-        return PostType(**result.data[0])
+        created = PostType(**result.data[0])
+        await cache.invalidate_post_types(str(created.user_id))
+        return created
 
     async def create_post_types_bulk(self, post_types: List[PostType]) -> List[PostType]:
         """Create multiple post types at once."""
@@ -251,10 +278,18 @@ class DatabaseClient:
             lambda: self.client.table("post_types").insert(data).execute()
         )
         logger.info(f"Created {len(result.data)} post types")
-        return [PostType(**item) for item in result.data]
+        created = [PostType(**item) for item in result.data]
+        affected_user_ids = {str(pt.user_id) for pt in created}
+        for uid in affected_user_ids:
+            await cache.invalidate_post_types(uid)
+        return created
 
     async def get_post_types(self, user_id: UUID, active_only: bool = True) -> List[PostType]:
         """Get all post types for a user."""
+        key = cache.post_types_key(str(user_id), active_only)
+        if (hit := await cache.get(key)) is not None:
+            return [PostType(**item) for item in hit]
+
         def _query():
             query = self.client.table("post_types").select("*").eq("user_id", str(user_id))
             if active_only:
@@ -262,7 +297,9 @@ class DatabaseClient:
             return query.order("name").execute()
 
         result = await asyncio.to_thread(_query)
-        return [PostType(**item) for item in result.data]
+        post_types = [PostType(**item) for item in result.data]
+        await cache.set(key, [pt.model_dump(mode="json") for pt in post_types], POST_TYPES_TTL)
+        return post_types
 
     # Alias for get_post_types
     async def get_customer_post_types(self, user_id: UUID, active_only: bool = True) -> List[PostType]:
@@ -271,13 +308,18 @@ class DatabaseClient:
 
     async def get_post_type(self, post_type_id: UUID) -> Optional[PostType]:
         """Get a single post type by ID."""
+        key = cache.post_type_key(str(post_type_id))
+        if (hit := await cache.get(key)) is not None:
+            return PostType(**hit)
         result = await asyncio.to_thread(
             lambda: self.client.table("post_types").select("*").eq(
                 "id", str(post_type_id)
             ).execute()
         )
         if result.data:
-            return PostType(**result.data[0])
+            pt = PostType(**result.data[0])
+            await cache.set(key, pt.model_dump(mode="json"), POST_TYPE_TTL)
+            return pt
         return None
 
     async def update_post_type(self, post_type_id: UUID, updates: Dict[str, Any]) -> PostType:
@@ -288,7 +330,9 @@ class DatabaseClient:
             ).execute()
         )
         logger.info(f"Updated post type: {post_type_id}")
-        return PostType(**result.data[0])
+        pt = PostType(**result.data[0])
+        await cache.invalidate_post_type(str(post_type_id), str(pt.user_id))
+        return pt
 
     async def update_post_type_analysis(
         self,
@@ -306,23 +350,36 @@ class DatabaseClient:
             }).eq("id", str(post_type_id)).execute()
         )
         logger.info(f"Updated analysis for post type: {post_type_id}")
-        return PostType(**result.data[0])
+        pt = PostType(**result.data[0])
+        await cache.invalidate_post_type(str(post_type_id), str(pt.user_id))
+        return pt
 
     async def delete_post_type(self, post_type_id: UUID, soft: bool = True) -> None:
         """Delete a post type (soft delete by default)."""
         if soft:
-            await asyncio.to_thread(
+            result = await asyncio.to_thread(
                 lambda: self.client.table("post_types").update({
                     "is_active": False
                 }).eq("id", str(post_type_id)).execute()
            )
            logger.info(f"Soft deleted post type: {post_type_id}")
+            if result.data:
+                await cache.invalidate_post_type(str(post_type_id), result.data[0]["user_id"])
         else:
+            # Fetch user_id before hard delete for cache invalidation
+            lookup = await asyncio.to_thread(
+                lambda: self.client.table("post_types").select("user_id").eq(
+                    "id", str(post_type_id)
+                ).execute()
+            )
+            user_id = lookup.data[0]["user_id"] if lookup.data else None
             await asyncio.to_thread(
                 lambda: self.client.table("post_types").delete().eq(
                     "id", str(post_type_id)
                 ).execute()
             )
+            if user_id:
+                await cache.invalidate_post_type(str(post_type_id), user_id)
             logger.info(f"Hard deleted post type: {post_type_id}")
 
     # ==================== TOPICS ====================
@@ -418,17 +475,24 @@ class DatabaseClient:
         )
 
         logger.info(f"Saved profile analysis for user: {analysis.user_id}")
-        return ProfileAnalysis(**result.data[0])
+        saved = ProfileAnalysis(**result.data[0])
+        await cache.delete(cache.profile_analysis_key(str(analysis.user_id)))
+        return saved
 
     async def get_profile_analysis(self, user_id: UUID) -> Optional[ProfileAnalysis]:
         """Get profile analysis for user."""
+        key = cache.profile_analysis_key(str(user_id))
+        if (hit := await cache.get(key)) is not None:
+            return ProfileAnalysis(**hit)
         result = await asyncio.to_thread(
             lambda: self.client.table("profile_analyses").select("*").eq(
                 "user_id", str(user_id)
             ).execute()
         )
         if result.data:
-            return ProfileAnalysis(**result.data[0])
+            pa = ProfileAnalysis(**result.data[0])
+            await cache.set(key, pa.model_dump(mode="json"), PROFILE_ANALYSIS_TTL)
+            return pa
         return None
 
     # ==================== RESEARCH RESULTS ====================
@@ -491,7 +555,9 @@ class DatabaseClient:
             lambda: self.client.table("generated_posts").insert(data).execute()
         )
         logger.info(f"Saved generated post: {result.data[0]['id']}")
-        return GeneratedPost(**result.data[0])
+        saved = GeneratedPost(**result.data[0])
+        await cache.invalidate_gen_posts(str(saved.user_id))
+        return saved
 
     async def update_generated_post(self, post_id: UUID, updates: Dict[str, Any]) -> GeneratedPost:
         """Update generated post."""
@@ -504,11 +570,8 @@ class DatabaseClient:
         # Handle metadata dict - ensure all nested datetime values are serialized
         if 'metadata' in updates and isinstance(updates['metadata'], dict):
             serialized_metadata = {}
-            for key, value in updates['metadata'].items():
-                if isinstance(value, datetime):
-                    serialized_metadata[key] = value.isoformat()
-                else:
-                    serialized_metadata[key] = value
+            for k, value in updates['metadata'].items():
+                serialized_metadata[k] = value.isoformat() if isinstance(value, datetime) else value
             updates['metadata'] = serialized_metadata
 
         result = await asyncio.to_thread(
@@ -517,26 +580,38 @@ class DatabaseClient:
             ).execute()
         )
         logger.info(f"Updated generated post: {post_id}")
-        return GeneratedPost(**result.data[0])
+        updated = GeneratedPost(**result.data[0])
+        await cache.invalidate_gen_post(str(post_id), str(updated.user_id))
+        return updated
 
     async def get_generated_posts(self, user_id: UUID) -> List[GeneratedPost]:
         """Get all generated posts for user."""
+        key = cache.gen_posts_key(str(user_id))
+        if (hit := await cache.get(key)) is not None:
+            return [GeneratedPost(**item) for item in hit]
         result = await asyncio.to_thread(
             lambda: self.client.table("generated_posts").select("*").eq(
                 "user_id", str(user_id)
             ).order("created_at", desc=True).execute()
         )
-        return [GeneratedPost(**item) for item in result.data]
+        posts = [GeneratedPost(**item) for item in result.data]
+        await cache.set(key, [p.model_dump(mode="json") for p in posts], GEN_POSTS_TTL)
+        return posts
 
     async def get_generated_post(self, post_id: UUID) -> Optional[GeneratedPost]:
         """Get a single generated post by ID."""
+        key = cache.gen_post_key(str(post_id))
+        if (hit := await cache.get(key)) is not None:
+            return GeneratedPost(**hit)
         result = await asyncio.to_thread(
             lambda: self.client.table("generated_posts").select("*").eq(
                 "id", str(post_id)
             ).execute()
         )
         if result.data:
-            return GeneratedPost(**result.data[0])
+            post = GeneratedPost(**result.data[0])
+            await cache.set(key, post.model_dump(mode="json"), GEN_POST_TTL)
+            return post
         return None
 
     async def get_scheduled_posts_due(self) -> List[GeneratedPost]:
@@ -627,11 +702,16 @@ class DatabaseClient:
 
     async def get_profile(self, user_id: UUID) -> Optional[Profile]:
         """Get profile by user ID."""
+        key = cache.profile_key(str(user_id))
+        if (hit := await cache.get(key)) is not None:
+            return Profile(**hit)
         result = await asyncio.to_thread(
             lambda: self.client.table("profiles").select("*").eq("id", str(user_id)).execute()
         )
         if result.data:
-            return Profile(**result.data[0])
+            profile = Profile(**result.data[0])
+            await cache.set(key, profile.model_dump(mode="json"), PROFILE_TTL)
+            return profile
         return None
 
     async def get_profiles_by_linkedin_url(self, linkedin_url: str) -> List[Profile]:
@@ -645,9 +725,9 @@ class DatabaseClient:
         """Update profile fields."""
         if "company_id" in updates and updates["company_id"]:
             updates["company_id"] = str(updates["company_id"])
-        for key in ["account_type", "onboarding_status"]:
-            if key in updates and hasattr(updates[key], "value"):
-                updates[key] = updates[key].value
+        for k in ["account_type", "onboarding_status"]:
+            if k in updates and hasattr(updates[k], "value"):
+                updates[k] = updates[k].value
 
         result = await asyncio.to_thread(
             lambda: self.client.table("profiles").update(updates).eq(
@@ -655,19 +735,26 @@ class DatabaseClient:
             ).execute()
         )
         logger.info(f"Updated profile: {user_id}")
-        return Profile(**result.data[0])
+        profile = Profile(**result.data[0])
+        await cache.invalidate_profile(str(user_id))
+        return profile
 
     # ==================== LINKEDIN ACCOUNTS ====================
 
     async def get_linkedin_account(self, user_id: UUID) -> Optional['LinkedInAccount']:
         """Get LinkedIn account for user."""
         from src.database.models import LinkedInAccount
+        key = cache.linkedin_account_key(str(user_id))
+        if (hit := await cache.get(key)) is not None:
+            return LinkedInAccount(**hit)
         result = await asyncio.to_thread(
             lambda: self.client.table("linkedin_accounts").select("*")
             .eq("user_id", str(user_id)).eq("is_active", True).execute()
         )
         if result.data:
-            return LinkedInAccount(**result.data[0])
+            account = LinkedInAccount(**result.data[0])
+            await cache.set(key, account.model_dump(mode="json"), LINKEDIN_ACCOUNT_TTL)
+            return account
         return None
 
     async def get_linkedin_account_by_id(self, account_id: UUID) -> Optional['LinkedInAccount']:
@@ -692,7 +779,9 @@ class DatabaseClient:
             lambda: self.client.table("linkedin_accounts").insert(data).execute()
         )
         logger.info(f"Created LinkedIn account for user: {account.user_id}")
-        return LinkedInAccount(**result.data[0])
+        created = LinkedInAccount(**result.data[0])
+        await cache.invalidate_linkedin_account(str(created.user_id))
+        return created
 
     async def update_linkedin_account(self, account_id: UUID, updates: Dict) -> 'LinkedInAccount':
         """Update LinkedIn account."""
@@ -708,25 +797,40 @@ class DatabaseClient:
             .eq("id", str(account_id)).execute()
         )
         logger.info(f"Updated LinkedIn account: {account_id}")
-        return LinkedInAccount(**result.data[0])
+        updated = LinkedInAccount(**result.data[0])
+        await cache.invalidate_linkedin_account(str(updated.user_id))
+        return updated
 
     async def delete_linkedin_account(self, account_id: UUID) -> None:
         """Delete LinkedIn account connection."""
+        # Fetch user_id before delete for cache invalidation
+        lookup = await asyncio.to_thread(
+            lambda: self.client.table("linkedin_accounts").select("user_id")
+            .eq("id", str(account_id)).execute()
+        )
+        user_id = lookup.data[0]["user_id"] if lookup.data else None
         await asyncio.to_thread(
             lambda: self.client.table("linkedin_accounts").delete()
             .eq("id", str(account_id)).execute()
         )
+        if user_id:
+            await cache.invalidate_linkedin_account(user_id)
         logger.info(f"Deleted LinkedIn account: {account_id}")
 
     # ==================== USERS ====================
 
     async def get_user(self, user_id: UUID) -> Optional[User]:
         """Get user by ID (from users view)."""
+        key = cache.user_key(str(user_id))
+        if (hit := await cache.get(key)) is not None:
+            return User(**hit)
         result = await asyncio.to_thread(
             lambda: self.client.table("users").select("*").eq("id", str(user_id)).execute()
         )
         if result.data:
-            return User(**result.data[0])
+            user = User(**result.data[0])
+            await cache.set(key, user.model_dump(mode="json"), USER_TTL)
+            return user
         return None
 
     async def get_user_by_email(self, email: str) -> Optional[User]:
@@ -757,7 +861,10 @@ class DatabaseClient:
 
         if profile_updates:
             await self.update_profile(user_id, profile_updates)
+            # update_profile already calls cache.invalidate_profile which also kills user_key
 
+        # Invalidate user view separately (in case it wasn't covered above)
+        await cache.delete(cache.user_key(str(user_id)))
         return await self.get_user(user_id)
 
     async def list_users(self, account_type: Optional[str] = None, company_id: Optional[UUID] = None) -> List[User]:
@@ -838,11 +945,16 @@ class DatabaseClient:
 
     async def get_company(self, company_id: UUID) -> Optional[Company]:
         """Get company by ID."""
+        key = cache.company_key(str(company_id))
+        if (hit := await cache.get(key)) is not None:
+            return Company(**hit)
         result = await asyncio.to_thread(
             lambda: self.client.table("companies").select("*").eq("id", str(company_id)).execute()
         )
         if result.data:
-            return Company(**result.data[0])
+            company = Company(**result.data[0])
+            await cache.set(key, company.model_dump(mode="json"), COMPANY_TTL)
+            return company
         return None
 
     async def get_company_by_owner(self, owner_user_id: UUID) -> Optional[Company]:
@@ -871,7 +983,9 @@ class DatabaseClient:
             ).execute()
         )
         logger.info(f"Updated company: {company_id}")
-        return Company(**result.data[0])
+        company = Company(**result.data[0])
+        await cache.invalidate_company(str(company_id))
+        return company
 
     async def list_companies(self) -> List[Company]:
         """List all companies."""
@@ -156,8 +156,10 @@ class BackgroundJobManager:
         logger.info(f"Cleaned up {len(to_remove)} old background jobs")
 
 
-# Global instance
-job_manager = BackgroundJobManager()
+# Global instance — backed by Supabase DB + Redis pub/sub for multi-worker safety.
+# db_job_manager imports BackgroundJob/JobType/JobStatus from this module, so
+# this import must stay at the bottom to avoid a circular-import issue.
+from src.services.db_job_manager import job_manager  # noqa: F401
 
 
 async def run_post_scraping(user_id: UUID, linkedin_url: str, job_id: str):
@@ -385,6 +387,11 @@ async def run_post_categorization(user_id: UUID, job_id: str):
             message=f"{len(classifications)} Posts kategorisiert!"
         )
 
+        # Invalidate cached LinkedIn posts — classifications changed but bulk update
+        # doesn't have user_id per-row, so we invalidate explicitly here.
+        from src.services.cache_service import cache as _cache
+        await _cache.invalidate_linkedin_posts(str(user_id))
+
         logger.info(f"Post categorization completed for user {user_id}: {len(classifications)} posts")
 
     except Exception as e:
@@ -480,6 +487,9 @@ async def run_post_recategorization(user_id: UUID, job_id: str):
             message=f"{len(classifications)} Posts re-kategorisiert!"
         )
 
+        from src.services.cache_service import cache as _cache
+        await _cache.invalidate_linkedin_posts(str(user_id))
+
         logger.info(f"Post re-categorization completed for user {user_id}: {len(classifications)} posts")
 
     except Exception as e:
@@ -556,21 +566,15 @@ async def run_full_analysis_pipeline(user_id: UUID):
     logger.info(f"Starting full analysis pipeline for user {user_id}")
 
     # 1. Profile Analysis
-    job1 = job_manager.create_job(JobType.PROFILE_ANALYSIS, str(user_id))
+    job1 = await job_manager.create_job(JobType.PROFILE_ANALYSIS, str(user_id))
     await run_profile_analysis(user_id, job1.id)
 
-    if job1.status == JobStatus.FAILED:
-        logger.warning(f"Profile analysis failed, continuing with categorization")
-
-    # 2. Post Categorization
-    job2 = job_manager.create_job(JobType.POST_CATEGORIZATION, str(user_id))
+    # 2. Post Categorization (always continue regardless of previous step outcome)
+    job2 = await job_manager.create_job(JobType.POST_CATEGORIZATION, str(user_id))
     await run_post_categorization(user_id, job2.id)
 
-    if job2.status == JobStatus.FAILED:
-        logger.warning(f"Post categorization failed, continuing with post type analysis")
-
     # 3. Post Type Analysis
-    job3 = job_manager.create_job(JobType.POST_TYPE_ANALYSIS, str(user_id))
+    job3 = await job_manager.create_job(JobType.POST_TYPE_ANALYSIS, str(user_id))
    await run_post_type_analysis(user_id, job3.id)
 
     logger.info(f"Full analysis pipeline completed for user {user_id}")
src/services/cache_service.py (new file, 157 lines)
@@ -0,0 +1,157 @@
"""Typed cache helpers backed by Redis.

All failures are silent (logged as warnings) so Redis being down never causes an outage.

Key design:
    profile:{user_id}            — Profile row
    user:{user_id}               — User view row
    linkedin_account:{user_id}   — Active LinkedInAccount row
    profile_analysis:{user_id}   — ProfileAnalysis row
    company:{company_id}         — Company row
    linkedin_posts:{user_id}     — List[LinkedInPost] (scraped reference posts)
    post_types:{user_id}:1 or :0 — List[PostType] (active_only=True/False)
    post_type:{post_type_id}     — Single PostType row
    gen_post:{post_id}           — Single GeneratedPost row
    gen_posts:{user_id}          — List[GeneratedPost] for user
"""
import json
from typing import Any, Optional
from loguru import logger

# TTL constants (seconds)
PROFILE_TTL = 300            # 5 min — updated on settings / onboarding changes
USER_TTL = 300               # 5 min
LINKEDIN_ACCOUNT_TTL = 300   # 5 min — updated only on OAuth connect/disconnect
PROFILE_ANALYSIS_TTL = 600   # 10 min — computed infrequently by background job
COMPANY_TTL = 300            # 5 min — company settings
LINKEDIN_POSTS_TTL = 600     # 10 min — scraped reference data, rarely changes
POST_TYPES_TTL = 600         # 10 min — strategy config, rarely changes
POST_TYPE_TTL = 600          # 10 min
GEN_POST_TTL = 120           # 2 min — status/content changes frequently
GEN_POSTS_TTL = 120          # 2 min


class CacheService:
    """Redis-backed cache with typed key helpers and silent failure semantics."""

    # ------------------------------------------------------------------
    # Key helpers
    # ------------------------------------------------------------------

    def profile_key(self, user_id: str) -> str:
        return f"profile:{user_id}"

    def user_key(self, user_id: str) -> str:
        return f"user:{user_id}"

    def linkedin_account_key(self, user_id: str) -> str:
        return f"linkedin_account:{user_id}"

    def profile_analysis_key(self, user_id: str) -> str:
        return f"profile_analysis:{user_id}"

    def company_key(self, company_id: str) -> str:
        return f"company:{company_id}"

    def linkedin_posts_key(self, user_id: str) -> str:
        return f"linkedin_posts:{user_id}"

    def post_types_key(self, user_id: str, active_only: bool = True) -> str:
        return f"post_types:{user_id}:{'1' if active_only else '0'}"

    def post_type_key(self, post_type_id: str) -> str:
        return f"post_type:{post_type_id}"

    def gen_post_key(self, post_id: str) -> str:
        return f"gen_post:{post_id}"

    def gen_posts_key(self, user_id: str) -> str:
        return f"gen_posts:{user_id}"

    # ------------------------------------------------------------------
    # Core operations
    # ------------------------------------------------------------------

    async def get(self, key: str) -> Optional[Any]:
        try:
            from src.services.redis_client import get_redis
            r = await get_redis()
            value = await r.get(key)
            if value is not None:
                return json.loads(value)
            return None
        except Exception as e:
            logger.warning(f"Cache get failed for {key}: {e}")
            return None

    async def set(self, key: str, value: Any, ttl: int):
        try:
            from src.services.redis_client import get_redis
            r = await get_redis()
            await r.setex(key, ttl, json.dumps(value, default=str))
        except Exception as e:
            logger.warning(f"Cache set failed for {key}: {e}")

    async def delete(self, *keys: str):
        if not keys:
            return
        try:
            from src.services.redis_client import get_redis
            r = await get_redis()
            await r.delete(*keys)
        except Exception as e:
            logger.warning(f"Cache delete failed for {keys}: {e}")

    # ------------------------------------------------------------------
    # Compound invalidation helpers
    # ------------------------------------------------------------------

    async def invalidate_profile(self, user_id: str):
        """Invalidate profile, user view, profile_analysis, and linkedin_account."""
        await self.delete(
            self.profile_key(user_id),
            self.user_key(user_id),
            self.profile_analysis_key(user_id),
            self.linkedin_account_key(user_id),
        )

    async def invalidate_company(self, company_id: str):
        await self.delete(self.company_key(company_id))

    async def invalidate_linkedin_posts(self, user_id: str):
        """Invalidate the scraped LinkedIn posts list for a user."""
        await self.delete(self.linkedin_posts_key(user_id))

    async def invalidate_post_types(self, user_id: str):
        """Invalidate both active_only variants of the post types list."""
        await self.delete(
            self.post_types_key(user_id, True),
            self.post_types_key(user_id, False),
        )

    async def invalidate_post_type(self, post_type_id: str, user_id: str):
        """Invalidate a single PostType row + both list variants."""
        await self.delete(
            self.post_type_key(post_type_id),
            self.post_types_key(user_id, True),
            self.post_types_key(user_id, False),
        )

    async def invalidate_gen_post(self, post_id: str, user_id: str):
        """Invalidate a single GeneratedPost + the user's post list."""
        await self.delete(
            self.gen_post_key(post_id),
            self.gen_posts_key(user_id),
        )

    async def invalidate_gen_posts(self, user_id: str):
        """Invalidate only the user's GeneratedPost list (not a single entry)."""
        await self.delete(self.gen_posts_key(user_id))

    async def invalidate_linkedin_account(self, user_id: str):
        """Invalidate the LinkedIn account cache for a user."""
        await self.delete(self.linkedin_account_key(user_id))


# Global singleton
cache = CacheService()
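For orientation, here is a minimal sketch of the read-through pattern the DatabaseClient methods in this commit follow when using this service. Only cache, cache.profile_key, cache.get/set and PROFILE_TTL come from the file above; fetch_profile_row and the hard-coded return value are hypothetical stand-ins for the real Supabase query.

# Illustrative only; not part of the commit.
import asyncio
from src.services.cache_service import cache, PROFILE_TTL

async def fetch_profile_row(user_id: str) -> dict | None:
    # Hypothetical stand-in for the Supabase lookup in DatabaseClient.get_profile()
    return {"id": user_id, "display_name": "Example"}

async def get_profile_cached(user_id: str) -> dict | None:
    key = cache.profile_key(user_id)
    if (hit := await cache.get(key)) is not None:
        return hit                               # cache hit: already JSON-decoded
    row = await fetch_profile_row(user_id)
    if row is not None:
        await cache.set(key, row, PROFILE_TTL)   # silently skipped if Redis is down
    return row

asyncio.run(get_profile_cached("00000000-0000-0000-0000-000000000001"))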
src/services/db_job_manager.py (new file, 208 lines)
@@ -0,0 +1,208 @@
"""Database-backed job manager using Supabase + Redis pub/sub for cross-worker job updates."""
import asyncio
import json
from datetime import datetime, timezone, timedelta
from typing import Dict, Any, Optional
from loguru import logger

# BackgroundJob, JobType, JobStatus are defined in background_jobs.py.
# We import them here; background_jobs.py imports job_manager from this module at
# its bottom — Python handles this circular reference safely because background_jobs.py
# defines these symbols *before* it reaches the import-from-db_job_manager line.
from src.services.background_jobs import BackgroundJob, JobType, JobStatus


def _parse_ts(value: Optional[str]) -> Optional[datetime]:
    """Parse an ISO-8601 / Supabase timestamp string to a timezone-aware datetime."""
    if not value:
        return None
    try:
        # Supabase returns strings like "2024-01-01T12:00:00+00:00" or ending in "Z"
        return datetime.fromisoformat(value.replace("Z", "+00:00"))
    except ValueError:
        return None


class DBJobManager:
    """Manages background jobs backed by the Supabase `background_jobs` table.

    Publishes job-update payloads to Redis channel ``job_updates:{user_id}`` so
    every worker process can push SSE events to its own connected clients.
    """

    def _db(self):
        """Lazy DatabaseClient — avoids import at module load time."""
        from src.database.client import DatabaseClient
        if not hasattr(self, "_db_client"):
            self._db_client = DatabaseClient()
        return self._db_client

    def _row_to_job(self, row: dict) -> BackgroundJob:
        return BackgroundJob(
            id=row["id"],
            job_type=JobType(row["job_type"]),
            user_id=row["user_id"],
            status=JobStatus(row["status"]),
            progress=row.get("progress") or 0,
            message=row.get("message") or "",
            error=row.get("error"),
            result=row.get("result"),
            created_at=_parse_ts(row.get("created_at")) or datetime.now(timezone.utc),
            started_at=_parse_ts(row.get("started_at")),
            completed_at=_parse_ts(row.get("completed_at")),
        )

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    async def create_job(self, job_type: JobType, user_id: str) -> BackgroundJob:
        """Create a new background job row in the database."""
        db = self._db()
        try:
            resp = await asyncio.to_thread(
                lambda: db.client.table("background_jobs").insert({
                    "job_type": job_type.value,
                    "user_id": user_id,
                    "status": JobStatus.PENDING.value,
                    "progress": 0,
                    "message": "",
                }).execute()
            )
            job = self._row_to_job(resp.data[0])
            logger.info(f"Created background job {job.id} of type {job_type} for user {user_id}")
            return job
        except Exception as e:
            logger.error(f"Failed to create job in DB: {e}")
            raise

    async def get_job(self, job_id: str) -> Optional[BackgroundJob]:
        """Fetch a single job by ID from the database."""
        db = self._db()
        try:
            resp = await asyncio.to_thread(
                lambda: db.client.table("background_jobs").select("*").eq("id", job_id).execute()
            )
            if resp.data:
                return self._row_to_job(resp.data[0])
            return None
        except Exception as e:
            logger.warning(f"Failed to get job {job_id}: {e}")
            return None

    async def get_user_jobs(self, user_id: str) -> list[BackgroundJob]:
        """Fetch the 50 most-recent jobs for a user from the database."""
        db = self._db()
        try:
            resp = await asyncio.to_thread(
                lambda: db.client.table("background_jobs").select("*")
                .eq("user_id", user_id)
                .order("created_at", desc=True)
                .limit(50)
                .execute()
            )
            return [self._row_to_job(r) for r in resp.data]
        except Exception as e:
            logger.warning(f"Failed to get jobs for user {user_id}: {e}")
            return []

    async def get_active_jobs(self, user_id: str) -> list[BackgroundJob]:
        """Fetch pending/running jobs for a user from the database."""
        db = self._db()
        try:
            resp = await asyncio.to_thread(
                lambda: db.client.table("background_jobs").select("*")
                .eq("user_id", user_id)
                .in_("status", [JobStatus.PENDING.value, JobStatus.RUNNING.value])
                .order("created_at", desc=True)
                .execute()
            )
            return [self._row_to_job(r) for r in resp.data]
        except Exception as e:
            logger.warning(f"Failed to get active jobs for user {user_id}: {e}")
            return []

    async def update_job(
        self,
        job_id: str,
        status: Optional[JobStatus] = None,
        progress: Optional[int] = None,
        message: Optional[str] = None,
        error: Optional[str] = None,
        result: Optional[Dict[str, Any]] = None,
    ):
        """Update a job row in the database, then publish the new state to Redis."""
        db = self._db()
        update_data: dict = {}

        if status is not None:
            update_data["status"] = status.value
            if status == JobStatus.RUNNING:
                update_data["started_at"] = datetime.now(timezone.utc).isoformat()
            elif status in (JobStatus.COMPLETED, JobStatus.FAILED):
                update_data["completed_at"] = datetime.now(timezone.utc).isoformat()
        if progress is not None:
            update_data["progress"] = progress
        if message is not None:
            update_data["message"] = message
        if error is not None:
            update_data["error"] = error
        if result is not None:
            update_data["result"] = result

        if not update_data:
            return

        try:
            resp = await asyncio.to_thread(
                lambda: db.client.table("background_jobs")
                .update(update_data)
                .eq("id", job_id)
                .execute()
            )
            if resp.data:
                job = self._row_to_job(resp.data[0])
                await self._publish_update(job)
        except Exception as e:
            logger.error(f"Failed to update job {job_id} in DB: {e}")

    async def cleanup_old_jobs(self, max_age_hours: int = 24):
        """Delete completed/failed jobs older than *max_age_hours* from the database."""
        db = self._db()
        try:
            cutoff = (datetime.now(timezone.utc) - timedelta(hours=max_age_hours)).isoformat()
            await asyncio.to_thread(
                lambda: db.client.table("background_jobs")
                .delete()
                .in_("status", [JobStatus.COMPLETED.value, JobStatus.FAILED.value])
                .lt("completed_at", cutoff)
                .execute()
            )
            logger.info(f"Cleaned up background jobs older than {max_age_hours}h")
        except Exception as e:
            logger.warning(f"Failed to cleanup old jobs: {e}")

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    async def _publish_update(self, job: BackgroundJob):
        """Publish a job-update payload to Redis pub/sub so all workers can forward it."""
        try:
            from src.services.redis_client import get_redis
            r = await get_redis()
            payload = json.dumps({
                "id": job.id,
                "job_type": job.job_type.value,
                "status": job.status.value,
                "progress": job.progress,
                "message": job.message,
                "error": job.error,
            })
            await r.publish(f"job_updates:{job.user_id}", payload)
        except Exception as e:
            logger.warning(f"Failed to publish job update to Redis: {e}")


# Global singleton — replaces the old BackgroundJobManager instance
job_manager = DBJobManager()
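As a usage illustration of the job_updates:{user_id} channel described in the class docstring above, any worker can subscribe and forward payloads to its own clients (the SSE endpoint later in this commit does exactly this). The sketch below assumes only get_redis() and the payload fields shown in _publish_update(); the watch_job_updates function itself is hypothetical.

# Illustrative only; not part of the commit.
import asyncio
import json
from src.services.redis_client import get_redis

async def watch_job_updates(user_id: str) -> None:
    r = await get_redis()
    pubsub = r.pubsub()
    await pubsub.subscribe(f"job_updates:{user_id}")
    try:
        while True:
            msg = await pubsub.get_message(ignore_subscribe_messages=True, timeout=30)
            if msg and msg.get("type") == "message":
                update = json.loads(msg["data"])
                print(update["status"], update["progress"], update["message"])
    finally:
        await pubsub.unsubscribe(f"job_updates:{user_id}")
        await pubsub.aclose()

# asyncio.run(watch_job_updates("some-user-id"))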
src/services/redis_client.py (new file, 27 lines)
@@ -0,0 +1,27 @@
"""Single async Redis connection pool for the whole app."""
import redis.asyncio as aioredis
from loguru import logger

from src.config import settings

_pool: aioredis.Redis | None = None


async def get_redis() -> aioredis.Redis:
    global _pool
    if _pool is None:
        _pool = aioredis.from_url(
            settings.redis_url,
            encoding="utf-8",
            decode_responses=True,
            max_connections=20,
        )
    return _pool


async def close_redis():
    global _pool
    if _pool:
        await _pool.aclose()
        _pool = None
        logger.info("Redis connection pool closed")
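A quick usage sketch of the pool helper above, assuming a reachable Redis at settings.redis_url; the jobs:ping key is made up for the demo.

# Illustrative only; not part of the commit.
import asyncio
from src.services.redis_client import get_redis, close_redis

async def demo() -> None:
    r = await get_redis()                    # lazily creates the shared pool
    await r.setex("jobs:ping", 30, "pong")   # plain strings, decode_responses=True
    print(await r.get("jobs:ping"))          # -> "pong"
    await close_redis()                      # normally done once, on app shutdown

asyncio.run(demo())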
src/services/scheduler_runner.py (new file, 42 lines)
@@ -0,0 +1,42 @@
"""Standalone scheduler process entry point.

Run with:
    python -m src.services.scheduler_runner

This module intentionally does NOT import the FastAPI app — it only starts the
SchedulerService so it can run in its own container without duplicating work in
the main web-worker containers (which set SCHEDULER_ENABLED=false).
"""
import asyncio
import signal
from loguru import logger

from src.database.client import DatabaseClient
from src.services.scheduler_service import init_scheduler


async def main():
    db = DatabaseClient()
    scheduler = init_scheduler(db, check_interval=60)

    stop_event = asyncio.Event()

    def handle_signal():
        logger.info("Scheduler received shutdown signal — stopping…")
        stop_event.set()

    loop = asyncio.get_running_loop()
    for sig in (signal.SIGINT, signal.SIGTERM):
        loop.add_signal_handler(sig, handle_signal)

    await scheduler.start()
    logger.info("Scheduler started (dedicated process)")

    await stop_event.wait()

    await scheduler.stop()
    logger.info("Scheduler stopped")


if __name__ == "__main__":
    asyncio.run(main())
@@ -5,6 +5,7 @@ from pathlib import Path
 from fastapi import FastAPI
 from fastapi.staticfiles import StaticFiles
 from fastapi.responses import RedirectResponse
+from starlette.middleware.base import BaseHTTPMiddleware
 from loguru import logger
 
 from src.config import settings
@@ -14,35 +15,57 @@ from src.web.admin import admin_router
 @asynccontextmanager
 async def lifespan(app: FastAPI):
     """Manage application lifecycle - startup and shutdown."""
     # Startup
     logger.info("Starting LinkedIn Post Creation System...")
 
-    # Initialize and start scheduler if enabled
+    # Warm up Redis connection pool
+    from src.services.redis_client import get_redis, close_redis
+    await get_redis()
+
+    # Start scheduler only when this process is the dedicated scheduler container
     scheduler = None
-    if settings.user_frontend_enabled:
+    if settings.scheduler_enabled:
         try:
             from src.database.client import DatabaseClient
             from src.services.scheduler_service import init_scheduler
 
             db = DatabaseClient()
-            scheduler = init_scheduler(db, check_interval=60)  # Check every 60 seconds
+            scheduler = init_scheduler(db, check_interval=60)
             await scheduler.start()
-            logger.info("Scheduler service started")
+            logger.info("Scheduler started (dedicated process)")
         except Exception as e:
             logger.error(f"Failed to start scheduler: {e}")
 
     yield  # Application runs here
 
     # Shutdown
     logger.info("Shutting down LinkedIn Post Creation System...")
     if scheduler:
         await scheduler.stop()
         logger.info("Scheduler service stopped")
 
+    await close_redis()
+
 
 # Setup
 app = FastAPI(title="LinkedIn Post Creation System", lifespan=lifespan)
 
 
+class StaticCacheMiddleware(BaseHTTPMiddleware):
+    """Set long-lived Cache-Control headers on static assets."""
+
+    async def dispatch(self, request, call_next):
+        response = await call_next(request)
+        if request.url.path.startswith("/static/"):
+            if request.url.path.endswith((".css", ".js")):
+                response.headers["Cache-Control"] = "public, max-age=86400, stale-while-revalidate=3600"
+            elif request.url.path.endswith((".png", ".jpg", ".jpeg", ".svg", ".ico", ".webp")):
+                response.headers["Cache-Control"] = "public, max-age=604800, immutable"
+            else:
+                response.headers["Cache-Control"] = "public, max-age=3600"
+        return response
+
+
+app.add_middleware(StaticCacheMiddleware)
+
 # Static files
 app.mount("/static", StaticFiles(directory=Path(__file__).parent / "static"), name="static")
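The middleware added above can be exercised in isolation. The following self-contained sketch uses a throwaway FastAPI app plus a trimmed-down copy of the same dispatch logic rather than importing the project app, so the demo route, file name, and class name here are made up.

# Illustrative only; not part of the commit.
from fastapi import FastAPI
from fastapi.responses import PlainTextResponse
from fastapi.testclient import TestClient
from starlette.middleware.base import BaseHTTPMiddleware

demo_app = FastAPI()

class DemoStaticCacheMiddleware(BaseHTTPMiddleware):
    # Same idea as StaticCacheMiddleware above, reduced to the CSS/JS branch.
    async def dispatch(self, request, call_next):
        response = await call_next(request)
        if request.url.path.startswith("/static/") and request.url.path.endswith((".css", ".js")):
            response.headers["Cache-Control"] = "public, max-age=86400, stale-while-revalidate=3600"
        return response

demo_app.add_middleware(DemoStaticCacheMiddleware)

@demo_app.get("/static/app.css")
def fake_css() -> PlainTextResponse:
    return PlainTextResponse("body {}", media_type="text/css")

client = TestClient(demo_app)
assert client.get("/static/app.css").headers["Cache-Control"].startswith("public, max-age=86400")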
@@ -32,10 +32,11 @@ from src.services.email_service import (
     mark_token_used,
 )
 from src.services.background_jobs import (
-    job_manager, JobType, JobStatus,
+    JobType, JobStatus,
     run_post_scraping, run_profile_analysis, run_post_categorization, run_post_type_analysis,
     run_full_analysis_pipeline, run_post_recategorization
 )
+from src.services.db_job_manager import job_manager
 from src.services.storage_service import storage
 
 # Router for user frontend
@@ -93,6 +94,7 @@ async def get_user_avatar(session: UserSession, user_id: UUID) -> Optional[str]:
     return None
 
 
+
 def require_user_session(request: Request) -> Optional[UserSession]:
     """Check if user is authenticated, redirect to login if not."""
     session = get_user_session(request)
@@ -676,7 +678,7 @@ async def onboarding_profile_submit(
         logger.info(f"Skipping scraping - {len(existing_posts)} posts already exist for user {user_id}")
 
     if should_scrape:
-        job = job_manager.create_job(JobType.POST_SCRAPING, str(user_id))
+        job = await job_manager.create_job(JobType.POST_SCRAPING, str(user_id))
         background_tasks.add_task(run_post_scraping, user_id, linkedin_url, job.id)
         logger.info(f"Started background scraping for user {user_id}")
 
@@ -829,7 +831,7 @@ async def api_rescrape(request: Request, background_tasks: BackgroundTasks):
         return JSONResponse({"error": "No LinkedIn URL found"}, status_code=400)
 
     # Create job and start scraping
-    job = job_manager.create_job(JobType.POST_SCRAPING, session.user_id)
+    job = await job_manager.create_job(JobType.POST_SCRAPING, session.user_id)
     background_tasks.add_task(run_post_scraping, user_id, profile.linkedin_url, job.id)
 
     return JSONResponse({"success": True, "job_id": job.id})
@@ -1451,53 +1453,45 @@ async def api_categorize_post(request: Request):
 
 @user_router.get("/api/job-updates")
 async def job_updates_sse(request: Request):
-    """Server-Sent Events endpoint for job updates."""
+    """Server-Sent Events endpoint for job updates (Redis pub/sub — works across workers)."""
     session = require_user_session(request)
     tracking_id = getattr(session, 'user_id', None) or getattr(session, 'company_id', None)
     if not session or not tracking_id:
         return JSONResponse({"error": "Not authenticated"}, status_code=401)
 
     async def event_generator():
-        queue = asyncio.Queue()
-
-        async def on_job_update(job):
-            await queue.put(job)
-
-        # Register listener
-        job_manager.add_listener(tracking_id, on_job_update)
-
+        from src.services.redis_client import get_redis
+        r = await get_redis()
+        pubsub = r.pubsub()
+        await pubsub.subscribe(f"job_updates:{tracking_id}")
         try:
-            # Send initial active jobs
-            active_jobs = job_manager.get_active_jobs(tracking_id)
-            for job in active_jobs:
+            # Send any currently active jobs as the initial state
+            for job in await job_manager.get_active_jobs(tracking_id):
                 data = {
                     "id": job.id,
                     "job_type": job.job_type.value,
                     "status": job.status.value,
                     "progress": job.progress,
                     "message": job.message,
-                    "error": job.error
+                    "error": job.error,
                 }
                 yield f"data: {json.dumps(data)}\n\n"
 
-            # Stream updates
+            # Stream pub/sub messages, keepalive on timeout
            while True:
                try:
-                    job = await asyncio.wait_for(queue.get(), timeout=30)
-                    data = {
-                        "id": job.id,
-                        "job_type": job.job_type.value,
-                        "status": job.status.value,
-                        "progress": job.progress,
-                        "message": job.message,
-                        "error": job.error
-                    }
-                    yield f"data: {json.dumps(data)}\n\n"
+                    msg = await asyncio.wait_for(
+                        pubsub.get_message(ignore_subscribe_messages=True), timeout=30
+                    )
+                    if msg and msg.get("type") == "message":
+                        yield f"data: {msg['data']}\n\n"
+                    else:
+                        yield ": keepalive\n\n"
                 except asyncio.TimeoutError:
-                    # Send keepalive
                     yield ": keepalive\n\n"
         finally:
-            job_manager.remove_listener(tracking_id, on_job_update)
+            await pubsub.unsubscribe(f"job_updates:{tracking_id}")
+            await pubsub.aclose()
 
     return StreamingResponse(
         event_generator(),
@@ -1505,8 +1499,8 @@ async def job_updates_sse(request: Request):
         headers={
             "Cache-Control": "no-cache",
             "Connection": "keep-alive",
-            "X-Accel-Buffering": "no"
-        }
+            "X-Accel-Buffering": "no",
+        },
     )
 
 
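For completeness, a client-side sketch of consuming this endpoint with httpx streaming. The base URL, route prefix, and cookie handling are placeholders; in the real app authentication goes through the session cookie set at login, and keepalive frames arrive as ": keepalive" comment lines that are simply skipped here.

# Illustrative only; not part of the commit.
import json
import httpx

def follow_job_updates(base_url: str, cookies: dict) -> None:
    # Each SSE data frame looks like "data: {...}" followed by a blank line.
    with httpx.stream("GET", f"{base_url}/api/job-updates", cookies=cookies, timeout=None) as resp:
        for line in resp.iter_lines():
            if line.startswith("data: "):
                update = json.loads(line[len("data: "):])
                print(update["job_type"], update["status"], update["progress"])

# follow_job_updates("http://localhost:8000", {"session": "..."})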
@@ -1521,7 +1515,7 @@ async def api_run_post_type_analysis(request: Request, background_tasks: Backgro
     user_id = UUID(session.user_id)
 
     # Create job
-    job = job_manager.create_job(JobType.POST_TYPE_ANALYSIS, session.user_id)
+    job = await job_manager.create_job(JobType.POST_TYPE_ANALYSIS, session.user_id)
 
     # Run in background
     background_tasks.add_task(run_post_type_analysis, user_id, job.id)
@@ -3278,13 +3272,13 @@ async def save_all_and_reanalyze(request: Request, background_tasks: BackgroundT
     # Only trigger re-categorization and analysis if there were structural changes
     if has_structural_changes:
         # Create background job for post re-categorization (ALL posts)
-        categorization_job = job_manager.create_job(
+        categorization_job = await job_manager.create_job(
             job_type=JobType.POST_CATEGORIZATION,
             user_id=user_id_str
         )
 
         # Create background job for post type analysis
-        analysis_job = job_manager.create_job(
+        analysis_job = await job_manager.create_job(
             job_type=JobType.POST_TYPE_ANALYSIS,
             user_id=user_id_str
         )
 