Added scalability and performance improvements (Redis, HTTP caching, etc.)

2026-02-19 17:19:41 +01:00
parent d8d054c9a8
commit 4b15b552d6
12 changed files with 763 additions and 88 deletions

View File

@@ -0,0 +1,67 @@
version: '3.8'
# Coolify variant of docker-compose.simple.yml
#
# Differences from the local file:
# 1. No "env_file: .env" - Coolify injects all env vars via its UI
# 2. No "container_name" - Coolify manages container names itself
# 3. Port bound without 127.0.0.1 - Coolify's Traefik proxy handles the routing
# 4. Named volume for logs instead of a host path
#
# Coolify setup:
# - Resource Type: Docker Compose
# - Compose File: docker-compose.coolify.yml
# - Main Service: linkedin-posts
# - Port: 8001
# - Set all env vars (API keys, Supabase, etc.) in the Coolify UI
services:
redis:
image: redis:7-alpine
restart: unless-stopped
command: redis-server --maxmemory 128mb --maxmemory-policy allkeys-lru --save ""
healthcheck:
test: ["CMD", "redis-cli", "ping"]
interval: 10s
timeout: 5s
retries: 3
linkedin-scheduler:
build: .
restart: unless-stopped
command: python -m src.services.scheduler_runner
environment:
- PYTHONPATH=/app
- SCHEDULER_ENABLED=true
- REDIS_URL=redis://redis:6379/0
volumes:
- logs:/app/logs
depends_on:
redis:
condition: service_healthy
linkedin-posts:
build: .
restart: unless-stopped
command: python -m uvicorn src.web.app:app --host 0.0.0.0 --port 8001 --workers 2
ports:
- "8001:8001"
environment:
- PYTHONPATH=/app
- PORT=8001
- SCHEDULER_ENABLED=false
- REDIS_URL=redis://redis:6379/0
volumes:
- logs:/app/logs
depends_on:
redis:
condition: service_healthy
healthcheck:
test: ["CMD", "python", "-c", "import httpx; httpx.get('http://localhost:8001/login', timeout=5)"]
interval: 30s
timeout: 10s
retries: 3
start_period: 15s
volumes:
logs:

View File

@@ -1,22 +1,54 @@
version: '3.8'
services:
+  redis:
+    image: redis:7-alpine
+    container_name: linkedin-redis
+    restart: unless-stopped
+    command: redis-server --maxmemory 128mb --maxmemory-policy allkeys-lru --save ""
+    healthcheck:
+      test: ["CMD", "redis-cli", "ping"]
+      interval: 10s
+      timeout: 5s
+      retries: 3
+
+  linkedin-scheduler:
+    build: .
+    container_name: linkedin-scheduler
+    restart: unless-stopped
+    command: sh -c "python -m src.services.scheduler_runner"
+    env_file: .env
+    environment:
+      - PYTHONPATH=/app
+      - SCHEDULER_ENABLED=true
+      - REDIS_URL=redis://redis:6379/0
+    volumes:
+      - ./logs:/app/logs
+    depends_on:
+      redis:
+        condition: service_healthy
+
  linkedin-posts:
    build: .
    container_name: linkedin-posts
    restart: unless-stopped
+    command: sh -c "python -m uvicorn src.web.app:app --host 0.0.0.0 --port ${PORT:-8001} --workers 2"
    ports:
-      - "127.0.0.1:8001:8001"  # Only reachable locally
+      - "127.0.0.1:8001:8001"
-    env_file:
-      - .env
+    env_file: .env
    environment:
      - PYTHONPATH=/app
      - PORT=8001
+      - SCHEDULER_ENABLED=false
+      - REDIS_URL=redis://redis:6379/0
    volumes:
      - ./logs:/app/logs
+    depends_on:
+      redis:
+        condition: service_healthy
    healthcheck:
      test: ["CMD", "python", "-c", "import httpx; httpx.get('http://localhost:8001/login', timeout=5)"]
      interval: 30s
      timeout: 10s
      retries: 3
-      start_period: 10s
+      start_period: 15s

View File

@@ -32,6 +32,9 @@ setuptools>=65.0.0
# PDF Generation
reportlab>=4.0.0

+# Redis (async client for caching + pub/sub)
+redis==5.2.1
+
# Web Frontend
fastapi==0.115.0
uvicorn==0.32.0

View File

@@ -65,6 +65,10 @@ class Settings(BaseSettings):
    moco_api_key: str = ""  # Token for the Authorization header
    moco_domain: str = ""   # Subdomain: {domain}.mocoapp.com

+    # Redis
+    redis_url: str = "redis://redis:6379/0"
+    scheduler_enabled: bool = False  # True only on the dedicated scheduler container
    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
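For reference, a minimal sketch of how the env vars set in the compose files map onto these new fields via pydantic-settings. DemoSettings is a standalone stand-in, not the project's Settings class; the defaults are copied from the diff above.

import os
from pydantic_settings import BaseSettings

class DemoSettings(BaseSettings):
    # Mirrors the two fields added above (stand-in class)
    redis_url: str = "redis://redis:6379/0"
    scheduler_enabled: bool = False

os.environ["SCHEDULER_ENABLED"] = "true"              # as set on the linkedin-scheduler service
os.environ["REDIS_URL"] = "redis://localhost:6379/0"
settings = DemoSettings()
assert settings.scheduler_enabled is True
print(settings.redis_url)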

View File

@@ -13,6 +13,12 @@ from src.database.models import (
    User, Profile, Company, Invitation, ExamplePost, ReferenceProfile,
    ApiUsageLog, LicenseKey, CompanyDailyQuota, LicenseKeyOffer
)
+from src.services.cache_service import (
+    cache,
+    PROFILE_TTL, USER_TTL, LINKEDIN_ACCOUNT_TTL, PROFILE_ANALYSIS_TTL,
+    COMPANY_TTL, LINKEDIN_POSTS_TTL, POST_TYPES_TTL, POST_TYPE_TTL,
+    GEN_POST_TTL, GEN_POSTS_TTL,
+)

class DatabaseClient:
@@ -111,16 +117,26 @@ class DatabaseClient:
            ).execute()
        )
        logger.info(f"Saved {len(result.data)} LinkedIn posts")
-        return [LinkedInPost(**item) for item in result.data]
+        saved = [LinkedInPost(**item) for item in result.data]
+        # Invalidate cache for all affected users
+        affected_user_ids = {str(p.user_id) for p in saved}
+        for uid in affected_user_ids:
+            await cache.invalidate_linkedin_posts(uid)
+        return saved

    async def get_linkedin_posts(self, user_id: UUID) -> List[LinkedInPost]:
        """Get all LinkedIn posts for user."""
+        key = cache.linkedin_posts_key(str(user_id))
+        if (hit := await cache.get(key)) is not None:
+            return [LinkedInPost(**item) for item in hit]
        result = await asyncio.to_thread(
            lambda: self.client.table("linkedin_posts").select("*").eq(
                "user_id", str(user_id)
            ).order("post_date", desc=True).execute()
        )
-        return [LinkedInPost(**item) for item in result.data]
+        posts = [LinkedInPost(**item) for item in result.data]
+        await cache.set(key, [p.model_dump(mode="json") for p in posts], LINKEDIN_POSTS_TTL)
+        return posts

    async def copy_posts_to_user(self, source_user_id: UUID, target_user_id: UUID) -> List[LinkedInPost]:
        """Copy all LinkedIn posts from one user to another."""
@@ -148,9 +164,16 @@ class DatabaseClient:
    async def delete_linkedin_post(self, post_id: UUID) -> None:
        """Delete a LinkedIn post."""
+        # Fetch user_id before deleting so we can invalidate the cache
+        lookup = await asyncio.to_thread(
+            lambda: self.client.table("linkedin_posts").select("user_id").eq("id", str(post_id)).execute()
+        )
+        user_id = lookup.data[0]["user_id"] if lookup.data else None
        await asyncio.to_thread(
            lambda: self.client.table("linkedin_posts").delete().eq("id", str(post_id)).execute()
        )
+        if user_id:
+            await cache.invalidate_linkedin_posts(user_id)
        logger.info(f"Deleted LinkedIn post: {post_id}")

    async def get_unclassified_posts(self, user_id: UUID) -> List[LinkedInPost]:
@@ -195,9 +218,11 @@ class DatabaseClient:
if "user_id" in updates and updates["user_id"]: if "user_id" in updates and updates["user_id"]:
updates["user_id"] = str(updates["user_id"]) updates["user_id"] = str(updates["user_id"])
await asyncio.to_thread( result = await asyncio.to_thread(
lambda: self.client.table("linkedin_posts").update(updates).eq("id", str(post_id)).execute() lambda: self.client.table("linkedin_posts").update(updates).eq("id", str(post_id)).execute()
) )
if result.data:
await cache.invalidate_linkedin_posts(result.data[0]["user_id"])
logger.debug(f"Updated LinkedIn post {post_id}") logger.debug(f"Updated LinkedIn post {post_id}")
async def update_posts_classification_bulk( async def update_posts_classification_bulk(
@@ -233,7 +258,9 @@ class DatabaseClient:
lambda: self.client.table("post_types").insert(data).execute() lambda: self.client.table("post_types").insert(data).execute()
) )
logger.info(f"Created post type: {result.data[0]['name']}") logger.info(f"Created post type: {result.data[0]['name']}")
return PostType(**result.data[0]) created = PostType(**result.data[0])
await cache.invalidate_post_types(str(created.user_id))
return created
async def create_post_types_bulk(self, post_types: List[PostType]) -> List[PostType]: async def create_post_types_bulk(self, post_types: List[PostType]) -> List[PostType]:
"""Create multiple post types at once.""" """Create multiple post types at once."""
@@ -251,10 +278,18 @@ class DatabaseClient:
lambda: self.client.table("post_types").insert(data).execute() lambda: self.client.table("post_types").insert(data).execute()
) )
logger.info(f"Created {len(result.data)} post types") logger.info(f"Created {len(result.data)} post types")
return [PostType(**item) for item in result.data] created = [PostType(**item) for item in result.data]
affected_user_ids = {str(pt.user_id) for pt in created}
for uid in affected_user_ids:
await cache.invalidate_post_types(uid)
return created
async def get_post_types(self, user_id: UUID, active_only: bool = True) -> List[PostType]: async def get_post_types(self, user_id: UUID, active_only: bool = True) -> List[PostType]:
"""Get all post types for a user.""" """Get all post types for a user."""
key = cache.post_types_key(str(user_id), active_only)
if (hit := await cache.get(key)) is not None:
return [PostType(**item) for item in hit]
def _query(): def _query():
query = self.client.table("post_types").select("*").eq("user_id", str(user_id)) query = self.client.table("post_types").select("*").eq("user_id", str(user_id))
if active_only: if active_only:
@@ -262,7 +297,9 @@ class DatabaseClient:
            return query.order("name").execute()
        result = await asyncio.to_thread(_query)
-        return [PostType(**item) for item in result.data]
+        post_types = [PostType(**item) for item in result.data]
+        await cache.set(key, [pt.model_dump(mode="json") for pt in post_types], POST_TYPES_TTL)
+        return post_types

    # Alias for get_post_types
    async def get_customer_post_types(self, user_id: UUID, active_only: bool = True) -> List[PostType]:
@@ -271,13 +308,18 @@ class DatabaseClient:
    async def get_post_type(self, post_type_id: UUID) -> Optional[PostType]:
        """Get a single post type by ID."""
+        key = cache.post_type_key(str(post_type_id))
+        if (hit := await cache.get(key)) is not None:
+            return PostType(**hit)
        result = await asyncio.to_thread(
            lambda: self.client.table("post_types").select("*").eq(
                "id", str(post_type_id)
            ).execute()
        )
        if result.data:
-            return PostType(**result.data[0])
+            pt = PostType(**result.data[0])
+            await cache.set(key, pt.model_dump(mode="json"), POST_TYPE_TTL)
+            return pt
        return None

    async def update_post_type(self, post_type_id: UUID, updates: Dict[str, Any]) -> PostType:
@@ -288,7 +330,9 @@ class DatabaseClient:
            ).execute()
        )
        logger.info(f"Updated post type: {post_type_id}")
-        return PostType(**result.data[0])
+        pt = PostType(**result.data[0])
+        await cache.invalidate_post_type(str(post_type_id), str(pt.user_id))
+        return pt

    async def update_post_type_analysis(
        self,
@@ -306,23 +350,36 @@ class DatabaseClient:
}).eq("id", str(post_type_id)).execute() }).eq("id", str(post_type_id)).execute()
) )
logger.info(f"Updated analysis for post type: {post_type_id}") logger.info(f"Updated analysis for post type: {post_type_id}")
return PostType(**result.data[0]) pt = PostType(**result.data[0])
await cache.invalidate_post_type(str(post_type_id), str(pt.user_id))
return pt
async def delete_post_type(self, post_type_id: UUID, soft: bool = True) -> None: async def delete_post_type(self, post_type_id: UUID, soft: bool = True) -> None:
"""Delete a post type (soft delete by default).""" """Delete a post type (soft delete by default)."""
if soft: if soft:
await asyncio.to_thread( result = await asyncio.to_thread(
lambda: self.client.table("post_types").update({ lambda: self.client.table("post_types").update({
"is_active": False "is_active": False
}).eq("id", str(post_type_id)).execute() }).eq("id", str(post_type_id)).execute()
) )
logger.info(f"Soft deleted post type: {post_type_id}") logger.info(f"Soft deleted post type: {post_type_id}")
if result.data:
await cache.invalidate_post_type(str(post_type_id), result.data[0]["user_id"])
else: else:
# Fetch user_id before hard delete for cache invalidation
lookup = await asyncio.to_thread(
lambda: self.client.table("post_types").select("user_id").eq(
"id", str(post_type_id)
).execute()
)
user_id = lookup.data[0]["user_id"] if lookup.data else None
await asyncio.to_thread( await asyncio.to_thread(
lambda: self.client.table("post_types").delete().eq( lambda: self.client.table("post_types").delete().eq(
"id", str(post_type_id) "id", str(post_type_id)
).execute() ).execute()
) )
if user_id:
await cache.invalidate_post_type(str(post_type_id), user_id)
logger.info(f"Hard deleted post type: {post_type_id}") logger.info(f"Hard deleted post type: {post_type_id}")
# ==================== TOPICS ==================== # ==================== TOPICS ====================
@@ -418,17 +475,24 @@ class DatabaseClient:
        )
        logger.info(f"Saved profile analysis for user: {analysis.user_id}")
-        return ProfileAnalysis(**result.data[0])
+        saved = ProfileAnalysis(**result.data[0])
+        await cache.delete(cache.profile_analysis_key(str(analysis.user_id)))
+        return saved

    async def get_profile_analysis(self, user_id: UUID) -> Optional[ProfileAnalysis]:
        """Get profile analysis for user."""
+        key = cache.profile_analysis_key(str(user_id))
+        if (hit := await cache.get(key)) is not None:
+            return ProfileAnalysis(**hit)
        result = await asyncio.to_thread(
            lambda: self.client.table("profile_analyses").select("*").eq(
                "user_id", str(user_id)
            ).execute()
        )
        if result.data:
-            return ProfileAnalysis(**result.data[0])
+            pa = ProfileAnalysis(**result.data[0])
+            await cache.set(key, pa.model_dump(mode="json"), PROFILE_ANALYSIS_TTL)
+            return pa
        return None

    # ==================== RESEARCH RESULTS ====================
@@ -491,7 +555,9 @@ class DatabaseClient:
lambda: self.client.table("generated_posts").insert(data).execute() lambda: self.client.table("generated_posts").insert(data).execute()
) )
logger.info(f"Saved generated post: {result.data[0]['id']}") logger.info(f"Saved generated post: {result.data[0]['id']}")
return GeneratedPost(**result.data[0]) saved = GeneratedPost(**result.data[0])
await cache.invalidate_gen_posts(str(saved.user_id))
return saved
async def update_generated_post(self, post_id: UUID, updates: Dict[str, Any]) -> GeneratedPost: async def update_generated_post(self, post_id: UUID, updates: Dict[str, Any]) -> GeneratedPost:
"""Update generated post.""" """Update generated post."""
@@ -504,11 +570,8 @@ class DatabaseClient:
        # Handle metadata dict - ensure all nested datetime values are serialized
        if 'metadata' in updates and isinstance(updates['metadata'], dict):
            serialized_metadata = {}
-            for key, value in updates['metadata'].items():
-                if isinstance(value, datetime):
-                    serialized_metadata[key] = value.isoformat()
-                else:
-                    serialized_metadata[key] = value
+            for k, value in updates['metadata'].items():
+                serialized_metadata[k] = value.isoformat() if isinstance(value, datetime) else value
            updates['metadata'] = serialized_metadata

        result = await asyncio.to_thread(
@@ -517,26 +580,38 @@ class DatabaseClient:
            ).execute()
        )
        logger.info(f"Updated generated post: {post_id}")
-        return GeneratedPost(**result.data[0])
+        updated = GeneratedPost(**result.data[0])
+        await cache.invalidate_gen_post(str(post_id), str(updated.user_id))
+        return updated

    async def get_generated_posts(self, user_id: UUID) -> List[GeneratedPost]:
        """Get all generated posts for user."""
+        key = cache.gen_posts_key(str(user_id))
+        if (hit := await cache.get(key)) is not None:
+            return [GeneratedPost(**item) for item in hit]
        result = await asyncio.to_thread(
            lambda: self.client.table("generated_posts").select("*").eq(
                "user_id", str(user_id)
            ).order("created_at", desc=True).execute()
        )
-        return [GeneratedPost(**item) for item in result.data]
+        posts = [GeneratedPost(**item) for item in result.data]
+        await cache.set(key, [p.model_dump(mode="json") for p in posts], GEN_POSTS_TTL)
+        return posts

    async def get_generated_post(self, post_id: UUID) -> Optional[GeneratedPost]:
        """Get a single generated post by ID."""
+        key = cache.gen_post_key(str(post_id))
+        if (hit := await cache.get(key)) is not None:
+            return GeneratedPost(**hit)
        result = await asyncio.to_thread(
            lambda: self.client.table("generated_posts").select("*").eq(
                "id", str(post_id)
            ).execute()
        )
        if result.data:
-            return GeneratedPost(**result.data[0])
+            post = GeneratedPost(**result.data[0])
+            await cache.set(key, post.model_dump(mode="json"), GEN_POST_TTL)
+            return post
        return None

    async def get_scheduled_posts_due(self) -> List[GeneratedPost]:
@@ -627,11 +702,16 @@ class DatabaseClient:
    async def get_profile(self, user_id: UUID) -> Optional[Profile]:
        """Get profile by user ID."""
+        key = cache.profile_key(str(user_id))
+        if (hit := await cache.get(key)) is not None:
+            return Profile(**hit)
        result = await asyncio.to_thread(
            lambda: self.client.table("profiles").select("*").eq("id", str(user_id)).execute()
        )
        if result.data:
-            return Profile(**result.data[0])
+            profile = Profile(**result.data[0])
+            await cache.set(key, profile.model_dump(mode="json"), PROFILE_TTL)
+            return profile
        return None

    async def get_profiles_by_linkedin_url(self, linkedin_url: str) -> List[Profile]:
@@ -645,9 +725,9 @@ class DatabaseClient:
"""Update profile fields.""" """Update profile fields."""
if "company_id" in updates and updates["company_id"]: if "company_id" in updates and updates["company_id"]:
updates["company_id"] = str(updates["company_id"]) updates["company_id"] = str(updates["company_id"])
for key in ["account_type", "onboarding_status"]: for k in ["account_type", "onboarding_status"]:
if key in updates and hasattr(updates[key], "value"): if k in updates and hasattr(updates[k], "value"):
updates[key] = updates[key].value updates[k] = updates[k].value
result = await asyncio.to_thread( result = await asyncio.to_thread(
lambda: self.client.table("profiles").update(updates).eq( lambda: self.client.table("profiles").update(updates).eq(
@@ -655,19 +735,26 @@ class DatabaseClient:
            ).execute()
        )
        logger.info(f"Updated profile: {user_id}")
-        return Profile(**result.data[0])
+        profile = Profile(**result.data[0])
+        await cache.invalidate_profile(str(user_id))
+        return profile

    # ==================== LINKEDIN ACCOUNTS ====================

    async def get_linkedin_account(self, user_id: UUID) -> Optional['LinkedInAccount']:
        """Get LinkedIn account for user."""
        from src.database.models import LinkedInAccount
+        key = cache.linkedin_account_key(str(user_id))
+        if (hit := await cache.get(key)) is not None:
+            return LinkedInAccount(**hit)
        result = await asyncio.to_thread(
            lambda: self.client.table("linkedin_accounts").select("*")
            .eq("user_id", str(user_id)).eq("is_active", True).execute()
        )
        if result.data:
-            return LinkedInAccount(**result.data[0])
+            account = LinkedInAccount(**result.data[0])
+            await cache.set(key, account.model_dump(mode="json"), LINKEDIN_ACCOUNT_TTL)
+            return account
        return None

    async def get_linkedin_account_by_id(self, account_id: UUID) -> Optional['LinkedInAccount']:
@@ -692,7 +779,9 @@ class DatabaseClient:
lambda: self.client.table("linkedin_accounts").insert(data).execute() lambda: self.client.table("linkedin_accounts").insert(data).execute()
) )
logger.info(f"Created LinkedIn account for user: {account.user_id}") logger.info(f"Created LinkedIn account for user: {account.user_id}")
return LinkedInAccount(**result.data[0]) created = LinkedInAccount(**result.data[0])
await cache.invalidate_linkedin_account(str(created.user_id))
return created
async def update_linkedin_account(self, account_id: UUID, updates: Dict) -> 'LinkedInAccount': async def update_linkedin_account(self, account_id: UUID, updates: Dict) -> 'LinkedInAccount':
"""Update LinkedIn account.""" """Update LinkedIn account."""
@@ -708,25 +797,40 @@ class DatabaseClient:
.eq("id", str(account_id)).execute() .eq("id", str(account_id)).execute()
) )
logger.info(f"Updated LinkedIn account: {account_id}") logger.info(f"Updated LinkedIn account: {account_id}")
return LinkedInAccount(**result.data[0]) updated = LinkedInAccount(**result.data[0])
await cache.invalidate_linkedin_account(str(updated.user_id))
return updated
async def delete_linkedin_account(self, account_id: UUID) -> None: async def delete_linkedin_account(self, account_id: UUID) -> None:
"""Delete LinkedIn account connection.""" """Delete LinkedIn account connection."""
# Fetch user_id before delete for cache invalidation
lookup = await asyncio.to_thread(
lambda: self.client.table("linkedin_accounts").select("user_id")
.eq("id", str(account_id)).execute()
)
user_id = lookup.data[0]["user_id"] if lookup.data else None
await asyncio.to_thread( await asyncio.to_thread(
lambda: self.client.table("linkedin_accounts").delete() lambda: self.client.table("linkedin_accounts").delete()
.eq("id", str(account_id)).execute() .eq("id", str(account_id)).execute()
) )
if user_id:
await cache.invalidate_linkedin_account(user_id)
logger.info(f"Deleted LinkedIn account: {account_id}") logger.info(f"Deleted LinkedIn account: {account_id}")
# ==================== USERS ==================== # ==================== USERS ====================
async def get_user(self, user_id: UUID) -> Optional[User]: async def get_user(self, user_id: UUID) -> Optional[User]:
"""Get user by ID (from users view).""" """Get user by ID (from users view)."""
key = cache.user_key(str(user_id))
if (hit := await cache.get(key)) is not None:
return User(**hit)
result = await asyncio.to_thread( result = await asyncio.to_thread(
lambda: self.client.table("users").select("*").eq("id", str(user_id)).execute() lambda: self.client.table("users").select("*").eq("id", str(user_id)).execute()
) )
if result.data: if result.data:
return User(**result.data[0]) user = User(**result.data[0])
await cache.set(key, user.model_dump(mode="json"), USER_TTL)
return user
return None return None
async def get_user_by_email(self, email: str) -> Optional[User]: async def get_user_by_email(self, email: str) -> Optional[User]:
@@ -757,7 +861,10 @@ class DatabaseClient:
        if profile_updates:
            await self.update_profile(user_id, profile_updates)

+        # update_profile already calls cache.invalidate_profile, which also clears user_key,
+        # but invalidate the user view explicitly in case no profile fields were updated above.
+        await cache.delete(cache.user_key(str(user_id)))
        return await self.get_user(user_id)

    async def list_users(self, account_type: Optional[str] = None, company_id: Optional[UUID] = None) -> List[User]:
@@ -838,11 +945,16 @@ class DatabaseClient:
    async def get_company(self, company_id: UUID) -> Optional[Company]:
        """Get company by ID."""
+        key = cache.company_key(str(company_id))
+        if (hit := await cache.get(key)) is not None:
+            return Company(**hit)
        result = await asyncio.to_thread(
            lambda: self.client.table("companies").select("*").eq("id", str(company_id)).execute()
        )
        if result.data:
-            return Company(**result.data[0])
+            company = Company(**result.data[0])
+            await cache.set(key, company.model_dump(mode="json"), COMPANY_TTL)
+            return company
        return None

    async def get_company_by_owner(self, owner_user_id: UUID) -> Optional[Company]:
@@ -871,7 +983,9 @@ class DatabaseClient:
            ).execute()
        )
        logger.info(f"Updated company: {company_id}")
-        return Company(**result.data[0])
+        company = Company(**result.data[0])
+        await cache.invalidate_company(str(company_id))
+        return company

    async def list_companies(self) -> List[Company]:
        """List all companies."""

View File

@@ -156,8 +156,10 @@ class BackgroundJobManager:
logger.info(f"Cleaned up {len(to_remove)} old background jobs") logger.info(f"Cleaned up {len(to_remove)} old background jobs")
# Global instance # Global instance — backed by Supabase DB + Redis pub/sub for multi-worker safety.
job_manager = BackgroundJobManager() # db_job_manager imports BackgroundJob/JobType/JobStatus from this module, so
# this import must stay at the bottom to avoid a circular-import issue.
from src.services.db_job_manager import job_manager # noqa: F401
async def run_post_scraping(user_id: UUID, linkedin_url: str, job_id: str): async def run_post_scraping(user_id: UUID, linkedin_url: str, job_id: str):
@@ -385,6 +387,11 @@ async def run_post_categorization(user_id: UUID, job_id: str):
message=f"{len(classifications)} Posts kategorisiert!" message=f"{len(classifications)} Posts kategorisiert!"
) )
# Invalidate cached LinkedIn posts — classifications changed but bulk update
# doesn't have user_id per-row, so we invalidate explicitly here.
from src.services.cache_service import cache as _cache
await _cache.invalidate_linkedin_posts(str(user_id))
logger.info(f"Post categorization completed for user {user_id}: {len(classifications)} posts") logger.info(f"Post categorization completed for user {user_id}: {len(classifications)} posts")
except Exception as e: except Exception as e:
@@ -480,6 +487,9 @@ async def run_post_recategorization(user_id: UUID, job_id: str):
message=f"{len(classifications)} Posts re-kategorisiert!" message=f"{len(classifications)} Posts re-kategorisiert!"
) )
from src.services.cache_service import cache as _cache
await _cache.invalidate_linkedin_posts(str(user_id))
logger.info(f"Post re-categorization completed for user {user_id}: {len(classifications)} posts") logger.info(f"Post re-categorization completed for user {user_id}: {len(classifications)} posts")
except Exception as e: except Exception as e:
@@ -556,21 +566,15 @@ async def run_full_analysis_pipeline(user_id: UUID):
logger.info(f"Starting full analysis pipeline for user {user_id}") logger.info(f"Starting full analysis pipeline for user {user_id}")
# 1. Profile Analysis # 1. Profile Analysis
job1 = job_manager.create_job(JobType.PROFILE_ANALYSIS, str(user_id)) job1 = await job_manager.create_job(JobType.PROFILE_ANALYSIS, str(user_id))
await run_profile_analysis(user_id, job1.id) await run_profile_analysis(user_id, job1.id)
if job1.status == JobStatus.FAILED: # 2. Post Categorization (always continue regardless of previous step outcome)
logger.warning(f"Profile analysis failed, continuing with categorization") job2 = await job_manager.create_job(JobType.POST_CATEGORIZATION, str(user_id))
# 2. Post Categorization
job2 = job_manager.create_job(JobType.POST_CATEGORIZATION, str(user_id))
await run_post_categorization(user_id, job2.id) await run_post_categorization(user_id, job2.id)
if job2.status == JobStatus.FAILED:
logger.warning(f"Post categorization failed, continuing with post type analysis")
# 3. Post Type Analysis # 3. Post Type Analysis
job3 = job_manager.create_job(JobType.POST_TYPE_ANALYSIS, str(user_id)) job3 = await job_manager.create_job(JobType.POST_TYPE_ANALYSIS, str(user_id))
await run_post_type_analysis(user_id, job3.id) await run_post_type_analysis(user_id, job3.id)
logger.info(f"Full analysis pipeline completed for user {user_id}") logger.info(f"Full analysis pipeline completed for user {user_id}")

View File

@@ -0,0 +1,157 @@
"""Typed cache helpers backed by Redis.
All failures are silent (logged as warnings) so Redis being down never causes an outage.
Key design:
profile:{user_id} — Profile row
user:{user_id} — User view row
linkedin_account:{user_id} — Active LinkedInAccount row
profile_analysis:{user_id} — ProfileAnalysis row
company:{company_id} — Company row
linkedin_posts:{user_id} — List[LinkedInPost] (scraped reference posts)
post_types:{user_id}:1 or :0 — List[PostType] (active_only=True/False)
post_type:{post_type_id} — Single PostType row
gen_post:{post_id} — Single GeneratedPost row
gen_posts:{user_id} — List[GeneratedPost] for user
"""
import json
from typing import Any, Optional
from loguru import logger
# TTL constants (seconds)
PROFILE_TTL = 300 # 5 min — updated on settings / onboarding changes
USER_TTL = 300 # 5 min
LINKEDIN_ACCOUNT_TTL = 300 # 5 min — updated only on OAuth connect/disconnect
PROFILE_ANALYSIS_TTL = 600 # 10 min — computed infrequently by background job
COMPANY_TTL = 300 # 5 min — company settings
LINKEDIN_POSTS_TTL = 600 # 10 min — scraped reference data, rarely changes
POST_TYPES_TTL = 600 # 10 min — strategy config, rarely changes
POST_TYPE_TTL = 600 # 10 min
GEN_POST_TTL = 120 # 2 min — status/content changes frequently
GEN_POSTS_TTL = 120 # 2 min
class CacheService:
"""Redis-backed cache with typed key helpers and silent failure semantics."""
# ------------------------------------------------------------------
# Key helpers
# ------------------------------------------------------------------
def profile_key(self, user_id: str) -> str:
return f"profile:{user_id}"
def user_key(self, user_id: str) -> str:
return f"user:{user_id}"
def linkedin_account_key(self, user_id: str) -> str:
return f"linkedin_account:{user_id}"
def profile_analysis_key(self, user_id: str) -> str:
return f"profile_analysis:{user_id}"
def company_key(self, company_id: str) -> str:
return f"company:{company_id}"
def linkedin_posts_key(self, user_id: str) -> str:
return f"linkedin_posts:{user_id}"
def post_types_key(self, user_id: str, active_only: bool = True) -> str:
return f"post_types:{user_id}:{'1' if active_only else '0'}"
def post_type_key(self, post_type_id: str) -> str:
return f"post_type:{post_type_id}"
def gen_post_key(self, post_id: str) -> str:
return f"gen_post:{post_id}"
def gen_posts_key(self, user_id: str) -> str:
return f"gen_posts:{user_id}"
# ------------------------------------------------------------------
# Core operations
# ------------------------------------------------------------------
async def get(self, key: str) -> Optional[Any]:
try:
from src.services.redis_client import get_redis
r = await get_redis()
value = await r.get(key)
if value is not None:
return json.loads(value)
return None
except Exception as e:
logger.warning(f"Cache get failed for {key}: {e}")
return None
async def set(self, key: str, value: Any, ttl: int):
try:
from src.services.redis_client import get_redis
r = await get_redis()
await r.setex(key, ttl, json.dumps(value, default=str))
except Exception as e:
logger.warning(f"Cache set failed for {key}: {e}")
async def delete(self, *keys: str):
if not keys:
return
try:
from src.services.redis_client import get_redis
r = await get_redis()
await r.delete(*keys)
except Exception as e:
logger.warning(f"Cache delete failed for {keys}: {e}")
# ------------------------------------------------------------------
# Compound invalidation helpers
# ------------------------------------------------------------------
async def invalidate_profile(self, user_id: str):
"""Invalidate profile, user view, profile_analysis, and linkedin_account."""
await self.delete(
self.profile_key(user_id),
self.user_key(user_id),
self.profile_analysis_key(user_id),
self.linkedin_account_key(user_id),
)
async def invalidate_company(self, company_id: str):
await self.delete(self.company_key(company_id))
async def invalidate_linkedin_posts(self, user_id: str):
"""Invalidate the scraped LinkedIn posts list for a user."""
await self.delete(self.linkedin_posts_key(user_id))
async def invalidate_post_types(self, user_id: str):
"""Invalidate both active_only variants of the post types list."""
await self.delete(
self.post_types_key(user_id, True),
self.post_types_key(user_id, False),
)
async def invalidate_post_type(self, post_type_id: str, user_id: str):
"""Invalidate a single PostType row + both list variants."""
await self.delete(
self.post_type_key(post_type_id),
self.post_types_key(user_id, True),
self.post_types_key(user_id, False),
)
async def invalidate_gen_post(self, post_id: str, user_id: str):
"""Invalidate a single GeneratedPost + the user's post list."""
await self.delete(
self.gen_post_key(post_id),
self.gen_posts_key(user_id),
)
async def invalidate_gen_posts(self, user_id: str):
"""Invalidate only the user's GeneratedPost list (not a single entry)."""
await self.delete(self.gen_posts_key(user_id))
async def invalidate_linkedin_account(self, user_id: str):
"""Invalidate the LinkedIn account cache for a user."""
await self.delete(self.linkedin_account_key(user_id))
# Global singleton
cache = CacheService()
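A quick way to see the silent-failure semantics described in the docstring: with Redis unreachable, set() becomes a no-op and get() reports a miss, so callers fall back to the database. A minimal sketch, assuming it runs inside the project so the import resolves; the user id is hypothetical.

import asyncio
from src.services.cache_service import cache, PROFILE_TTL

async def demo():
    key = cache.profile_key("00000000-0000-0000-0000-000000000000")  # hypothetical user id
    await cache.set(key, {"first_name": "Demo"}, PROFILE_TTL)  # only logs a warning if Redis is down
    print(await cache.get(key))  # the dict on a hit, None on a miss or when Redis is unreachable

asyncio.run(demo())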

View File

@@ -0,0 +1,208 @@
"""Database-backed job manager using Supabase + Redis pub/sub for cross-worker job updates."""
import asyncio
import json
from datetime import datetime, timezone, timedelta
from typing import Dict, Any, Optional
from loguru import logger
# BackgroundJob, JobType, JobStatus are defined in background_jobs.py.
# We import them here; background_jobs.py imports job_manager from this module at
# its bottom — Python handles this circular reference safely because background_jobs.py
# defines these symbols *before* it reaches the import-from-db_job_manager line.
from src.services.background_jobs import BackgroundJob, JobType, JobStatus
def _parse_ts(value: Optional[str]) -> Optional[datetime]:
"""Parse an ISO-8601 / Supabase timestamp string to a timezone-aware datetime."""
if not value:
return None
try:
# Supabase returns strings like "2024-01-01T12:00:00+00:00" or ending in "Z"
return datetime.fromisoformat(value.replace("Z", "+00:00"))
except ValueError:
return None
class DBJobManager:
"""Manages background jobs backed by the Supabase `background_jobs` table.
Publishes job-update payloads to Redis channel ``job_updates:{user_id}`` so
every worker process can push SSE events to its own connected clients.
"""
def _db(self):
"""Lazy DatabaseClient — avoids import at module load time."""
from src.database.client import DatabaseClient
if not hasattr(self, "_db_client"):
self._db_client = DatabaseClient()
return self._db_client
def _row_to_job(self, row: dict) -> BackgroundJob:
return BackgroundJob(
id=row["id"],
job_type=JobType(row["job_type"]),
user_id=row["user_id"],
status=JobStatus(row["status"]),
progress=row.get("progress") or 0,
message=row.get("message") or "",
error=row.get("error"),
result=row.get("result"),
created_at=_parse_ts(row.get("created_at")) or datetime.now(timezone.utc),
started_at=_parse_ts(row.get("started_at")),
completed_at=_parse_ts(row.get("completed_at")),
)
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
async def create_job(self, job_type: JobType, user_id: str) -> BackgroundJob:
"""Create a new background job row in the database."""
db = self._db()
try:
resp = await asyncio.to_thread(
lambda: db.client.table("background_jobs").insert({
"job_type": job_type.value,
"user_id": user_id,
"status": JobStatus.PENDING.value,
"progress": 0,
"message": "",
}).execute()
)
job = self._row_to_job(resp.data[0])
logger.info(f"Created background job {job.id} of type {job_type} for user {user_id}")
return job
except Exception as e:
logger.error(f"Failed to create job in DB: {e}")
raise
async def get_job(self, job_id: str) -> Optional[BackgroundJob]:
"""Fetch a single job by ID from the database."""
db = self._db()
try:
resp = await asyncio.to_thread(
lambda: db.client.table("background_jobs").select("*").eq("id", job_id).execute()
)
if resp.data:
return self._row_to_job(resp.data[0])
return None
except Exception as e:
logger.warning(f"Failed to get job {job_id}: {e}")
return None
async def get_user_jobs(self, user_id: str) -> list[BackgroundJob]:
"""Fetch the 50 most-recent jobs for a user from the database."""
db = self._db()
try:
resp = await asyncio.to_thread(
lambda: db.client.table("background_jobs").select("*")
.eq("user_id", user_id)
.order("created_at", desc=True)
.limit(50)
.execute()
)
return [self._row_to_job(r) for r in resp.data]
except Exception as e:
logger.warning(f"Failed to get jobs for user {user_id}: {e}")
return []
async def get_active_jobs(self, user_id: str) -> list[BackgroundJob]:
"""Fetch pending/running jobs for a user from the database."""
db = self._db()
try:
resp = await asyncio.to_thread(
lambda: db.client.table("background_jobs").select("*")
.eq("user_id", user_id)
.in_("status", [JobStatus.PENDING.value, JobStatus.RUNNING.value])
.order("created_at", desc=True)
.execute()
)
return [self._row_to_job(r) for r in resp.data]
except Exception as e:
logger.warning(f"Failed to get active jobs for user {user_id}: {e}")
return []
async def update_job(
self,
job_id: str,
status: Optional[JobStatus] = None,
progress: Optional[int] = None,
message: Optional[str] = None,
error: Optional[str] = None,
result: Optional[Dict[str, Any]] = None,
):
"""Update a job row in the database, then publish the new state to Redis."""
db = self._db()
update_data: dict = {}
if status is not None:
update_data["status"] = status.value
if status == JobStatus.RUNNING:
update_data["started_at"] = datetime.now(timezone.utc).isoformat()
elif status in (JobStatus.COMPLETED, JobStatus.FAILED):
update_data["completed_at"] = datetime.now(timezone.utc).isoformat()
if progress is not None:
update_data["progress"] = progress
if message is not None:
update_data["message"] = message
if error is not None:
update_data["error"] = error
if result is not None:
update_data["result"] = result
if not update_data:
return
try:
resp = await asyncio.to_thread(
lambda: db.client.table("background_jobs")
.update(update_data)
.eq("id", job_id)
.execute()
)
if resp.data:
job = self._row_to_job(resp.data[0])
await self._publish_update(job)
except Exception as e:
logger.error(f"Failed to update job {job_id} in DB: {e}")
async def cleanup_old_jobs(self, max_age_hours: int = 24):
"""Delete completed/failed jobs older than *max_age_hours* from the database."""
db = self._db()
try:
cutoff = (datetime.now(timezone.utc) - timedelta(hours=max_age_hours)).isoformat()
await asyncio.to_thread(
lambda: db.client.table("background_jobs")
.delete()
.in_("status", [JobStatus.COMPLETED.value, JobStatus.FAILED.value])
.lt("completed_at", cutoff)
.execute()
)
logger.info(f"Cleaned up background jobs older than {max_age_hours}h")
except Exception as e:
logger.warning(f"Failed to cleanup old jobs: {e}")
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
async def _publish_update(self, job: BackgroundJob):
"""Publish a job-update payload to Redis pub/sub so all workers can forward it."""
try:
from src.services.redis_client import get_redis
r = await get_redis()
payload = json.dumps({
"id": job.id,
"job_type": job.job_type.value,
"status": job.status.value,
"progress": job.progress,
"message": job.message,
"error": job.error,
})
await r.publish(f"job_updates:{job.user_id}", payload)
except Exception as e:
logger.warning(f"Failed to publish job update to Redis: {e}")
# Global singleton — replaces the old BackgroundJobManager instance
job_manager = DBJobManager()
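A minimal usage sketch (the user id is hypothetical, and the Supabase `background_jobs` table referenced above is assumed to exist; its migration is not part of this commit):

import asyncio
from src.services.background_jobs import JobType, JobStatus
from src.services.db_job_manager import job_manager

async def demo():
    job = await job_manager.create_job(JobType.POST_SCRAPING, "00000000-0000-0000-0000-000000000000")
    # Each update writes the row and publishes to job_updates:{user_id},
    # so any web worker's SSE stream can forward it to its clients.
    await job_manager.update_job(job.id, status=JobStatus.RUNNING, progress=10, message="Scraping posts")
    await job_manager.update_job(job.id, status=JobStatus.COMPLETED, progress=100, message="Done")

asyncio.run(demo())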

View File

@@ -0,0 +1,27 @@
"""Single async Redis connection pool for the whole app."""
import redis.asyncio as aioredis
from loguru import logger
from src.config import settings
_pool: aioredis.Redis | None = None
async def get_redis() -> aioredis.Redis:
global _pool
if _pool is None:
_pool = aioredis.from_url(
settings.redis_url,
encoding="utf-8",
decode_responses=True,
max_connections=20,
)
return _pool
async def close_redis():
global _pool
if _pool:
await _pool.aclose()
_pool = None
logger.info("Redis connection pool closed")

View File

@@ -0,0 +1,42 @@
"""Standalone scheduler process entry point.
Run with:
python -m src.services.scheduler_runner
This module intentionally does NOT import the FastAPI app — it only starts the
SchedulerService so it can run in its own container without duplicating work in
the main web-worker containers (which set SCHEDULER_ENABLED=false).
"""
import asyncio
import signal
from loguru import logger
from src.database.client import DatabaseClient
from src.services.scheduler_service import init_scheduler
async def main():
db = DatabaseClient()
scheduler = init_scheduler(db, check_interval=60)
stop_event = asyncio.Event()
def handle_signal():
logger.info("Scheduler received shutdown signal — stopping…")
stop_event.set()
loop = asyncio.get_running_loop()
for sig in (signal.SIGINT, signal.SIGTERM):
loop.add_signal_handler(sig, handle_signal)
await scheduler.start()
logger.info("Scheduler started (dedicated process)")
await stop_event.wait()
await scheduler.stop()
logger.info("Scheduler stopped")
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -5,6 +5,7 @@ from pathlib import Path
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from fastapi.responses import RedirectResponse
+from starlette.middleware.base import BaseHTTPMiddleware
from loguru import logger

from src.config import settings
@@ -14,35 +15,57 @@ from src.web.admin import admin_router
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage application lifecycle - startup and shutdown."""
-    # Startup
    logger.info("Starting LinkedIn Post Creation System...")

+    # Warm up Redis connection pool
+    from src.services.redis_client import get_redis, close_redis
+    await get_redis()

-    # Initialize and start scheduler if enabled
+    # Start scheduler only when this process is the dedicated scheduler container
    scheduler = None
-    if settings.user_frontend_enabled:
+    if settings.scheduler_enabled:
        try:
            from src.database.client import DatabaseClient
            from src.services.scheduler_service import init_scheduler
            db = DatabaseClient()
-            scheduler = init_scheduler(db, check_interval=60)  # Check every 60 seconds
+            scheduler = init_scheduler(db, check_interval=60)
            await scheduler.start()
-            logger.info("Scheduler service started")
+            logger.info("Scheduler started (dedicated process)")
        except Exception as e:
            logger.error(f"Failed to start scheduler: {e}")

    yield  # Application runs here

-    # Shutdown
    logger.info("Shutting down LinkedIn Post Creation System...")
    if scheduler:
        await scheduler.stop()
        logger.info("Scheduler service stopped")
+    await close_redis()


# Setup
app = FastAPI(title="LinkedIn Post Creation System", lifespan=lifespan)


+class StaticCacheMiddleware(BaseHTTPMiddleware):
+    """Set long-lived Cache-Control headers on static assets."""
+
+    async def dispatch(self, request, call_next):
+        response = await call_next(request)
+        if request.url.path.startswith("/static/"):
+            if request.url.path.endswith((".css", ".js")):
+                response.headers["Cache-Control"] = "public, max-age=86400, stale-while-revalidate=3600"
+            elif request.url.path.endswith((".png", ".jpg", ".jpeg", ".svg", ".ico", ".webp")):
+                response.headers["Cache-Control"] = "public, max-age=604800, immutable"
+            else:
+                response.headers["Cache-Control"] = "public, max-age=3600"
+        return response
+
+
+app.add_middleware(StaticCacheMiddleware)

# Static files
app.mount("/static", StaticFiles(directory=Path(__file__).parent / "static"), name="static")

View File

@@ -32,10 +32,11 @@ from src.services.email_service import (
    mark_token_used,
)
from src.services.background_jobs import (
-    job_manager, JobType, JobStatus,
+    JobType, JobStatus,
    run_post_scraping, run_profile_analysis, run_post_categorization, run_post_type_analysis,
    run_full_analysis_pipeline, run_post_recategorization
)
+from src.services.db_job_manager import job_manager
from src.services.storage_service import storage

# Router for user frontend
@@ -93,6 +94,7 @@ async def get_user_avatar(session: UserSession, user_id: UUID) -> Optional[str]:
    return None


def require_user_session(request: Request) -> Optional[UserSession]:
    """Check if user is authenticated, redirect to login if not."""
    session = get_user_session(request)
@@ -676,7 +678,7 @@ async def onboarding_profile_submit(
logger.info(f"Skipping scraping - {len(existing_posts)} posts already exist for user {user_id}") logger.info(f"Skipping scraping - {len(existing_posts)} posts already exist for user {user_id}")
if should_scrape: if should_scrape:
job = job_manager.create_job(JobType.POST_SCRAPING, str(user_id)) job = await job_manager.create_job(JobType.POST_SCRAPING, str(user_id))
background_tasks.add_task(run_post_scraping, user_id, linkedin_url, job.id) background_tasks.add_task(run_post_scraping, user_id, linkedin_url, job.id)
logger.info(f"Started background scraping for user {user_id}") logger.info(f"Started background scraping for user {user_id}")
@@ -829,7 +831,7 @@ async def api_rescrape(request: Request, background_tasks: BackgroundTasks):
return JSONResponse({"error": "No LinkedIn URL found"}, status_code=400) return JSONResponse({"error": "No LinkedIn URL found"}, status_code=400)
# Create job and start scraping # Create job and start scraping
job = job_manager.create_job(JobType.POST_SCRAPING, session.user_id) job = await job_manager.create_job(JobType.POST_SCRAPING, session.user_id)
background_tasks.add_task(run_post_scraping, user_id, profile.linkedin_url, job.id) background_tasks.add_task(run_post_scraping, user_id, profile.linkedin_url, job.id)
return JSONResponse({"success": True, "job_id": job.id}) return JSONResponse({"success": True, "job_id": job.id})
@@ -1451,53 +1453,45 @@ async def api_categorize_post(request: Request):
@user_router.get("/api/job-updates") @user_router.get("/api/job-updates")
async def job_updates_sse(request: Request): async def job_updates_sse(request: Request):
"""Server-Sent Events endpoint for job updates.""" """Server-Sent Events endpoint for job updates (Redis pub/sub — works across workers)."""
session = require_user_session(request) session = require_user_session(request)
tracking_id = getattr(session, 'user_id', None) or getattr(session, 'company_id', None) tracking_id = getattr(session, 'user_id', None) or getattr(session, 'company_id', None)
if not session or not tracking_id: if not session or not tracking_id:
return JSONResponse({"error": "Not authenticated"}, status_code=401) return JSONResponse({"error": "Not authenticated"}, status_code=401)
async def event_generator(): async def event_generator():
queue = asyncio.Queue() from src.services.redis_client import get_redis
r = await get_redis()
async def on_job_update(job): pubsub = r.pubsub()
await queue.put(job) await pubsub.subscribe(f"job_updates:{tracking_id}")
# Register listener
job_manager.add_listener(tracking_id, on_job_update)
try: try:
# Send initial active jobs # Send any currently active jobs as the initial state
active_jobs = job_manager.get_active_jobs(tracking_id) for job in await job_manager.get_active_jobs(tracking_id):
for job in active_jobs:
data = { data = {
"id": job.id, "id": job.id,
"job_type": job.job_type.value, "job_type": job.job_type.value,
"status": job.status.value, "status": job.status.value,
"progress": job.progress, "progress": job.progress,
"message": job.message, "message": job.message,
"error": job.error "error": job.error,
} }
yield f"data: {json.dumps(data)}\n\n" yield f"data: {json.dumps(data)}\n\n"
# Stream updates # Stream pub/sub messages, keepalive on timeout
while True: while True:
try: try:
job = await asyncio.wait_for(queue.get(), timeout=30) msg = await asyncio.wait_for(
data = { pubsub.get_message(ignore_subscribe_messages=True), timeout=30
"id": job.id, )
"job_type": job.job_type.value, if msg and msg.get("type") == "message":
"status": job.status.value, yield f"data: {msg['data']}\n\n"
"progress": job.progress, else:
"message": job.message, yield ": keepalive\n\n"
"error": job.error
}
yield f"data: {json.dumps(data)}\n\n"
except asyncio.TimeoutError: except asyncio.TimeoutError:
# Send keepalive
yield ": keepalive\n\n" yield ": keepalive\n\n"
finally: finally:
job_manager.remove_listener(tracking_id, on_job_update) await pubsub.unsubscribe(f"job_updates:{tracking_id}")
await pubsub.aclose()
return StreamingResponse( return StreamingResponse(
event_generator(), event_generator(),
@@ -1505,8 +1499,8 @@ async def job_updates_sse(request: Request):
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
-            "X-Accel-Buffering": "no"
-        }
+            "X-Accel-Buffering": "no",
+        },
    )
@@ -1521,7 +1515,7 @@ async def api_run_post_type_analysis(request: Request, background_tasks: Backgro
    user_id = UUID(session.user_id)

    # Create job
-    job = job_manager.create_job(JobType.POST_TYPE_ANALYSIS, session.user_id)
+    job = await job_manager.create_job(JobType.POST_TYPE_ANALYSIS, session.user_id)

    # Run in background
    background_tasks.add_task(run_post_type_analysis, user_id, job.id)
@@ -3278,13 +3272,13 @@ async def save_all_and_reanalyze(request: Request, background_tasks: BackgroundT
    # Only trigger re-categorization and analysis if there were structural changes
    if has_structural_changes:
        # Create background job for post re-categorization (ALL posts)
-        categorization_job = job_manager.create_job(
+        categorization_job = await job_manager.create_job(
            job_type=JobType.POST_CATEGORIZATION,
            user_id=user_id_str
        )

        # Create background job for post type analysis
-        analysis_job = job_manager.create_job(
+        analysis_job = await job_manager.create_job(
            job_type=JobType.POST_TYPE_ANALYSIS,
            user_id=user_id_str
        )