Skip to content

Commit

Permalink
migrate to evaldata
Browse files Browse the repository at this point in the history
  • Loading branch information
Zentavious committed Dec 3, 2024
1 parent 6c619f1 commit baf92b7
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/poprox_recommender/evaluation/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from poprox_concepts.api.recommendations import RecommendationRequest
from poprox_concepts.domain import ArticleSet
from poprox_recommender.config import default_device
from poprox_recommender.data.data import Data
from poprox_recommender.data.eval import EvalData
from poprox_recommender.data.mind import TEST_REC_COUNT, MindData
from poprox_recommender.data.poprox import PoproxData
from poprox_recommender.lkpipeline import PipelineState
Expand Down Expand Up @@ -108,7 +108,7 @@ def extract_recs(
return output_df


def generate_user_recs(data: Data, pipe_names: list[str] | None = None, n_users: int | None = None):
def generate_user_recs(data: EvalData, pipe_names: list[str] | None = None, n_users: int | None = None):
pipelines = recommendation_pipelines(device=default_device())
if pipe_names is not None:
pipelines = {name: pipelines[name] for name in pipe_names} # type: ignore
Expand All @@ -118,9 +118,9 @@ def generate_user_recs(data: Data, pipe_names: list[str] | None = None, n_users:
logger.info("generating recommendations")
user_recs = []

user_iter = data.iter_users()
user_iter = data.iter_profiles()
if n_users is None:
n_users = data.n_users
n_users = data.n_profiles
logger.info("recommending for all %d users", n_users)
else:
logger.info("running on subset of %d users", n_users)
Expand Down

0 comments on commit baf92b7

Please sign in to comment.