forked from CCRI-POPROX/poprox-recommender
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
News Locality Calibration (CCRI-POPROX#103)
Apply some refactoring by creating a `Calibrator` object that both `TopicCalibration` and the new `LocalityCalibration` inherit from. Updated current `test_calibration` test suites. Some TODOs we need to figure out for next step: - [x] Add test for current logic - [ ] Tune the parameter of calibration (The current params are borrowed from topic calibration) - [ ] Run the code on an S3 instance and call necessary endpoints from POPROX - [ ] Figure out how to connect the logic to participant selection (who will be included in this experiment?) - [ ] Integrate the LLM context generation into calibrated articles
- Loading branch information
1 parent
114bc34
commit 2fd4b13
Showing
10 changed files
with
258 additions
and
86 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,7 @@ | ||
from poprox_recommender.components.diversifiers.calibration import Calibrator | ||
from poprox_recommender.components.diversifiers.locality_calibration import LocalityCalibrator | ||
from poprox_recommender.components.diversifiers.mmr import MMRDiversifier | ||
from poprox_recommender.components.diversifiers.pfar import PFARDiversifier | ||
from poprox_recommender.components.diversifiers.topic_calibration import TopicCalibrator | ||
|
||
__all__ = ["MMRDiversifier", "PFARDiversifier", "TopicCalibrator"] | ||
__all__ = ["MMRDiversifier", "PFARDiversifier", "Calibrator", "TopicCalibrator", "LocalityCalibrator"] |
77 changes: 77 additions & 0 deletions
77
src/poprox_recommender/components/diversifiers/calibration.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
from collections import defaultdict | ||
|
||
import numpy as np | ||
|
||
from poprox_concepts import Article | ||
from poprox_recommender.lkpipeline import Component | ||
from poprox_recommender.topics import normalized_category_count | ||
|
||
|
||
# General calibration uses MMR | ||
# to rerank recommendations according to | ||
# certain calibration context (e.g. news topic, locality) | ||
class Calibrator(Component): | ||
def __init__(self, theta: float = 0.1, num_slots=10): | ||
# Theta term controls the score and calibration tradeoff, the higher | ||
# the theta the higher the resulting recommendation will be calibrated. | ||
self.theta = theta | ||
self.num_slots = num_slots | ||
|
||
def __call__(): | ||
pass | ||
|
||
def add_article_to_categories(self, rec_categories_with_candidate, article): | ||
pass | ||
|
||
def normalized_categories_with_candidate(self, rec_categories, article): | ||
rec_categories_with_candidate = rec_categories.copy() | ||
self.add_article_to_categories(rec_categories_with_candidate, article) | ||
return normalized_category_count(rec_categories_with_candidate) | ||
|
||
def calibration(self, relevance_scores, articles, preferences, theta, topk) -> list[Article]: | ||
# MR_i = \theta * reward_i - (1 - \theta)*C(S + i) # C is calibration | ||
# R is all candidates (not selected yet) | ||
|
||
recommendations = [] # final recommendation (topk index) | ||
rec_categories = defaultdict(int) # frequency distribution of categories of S | ||
|
||
for _ in range(topk): | ||
candidate = None # next item | ||
best_candidate_score = float("-inf") | ||
|
||
for article_idx, article_score in enumerate(relevance_scores): # iterate R for next item | ||
if article_idx in recommendations: | ||
continue | ||
|
||
normalized_candidate_topics = self.normalized_categories_with_candidate( | ||
rec_categories, articles[article_idx] | ||
) | ||
calibration = compute_kl_divergence(preferences, normalized_candidate_topics) | ||
|
||
adjusted_candidate_score = (1 - theta) * article_score - (theta * calibration) | ||
if adjusted_candidate_score > best_candidate_score: | ||
best_candidate_score = adjusted_candidate_score | ||
candidate = article_idx | ||
|
||
if candidate is not None: | ||
recommendations.append(candidate) | ||
self.add_article_to_categories(rec_categories, articles[candidate]) | ||
|
||
return recommendations | ||
|
||
|
||
# from https://github.com/CCRI-POPROX/poprox-recommender/blob/feature/experiment0/tests/test_calibration.ipynb | ||
def compute_kl_divergence(interacted_distr, reco_distr, kl_div=0.0, alpha=0.01): | ||
""" | ||
KL (p || q), the lower the better. | ||
alpha is not really a tuning parameter, it's just there to make the | ||
computation more numerically stable. | ||
""" | ||
for category, score in interacted_distr.items(): | ||
reco_score = reco_distr.get(category, 0.0) | ||
reco_score = (1 - alpha) * reco_score + alpha * score | ||
if reco_score != 0.0: | ||
kl_div += score * np.log2(score / reco_score) | ||
|
||
return kl_div |
37 changes: 37 additions & 0 deletions
37
src/poprox_recommender/components/diversifiers/locality_calibration.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
import torch as th | ||
|
||
from poprox_concepts import ArticleSet, InterestProfile | ||
from poprox_recommender.components.diversifiers.calibration import Calibrator | ||
from poprox_recommender.topics import extract_locality, normalized_category_count | ||
|
||
|
||
# Locality Calibration uses MMR | ||
# to rerank recommendations according to | ||
# locality calibration | ||
class LocalityCalibrator(Calibrator): | ||
def __init__(self, theta: float = 0.1, num_slots=10): | ||
super().__init__(theta, num_slots) | ||
|
||
def __call__(self, candidate_articles: ArticleSet, interest_profile: InterestProfile) -> ArticleSet: | ||
normalized_locality_prefs = normalized_category_count(interest_profile.click_locality_counts) | ||
|
||
if candidate_articles.scores is not None: | ||
article_scores = th.sigmoid(th.tensor(candidate_articles.scores)) | ||
else: | ||
article_scores = th.zeros(len(candidate_articles.articles)) | ||
|
||
article_scores = article_scores.cpu().detach().numpy() | ||
|
||
article_indices = self.calibration( | ||
article_scores, | ||
candidate_articles.articles, | ||
normalized_locality_prefs, | ||
self.theta, | ||
topk=self.num_slots, | ||
) | ||
return ArticleSet(articles=[candidate_articles.articles[int(idx)] for idx in article_indices]) | ||
|
||
def add_article_to_categories(self, rec_categories, article): | ||
locality_list = extract_locality(article) | ||
for locality in locality_list: | ||
rec_categories[locality] = rec_categories.get(locality, 0) + 1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.