Skip to content
This repository has been archived by the owner on May 9, 2022. It is now read-only.

Ty 1686 analytics initial rank #55

Merged
merged 2 commits into from
Apr 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 14 additions & 11 deletions xayn-ai/src/analytics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,18 +53,22 @@ impl systems::AnalyticsSystem for AnalyticsSystem {

let mut paired_ltr_scores = Vec::new();
let mut paired_context_scores = Vec::new();
let mut paired_final_ranking_score = Vec::new();
let mut paired_final_ranking_scores = Vec::new();
let mut paired_initial_ranking_scores = Vec::new();

for document in documents {
if let Some(relevance) = relevance_lookups.get(&document.document_id.id).copied() {
if let Some(relevance) = relevance_lookups.get(&document.document_base.id).copied() {
paired_ltr_scores.push((relevance, document.ltr.ltr_score));
paired_context_scores.push((relevance, document.context.context_value));

// nDCG expects higher scores to be better but for the ranking
// it's the oposite, the solution carried over from the dart impl
// is to multiply by -1.
let final_ranking_desc = -(document.mab.rank as f32);
paired_final_ranking_score.push((relevance, final_ranking_desc));
paired_final_ranking_scores.push((relevance, final_ranking_desc));

let intial_ranking_desc = -(document.document_base.initial_ranking as f32);
paired_initial_ranking_scores.push((relevance, intial_ranking_desc));
}
}

Expand All @@ -78,13 +82,13 @@ impl systems::AnalyticsSystem for AnalyticsSystem {
calcuate_reordered_ndcg_at_k_score(&mut paired_context_scores, DEFAULT_NDCG_K);

let ndcg_final_ranking =
calcuate_reordered_ndcg_at_k_score(&mut paired_final_ranking_score, DEFAULT_NDCG_K);
calcuate_reordered_ndcg_at_k_score(&mut paired_final_ranking_scores, DEFAULT_NDCG_K);

let ndcg_initial_ranking =
calcuate_reordered_ndcg_at_k_score(&mut paired_initial_ranking_scores, DEFAULT_NDCG_K);

Ok(Analytics {
//FIXME: We currently have no access to the initial score as thiss will require
// some changes to the main applications type state/component system this
// will be done in a followup PR.
ndcg_initial_ranking: f32::NAN,
ndcg_initial_ranking,
ndcg_ltr,
ndcg_context,
ndcg_final_ranking,
Expand Down Expand Up @@ -215,16 +219,15 @@ mod tests {
let Analytics {
ndcg_ltr,
ndcg_context,
ndcg_initial_ranking: _,
ndcg_initial_ranking,
ndcg_final_ranking,
} = AnalyticsSystem
.compute_analytics(&history, &documents)
.unwrap();

assert_f32_eq!(ndcg_ltr, 0.173_765_35);
assert_f32_eq!(ndcg_context, 0.826_234_64);
//FIXME: Currently not possible as `ndcg_initial_ranking` is not yet computed
// assert!(approx_eq!(f32, ndcg_initial_ranking, 0.7967075809905066, ulps = 2));
assert_f32_eq!(ndcg_initial_ranking, 0.796_707_6);
assert_f32_eq!(ndcg_final_ranking, 1.0);
}

Expand Down
5 changes: 3 additions & 2 deletions xayn-ai/src/coi/system.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,8 +211,8 @@ mod tests {
document::{DocumentId, Relevance, UserFeedback},
document_data::{
ContextComponent,
DocumentBaseComponent,
DocumentDataWithMab,
DocumentIdComponent,
EmbeddingComponent,
LtrComponent,
MabComponent,
Expand All @@ -231,8 +231,9 @@ mod tests {
.iter()
.enumerate()
.map(|(id, embedding)| DocumentDataWithMab {
document_id: DocumentIdComponent {
document_base: DocumentBaseComponent {
id: DocumentId(id.to_string()),
initial_ranking: id,
},
embedding: EmbeddingComponent {
embedding: arr1(embedding.as_init_slice()).into(),
Expand Down
10 changes: 6 additions & 4 deletions xayn-ai/src/coi/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,8 @@ pub(super) mod tests {
data::{
document_data::{
CoiComponent,
DocumentBaseComponent,
DocumentDataWithEmbedding,
DocumentIdComponent,
EmbeddingComponent,
},
CoiId,
Expand Down Expand Up @@ -212,17 +212,19 @@ pub(super) mod tests {
embeddings
.iter()
.enumerate()
.map(|(id, embedding)| create_data_with_embedding(id, embedding.as_init_slice()))
.map(|(id, embedding)| create_data_with_embedding(id, id, embedding.as_init_slice()))
.collect()
}

pub(crate) fn create_data_with_embedding(
id: usize,
initial_ranking: usize,
embedding: &[f32],
) -> DocumentDataWithEmbedding {
DocumentDataWithEmbedding {
document_id: DocumentIdComponent {
document_base: DocumentBaseComponent {
id: DocumentId(id.to_string()),
initial_ranking,
},
embedding: EmbeddingComponent {
embedding: arr1(embedding).into(),
Expand Down Expand Up @@ -416,7 +418,7 @@ pub(super) mod tests {
]);

let mut documents = create_data_with_embeddings(&[[1., 2., 3.], [3., 2., 1.]]);
documents.push(create_data_with_embedding(5, &[4., 5., 6.]));
documents.push(create_data_with_embedding(5, 0, &[4., 5., 6.]));
let documents = to_vec_of_ref_of!(documents, &dyn CoiSystemData);

let matching_documents = collect_matching_documents(&history, &documents);
Expand Down
7 changes: 5 additions & 2 deletions xayn-ai/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ mod tests {
use super::*;
use crate::data::{
document::DocumentId,
document_data::{CoiComponent, DocumentIdComponent, EmbeddingComponent, LtrComponent},
document_data::{CoiComponent, DocumentBaseComponent, EmbeddingComponent, LtrComponent},
CoiId,
};

Expand All @@ -80,7 +80,10 @@ mod tests {
let embedding = arr1::<f32>(&[]).into();

self.docs.push(DocumentDataWithLtr {
document_id: DocumentIdComponent { id },
document_base: DocumentBaseComponent {
id,
initial_ranking: 13,
},
embedding: EmbeddingComponent { embedding },
coi: CoiComponent {
id: CoiId(0),
Expand Down
46 changes: 24 additions & 22 deletions xayn-ai/src/data/document_data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@ use crate::{

#[cfg_attr(test, derive(Debug, PartialEq, Clone))]
#[derive(Serialize, Deserialize)]
pub(crate) struct DocumentIdComponent {
pub(crate) struct DocumentBaseComponent {
pub(crate) id: DocumentId,
pub(crate) initial_ranking: usize,
}

#[cfg_attr(test, derive(Debug, PartialEq, Clone))]
Expand Down Expand Up @@ -59,14 +60,14 @@ pub(crate) struct MabComponent {
// DocumentDataWithLtr -> DocumentDataWithContext -> DocumentDataWithMab

pub(crate) struct DocumentDataWithDocument {
pub(crate) document_id: DocumentIdComponent,
pub(crate) document_base: DocumentBaseComponent,
pub(crate) document_content: DocumentContentComponent,
}

#[cfg_attr(test, derive(Debug, PartialEq, Clone))]
#[derive(Serialize, Deserialize)]
pub(crate) struct DocumentDataWithEmbedding {
pub(crate) document_id: DocumentIdComponent,
pub(crate) document_base: DocumentBaseComponent,
pub(crate) embedding: EmbeddingComponent,
}

Expand All @@ -76,15 +77,15 @@ impl DocumentDataWithEmbedding {
embedding: EmbeddingComponent,
) -> Self {
Self {
document_id: document.document_id,
document_base: document.document_base,
embedding,
}
}
}

impl CoiSystemData for DocumentDataWithEmbedding {
fn id(&self) -> &DocumentId {
&self.document_id.id
&self.document_base.id
}

fn embedding(&self) -> &EmbeddingComponent {
Expand All @@ -97,15 +98,15 @@ impl CoiSystemData for DocumentDataWithEmbedding {
}

pub(crate) struct DocumentDataWithCoi {
pub(crate) document_id: DocumentIdComponent,
pub(crate) document_base: DocumentBaseComponent,
pub(crate) embedding: EmbeddingComponent,
pub(crate) coi: CoiComponent,
}

impl DocumentDataWithCoi {
pub(crate) fn from_document(document: DocumentDataWithEmbedding, coi: CoiComponent) -> Self {
Self {
document_id: document.document_id,
document_base: document.document_base,
embedding: document.embedding,
coi,
}
Expand All @@ -114,7 +115,7 @@ impl DocumentDataWithCoi {

#[cfg_attr(test, derive(Debug))]
pub(crate) struct DocumentDataWithLtr {
pub(crate) document_id: DocumentIdComponent,
pub(crate) document_base: DocumentBaseComponent,
pub(crate) embedding: EmbeddingComponent,
pub(crate) coi: CoiComponent,
pub(crate) ltr: LtrComponent,
Expand All @@ -123,7 +124,7 @@ pub(crate) struct DocumentDataWithLtr {
impl DocumentDataWithLtr {
pub(crate) fn from_document(document: DocumentDataWithCoi, ltr: LtrComponent) -> Self {
Self {
document_id: document.document_id,
document_base: document.document_base,
embedding: document.embedding,
coi: document.coi,
ltr,
Expand All @@ -133,7 +134,7 @@ impl DocumentDataWithLtr {

#[cfg_attr(test, derive(Debug, Clone))]
pub(crate) struct DocumentDataWithContext {
pub(crate) document_id: DocumentIdComponent,
pub(crate) document_base: DocumentBaseComponent,
pub(crate) embedding: EmbeddingComponent,
pub(crate) coi: CoiComponent,
pub(crate) ltr: LtrComponent,
Expand All @@ -143,7 +144,7 @@ pub(crate) struct DocumentDataWithContext {
impl DocumentDataWithContext {
pub(crate) fn from_document(document: DocumentDataWithLtr, context: ContextComponent) -> Self {
Self {
document_id: document.document_id,
document_base: document.document_base,
embedding: document.embedding,
coi: document.coi,
ltr: document.ltr,
Expand All @@ -155,7 +156,7 @@ impl DocumentDataWithContext {
#[cfg_attr(test, derive(Clone, Debug, PartialEq))]
#[derive(Serialize, Deserialize)]
pub(crate) struct DocumentDataWithMab {
pub(crate) document_id: DocumentIdComponent,
pub(crate) document_base: DocumentBaseComponent,
pub(crate) embedding: EmbeddingComponent,
pub(crate) coi: CoiComponent,
pub(crate) ltr: LtrComponent,
Expand All @@ -166,7 +167,7 @@ pub(crate) struct DocumentDataWithMab {
impl DocumentDataWithMab {
pub(crate) fn from_document(document: DocumentDataWithContext, mab: MabComponent) -> Self {
Self {
document_id: document.document_id,
document_base: document.document_base,
embedding: document.embedding,
coi: document.coi,
ltr: document.ltr,
Expand All @@ -178,7 +179,7 @@ impl DocumentDataWithMab {

impl CoiSystemData for DocumentDataWithMab {
fn id(&self) -> &DocumentId {
&self.document_id.id
&self.document_base.id
}

fn embedding(&self) -> &EmbeddingComponent {
Expand All @@ -197,25 +198,26 @@ mod tests {

#[test]
fn transition_and_get() {
let document_id = DocumentIdComponent {
let document_id = DocumentBaseComponent {
id: DocumentId("id".to_string()),
initial_ranking: 23,
};
let document_content = DocumentContentComponent {
snippet: "snippet".to_string(),
};
let document_data = DocumentDataWithDocument {
document_id: document_id.clone(),
document_base: document_id.clone(),
document_content: document_content.clone(),
};
assert_eq!(document_data.document_id, document_id);
assert_eq!(document_data.document_base, document_id);
assert_eq!(document_data.document_content, document_content);

let embedding = EmbeddingComponent {
embedding: arr1(&[1., 2., 3., 4.]).into(),
};
let document_data =
DocumentDataWithEmbedding::from_document(document_data, embedding.clone());
assert_eq!(document_data.document_id, document_id);
assert_eq!(document_data.document_base, document_id);
assert_eq!(document_data.embedding, embedding);

let coi = CoiComponent {
Expand All @@ -224,13 +226,13 @@ mod tests {
neg_distance: 0.2,
};
let document_data = DocumentDataWithCoi::from_document(document_data, coi.clone());
assert_eq!(document_data.document_id, document_id);
assert_eq!(document_data.document_base, document_id);
assert_eq!(document_data.embedding, embedding);
assert_eq!(document_data.coi, coi);

let ltr = LtrComponent { ltr_score: 0.3 };
let document_data = DocumentDataWithLtr::from_document(document_data, ltr.clone());
assert_eq!(document_data.document_id, document_id);
assert_eq!(document_data.document_base, document_id);
assert_eq!(document_data.embedding, embedding);
assert_eq!(document_data.coi, coi);
assert_eq!(document_data.ltr, ltr);
Expand All @@ -239,15 +241,15 @@ mod tests {
context_value: 1.23,
};
let document_data = DocumentDataWithContext::from_document(document_data, context.clone());
assert_eq!(document_data.document_id, document_id);
assert_eq!(document_data.document_base, document_id);
assert_eq!(document_data.embedding, embedding);
assert_eq!(document_data.coi, coi);
assert_eq!(document_data.ltr, ltr);
assert_eq!(document_data.context, context);

let mab = MabComponent { rank: 3 };
let document_data = DocumentDataWithMab::from_document(document_data, mab.clone());
assert_eq!(document_data.document_id, document_id);
assert_eq!(document_data.document_base, document_id);
assert_eq!(document_data.embedding, embedding);
assert_eq!(document_data.coi, coi);
assert_eq!(document_data.ltr, ltr);
Expand Down
12 changes: 9 additions & 3 deletions xayn-ai/src/ltr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ mod tests {
use super::*;
use crate::data::{
document::DocumentId,
document_data::{CoiComponent, DocumentIdComponent, EmbeddingComponent},
document_data::{CoiComponent, DocumentBaseComponent, EmbeddingComponent},
CoiId,
};

Expand All @@ -59,7 +59,10 @@ mod tests {
neg_distance: 0.2,
};
let doc1 = DocumentDataWithCoi {
document_id: DocumentIdComponent { id },
document_base: DocumentBaseComponent {
id,
initial_ranking: 24,
},
embedding: EmbeddingComponent { embedding },
coi,
};
Expand All @@ -72,7 +75,10 @@ mod tests {
neg_distance: 0.9,
};
let doc2 = DocumentDataWithCoi {
document_id: DocumentIdComponent { id },
document_base: DocumentBaseComponent {
id,
initial_ranking: 42,
},
embedding: EmbeddingComponent { embedding },
coi,
};
Expand Down
Loading