diff --git a/app/api/endpoints/catalogs.py b/app/api/endpoints/catalogs.py
index 34761d4..15f262a 100644
--- a/app/api/endpoints/catalogs.py
+++ b/app/api/endpoints/catalogs.py
@@ -11,7 +11,7 @@
 from app.services.token_store import token_store
 
 MAX_RESULTS = 50
-SOURCE_ITEMS_LIMIT = 15
+SOURCE_ITEMS_LIMIT = 10
 
 router = APIRouter()
 
@@ -54,8 +54,14 @@ async def get_catalog(type: str, id: str, response: Response, token: str):
 
         # Create services with credentials
         stremio_service = StremioService(auth_key=credentials.get("authKey"))
+        # Fetch library once per request and reuse across recommendation paths
+        library_items = await stremio_service.get_library_items()
         recommendation_service = RecommendationService(
-            stremio_service=stremio_service, language=language, user_settings=user_settings
+            stremio_service=stremio_service,
+            language=language,
+            user_settings=user_settings,
+            token=token,
+            library_data=library_items,
         )
 
         # Handle item-based recommendations
@@ -83,7 +89,9 @@ async def get_catalog(type: str, id: str, response: Response, token: str):
         logger.info(f"Returning {len(recommendations)} items for {type}")
 
         # Cache catalog responses for 4 hours
-        response.headers["Cache-Control"] = "public, max-age=14400" if len(recommendations) > 0 else "no-cache"
+        response.headers["Cache-Control"] = (
+            "public, max-age=14400" if len(recommendations) > 0 else "public, max-age=7200"
+        )
 
         return {"metas": recommendations}
     except HTTPException:
diff --git a/app/api/endpoints/manifest.py b/app/api/endpoints/manifest.py
index 9c89489..bf1c7bf 100644
--- a/app/api/endpoints/manifest.py
+++ b/app/api/endpoints/manifest.py
@@ -71,7 +71,7 @@ async def fetch_catalogs(token: str):
     # Note: get_library_items is expensive, but we need it to determine *which* genre catalogs to show.
     library_items = await stremio_service.get_library_items()
 
-    dynamic_catalog_service = DynamicCatalogService(stremio_service=stremio_service)
+    dynamic_catalog_service = DynamicCatalogService(stremio_service=stremio_service, language=user_settings.language)
 
     # Base catalogs are already in manifest, these are *extra* dynamic ones
     # Pass user_settings to filter/rename
@@ -96,7 +96,7 @@ def get_config_id(catalog) -> str | None:
 
 
 async def _manifest_handler(response: Response, token: str):
-    response.headers["Cache-Control"] = "no-cache"
+    response.headers["Cache-Control"] = "public, max-age=7200"
 
     if not token:
         raise HTTPException(status_code=401, detail="Missing token. Please reconfigure the addon.")
diff --git a/app/api/endpoints/meta.py b/app/api/endpoints/meta.py
index 3387638..ea5579e 100644
--- a/app/api/endpoints/meta.py
+++ b/app/api/endpoints/meta.py
@@ -1,19 +1,25 @@
+from async_lru import alru_cache
 from fastapi import APIRouter, HTTPException
 from loguru import logger
 
-from app.services.tmdb_service import TMDBService
+from app.services.tmdb_service import get_tmdb_service
 
 router = APIRouter()
 
 
+@alru_cache(maxsize=1, ttl=24 * 60 * 60)
+async def _cached_languages():
+    tmdb = get_tmdb_service()
+    return await tmdb._make_request("/configuration/languages")
+
+
 @router.get("/api/languages")
 async def get_languages():
     """
     Proxy endpoint to fetch languages from TMDB.
""" - tmdb_service = TMDBService() try: - languages = await tmdb_service._make_request("/configuration/languages") + languages = await _cached_languages() if not languages: return [] return languages @@ -21,4 +27,5 @@ async def get_languages(): logger.error(f"Failed to fetch languages: {e}") raise HTTPException(status_code=502, detail="Failed to fetch languages from TMDB") finally: - await tmdb_service.close() + # shared client: no explicit close + pass diff --git a/app/core/app.py b/app/core/app.py index bdff5f7..7885fd3 100644 --- a/app/core/app.py +++ b/app/core/app.py @@ -3,7 +3,7 @@ from contextlib import asynccontextmanager from pathlib import Path -from fastapi import FastAPI +from fastapi import FastAPI, Request from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import HTMLResponse from fastapi.staticfiles import StaticFiles @@ -11,6 +11,7 @@ from app.api.main import api_router from app.services.catalog_updater import BackgroundCatalogUpdater +from app.services.token_store import token_store from app.startup.migration import migrate_tokens from .config import settings @@ -82,6 +83,23 @@ def _on_done(t: asyncio.Task): allow_headers=["*"], ) + +# Middleware to track per-request Redis calls and attach as response header for diagnostics +@app.middleware("http") +async def redis_calls_middleware(request: Request, call_next): + try: + token_store.reset_call_counter() + except Exception: + pass + response = await call_next(request) + try: + count = token_store.get_call_count() + response.headers["X-Redis-Calls"] = str(count) + except Exception: + pass + return response + + # Serve static files # Static directory is at project root (3 levels up from app/core/app.py) # app/core/app.py -> app/core -> app -> root diff --git a/app/core/settings.py b/app/core/settings.py index 6a8807e..b246215 100644 --- a/app/core/settings.py +++ b/app/core/settings.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel +from pydantic import BaseModel, Field class CatalogConfig(BaseModel): @@ -11,8 +11,8 @@ class UserSettings(BaseModel): catalogs: list[CatalogConfig] language: str = "en-US" rpdb_key: str | None = None - excluded_movie_genres: list[str] = [] - excluded_series_genres: list[str] = [] + excluded_movie_genres: list[str] = Field(default_factory=list) + excluded_series_genres: list[str] = Field(default_factory=list) def get_default_settings() -> UserSettings: diff --git a/app/core/version.py b/app/core/version.py index 5c4105c..6849410 100644 --- a/app/core/version.py +++ b/app/core/version.py @@ -1 +1 @@ -__version__ = "1.0.1" +__version__ = "1.1.0" diff --git a/app/models/profile.py b/app/models/profile.py index a96f6cb..11d5d33 100644 --- a/app/models/profile.py +++ b/app/models/profile.py @@ -2,21 +2,6 @@ class SparseVector(BaseModel): - """ - Represents a sparse vector where keys are feature IDs and values are weights. - For countries, keys can be string codes (hashed or mapped to int if strictly int keys needed, - but let's check if we can use str keys or if we stick to int. - Original SparseVector uses `dict[int, float]`. - TMDB country codes are strings (e.g. "US"). - We can either map them to ints or change the model to support str keys. - Let's update the model to support string keys for versatility, or keep int and hash strings. - However, for Pydantic and JSON, string keys are native. - Let's change keys to string/int union or just strings (since ints are valid dict keys too). - Actually, since `genres` IDs are ints, let's allow both or specific types. 
-    For simplicity, let's stick to `dict[str, float]` since JSON keys are strings anyway.
-    But wait, existing code uses ints for IDs.
-    Let's make a separate StringSparseVector or just genericize it.
-    """
 
     values: dict[int, float] = Field(default_factory=dict)
 
@@ -67,6 +52,8 @@ class UserTasteProfile(BaseModel):
     crew: SparseVector = Field(default_factory=SparseVector)
     years: SparseVector = Field(default_factory=SparseVector)
     countries: StringSparseVector = Field(default_factory=StringSparseVector)
+    # Free-text/topic tokens from titles/overviews/keyword names
+    topics: StringSparseVector = Field(default_factory=StringSparseVector)
 
     def normalize_all(self):
         """Normalize all component vectors."""
@@ -76,6 +63,7 @@ def normalize_all(self):
         self.crew.normalize()
         self.years.normalize()
         self.countries.normalize()
+        self.topics.normalize()
 
     def get_top_genres(self, limit: int = 3) -> list[tuple[int, float]]:
         return self.genres.get_top_features(limit)
diff --git a/app/services/catalog.py b/app/services/catalog.py
index 17dbc8e..7525f5c 100644
--- a/app/services/catalog.py
+++ b/app/services/catalog.py
@@ -4,7 +4,7 @@
 from app.services.row_generator import RowGeneratorService
 from app.services.scoring import ScoringService
 from app.services.stremio_service import StremioService
-from app.services.tmdb_service import TMDBService
+from app.services.tmdb_service import get_tmdb_service
 from app.services.user_profile import UserProfileService
 
 
@@ -13,11 +13,11 @@ class DynamicCatalogService:
     Generates dynamic catalog rows based on user library and preferences.
     """
 
-    def __init__(self, stremio_service: StremioService):
+    def __init__(self, stremio_service: StremioService, language: str = "en-US"):
         self.stremio_service = stremio_service
-        self.tmdb_service = TMDBService()
+        self.tmdb_service = get_tmdb_service(language=language)
         self.scoring_service = ScoringService()
-        self.user_profile_service = UserProfileService()
+        self.user_profile_service = UserProfileService(language=language)
         self.row_generator = RowGeneratorService(tmdb_service=self.tmdb_service)
 
     @staticmethod
diff --git a/app/services/catalog_updater.py b/app/services/catalog_updater.py
index 53a5db3..0a1af73 100644
--- a/app/services/catalog_updater.py
+++ b/app/services/catalog_updater.py
@@ -38,9 +38,6 @@ async def refresh_catalogs_for_credentials(token: str, credentials: dict[str, An
         logger.exception(f"[{redact_token(token)}] Failed to check if addon is installed: {e}")
 
     try:
-        library_items = await stremio_service.get_library_items()
-        dynamic_catalog_service = DynamicCatalogService(stremio_service=stremio_service)
-
         # Ensure user_settings is available
         user_settings = get_default_settings()
         if credentials.get("settings"):
@@ -49,7 +46,12 @@ async def refresh_catalogs_for_credentials(token: str, credentials: dict[str, An
         except Exception as e:
             user_settings = get_default_settings()
             logger.warning(f"[{redact_token(token)}] Failed to parse user settings from credentials: {e}")
-
+        # force fresh library for background refresh
+        library_items = await stremio_service.get_library_items(use_cache=False)
+        dynamic_catalog_service = DynamicCatalogService(
+            stremio_service=stremio_service,
+            language=(user_settings.language if user_settings else "en-US"),
+        )
         catalogs = await dynamic_catalog_service.get_dynamic_catalogs(
             library_items=library_items, user_settings=user_settings
         )
diff --git a/app/services/discovery.py b/app/services/discovery.py
index b52d485..6f0d09c 100644
--- a/app/services/discovery.py
+++ b/app/services/discovery.py
@@ -1,7 +1,7 @@
import asyncio from app.models.profile import UserTasteProfile -from app.services.tmdb_service import TMDBService +from app.services.tmdb_service import get_tmdb_service class DiscoveryEngine: @@ -10,8 +10,8 @@ class DiscoveryEngine: Uses TMDB Discovery API with weighted query parameters derived from the user profile. """ - def __init__(self): - self.tmdb_service = TMDBService() + def __init__(self, language: str = "en-US"): + self.tmdb_service = get_tmdb_service(language=language) # Limit concurrent discovery calls to avoid rate limiting self._sem = asyncio.Semaphore(10) @@ -21,6 +21,13 @@ async def discover_recommendations( content_type: str, limit: int = 20, excluded_genres: list[int] | None = None, + *, + use_genres: bool = False, + use_keywords: bool = True, + use_cast: bool = True, + use_director: bool = True, + use_countries: bool = False, + use_year: bool = False, ) -> list[dict]: """ Find content that matches the user's taste profile. @@ -31,15 +38,15 @@ async def discover_recommendations( 4. Return the combined candidate set (B). """ # 1. Extract Top Features - top_genres = profile.get_top_genres(limit=3) # e.g. [(28, 1.0), (878, 0.8)] - top_keywords = profile.get_top_keywords(limit=3) # e.g. [(123, 0.9)] + top_genres = profile.get_top_genres(limit=3) if use_genres else [] # e.g. [(28, 1.0), (878, 0.8)] + top_keywords = profile.get_top_keywords(limit=3) if use_keywords else [] # e.g. [(123, 0.9)] # Need to add get_top_cast to UserTasteProfile model first, assuming it exists or using profile.cast directly # Based on previous step, profile.cast exists. - top_cast = profile.cast.get_top_features(limit=2) - top_crew = profile.get_top_crew(limit=1) # e.g. [(555, 1.0)] - Director + top_cast = profile.cast.get_top_features(limit=2) if use_cast else [] + top_crew = profile.get_top_crew(limit=1) if use_director else [] # e.g. 
[(555, 1.0)] - Director - top_countries = profile.get_top_countries(limit=2) - top_year = profile.get_top_year(limit=1) + top_countries = profile.get_top_countries(limit=2) if use_countries else [] + top_year = profile.get_top_year(limit=1) if use_year else [] if not top_genres and not top_keywords and not top_cast: # Fallback if profile is empty @@ -50,7 +57,7 @@ async def discover_recommendations( if excluded_genres: base_params["without_genres"] = "|".join([str(g) for g in excluded_genres]) - # Query 1: Top Genres Mix + # Phase 1: build first-page tasks only if top_genres: genre_ids = "|".join([str(g[0]) for g in top_genres]) params_popular = { @@ -60,17 +67,13 @@ async def discover_recommendations( **base_params, } tasks.append(self._fetch_discovery(content_type, params_popular)) - - # fetch atleast two pages of results - for i in range(2): - params_rating = { - "with_genres": genre_ids, - "sort_by": "vote_average.desc", - "vote_count.gte": 500, - "page": i + 1, - **base_params, - } - tasks.append(self._fetch_discovery(content_type, params_rating)) + params_rating = { + "with_genres": genre_ids, + "sort_by": "vote_average.desc", + "vote_count.gte": 500, + **base_params, + } + tasks.append(self._fetch_discovery(content_type, params_rating)) # Query 2: Top Keywords if top_keywords: @@ -83,16 +86,15 @@ async def discover_recommendations( } tasks.append(self._fetch_discovery(content_type, params_keywords)) - # fetch atleast two pages of results - for i in range(3): - params_rating = { + for page in range(1, 3): + params_rating_kw = { "with_keywords": keyword_ids, "sort_by": "vote_average.desc", "vote_count.gte": 500, - "page": i + 1, + "page": page, **base_params, } - tasks.append(self._fetch_discovery(content_type, params_rating)) + tasks.append(self._fetch_discovery(content_type, params_rating_kw)) # Query 3: Top Actors for actor in top_cast: @@ -104,14 +106,13 @@ async def discover_recommendations( **base_params, } tasks.append(self._fetch_discovery(content_type, params_actor)) - - params_rating = { - "with_cast": str(actor_id), - "sort_by": "vote_average.desc", - "vote_count.gte": 500, - **base_params, - } - tasks.append(self._fetch_discovery(content_type, params_rating)) + # params_rating = { + # "with_cast": str(actor_id), + # "sort_by": "vote_average.desc", + # "vote_count.gte": 500, + # **base_params, + # } + # tasks.append(self._fetch_discovery(content_type, params_rating)) # Query 4: Top Director if top_crew: @@ -124,14 +125,6 @@ async def discover_recommendations( } tasks.append(self._fetch_discovery(content_type, params_director)) - params_rating = { - "with_crew": str(director_id), - "sort_by": "vote_average.desc", - "vote_count.gte": 500, - **base_params, - } - tasks.append(self._fetch_discovery(content_type, params_rating)) - # Query 5: Top Countries if top_countries: country_ids = "|".join([str(c[0]) for c in top_countries]) @@ -142,14 +135,13 @@ async def discover_recommendations( **base_params, } tasks.append(self._fetch_discovery(content_type, params_country)) - - params_rating = { - "with_origin_country": country_ids, - "sort_by": "vote_average.desc", - "vote_count.gte": 300, - **base_params, - } - tasks.append(self._fetch_discovery(content_type, params_rating)) + # params_rating = { + # "with_origin_country": country_ids, + # "sort_by": "vote_average.desc", + # "vote_count.gte": 300, + # **base_params, + # } + # tasks.append(self._fetch_discovery(content_type, params_rating)) # query 6: Top year if top_year: @@ -166,7 +158,7 @@ async def discover_recommendations( } 
tasks.append(self._fetch_discovery(content_type, params_rating)) - # 3. Execute Parallel Queries + # 3. Execute Phase 1 results_batches = await asyncio.gather(*tasks, return_exceptions=True) # 4. Aggregate and Deduplicate @@ -178,6 +170,61 @@ async def discover_recommendations( if item["id"] not in all_candidates: all_candidates[item["id"]] = item + # Conditional Phase 2: fetch page 2 if pool is thin + if len(all_candidates) < 120: + tasks2 = [] + if top_genres: + genre_ids = "|".join([str(g[0]) for g in top_genres]) + tasks2.append( + self._fetch_discovery( + content_type, + { + "with_genres": genre_ids, + "sort_by": "vote_average.desc", + "vote_count.gte": 400, + "page": 2, + **base_params, + }, + ) + ) + if top_keywords: + keyword_ids = "|".join([str(k[0]) for k in top_keywords]) + tasks2.append( + self._fetch_discovery( + content_type, + { + "with_keywords": keyword_ids, + "sort_by": "vote_average.desc", + "vote_count.gte": 400, + "page": 2, + **base_params, + }, + ) + ) + for actor in top_cast[:1]: + actor_id = actor[0] + tasks2.append( + self._fetch_discovery( + content_type, + { + "with_cast": str(actor_id), + "sort_by": "vote_average.desc", + "vote_count.gte": 400, + "page": 2, + **base_params, + }, + ) + ) + + if tasks2: + results_batches2 = await asyncio.gather(*tasks2, return_exceptions=True) + for batch in results_batches2: + if isinstance(batch, Exception) or not batch: + continue + for item in batch: + if item["id"] not in all_candidates: + all_candidates[item["id"]] = item + return list(all_candidates.values()) async def _fetch_discovery(self, media_type: str, params: dict) -> list[dict]: diff --git a/app/services/recommendation_service.py b/app/services/recommendation_service.py index a11170b..e032de6 100644 --- a/app/services/recommendation_service.py +++ b/app/services/recommendation_service.py @@ -1,5 +1,6 @@ import asyncio -import random +import hashlib +import math from urllib.parse import unquote from loguru import logger @@ -9,17 +10,11 @@ from app.services.rpdb import RPDBService from app.services.scoring import ScoringService from app.services.stremio_service import StremioService -from app.services.tmdb_service import TMDBService -from app.services.user_profile import UserProfileService +from app.services.tmdb_service import get_tmdb_service +from app.services.user_profile import TOP_GENRE_WHITELIST_LIMIT, UserProfileService - -def normalize(value, min_v=0, max_v=10): - """ - Normalize popularity / rating when blending. 
- """ - if max_v == min_v: - return 0 - return (value - min_v) / (max_v - min_v) +# Diversification: cap per-genre share in final results (e.g., 0.4 => max 40% per genre) +PER_GENRE_MAX_SHARE = 0.4 def _parse_identifier(identifier: str) -> tuple[str | None, int | None]: @@ -59,32 +54,165 @@ def __init__( stremio_service: StremioService | None = None, language: str = "en-US", user_settings: UserSettings | None = None, + token: str | None = None, + library_data: dict | None = None, ): if stremio_service is None: raise ValueError("StremioService instance is required for personalized recommendations") - self.tmdb_service = TMDBService(language=language) + self.tmdb_service = get_tmdb_service(language=language) self.stremio_service = stremio_service self.scoring_service = ScoringService() - self.user_profile_service = UserProfileService() - self.discovery_engine = DiscoveryEngine() + self.user_profile_service = UserProfileService(language=language) + self.discovery_engine = DiscoveryEngine(language=language) self.per_item_limit = 20 self.user_settings = user_settings + # Stable seed for tie-breaking and per-token caching + self.stable_seed = token or "" + # Optional pre-fetched library payload (reuse within the request) + self._library_data: dict | None = library_data + # cache: content_type -> set of top genre IDs + self._whitelist_cache: dict[str, set[int]] = {} + + def _stable_epsilon(self, tmdb_id: int) -> float: + if not self.stable_seed: + return 0.0 + h = hashlib.md5(f"{self.stable_seed}:{tmdb_id}".encode()).hexdigest() + # Use last 6 hex digits for tiny epsilon + eps = int(h[-6:], 16) % 1000 + return eps / 1_000_000.0 + + @staticmethod + def _normalize(value: float, min_v: float = 0.0, max_v: float = 10.0) -> float: + if max_v == min_v: + return 0.0 + return max(0.0, min(1.0, (value - min_v) / (max_v - min_v))) + + @staticmethod + def _weighted_rating(vote_avg: float | None, vote_count: int | None, C: float = 6.8, m: int = 300) -> float: + """ + IMDb-style weighted rating. Returns value on 0-10 scale. + C = global mean; m = minimum votes for full weight. + """ + try: + R = float(vote_avg or 0.0) + v = int(vote_count or 0) + except Exception: + R, v = 0.0, 0 + return ((v / (v + m)) * R) + ((m / (v + m)) * C) + + # ---------------- Recency preference (AUTO, sigmoid intensity) ---------------- + def _get_recency_multiplier_fn(self, profile, candidate_decades: set[int] | None = None): + """ + Build a multiplier function m(year) using a sigmoid-scaled intensity of the user's + recent vs classic preference derived from profile.years. + - Compute score in [-1,1] from recent (>=2015) vs classic (<2000) weights + - intensity = 2*(sigmoid(k*score)-0.5) in [-1,1] + - Apply per-year-bin deltas scaled by intensity, clamped to [0.85, 1.15] + """ + try: + years_map = getattr(profile.years, "values", {}) or {} + # Build user decade weights (keys are decades like 1990, 2000, ...) 
+ decade_weights = {int(k): float(v) for k, v in years_map.items() if isinstance(k, int)} + total_w = sum(decade_weights.values()) + except Exception: + decade_weights = {} + total_w = 0.0 + + # Recent vs classic signal for intensity + recent_w = sum(w for d, w in decade_weights.items() if d >= 2010) + classic_w = sum(w for d, w in decade_weights.items() if d < 2000) + total_rc = recent_w + classic_w + if total_rc <= 0: + # No signal → neutral function with zero intensity + return (lambda _y: 1.0), 0.0 + + score = (recent_w - classic_w) / (total_rc + 1e-6) + k = 2.0 + intensity_raw = 1.0 / (1.0 + math.exp(-k * score)) + intensity = 2.0 * (intensity_raw - 0.5) # [-1, 1] + alpha = abs(intensity) + + # Build p_user over the support set of decades (union of profile and candidate decades) + if candidate_decades: + support = {int(d) for d in candidate_decades if isinstance(d, int)} | set(decade_weights.keys()) + else: + support = set(decade_weights.keys()) + if not support: + return (lambda _y: 1.0), 0.0 + + # Normalize user distribution over support (zero for unseen decades) + # If total_w is zero, return neutral + if total_w > 0: + p_user = {d: (decade_weights.get(d, 0.0) / total_w) for d in support} + else: + p_user = {d: 0.0 for d in support} + D = max(1, len(support)) + uniform = 1.0 / D + + def m_raw(year: int | None) -> float: + if year is None: + return 1.0 + try: + y = int(year) + except Exception: + return 1.0 + decade = (y // 10) * 10 + pu = p_user.get(decade, 0.0) + return 1.0 + intensity * (pu - uniform) + + return m_raw, alpha + + @staticmethod + def _extract_year_from_item(item: dict) -> int | None: + """Extract year from a TMDB item dict (raw or enriched).""" + date_str = item.get("release_date") or item.get("first_air_date") + if not date_str: + ri = item.get("releaseInfo") + if isinstance(ri, str) and len(ri) >= 4 and ri[:4].isdigit(): + try: + return int(ri[:4]) + except Exception: + return None + return None + try: + return int(date_str[:4]) + except Exception: + return None + + @staticmethod + def _recency_multiplier(year: int | None) -> float: + """Prefer recent titles. Softly dampen very old titles.""" + if not year: + return 1.0 + try: + y = int(year) + except Exception: + return 1.0 + if y >= 2021: + return 1.12 + if y >= 2015: + return 1.06 + if y >= 2010: + return 1.00 + if y >= 2000: + return 0.92 + if y >= 1990: + return 0.82 + return 0.70 async def _get_exclusion_sets(self, content_type: str | None = None) -> tuple[set[str], set[int]]: """ Fetch library items and build strict exclusion sets for watched content. - Also exclude items the user has added to library to avoid recommending duplicates. + Excludes watched and loved items (and items user explicitly removed). + Note: We no longer exclude 'added' items to avoid over-thinning the pool. 
Returns (watched_imdb_ids, watched_tmdb_ids) """ - # Always fetch fresh library to ensure we don't recommend what was just watched - library_data = await self.stremio_service.get_library_items() + # Use cached/pre-fetched library data when available + if self._library_data is None: + self._library_data = await self.stremio_service.get_library_items() + library_data = self._library_data # Combine loved, watched, added, and removed (added/removed treated as exclude-only) - all_items = ( - library_data.get("loved", []) - + library_data.get("watched", []) - + library_data.get("added", []) - + library_data.get("removed", []) - ) + all_items = library_data.get("loved", []) + library_data.get("watched", []) + library_data.get("removed", []) imdb_ids = set() tmdb_ids = set() @@ -112,6 +240,107 @@ async def _get_exclusion_sets(self, content_type: str | None = None) -> tuple[se return imdb_ids, tmdb_ids + async def _get_top_genre_whitelist(self, content_type: str) -> set[int]: + """Compute and cache user's top-genre whitelist for the given content type.""" + if content_type in self._whitelist_cache: + return self._whitelist_cache[content_type] + + try: + if self._library_data is None: + self._library_data = await self.stremio_service.get_library_items() + all_items = ( + self._library_data.get("loved", []) + + self._library_data.get("watched", []) + + self._library_data.get("added", []) + ) + typed = [ + it + for it in all_items + if it.get("type") == content_type or (content_type in ("tv", "series") and it.get("type") == "series") + ] + unique_items = {it["_id"]: it for it in typed} + scored_objects = [] + sorted_history = sorted( + unique_items.values(), key=lambda x: x.get("state", {}).get("lastWatched"), reverse=True + ) + for it in sorted_history[:10]: + scored_objects.append(self.scoring_service.process_item(it)) + # UserProfileService expects 'movie' or 'series' + prof_content_type = "series" if content_type in ("tv", "series") else "movie" + user_profile = await self.user_profile_service.build_user_profile( + scored_objects, content_type=prof_content_type + ) + top_gen_pairs = user_profile.get_top_genres(limit=TOP_GENRE_WHITELIST_LIMIT) + whitelist = {int(gid) for gid, _ in top_gen_pairs} + except Exception: + whitelist = set() + + self._whitelist_cache[content_type] = whitelist + return whitelist + + async def _passes_top_genre(self, genre_ids: list[int] | None, content_type: str) -> bool: + whitelist = await self._get_top_genre_whitelist(content_type) + if not whitelist: + return True + gids = set(genre_ids or []) + if not gids: + return True + if 16 in gids and 16 not in whitelist: + return False + return bool(gids & whitelist) + + async def _inject_freshness( + self, + pool: list[dict], + media_type: str, + watched_tmdb: set[int], + excluded_ids: set[int], + cap_injection: int, + target_capacity: int, + ) -> list[dict]: + try: + mtype = "tv" if media_type in ("tv", "series") else "movie" + trending_resp = await self.tmdb_service.get_trending(mtype, time_window="week") + trending = trending_resp.get("results", []) if trending_resp else [] + top_rated_resp = await self.tmdb_service.get_top_rated(mtype) + top_rated = top_rated_resp.get("results", []) if top_rated_resp else [] + fresh_pool = [] + fresh_pool.extend(trending[:40]) + fresh_pool.extend(top_rated[:40]) + + from collections import defaultdict + + existing_ids = {it.get("id") for it in pool if it.get("id") is not None} + fresh_genre_counts = defaultdict(int) + fresh_added = 0 + for it in fresh_pool: + tid = it.get("id") + if not 
tid or tid in existing_ids or tid in watched_tmdb: + continue + gids = it.get("genre_ids") or [] + if excluded_ids and excluded_ids.intersection(set(gids)): + continue + if not await self._passes_top_genre(gids, media_type): + continue + if gids and any(fresh_genre_counts[g] >= cap_injection for g in gids): + continue + va = float(it.get("vote_average") or 0.0) + vc = int(it.get("vote_count") or 0) + if vc < 300 or va < 7.0: + continue + pool.append(it) + existing_ids.add(tid) + for g in gids: + fresh_genre_counts[g] += 1 + fresh_added += 1 + if len(pool) >= target_capacity: + break + if fresh_added: + logger.info(f"Freshness injection added {fresh_added} items") + except Exception as e: + logger.warning(f"Freshness injection failed: {e}") + return pool + async def _filter_candidates( self, candidates: list[dict], watched_imdb_ids: set[str], watched_tmdb_ids: set[int] ) -> list[dict]: @@ -134,7 +363,9 @@ async def _filter_candidates( filtered.append(item) return filtered - async def _fetch_metadata_for_items(self, items: list[dict], media_type: str) -> list[dict]: + async def _fetch_metadata_for_items( + self, items: list[dict], media_type: str, target_count: int | None = None, batch_size: int = 20 + ) -> list[dict]: """ Fetch detailed metadata for items directly from TMDB API and format for Stremio. """ @@ -155,72 +386,112 @@ async def _fetch_details(tmdb_id: int): logger.warning(f"Failed to fetch details for TMDB ID {tmdb_id}: {e}") return None - # Create tasks for all items to fetch details (needed for IMDB ID and full meta) - # Filter out items without ID + # Filter out items without ID and process in batches for early stop valid_items = [item for item in items if item.get("id")] - tasks = [_fetch_details(item["id"]) for item in valid_items] - - if not tasks: + if not valid_items: return [] - details_results = await asyncio.gather(*tasks) - - for details in details_results: - if not details: - continue - - # Extract IMDB ID from external_ids - external_ids = details.get("external_ids", {}) - imdb_id = external_ids.get("imdb_id") - # tmdb_id = details.get("id") - - # Prefer IMDB ID, fallback to TMDB ID - if imdb_id: - stremio_id = imdb_id - else: # skip content if imdb id is not available - continue + # Decide target_count if not provided + if target_count is None: + # Aim to collect up to 2x of typical need but not exceed total + target_count = min(len(valid_items), 40) - # Construct Stremio meta object - title = details.get("title") or details.get("name") - if not title: - continue + for i in range(0, len(valid_items), batch_size): + if len(final_results) >= target_count: + break + chunk = valid_items[i : i + batch_size] # noqa + tasks = [_fetch_details(item["id"]) for item in chunk] + details_results = await asyncio.gather(*tasks) + for details in details_results: + if not details: + continue - # Image paths - poster_path = details.get("poster_path") - backdrop_path = details.get("backdrop_path") + # Extract IMDB ID from external_ids + external_ids = details.get("external_ids", {}) + imdb_id = external_ids.get("imdb_id") + + # Prefer IMDB ID, fallback to TMDB ID (as stremio:tmdb:) to avoid losing candidates + if imdb_id: + stremio_id = imdb_id + else: + tmdb_fallback = details.get("id") + if tmdb_fallback: + stremio_id = f"tmdb:{tmdb_fallback}" + else: + continue - release_date = details.get("release_date") or details.get("first_air_date") or "" - year = release_date[:4] if release_date else None + # Construct Stremio meta object + title = details.get("title") or details.get("name") 
+ if not title: + continue - if self.user_settings and self.user_settings.rpdb_key: - poster_url = RPDBService.get_poster_url(self.user_settings.rpdb_key, stremio_id) - else: - poster_url = f"https://image.tmdb.org/t/p/w500{poster_path}" if poster_path else None - - meta_data = { - "id": stremio_id, - "imdb_id": stremio_id, - "type": "series" if media_type in ["tv", "series"] else "movie", - "name": title, - "poster": poster_url, - "background": f"https://image.tmdb.org/t/p/original{backdrop_path}" if backdrop_path else None, - "description": details.get("overview"), - "releaseInfo": year, - "imdbRating": str(details.get("vote_average", "")), - "genres": [g.get("name") for g in details.get("genres", [])], - # pass internal external_ids for post-filtering if needed - "_external_ids": external_ids, - } - - # Add runtime if available (Movie) or episode run time (TV) - runtime = details.get("runtime") - if not runtime and details.get("episode_run_time"): - runtime = details.get("episode_run_time")[0] - - if runtime: - meta_data["runtime"] = f"{runtime} min" - - final_results.append(meta_data) + # Image paths + poster_path = details.get("poster_path") + backdrop_path = details.get("backdrop_path") + + release_date = details.get("release_date") or details.get("first_air_date") or "" + year = release_date[:4] if release_date else None + + if self.user_settings and self.user_settings.rpdb_key: + poster_url = RPDBService.get_poster_url(self.user_settings.rpdb_key, stremio_id) + else: + poster_url = f"https://image.tmdb.org/t/p/w500{poster_path}" if poster_path else None + + genres_full = details.get("genres", []) or [] + genre_ids = [g.get("id") for g in genres_full if isinstance(g, dict) and g.get("id") is not None] + + meta_data = { + "id": stremio_id, + "imdb_id": imdb_id, + "type": "series" if media_type in ["tv", "series"] else "movie", + "name": title, + "poster": poster_url, + "background": f"https://image.tmdb.org/t/p/original{backdrop_path}" if backdrop_path else None, + "description": details.get("overview"), + "releaseInfo": year, + "imdbRating": str(details.get("vote_average", "")), + # Display genres (names) but keep full ids separately + "genres": [g.get("name") for g in genres_full], + # Keep fields for ranking and post-processing + "vote_average": details.get("vote_average"), + "vote_count": details.get("vote_count"), + "popularity": details.get("popularity"), + "original_language": details.get("original_language"), + # pass internal external_ids for post-filtering if needed + "_external_ids": external_ids, + # internal fields for suppression/rerank + "_tmdb_id": details.get("id"), + "genre_ids": genre_ids, + } + + # Add runtime if available (Movie) or episode run time (TV) + runtime = details.get("runtime") + if not runtime and details.get("episode_run_time"): + runtime = details.get("episode_run_time")[0] + + if runtime: + meta_data["runtime"] = f"{runtime} min" + + # internal fields for collection and cast (movies only for collection) + if query_media_type == "movie": + coll = details.get("belongs_to_collection") or {} + if isinstance(coll, dict): + meta_data["_collection_id"] = coll.get("id") + + # top 3 cast ids + cast = details.get("credits", {}).get("cast", []) or [] + meta_data["_top_cast_ids"] = [c.get("id") for c in cast[:3] if c.get("id") is not None] + + # Attach minimal structures for similarity to use keywords/credits later + if details.get("keywords"): + meta_data["keywords"] = details.get("keywords") + if details.get("credits"): + meta_data["credits"] = 
details.get("credits") + + final_results.append(meta_data) + + if len(final_results) >= target_count: + break return final_results @@ -251,6 +522,19 @@ async def get_recommendations_for_item(self, item_id: str) -> list[dict]: if not media_type: media_type = "movie" + # Build top-genre whitelist for this type + _whitelist = await self._get_top_genre_whitelist(media_type) + + def _passes_top_genre(item_genre_ids: list[int] | None) -> bool: + if not _whitelist: + return True + gids = set(item_genre_ids or []) + if not gids: + return True + if 16 in gids and 16 not in _whitelist: + return False + return bool(gids & _whitelist) + # Fetch more candidates to account for filtering # We want 20 final, so fetch 40 buffer_limit = self.per_item_limit * 2 @@ -271,9 +555,24 @@ async def get_recommendations_for_item(self, item_id: str) -> list[dict]: recommendations = [ item for item in recommendations if not excluded_ids.intersection(item.get("genre_ids") or []) ] + # Top-genre whitelist filter + recommendations = [it for it in recommendations if _passes_top_genre(it.get("genre_ids"))] + + # 1.6 Freshness: inject trending/top-rated within whitelist to expand pool + if len(recommendations) < buffer_limit: + recommendations = await self._inject_freshness( + recommendations, + media_type, + watched_tmdb, + excluded_ids, + max(1, int(self.per_item_limit * PER_GENRE_MAX_SHARE)), + buffer_limit, + ) # 2. Fetch Metadata (gets IMDB IDs) - meta_items = await self._fetch_metadata_for_items(recommendations, media_type) + meta_items = await self._fetch_metadata_for_items( + recommendations, media_type, target_count=self.per_item_limit * 2 + ) # 3. Strict Filter by IMDB ID (using metadata) final_items = [] @@ -285,6 +584,9 @@ async def get_recommendations_for_item(self, item_id: str) -> list[dict]: ext_ids = item.get("_external_ids", {}) if ext_ids.get("imdb_id") in watched_imdb: continue + # Apply top-genre whitelist with enriched genre_ids + if not _passes_top_genre(item.get("genre_ids")): + continue # Clean up internal fields item.pop("_external_ids", None) @@ -342,21 +644,72 @@ async def get_recommendations_for_theme(self, theme_id: str, content_type: str, if "sort_by" not in params: params["sort_by"] = "popularity.desc" - # Apply Excluded Genres + # Apply Excluded Genres but don't conflict with explicit with_genres from theme excluded_ids = self._get_excluded_genre_ids(content_type) if excluded_ids: - params["without_genres"] = "|".join(str(g) for g in excluded_ids) - - # Fetch - recommendations = await self.tmdb_service.get_discover(content_type, **params) - candidates = recommendations.get("results", []) + try: + with_ids = { + int(g) + for g in ( + params.get("with_genres", "").replace("|", ",").split(",") if params.get("with_genres") else [] + ) + if g + } + except Exception: + with_ids = set() + final_without = [g for g in excluded_ids if g not in with_ids] + if final_without: + params["without_genres"] = "|".join(str(g) for g in final_without) + + # Build whitelist via helper + _whitelist = await self._get_top_genre_whitelist(content_type) + + def _passes_top_genre(item_genre_ids: list[int] | None) -> bool: + if not _whitelist: + return True + gids = set(item_genre_ids or []) + if not gids: + return True + if 16 in gids and 16 not in _whitelist: + return False + return bool(gids & _whitelist) + + # Fetch (with simple multi-page fallback to increase pool) + candidates: list[dict] = [] + try: + first = await self.tmdb_service.get_discover(content_type, **params) + candidates.extend(first.get("results", [])) 
+ # If we have too few, try page 2 (and 3) to increase pool size + if len(candidates) < limit * 2: + second = await self.tmdb_service.get_discover(content_type, page=2, **params) + candidates.extend(second.get("results", [])) + if len(candidates) < limit * 2: + third = await self.tmdb_service.get_discover(content_type, page=3, **params) + candidates.extend(third.get("results", [])) + except Exception: + candidates = [] + + # Apply top-genre whitelist on raw candidates + if candidates: + candidates = [it for it in candidates if _passes_top_genre(it.get("genre_ids"))] # Strict Filtering watched_imdb, watched_tmdb = await self._get_exclusion_sets() filtered = await self._filter_candidates(candidates, watched_imdb, watched_tmdb) + # Freshness injection: add trending/popular/top-rated (within whitelist) if pool thin + if len(filtered) < limit * 2: + filtered = await self._inject_freshness( + filtered, + content_type, + watched_tmdb, + set(excluded_ids), + max(1, int(limit * PER_GENRE_MAX_SHARE)), + limit * 3, + ) + # Meta - meta_items = await self._fetch_metadata_for_items(filtered[: limit * 2], content_type) + meta_items = await self._fetch_metadata_for_items(filtered, content_type, target_count=limit * 3) final_items = [] for item in meta_items: @@ -364,9 +717,16 @@ async def get_recommendations_for_theme(self, theme_id: str, content_type: str, continue if item.get("_external_ids", {}).get("imdb_id") in watched_imdb: continue + # Apply whitelist again on enriched metadata + if not _passes_top_genre(item.get("genre_ids")): + continue item.pop("_external_ids", None) final_items.append(item) + # Enforce limit + if len(final_items) > limit: + final_items = final_items[:limit] + return final_items async def _fetch_recommendations_from_tmdb(self, item_id: str, media_type: str, limit: int) -> list[dict]: @@ -407,11 +767,42 @@ async def _fetch_recommendations_from_tmdb(self, item_id: str, media_type: str, # Normalize series alias mtype = "tv" if media_type in ("tv", "series") else "movie" - recommendation_response = await self.tmdb_service.get_recommendations(tmdb_id, mtype) - recommended_items = recommendation_response.get("results", []) - if not recommended_items: - return [] - return recommended_items + # Try multiple pages to increase pool + combined: dict[int, dict] = {} + try: + rec1 = await self.tmdb_service.get_recommendations(tmdb_id, mtype, page=1) + for it in rec1.get("results", []): + if it.get("id") is not None: + combined[it["id"]] = it + if len(combined) < limit: + rec2 = await self.tmdb_service.get_recommendations(tmdb_id, mtype, page=2) + for it in rec2.get("results", []): + if it.get("id") is not None: + combined[it["id"]] = it + if len(combined) < limit: + rec3 = await self.tmdb_service.get_recommendations(tmdb_id, mtype, page=3) + for it in rec3.get("results", []): + if it.get("id") is not None: + combined[it["id"]] = it + except Exception: + pass + + # If still thin, use similar as fallback + if len(combined) < max(20, limit // 2): + try: + sim1 = await self.tmdb_service.get_similar(tmdb_id, mtype, page=1) + for it in sim1.get("results", []): + if it.get("id") is not None: + combined[it["id"]] = it + if len(combined) < limit: + sim2 = await self.tmdb_service.get_similar(tmdb_id, mtype, page=2) + for it in sim2.get("results", []): + if it.get("id") is not None: + combined[it["id"]] = it + except Exception: + pass + + return list(combined.values()) async def get_recommendations( self, @@ -429,7 +820,9 @@ async def get_recommendations( logger.info(f"Starting Hybrid Recommendation 
Pipeline for {content_type}") # Step 1: Fetch & Score User Library - library_data = await self.stremio_service.get_library_items() + if self._library_data is None: + self._library_data = await self.stremio_service.get_library_items() + library_data = self._library_data all_items = library_data.get("loved", []) + library_data.get("watched", []) + library_data.get("added", []) logger.info(f"processing {len(all_items)} Items.") # Cold-start fallback remains (redundant safety) @@ -468,10 +861,27 @@ async def get_recommendations( excluded_ids = set(self._get_excluded_genre_ids(content_type)) similarity_recommendations = [item for item in similarity_recommendations if not isinstance(item, Exception)] + # Apply excluded-genre filter for similarity candidates (whitelist will be applied after profile build) for batch in similarity_recommendations: - similarity_candidates.extend( - item for item in batch if not excluded_ids.intersection(item.get("genre_ids") or []) - ) + for item in batch: + gids = item.get("genre_ids") or [] + if excluded_ids.intersection(gids): + continue + similarity_candidates.append(item) + + # Quality gate for similarity candidates: keep higher-quality when we have enough + def _qual(item: dict) -> bool: + try: + vc = int(item.get("vote_count") or 0) + va = float(item.get("vote_average") or 0.0) + wr = self._weighted_rating(va, vc) + return (vc >= 150 and wr >= 6.0) or (vc >= 500 and wr >= 5.6) + except Exception: + return False + + # filtered_sim = [it for it in similarity_candidates if _qual(it)] + # if len(filtered_sim) >= 40: + # similarity_candidates = filtered_sim # --- Candidate Set B: Profile-based Discovery --- # Extract excluded genres @@ -481,14 +891,50 @@ async def get_recommendations( user_profile = await self.user_profile_service.build_user_profile( scored_objects, content_type=content_type, excluded_genres=excluded_genres ) - discovery_candidates = await self.discovery_engine.discover_recommendations( - user_profile, content_type, limit=20, excluded_genres=excluded_genres - ) + # AUTO recency preference function based on profile years + # recency_fn = self._get_recency_multiplier_fn(user_profile) + # Build per-user top-genre whitelist + try: + top_gen_pairs = user_profile.get_top_genres(limit=TOP_GENRE_WHITELIST_LIMIT) + top_genre_whitelist: set[int] = {int(gid) for gid, _ in top_gen_pairs} + except Exception: + top_genre_whitelist = set() + + def _passes_top_genre(item_genre_ids: list[int] | None) -> bool: + if not top_genre_whitelist: + return True + gids = set(item_genre_ids or []) + if not gids: + return True + if 16 in gids and 16 not in top_genre_whitelist: + return False + return bool(gids & top_genre_whitelist) + + # Always include discovery, but bias to keywords/cast (avoid genre-heavy discovery) + try: + discovery_candidates = await self.discovery_engine.discover_recommendations( + user_profile, + content_type, + limit=max_results * 3, + excluded_genres=excluded_genres, + use_genres=False, + use_keywords=True, + use_cast=True, + use_director=True, + use_countries=False, + use_year=False, + ) + except Exception as e: + logger.warning(f"Discovery fetch failed: {e}") + discovery_candidates = [] # --- Combine & Deduplicate --- candidate_pool = {} # tmdb_id -> item_dict for item in discovery_candidates: + gids = item.get("genre_ids") or [] + if not _passes_top_genre(gids): + continue candidate_pool[item["id"]] = item for item in similarity_candidates: @@ -496,6 +942,71 @@ async def get_recommendations( item["_ranked_candidate"] = True 
candidate_pool[item["id"]] = item + logger.info(f"Similarity candidates collected: {len(similarity_candidates)}; pool size: {len(candidate_pool)}") + + # Build recency blend function (m_raw, alpha) based on profile and candidate decades + try: + candidate_decades = set() + for it in candidate_pool.values(): + y = self._extract_year_from_item(it) + if y: + candidate_decades.add((int(y) // 10) * 10) + recency_m_raw, recency_alpha = self._get_recency_multiplier_fn(user_profile, candidate_decades) + except Exception: + recency_m_raw, recency_alpha = (lambda _y: 1.0), 0.0 + + # Freshness injection: trending/highly rated items to broaden taste + try: + fresh_added = 0 + from collections import defaultdict + + fresh_genre_counts = defaultdict(int) + cap_injection = max(1, int(max_results * PER_GENRE_MAX_SHARE)) + mtype = "tv" if content_type in ("tv", "series") else "movie" + trending_resp = await self.tmdb_service.get_trending(mtype, time_window="week") + trending = trending_resp.get("results", []) if trending_resp else [] + # Mix in top-rated + top_rated_resp = await self.tmdb_service.get_top_rated(mtype) + top_rated = top_rated_resp.get("results", []) if top_rated_resp else [] + fresh_pool = [] + fresh_pool.extend(trending[:40]) + fresh_pool.extend(top_rated[:40]) + # Filter by excluded genres and quality threshold + for it in fresh_pool: + tid = it.get("id") + if not tid or tid in candidate_pool: + continue + # Exclude already watched by TMDB id + if tid in watched_tmdb_ids: + continue + # Excluded genres + gids = it.get("genre_ids") or [] + if excluded_ids and excluded_ids.intersection(set(gids)): + continue + # Respect top-genre whitelist + if not _passes_top_genre(gids): + continue + # Quality: prefer strong audience signal + va = float(it.get("vote_average") or 0.0) + vc = int(it.get("vote_count") or 0) + if vc < 300 or va < 7.0: + continue + # Genre diversity inside freshness injection + if gids and any(fresh_genre_counts[g] >= cap_injection for g in gids): + continue + # Mark as freshness candidate + it["_fresh_boost"] = True + candidate_pool[tid] = it + for g in gids: + fresh_genre_counts[g] += 1 + fresh_added += 1 + if fresh_added >= max_results * 2: + break + if fresh_added: + logger.info(f"Freshness injection added {fresh_added} trending/top-rated candidates") + except Exception as e: + logger.warning(f"Freshness injection failed: {e}") + # --- Re-Ranking & Filtering --- ranked_candidates = [] @@ -504,22 +1015,75 @@ async def get_recommendations( if tmdb_id in watched_tmdb_ids or f"tmdb:{tmdb_id}" in watched_imdb_ids: continue - sim_score = self.user_profile_service.calculate_similarity(user_profile, item) - vote_average = item.get("vote_average", 0) - popularity = item.get("popularity", 0) - - pop_score = normalize(popularity, 0, 1000) - vote_score = normalize(vote_average, 0, 10) - - final_score = (sim_score * 0.6) + (vote_score * 0.3) + (pop_score * 0.1) + # Use simple overlap similarity (Jaccard on tokens/genres/keywords) + try: + sim_score, sim_breakdown = self.user_profile_service.calculate_simple_overlap_with_breakdown( + user_profile, item + ) + except Exception: + sim_score = 0.0 + sim_breakdown = {} + # attach breakdown to item for later inspection + item["_sim_breakdown"] = sim_breakdown + + # If we only matched on genres (topics/keywords near zero), slightly penalize + try: + non_gen_relevance = float(sim_breakdown.get("topics_jaccard", 0.0)) + float( + sim_breakdown.get("keywords_jaccard", 0.0) + ) + if non_gen_relevance <= 0.0001: + sim_score *= 0.8 + 
item["_sim_penalty"] = True + item["_sim_penalty_reason"] = "genre_only_match" + except Exception: + pass + vote_avg = item.get("vote_average", 0.0) + vote_count = item.get("vote_count", 0) + popularity = float(item.get("popularity", 0.0)) + + # Weighted rating then normalize to 0-1 + wr = self._weighted_rating(vote_avg, vote_count) + vote_score = self._normalize(wr, 0.0, 10.0) + pop_score = self._normalize(popularity, 0.0, 1000.0) + + # Increase weight on quality to avoid low-rated picks + final_score = (sim_score * 0.55) + (vote_score * 0.35) + (pop_score * 0.10) + # AUTO recency (blend): final *= (1 - alpha) + alpha * m_raw + try: + y = self._extract_year_from_item(item) + m = recency_m_raw(y) + final_score *= (1.0 - recency_alpha) + (recency_alpha * m) + except Exception: + pass + # Stable tiny epsilon to break ties deterministically + final_score += self._stable_epsilon(tmdb_id) + + # Quality-aware multiplicative adjustments + q_mult = 1.0 + if vote_count < 50: + q_mult *= 0.6 + elif vote_count < 150: + q_mult *= 0.85 + if wr < 5.5: + q_mult *= 0.5 + elif wr < 6.0: + q_mult *= 0.7 + elif wr >= 7.0 and vote_count >= 500: + q_mult *= 1.10 + + # Boost candidate if from TMDB collaborative recommendations, but only if quality is decent + if item.get("_ranked_candidate"): + if wr >= 6.5 and vote_count >= 200: + q_mult *= 1.25 + elif wr >= 6.0 and vote_count >= 100: + q_mult *= 1.10 + # else no boost - # Add tiny jitter to promote freshness and avoid static ordering - jitter = random.uniform(-0.02, 0.02) # +/-2% - final_score = final_score * (1 + jitter) + # Mild boost for freshness-injected trending/top-rated picks to keep feed fresh + if item.get("_fresh_boost") and wr >= 7.0 and vote_count >= 300: + q_mult *= 1.10 - # Boost candidate if its from tmdb collaborative recommendations - if item.get("_ranked_candidate"): - final_score *= 1.25 + final_score *= q_mult ranked_candidates.append((final_score, item)) # Sort by Final Score and cache score on item for diversification @@ -527,71 +1091,201 @@ async def get_recommendations( for score, item in ranked_candidates: item["_final_score"] = score - # Diversify with MMR to avoid shallow, repetitive picks - def _jaccard(a: set, b: set) -> float: - if not a and not b: - return 0.0 - inter = len(a & b) - union = len(a | b) - return inter / union if union else 0.0 - - def _candidate_similarity(x: dict, y: dict) -> float: - gx = set(x.get("genre_ids") or []) - gy = set(y.get("genre_ids") or []) - s = _jaccard(gx, gy) - # Mild penalty if same language to encourage variety - lx = x.get("original_language") - ly = y.get("original_language") - if lx and ly and lx == ly: - s += 0.05 - return min(s, 1.0) - - def _mmr_select(cands: list[dict], k: int, lamb: float = 0.75) -> list[dict]: - selected: list[dict] = [] - remaining = cands[:] - while remaining and len(selected) < k: - if not selected: - best = remaining.pop(0) - selected.append(best) - continue - best_item = None - best_score = float("-inf") - for cand in remaining[:50]: # evaluate a window for speed - rel = cand.get("_final_score", 0.0) - div = 0.0 - for s in selected: - div = max(div, _candidate_similarity(cand, s)) - mmr = lamb * rel - (1 - lamb) * div - if mmr > best_score: - best_score = mmr - best_item = cand - if best_item is None: - break - selected.append(best_item) - try: - remaining.remove(best_item) - except ValueError: - pass - return selected - + # Lightweight logging: show top 5 ranked candidates with similarity breakdown + try: + top_n = ranked_candidates[:5] + if top_n: + 
logger.info("Top similarity-ranked candidates (pre-meta):") + for sc, it in top_n: + name = it.get("title") or it.get("name") or it.get("original_title") or it.get("id") + bd = it.get("_sim_breakdown") or {} + logger.info(f"- {name} (tmdb:{it.get('id')}): score={sc:.4f} breakdown={bd}") + except Exception: + pass + + # Simplified selection: take top-ranked items directly (no MMR diversification) top_ranked_items = [item for _, item in ranked_candidates] - diversified = _mmr_select(top_ranked_items, k=max_results * 2, lamb=0.75) - # Select with buffer for final IMDB filtering after diversification - buffer_selection = diversified + # Buffer selection size is 2x requested results to allow final filtering + buffer_selection = top_ranked_items[: max_results * 2] # Fetch Full Metadata - meta_items = await self._fetch_metadata_for_items(buffer_selection, content_type) + meta_items = await self._fetch_metadata_for_items(buffer_selection, content_type, target_count=max_results * 2) - # Final Strict Filter by IMDB ID + # Recompute similarity with enriched metadata (keywords, credits) final_items = [] + used_collections: set[int] = set() + used_cast: set[int] = set() for item in meta_items: if item["id"] in watched_imdb_ids: continue ext_ids = item.get("_external_ids", {}) if ext_ids.get("imdb_id") in watched_imdb_ids: continue + # Apply top-genre whitelist again using enriched genre_ids if present + if not _passes_top_genre(item.get("genre_ids")): + continue - item.pop("_external_ids", None) + try: + sim_score, sim_breakdown = self.user_profile_service.calculate_simple_overlap_with_breakdown( + user_profile, item + ) + except Exception: + sim_score = 0.0 + sim_breakdown = {} + item["_sim_breakdown"] = sim_breakdown + wr = self._weighted_rating(item.get("vote_average"), item.get("vote_count")) + vote_score = self._normalize(wr, 0.0, 10.0) + pop_score = self._normalize(float(item.get("popularity") or 0.0), 0.0, 1000.0) + + base = (sim_score * 0.55) + (vote_score * 0.35) + (pop_score * 0.10) + base += self._stable_epsilon(item.get("_tmdb_id") or 0) + + # Quality-aware adjustment + vc = int(item.get("vote_count") or 0) + q_mult = 1.0 + if vc < 50: + q_mult *= 0.6 + elif vc < 150: + q_mult *= 0.85 + if wr < 5.5: + q_mult *= 0.5 + elif wr < 6.0: + q_mult *= 0.7 + elif wr >= 7.0 and vc >= 500: + q_mult *= 1.10 + + # AUTO recency (blend) in post-metadata stage as well + try: + y = self._extract_year_from_item(item) + m = recency_m_raw(y) + q_mult *= (1.0 - recency_alpha) + (recency_alpha * m) + except Exception: + pass + + score = base * q_mult + + # Collection/cast suppression + penalty = 0.0 + coll_id = item.get("_collection_id") + if isinstance(coll_id, int) and coll_id in used_collections: + penalty += 0.05 + cast_ids = set(item.get("_top_cast_ids", []) or []) + overlap = len(cast_ids & used_cast) + if overlap: + penalty += min(0.03 * overlap, 0.09) + score *= 1.0 - penalty + item["_adjusted_score"] = score final_items.append(item) - return final_items + # Sort by adjusted score descending + final_items.sort(key=lambda x: x.get("_adjusted_score", 0.0), reverse=True) + + # Diversified selection: per-genre cap AND proportional decade apportionment + from collections import defaultdict + + genre_take_counts = defaultdict(int) + cap_per_genre = max(1, int(max_results * PER_GENRE_MAX_SHARE)) + + # Build decade targets from user profile distribution over decades present in final_items + decades_in_results = [] + for it in final_items: + y = self._extract_year_from_item(it) + if y: + 
decades_in_results.append((int(y) // 10) * 10) + else: + decades_in_results.append(None) + + # User decade prefs + try: + years_map = getattr(user_profile.years, "values", {}) or {} + decade_weights = {int(k): float(v) for k, v in years_map.items() if isinstance(k, int)} + total_w = sum(decade_weights.values()) + except Exception: + decade_weights = {} + total_w = 0.0 + + support = {d for d in decades_in_results if d is not None} + if total_w > 0 and support: + p_user = {d: (decade_weights.get(d, 0.0) / total_w) for d in support} + # Normalize to sum 1 over support + s = sum(p_user.values()) + if s > 0: + for d in list(p_user.keys()): + p_user[d] = p_user[d] / s + else: + # fallback to uniform over support + p_user = {d: 1.0 / len(support) for d in support} + else: + # Neutral: uniform over decades present + p_user = {d: 1.0 / len(support) for d in support} if support else {} + + # Largest remainder apportionment + targets = defaultdict(int) + remainders = [] + slots = max_results + for d, p in p_user.items(): + tgt = p * slots + base = int(tgt) + targets[d] = base + remainders.append((tgt - base, d)) + assigned = sum(targets.values()) + remaining = max(0, slots - assigned) + if remaining > 0 and remainders: + remainders.sort(key=lambda x: x[0], reverse=True) + for _, d in remainders[:remaining]: + targets[d] += 1 + + # First pass: honor decade targets and genre caps + decade_counts = defaultdict(int) + diversified = [] + for it in final_items: + if len(diversified) >= max_results * 2: + break + gids = list(it.get("genre_ids") or []) + if gids and any(genre_take_counts[g] >= cap_per_genre for g in gids): + continue + y = self._extract_year_from_item(it) + d = (int(y) // 10) * 10 if y else None + if d is not None and d in targets and decade_counts[d] >= targets[d]: + continue + diversified.append(it) + for g in gids: + genre_take_counts[g] += 1 + if d is not None: + decade_counts[d] += 1 + + # Second pass: fill remaining up to max_results ignoring decade targets but keeping genre caps + if len(diversified) < max_results: + for it in final_items: + if it in diversified: + continue + if len(diversified) >= max_results * 2: + break + gids = list(it.get("genre_ids") or []) + if gids and any(genre_take_counts[g] >= cap_per_genre for g in gids): + continue + diversified.append(it) + for g in gids: + genre_take_counts[g] += 1 + + # Update used sets for next requests (implicit) and cleanup internal fields + ordered = [] + for it in diversified: + coll = it.pop("_collection_id", None) + if isinstance(coll, int): + used_collections.add(coll) + for cid in it.pop("_top_cast_ids", []) or []: + try: + used_cast.add(int(cid)) + except Exception: + pass + it.pop("_external_ids", None) + it.pop("_tmdb_id", None) + it.pop("_adjusted_score", None) + ordered.append(it) + + # Enforce max_results limit + if len(ordered) > max_results: + ordered = ordered[:max_results] + + return ordered diff --git a/app/services/row_generator.py b/app/services/row_generator.py index 250177e..706c040 100644 --- a/app/services/row_generator.py +++ b/app/services/row_generator.py @@ -6,7 +6,7 @@ from app.services.gemini import gemini_service from app.services.tmdb.countries import COUNTRY_ADJECTIVES from app.services.tmdb.genre import movie_genres, series_genres -from app.services.tmdb_service import TMDBService +from app.services.tmdb_service import TMDBService, get_tmdb_service def normalize_keyword(kw): @@ -36,7 +36,7 @@ class RowGeneratorService: """ def __init__(self, tmdb_service: TMDBService | None = None): - 
self.tmdb_service = tmdb_service or TMDBService() + self.tmdb_service = tmdb_service or get_tmdb_service() async def generate_rows(self, profile: UserTasteProfile, content_type: str = "movie") -> list[RowDefinition]: """ diff --git a/app/services/stremio_service.py b/app/services/stremio_service.py index 1bee937..c21719e 100644 --- a/app/services/stremio_service.py +++ b/app/services/stremio_service.py @@ -45,6 +45,9 @@ def __init__( # Reuse HTTP client for connection pooling and better performance self._client: httpx.AsyncClient | None = None self._likes_client: httpx.AsyncClient | None = None + # lightweight per-instance cache for library fetch + self._library_cache: dict | None = None + self._library_cache_expiry: float = 0.0 async def _get_client(self) -> httpx.AsyncClient: """Get or create the main Stremio API client.""" @@ -212,11 +215,16 @@ async def get_user_email(self) -> str: user_info = await self.get_user_info() return user_info.get("email", "") - async def get_library_items(self) -> dict[str, list[dict]]: + async def get_library_items(self, use_cache: bool = True, cache_ttl_seconds: int = 30) -> dict[str, list[dict]]: """ Fetch library items from Stremio once and return both watched and loved items. Returns a dict with 'watched' and 'loved' keys. """ + import time + + if use_cache and self._library_cache and time.time() < self._library_cache_expiry: + return self._library_cache + if not self._auth_key: logger.warning("Stremio auth key not configured") return {"watched": [], "loved": []} @@ -318,13 +326,18 @@ def _sort_key(x: dict): added_items.append(item) logger.info(f"Found {len(added_items)} added (unwatched) and {len(removed_items)} removed library items") - # Return raw items; ScoringService will handle Pydantic conversion - return { + # Prepare result + result = { "watched": watched_items, "loved": loved_items, "added": added_items, "removed": removed_items, } + # cache + if use_cache and cache_ttl_seconds > 0: + self._library_cache = result + self._library_cache_expiry = time.time() + cache_ttl_seconds + return result except Exception as e: logger.error(f"Error fetching library items: {e}", exc_info=True) return {"watched": [], "loved": []} diff --git a/app/services/tmdb_service.py b/app/services/tmdb_service.py index 1ea9a5c..fed117d 100644 --- a/app/services/tmdb_service.py +++ b/app/services/tmdb_service.py @@ -1,4 +1,5 @@ import asyncio +import functools import random import httpx @@ -148,21 +149,21 @@ async def get_tv_details(self, tv_id: int) -> dict: params = {"append_to_response": "credits,external_ids,keywords"} return await self._make_request(f"/tv/{tv_id}", params=params) - @alru_cache(maxsize=1000) + @alru_cache(maxsize=1000, ttl=6 * 60 * 60) async def get_recommendations(self, tmdb_id: int, media_type: str, page: int = 1) -> dict: """Get recommendations based on TMDB ID and media type.""" params = {"page": page} endpoint = f"/{media_type}/{tmdb_id}/recommendations" return await self._make_request(endpoint, params=params) - @alru_cache(maxsize=1000) + @alru_cache(maxsize=1000, ttl=6 * 60 * 60) async def get_similar(self, tmdb_id: int, media_type: str, page: int = 1) -> dict: """Get similar content based on TMDB ID and media type.""" params = {"page": page} endpoint = f"/{media_type}/{tmdb_id}/similar" return await self._make_request(endpoint, params=params) - @alru_cache(maxsize=1000) + @alru_cache(maxsize=1000, ttl=30 * 60) async def get_discover( self, media_type: str, @@ -180,3 +181,25 @@ async def get_discover( params.update(kwargs) endpoint = 
f"/discover/{media_type}" return await self._make_request(endpoint, params=params) + + @alru_cache(maxsize=500, ttl=60 * 60) + async def get_trending(self, media_type: str, time_window: str = "week", page: int = 1) -> dict: + """Get trending content. media_type: 'movie' or 'tv'. time_window: 'day' or 'week'""" + mt = "movie" if media_type == "movie" else "tv" + params = {"page": page} + endpoint = f"/trending/{mt}/{time_window}" + return await self._make_request(endpoint, params=params) + + @alru_cache(maxsize=500, ttl=60 * 60) + async def get_top_rated(self, media_type: str, page: int = 1) -> dict: + """Get top-rated content list.""" + mt = "movie" if media_type == "movie" else "tv" + params = {"page": page} + endpoint = f"/{mt}/top_rated" + return await self._make_request(endpoint, params=params) + + +# Singleton factory to reuse clients and async caches per language +@functools.lru_cache(maxsize=16) +def get_tmdb_service(language: str = "en-US") -> TMDBService: + return TMDBService(language=language) diff --git a/app/services/token_store.py b/app/services/token_store.py index 304a6be..918af2c 100644 --- a/app/services/token_store.py +++ b/app/services/token_store.py @@ -1,4 +1,5 @@ import base64 +import contextvars import json from collections.abc import AsyncIterator from typing import Any @@ -24,6 +25,8 @@ def __init__(self) -> None: # Cache decrypted payloads for 1 day (86400s) to reduce Redis hits # Max size 5000 allows many active users without eviction self._payload_cache: TTLCache = TTLCache(maxsize=5000, ttl=86400) + # per-request redis call counter (context-local) + self._redis_calls_var: contextvars.ContextVar[int] = contextvars.ContextVar("watchly_redis_calls", default=0) if not settings.REDIS_URL: logger.warning("REDIS_URL is not set. Token storage will fail until a Redis instance is configured.") @@ -100,8 +103,10 @@ async def store_user_data(self, user_id: str, payload: dict[str, Any]) -> str: json_str = json.dumps(storage_data) if settings.TOKEN_TTL_SECONDS and settings.TOKEN_TTL_SECONDS > 0: + self._incr_calls() await client.setex(key, settings.TOKEN_TTL_SECONDS, json_str) else: + self._incr_calls() await client.set(key, json_str) # Update cache with the payload @@ -111,10 +116,13 @@ async def store_user_data(self, user_id: str, payload: dict[str, Any]) -> str: async def get_user_data(self, token: str) -> dict[str, Any] | None: if token in self._payload_cache: + logger.info(f"[REDIS] Using cached redis data {token}") return self._payload_cache[token] + logger.info(f"[REDIS]Caching Failed. 
Fetching data from redis for {token}") key = self._format_key(token) client = await self._get_client() + self._incr_calls() data_raw = await client.get(key) if not data_raw: @@ -136,13 +144,14 @@ async def delete_token(self, token: str = None, key: str = None) -> None: key = self._format_key(token) client = await self._get_client() + self._incr_calls() await client.delete(key) # Invalidate local cache if token and token in self._payload_cache: del self._payload_cache[token] - async def iter_payloads(self) -> AsyncIterator[tuple[str, dict[str, Any]]]: + async def iter_payloads(self, batch_size: int = 200) -> AsyncIterator[tuple[str, dict[str, Any]]]: try: client = await self._get_client() except (redis.RedisError, OSError) as exc: @@ -152,25 +161,82 @@ async def iter_payloads(self) -> AsyncIterator[tuple[str, dict[str, Any]]]: pattern = f"{self.KEY_PREFIX}*" try: - async for key in client.scan_iter(match=pattern): + buffer: list[str] = [] + async for key in client.scan_iter(match=pattern, count=batch_size): + buffer.append(key) + if len(buffer) >= batch_size: + try: + self._incr_calls() + values = await client.mget(buffer) + except (redis.RedisError, OSError) as exc: + logger.warning(f"Failed batch fetch for {len(buffer)} keys: {exc}") + values = [None] * len(buffer) + for k, data_raw in zip(buffer, values): + if not data_raw: + continue + try: + payload = json.loads(data_raw) + except json.JSONDecodeError: + logger.warning(f"Failed to decode payload for key {redact_token(k)}. Skipping.") + continue + # Decrypt authKey for downstream consumers + try: + if payload.get("authKey"): + payload["authKey"] = self.decrypt_token(payload["authKey"]) + except Exception: + pass + # Update L1 cache (token only) + tok = k[len(self.KEY_PREFIX) :] if k.startswith(self.KEY_PREFIX) else k # noqa + self._payload_cache[tok] = payload + yield k, payload + buffer.clear() + + # Flush remainder + if buffer: try: - data_raw = await client.get(key) + self._incr_calls() + values = await client.mget(buffer) except (redis.RedisError, OSError) as exc: - logger.warning(f"Failed to fetch payload for {redact_token(key)}: {exc}") - continue + logger.warning(f"Failed batch fetch for {len(buffer)} keys: {exc}") + values = [None] * len(buffer) + for k, data_raw in zip(buffer, values): + if not data_raw: + continue + try: + payload = json.loads(data_raw) + except json.JSONDecodeError: + logger.warning(f"Failed to decode payload for key {redact_token(k)}. Skipping.") + continue + try: + if payload.get("authKey"): + payload["authKey"] = self.decrypt_token(payload["authKey"]) + except Exception: + pass + tok = k[len(self.KEY_PREFIX) :] if k.startswith(self.KEY_PREFIX) else k # noqa + self._payload_cache[tok] = payload + yield k, payload + except (redis.RedisError, OSError) as exc: + logger.warning(f"Failed to scan credential tokens: {exc}") - if not data_raw: - continue + # ---- Diagnostics ---- + def _incr_calls(self) -> None: + try: + current = self._redis_calls_var.get() + self._redis_calls_var.set(current + 1) + except Exception: + pass - try: - payload = json.loads(data_raw) - except json.JSONDecodeError: - logger.warning(f"Failed to decode payload for key {redact_token(key)}. 
Skipping.") - continue + def reset_call_counter(self) -> None: + try: + self._redis_calls_var.set(0) + except Exception: + pass - yield key, payload - except (redis.RedisError, OSError) as exc: - logger.warning(f"Failed to scan credential tokens: {exc}") + def get_call_count(self) -> int: + try: + return int(self._redis_calls_var.get()) + except Exception: + return 0 token_store = TokenStore() diff --git a/app/services/user_profile.py b/app/services/user_profile.py index 2c0cbba..e5a0334 100644 --- a/app/services/user_profile.py +++ b/app/services/user_profile.py @@ -3,16 +3,20 @@ from app.models.profile import UserTasteProfile from app.models.scoring import ScoredItem -from app.services.tmdb_service import TMDBService +from app.services.tmdb_service import get_tmdb_service # TODO: Make these weights dynamic based on user's preferences. -GENRES_WEIGHT = 0.3 -KEYWORDS_WEIGHT = 0.40 -CAST_WEIGHT = 0.1 -CREW_WEIGHT = 0.1 +GENRES_WEIGHT = 0.20 +KEYWORDS_WEIGHT = 0.30 +CAST_WEIGHT = 0.12 +CREW_WEIGHT = 0.08 YEAR_WEIGHT = 0.05 COUNTRIES_WEIGHT = 0.05 -BASE_GENRE_WEIGHT = 0.15 +BASE_GENRE_WEIGHT = 0.05 +TOPICS_WEIGHT = 0.20 + +# Global constant to control size of user's top-genre whitelist used in filtering +TOP_GENRE_WHITELIST_LIMIT = 5 def emphasis(x: float) -> float: @@ -35,8 +39,8 @@ class UserProfileService: a single 'User Vector' representing their taste. """ - def __init__(self): - self.tmdb_service = TMDBService() + def __init__(self, language: str = "en-US"): + self.tmdb_service = get_tmdb_service(language=language) async def build_user_profile( self, @@ -56,6 +60,7 @@ async def build_user_profile( "crew": defaultdict(float), "years": defaultdict(float), "countries": defaultdict(float), + "topics": defaultdict(float), } async def _process(item): @@ -100,6 +105,7 @@ async def _process(item): crew={"values": dict(profile_data["crew"])}, years={"values": dict(profile_data["years"])}, countries={"values": dict(profile_data["countries"])}, + topics={"values": dict(profile_data["topics"])}, ) # Normalize all vectors to 0-1 range @@ -110,83 +116,142 @@ async def _process(item): def calculate_similarity(self, profile: UserTasteProfile, item_meta: dict) -> float: """ Final improved similarity scoring function. - Uses normalized sparse matching + rarity boosting + non-linear emphasis. + Simplified similarity: linear weighted sum across core dimensions. 
""" + item_vec = self._vectorize_item(item_meta) + # Linear weighted sum across selected dimensions + # For each dimension we average per-feature match to avoid bias from many features + def avg_pref(features, mapping): + if not features: + return 0.0 + s = 0.0 + for f in features: + s += mapping.get(f, 0.0) + return s / max(1, len(features)) + + g_score = avg_pref(item_vec.get("genres", []), profile.genres.values) * GENRES_WEIGHT + k_score = avg_pref(item_vec.get("keywords", []), profile.keywords.values) * KEYWORDS_WEIGHT + c_score = avg_pref(item_vec.get("cast", []), profile.cast.values) * CAST_WEIGHT + t_score = avg_pref(item_vec.get("topics", []), profile.topics.values) * TOPICS_WEIGHT + + # Optional extras with small weights + crew_score = avg_pref(item_vec.get("crew", []), profile.crew.values) * CREW_WEIGHT + country_score = avg_pref(item_vec.get("countries", []), profile.countries.values) * COUNTRIES_WEIGHT + year_val = item_vec.get("year") + year_score = 0.0 + if year_val is not None: + year_score = profile.years.values.get(year_val, 0.0) * YEAR_WEIGHT + + score = g_score + k_score + c_score + t_score + crew_score + country_score + year_score + + return float(score) + + def calculate_similarity_with_breakdown(self, profile: UserTasteProfile, item_meta: dict) -> tuple[float, dict]: + """ + Compute similarity and also return a per-dimension breakdown for logging/tuning. + Returns (score, breakdown_dict) + """ item_vec = self._vectorize_item(item_meta) - score = 0.0 - - # 1. GENRES - # Normalize so movies with many genres don't get excessive score. - for gid in item_vec["genres"]: - pref = profile.genres.values.get(gid, 0.0) - - if pref > 0: - s = emphasis(pref) - s = safe_div(s, len(item_vec["genres"])) - score += s * GENRES_WEIGHT - - # Soft prior bias (genre-only) - base_pref = profile.top_genres_normalized.get(gid, 0.0) - score += base_pref * BASE_GENRE_WEIGHT - - # 2. KEYWORDS - for kw in item_vec["keywords"]: - pref = profile.keywords.values.get(kw, 0.0) - - if pref > 0: - s = emphasis(pref) - s = safe_div(s, len(item_vec["keywords"])) - score += s * KEYWORDS_WEIGHT - - # 3. CAST - for cid in item_vec["cast"]: - pref = profile.cast.values.get(cid, 0.0) - - if pref > 0: - s = emphasis(pref) - s = safe_div(s, len(item_vec["cast"])) - score += s * CAST_WEIGHT - - # 4. CREW - for cr in item_vec["crew"]: - pref = profile.crew.values.get(cr, 0.0) - - if pref > 0: - s = emphasis(pref) - s = safe_div(s, len(item_vec["crew"])) - score += s * CREW_WEIGHT - - # 5. COUNTRIES - for c in item_vec["countries"]: - pref = profile.countries.values.get(c, 0.0) - - if pref > 0: - s = emphasis(pref) - s = safe_div(s, len(item_vec["countries"])) - score += s * COUNTRIES_WEIGHT - - # 6. YEAR/DECADE - # Reward matches on the user's preferred decades, with soft credit to adjacent decades. 
- item_year = item_vec.get("year") - if item_year is not None: - base_pref = profile.years.values.get(item_year, 0.0) - if base_pref > 0: - score += emphasis(base_pref) * YEAR_WEIGHT - else: - # Soft-match adjacent decades at half strength - prev_decade = item_year - 10 - next_decade = item_year + 10 - neighbor_pref = 0.0 - if prev_decade in profile.years.values: - neighbor_pref = max(neighbor_pref, profile.years.values.get(prev_decade, 0.0)) - if next_decade in profile.years.values: - neighbor_pref = max(neighbor_pref, profile.years.values.get(next_decade, 0.0)) - if neighbor_pref > 0: - score += emphasis(neighbor_pref) * (YEAR_WEIGHT * 0.5) - - return score + def avg_pref(features, mapping): + if not features: + return 0.0 + s = 0.0 + for f in features: + s += mapping.get(f, 0.0) + return s / max(1, len(features)) + + g_score = avg_pref(item_vec.get("genres", []), profile.genres.values) * GENRES_WEIGHT + k_score = avg_pref(item_vec.get("keywords", []), profile.keywords.values) * KEYWORDS_WEIGHT + c_score = avg_pref(item_vec.get("cast", []), profile.cast.values) * CAST_WEIGHT + t_score = avg_pref(item_vec.get("topics", []), profile.topics.values) * TOPICS_WEIGHT + crew_score = avg_pref(item_vec.get("crew", []), profile.crew.values) * CREW_WEIGHT + country_score = avg_pref(item_vec.get("countries", []), profile.countries.values) * COUNTRIES_WEIGHT + year_val = item_vec.get("year") + year_score = 0.0 + if year_val is not None: + year_score = profile.years.values.get(year_val, 0.0) * YEAR_WEIGHT + + score = g_score + k_score + c_score + t_score + crew_score + country_score + year_score + + breakdown = { + "genres": float(g_score), + "keywords": float(k_score), + "cast": float(c_score), + "topics": float(t_score), + "crew": float(crew_score), + "countries": float(country_score), + "year": float(year_score), + "total": float(score), + } + + return float(score), breakdown + + # ---------------- Super-simple overlap similarity ---------------- + @staticmethod + def _jaccard(a: set, b: set) -> float: + if not a and not b: + return 0.0 + if not a or not b: + return 0.0 + inter = len(a & b) + union = len(a | b) + if union == 0: + return 0.0 + return inter / union + + def calculate_simple_overlap_with_breakdown( + self, + profile: UserTasteProfile, + item_meta: dict, + *, + top_topic_tokens: int = 300, + top_genres: int = 20, + top_keyword_ids: int = 200, + ) -> tuple[float, dict]: + """ + Very simple, explainable similarity using plain set overlaps: + - Jaccard of token-level "topics" (title/overview/keyword-names tokens) + - Jaccard of genre ids + - Jaccard of TMDB keyword ids (optional, small weight) + + No embeddings; robust to partial-word matching via lightweight tokenization + and heuristic stemming in _tokenize(). 
+ """ + # Preference sets from profile (take top-N by weight to reduce noise) + pref_topics_sorted = sorted(profile.topics.values.items(), key=lambda kv: kv[1], reverse=True) + pref_topic_tokens = {k for k, _ in pref_topics_sorted[:top_topic_tokens]} + + pref_genres_sorted = sorted(profile.genres.values.items(), key=lambda kv: kv[1], reverse=True) + pref_genres = {int(k) for k, _ in pref_genres_sorted[:top_genres]} + + pref_keywords_sorted = sorted(profile.keywords.values.items(), key=lambda kv: kv[1], reverse=True) + pref_keyword_ids = {int(k) for k, _ in pref_keywords_sorted[:top_keyword_ids]} + + # Item sets + vec = self._vectorize_item(item_meta) + item_topic_tokens = set(vec.get("topics") or []) + item_genres = {int(g) for g in (vec.get("genres") or [])} + item_keyword_ids = {int(k) for k in (vec.get("keywords") or [])} + + # Jaccard components + topics_j = self._jaccard(item_topic_tokens, pref_topic_tokens) + genres_j = self._jaccard(item_genres, pref_genres) + kw_j = self._jaccard(item_keyword_ids, pref_keyword_ids) + + # Simple weighted sum; emphasize token overlap + w_topics, w_genres, w_kw = 0.6, 0.25, 0.15 + score = (topics_j * w_topics) + (genres_j * w_genres) + (kw_j * w_kw) + + breakdown = { + "topics_jaccard": float(topics_j), + "genres_jaccard": float(genres_j), + "keywords_jaccard": float(kw_j), + "total": float(score), + } + + return float(score), breakdown def _vectorize_item(self, meta: dict) -> dict[str, list[int] | int | list[str] | None]: """ @@ -206,13 +271,32 @@ def _vectorize_item(self, meta: dict) -> dict[str, list[int] | int | list[str] | elif "origin_country" in meta: countries = meta.get("origin_country", []) + # genres: prefer explicit genre_ids; fallback to dict list if present + genre_ids = meta.get("genre_ids") or [] + if not genre_ids: + genres_src = meta.get("genres") or [] + if genres_src and isinstance(genres_src, list) and genres_src and isinstance(genres_src[0], dict): + genre_ids = [g.get("id") for g in genres_src if isinstance(g, dict) and g.get("id") is not None] + + # Build topics tokens from title/overview and keyword names + # Handle both our enriched meta format and raw TMDB payloads + title_text = meta.get("name") or meta.get("title") or meta.get("original_title") or "" + overview_text = meta.get("description") or meta.get("overview") or "" + kw_names = [k.get("name") for k in keywords if isinstance(k, dict) and k.get("name")] + topics_tokens: list[str] = [] + topics_tokens.extend(self._tokenize(title_text)) + topics_tokens.extend(self._tokenize(overview_text)) + for nm in kw_names: + topics_tokens.extend(self._tokenize(nm)) + vector = { - "genres": [g["id"] for g in meta.get("genres", [])], + "genres": genre_ids, "keywords": [k["id"] for k in keywords], "cast": [], "crew": [], "year": None, "countries": countries, + "topics": topics_tokens, } # Cast (Top 3 only to reduce noise) @@ -254,6 +338,7 @@ def _merge_vector( "crew": CREW_WEIGHT, "year": YEAR_WEIGHT, "countries": COUNTRIES_WEIGHT, + "topics": TOPICS_WEIGHT, } for dim, ids in item_vector.items(): @@ -269,6 +354,93 @@ def _merge_vector( continue profile[dim][feature_id] += final_weight + # ---------------- Tokenization helpers ---------------- + _STOPWORDS = { + "a", + "an", + "and", + "the", + "of", + "to", + "in", + "on", + "for", + "with", + "by", + "from", + "at", + "as", + "is", + "it", + "this", + "that", + "be", + "or", + "are", + "was", + "were", + "has", + "have", + "had", + "into", + "their", + "his", + "her", + "its", + "but", + "not", + "no", + "so", + "about", + "over", + 
"under", + "after", + "before", + "than", + "then", + "out", + "up", + "down", + "off", + "only", + "more", + "most", + "some", + "any", + } + + @staticmethod + def _normalize_token(tok: str) -> str: + t = tok.lower() + t = "".join(ch for ch in t if ch.isalnum()) + if len(t) <= 2: + return "" + for suf in ("ing", "ers", "ies", "ment", "tion", "s", "ed"): + if t.endswith(suf) and len(t) - len(suf) >= 3: + t = t[: -len(suf)] + break + return t + + def _tokenize(self, text: str) -> list[str]: + if not text: + return [] + raw = text.replace("-", " ").replace("_", " ") + tokens = [] + for part in raw.split(): + t = self._normalize_token(part) + if not t or t in self._STOPWORDS: + continue + tokens.append(t) + # de-duplicate while preserving order + seen = set() + dedup = [] + for t in tokens: + if t in seen: + continue + seen.add(t) + dedup.append(t) + return dedup + async def _fetch_full_metadata(self, tmdb_id: int, type_: str) -> dict | None: """Helper to fetch deep metadata.""" try: diff --git a/app/startup/migration.py b/app/startup/migration.py index 432df02..39f46d9 100644 --- a/app/startup/migration.py +++ b/app/startup/migration.py @@ -6,11 +6,10 @@ import httpx import redis.asyncio as redis from cryptography.fernet import Fernet -from cryptography.hazmat.primitives import hashes -from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC from loguru import logger from app.core.config import settings +from app.services.token_store import token_store def decrypt_data(enc_json: str): @@ -117,18 +116,9 @@ async def decode_old_payloads(encrypted_raw: str): return payload -def encrypt_auth_key(auth_key): - salt = b"x7FDf9kypzQ1LmR32b8hWv49sKq2Pd8T" - kdf = PBKDF2HMAC( - algorithm=hashes.SHA256(), - length=32, - salt=salt, - iterations=200_000, - ) - - key = base64.urlsafe_b64encode(kdf.derive(settings.TOKEN_SALT.encode("utf-8"))) - client = Fernet(key) - return client.encrypt(auth_key.encode("utf-8")).decode("utf-8") +def encrypt_auth_key(auth_key: str) -> str: + # Delegate to TokenStore to keep encryption consistent everywhere + return token_store.encrypt_token(auth_key) def prepare_default_payload(email, user_id): @@ -157,7 +147,7 @@ async def store_payload(client: redis.Redis, email: str, user_id: str, auth_key: # encrypt auth_key if auth_key: payload["authKey"] = encrypt_auth_key(auth_key) - key = user_id.strip() + key = f"{settings.REDIS_TOKEN_KEY}{user_id.strip()}" await client.set(key, json.dumps(payload)) except (redis.RedisError, OSError) as exc: logger.warning(f"Failed to store payload for {key}: {exc}") @@ -200,7 +190,7 @@ async def process_migration_key(redis_client: redis.Redis, key: str) -> bool: if auth_key: new_payload["authKey"] = encrypt_auth_key(auth_key) - new_key = user_id.strip() + new_key = f"{settings.REDIS_TOKEN_KEY}{user_id.strip()}" payload_json = json.dumps(new_payload) if settings.TOKEN_TTL_SECONDS and settings.TOKEN_TTL_SECONDS > 0: @@ -233,7 +223,7 @@ async def migrate_tokens(): failed_tokens = 0 success_tokens = 0 try: - redis_client = redis.from_url(settings.REDIS_URL, decode_responses=True, encoding="utf-8") + redis_client = await token_store._get_client() except (redis.RedisError, OSError) as exc: logger.warning(f"Failed to connect to Redis: {exc}") return