diff --git a/xayn-ai/src/analytics.rs b/xayn-ai/src/analytics.rs index d9fafb092..2beaa791a 100644 --- a/xayn-ai/src/analytics.rs +++ b/xayn-ai/src/analytics.rs @@ -53,10 +53,11 @@ impl systems::AnalyticsSystem for AnalyticsSystem { let mut paired_ltr_scores = Vec::new(); let mut paired_context_scores = Vec::new(); - let mut paired_final_ranking_score = Vec::new(); + let mut paired_final_ranking_scores = Vec::new(); + let mut paired_initial_ranking_scores = Vec::new(); for document in documents { - if let Some(relevance) = relevance_lookups.get(&document.document_id.id).copied() { + if let Some(relevance) = relevance_lookups.get(&document.document_base.id).copied() { paired_ltr_scores.push((relevance, document.ltr.ltr_score)); paired_context_scores.push((relevance, document.context.context_value)); @@ -64,7 +65,10 @@ impl systems::AnalyticsSystem for AnalyticsSystem { // it's the oposite, the solution carried over from the dart impl // is to multiply by -1. let final_ranking_desc = -(document.mab.rank as f32); - paired_final_ranking_score.push((relevance, final_ranking_desc)); + paired_final_ranking_scores.push((relevance, final_ranking_desc)); + + let intial_ranking_desc = -(document.document_base.initial_ranking as f32); + paired_initial_ranking_scores.push((relevance, intial_ranking_desc)); } } @@ -78,13 +82,13 @@ impl systems::AnalyticsSystem for AnalyticsSystem { calcuate_reordered_ndcg_at_k_score(&mut paired_context_scores, DEFAULT_NDCG_K); let ndcg_final_ranking = - calcuate_reordered_ndcg_at_k_score(&mut paired_final_ranking_score, DEFAULT_NDCG_K); + calcuate_reordered_ndcg_at_k_score(&mut paired_final_ranking_scores, DEFAULT_NDCG_K); + + let ndcg_initial_ranking = + calcuate_reordered_ndcg_at_k_score(&mut paired_initial_ranking_scores, DEFAULT_NDCG_K); Ok(Analytics { - //FIXME: We currently have no access to the initial score as thiss will require - // some changes to the main applications type state/component system this - // will be done in a followup PR. - ndcg_initial_ranking: f32::NAN, + ndcg_initial_ranking, ndcg_ltr, ndcg_context, ndcg_final_ranking, @@ -215,7 +219,7 @@ mod tests { let Analytics { ndcg_ltr, ndcg_context, - ndcg_initial_ranking: _, + ndcg_initial_ranking, ndcg_final_ranking, } = AnalyticsSystem .compute_analytics(&history, &documents) @@ -223,8 +227,7 @@ mod tests { assert_f32_eq!(ndcg_ltr, 0.173_765_35); assert_f32_eq!(ndcg_context, 0.826_234_64); - //FIXME: Currently not possible as `ndcg_initial_ranking` is not yet computed - // assert!(approx_eq!(f32, ndcg_initial_ranking, 0.7967075809905066, ulps = 2)); + assert_f32_eq!(ndcg_initial_ranking, 0.796_707_6); assert_f32_eq!(ndcg_final_ranking, 1.0); } diff --git a/xayn-ai/src/coi/system.rs b/xayn-ai/src/coi/system.rs index e5fbb862f..8e203f9fa 100644 --- a/xayn-ai/src/coi/system.rs +++ b/xayn-ai/src/coi/system.rs @@ -211,8 +211,8 @@ mod tests { document::{DocumentId, Relevance, UserFeedback}, document_data::{ ContextComponent, + DocumentBaseComponent, DocumentDataWithMab, - DocumentIdComponent, EmbeddingComponent, LtrComponent, MabComponent, @@ -231,8 +231,9 @@ mod tests { .iter() .enumerate() .map(|(id, embedding)| DocumentDataWithMab { - document_id: DocumentIdComponent { + document_base: DocumentBaseComponent { id: DocumentId(id.to_string()), + initial_ranking: id, }, embedding: EmbeddingComponent { embedding: arr1(embedding.as_init_slice()).into(), diff --git a/xayn-ai/src/coi/utils.rs b/xayn-ai/src/coi/utils.rs index 5e85897d6..f5e262a99 100644 --- a/xayn-ai/src/coi/utils.rs +++ b/xayn-ai/src/coi/utils.rs @@ -139,8 +139,8 @@ pub(super) mod tests { data::{ document_data::{ CoiComponent, + DocumentBaseComponent, DocumentDataWithEmbedding, - DocumentIdComponent, EmbeddingComponent, }, CoiId, @@ -212,17 +212,19 @@ pub(super) mod tests { embeddings .iter() .enumerate() - .map(|(id, embedding)| create_data_with_embedding(id, embedding.as_init_slice())) + .map(|(id, embedding)| create_data_with_embedding(id, id, embedding.as_init_slice())) .collect() } pub(crate) fn create_data_with_embedding( id: usize, + initial_ranking: usize, embedding: &[f32], ) -> DocumentDataWithEmbedding { DocumentDataWithEmbedding { - document_id: DocumentIdComponent { + document_base: DocumentBaseComponent { id: DocumentId(id.to_string()), + initial_ranking, }, embedding: EmbeddingComponent { embedding: arr1(embedding).into(), @@ -416,7 +418,7 @@ pub(super) mod tests { ]); let mut documents = create_data_with_embeddings(&[[1., 2., 3.], [3., 2., 1.]]); - documents.push(create_data_with_embedding(5, &[4., 5., 6.])); + documents.push(create_data_with_embedding(5, 0, &[4., 5., 6.])); let documents = to_vec_of_ref_of!(documents, &dyn CoiSystemData); let matching_documents = collect_matching_documents(&history, &documents); diff --git a/xayn-ai/src/context.rs b/xayn-ai/src/context.rs index e7341871f..61513143d 100644 --- a/xayn-ai/src/context.rs +++ b/xayn-ai/src/context.rs @@ -62,7 +62,7 @@ mod tests { use super::*; use crate::data::{ document::DocumentId, - document_data::{CoiComponent, DocumentIdComponent, EmbeddingComponent, LtrComponent}, + document_data::{CoiComponent, DocumentBaseComponent, EmbeddingComponent, LtrComponent}, CoiId, }; @@ -80,7 +80,10 @@ mod tests { let embedding = arr1::(&[]).into(); self.docs.push(DocumentDataWithLtr { - document_id: DocumentIdComponent { id }, + document_base: DocumentBaseComponent { + id, + initial_ranking: 13, + }, embedding: EmbeddingComponent { embedding }, coi: CoiComponent { id: CoiId(0), diff --git a/xayn-ai/src/data/document_data.rs b/xayn-ai/src/data/document_data.rs index e96a51b39..c74f9f8bf 100644 --- a/xayn-ai/src/data/document_data.rs +++ b/xayn-ai/src/data/document_data.rs @@ -8,8 +8,9 @@ use crate::{ #[cfg_attr(test, derive(Debug, PartialEq, Clone))] #[derive(Serialize, Deserialize)] -pub(crate) struct DocumentIdComponent { +pub(crate) struct DocumentBaseComponent { pub(crate) id: DocumentId, + pub(crate) initial_ranking: usize, } #[cfg_attr(test, derive(Debug, PartialEq, Clone))] @@ -59,14 +60,14 @@ pub(crate) struct MabComponent { // DocumentDataWithLtr -> DocumentDataWithContext -> DocumentDataWithMab pub(crate) struct DocumentDataWithDocument { - pub(crate) document_id: DocumentIdComponent, + pub(crate) document_base: DocumentBaseComponent, pub(crate) document_content: DocumentContentComponent, } #[cfg_attr(test, derive(Debug, PartialEq, Clone))] #[derive(Serialize, Deserialize)] pub(crate) struct DocumentDataWithEmbedding { - pub(crate) document_id: DocumentIdComponent, + pub(crate) document_base: DocumentBaseComponent, pub(crate) embedding: EmbeddingComponent, } @@ -76,7 +77,7 @@ impl DocumentDataWithEmbedding { embedding: EmbeddingComponent, ) -> Self { Self { - document_id: document.document_id, + document_base: document.document_base, embedding, } } @@ -84,7 +85,7 @@ impl DocumentDataWithEmbedding { impl CoiSystemData for DocumentDataWithEmbedding { fn id(&self) -> &DocumentId { - &self.document_id.id + &self.document_base.id } fn embedding(&self) -> &EmbeddingComponent { @@ -97,7 +98,7 @@ impl CoiSystemData for DocumentDataWithEmbedding { } pub(crate) struct DocumentDataWithCoi { - pub(crate) document_id: DocumentIdComponent, + pub(crate) document_base: DocumentBaseComponent, pub(crate) embedding: EmbeddingComponent, pub(crate) coi: CoiComponent, } @@ -105,7 +106,7 @@ pub(crate) struct DocumentDataWithCoi { impl DocumentDataWithCoi { pub(crate) fn from_document(document: DocumentDataWithEmbedding, coi: CoiComponent) -> Self { Self { - document_id: document.document_id, + document_base: document.document_base, embedding: document.embedding, coi, } @@ -114,7 +115,7 @@ impl DocumentDataWithCoi { #[cfg_attr(test, derive(Debug))] pub(crate) struct DocumentDataWithLtr { - pub(crate) document_id: DocumentIdComponent, + pub(crate) document_base: DocumentBaseComponent, pub(crate) embedding: EmbeddingComponent, pub(crate) coi: CoiComponent, pub(crate) ltr: LtrComponent, @@ -123,7 +124,7 @@ pub(crate) struct DocumentDataWithLtr { impl DocumentDataWithLtr { pub(crate) fn from_document(document: DocumentDataWithCoi, ltr: LtrComponent) -> Self { Self { - document_id: document.document_id, + document_base: document.document_base, embedding: document.embedding, coi: document.coi, ltr, @@ -133,7 +134,7 @@ impl DocumentDataWithLtr { #[cfg_attr(test, derive(Debug, Clone))] pub(crate) struct DocumentDataWithContext { - pub(crate) document_id: DocumentIdComponent, + pub(crate) document_base: DocumentBaseComponent, pub(crate) embedding: EmbeddingComponent, pub(crate) coi: CoiComponent, pub(crate) ltr: LtrComponent, @@ -143,7 +144,7 @@ pub(crate) struct DocumentDataWithContext { impl DocumentDataWithContext { pub(crate) fn from_document(document: DocumentDataWithLtr, context: ContextComponent) -> Self { Self { - document_id: document.document_id, + document_base: document.document_base, embedding: document.embedding, coi: document.coi, ltr: document.ltr, @@ -155,7 +156,7 @@ impl DocumentDataWithContext { #[cfg_attr(test, derive(Clone, Debug, PartialEq))] #[derive(Serialize, Deserialize)] pub(crate) struct DocumentDataWithMab { - pub(crate) document_id: DocumentIdComponent, + pub(crate) document_base: DocumentBaseComponent, pub(crate) embedding: EmbeddingComponent, pub(crate) coi: CoiComponent, pub(crate) ltr: LtrComponent, @@ -166,7 +167,7 @@ pub(crate) struct DocumentDataWithMab { impl DocumentDataWithMab { pub(crate) fn from_document(document: DocumentDataWithContext, mab: MabComponent) -> Self { Self { - document_id: document.document_id, + document_base: document.document_base, embedding: document.embedding, coi: document.coi, ltr: document.ltr, @@ -178,7 +179,7 @@ impl DocumentDataWithMab { impl CoiSystemData for DocumentDataWithMab { fn id(&self) -> &DocumentId { - &self.document_id.id + &self.document_base.id } fn embedding(&self) -> &EmbeddingComponent { @@ -197,17 +198,18 @@ mod tests { #[test] fn transition_and_get() { - let document_id = DocumentIdComponent { + let document_id = DocumentBaseComponent { id: DocumentId("id".to_string()), + initial_ranking: 23, }; let document_content = DocumentContentComponent { snippet: "snippet".to_string(), }; let document_data = DocumentDataWithDocument { - document_id: document_id.clone(), + document_base: document_id.clone(), document_content: document_content.clone(), }; - assert_eq!(document_data.document_id, document_id); + assert_eq!(document_data.document_base, document_id); assert_eq!(document_data.document_content, document_content); let embedding = EmbeddingComponent { @@ -215,7 +217,7 @@ mod tests { }; let document_data = DocumentDataWithEmbedding::from_document(document_data, embedding.clone()); - assert_eq!(document_data.document_id, document_id); + assert_eq!(document_data.document_base, document_id); assert_eq!(document_data.embedding, embedding); let coi = CoiComponent { @@ -224,13 +226,13 @@ mod tests { neg_distance: 0.2, }; let document_data = DocumentDataWithCoi::from_document(document_data, coi.clone()); - assert_eq!(document_data.document_id, document_id); + assert_eq!(document_data.document_base, document_id); assert_eq!(document_data.embedding, embedding); assert_eq!(document_data.coi, coi); let ltr = LtrComponent { ltr_score: 0.3 }; let document_data = DocumentDataWithLtr::from_document(document_data, ltr.clone()); - assert_eq!(document_data.document_id, document_id); + assert_eq!(document_data.document_base, document_id); assert_eq!(document_data.embedding, embedding); assert_eq!(document_data.coi, coi); assert_eq!(document_data.ltr, ltr); @@ -239,7 +241,7 @@ mod tests { context_value: 1.23, }; let document_data = DocumentDataWithContext::from_document(document_data, context.clone()); - assert_eq!(document_data.document_id, document_id); + assert_eq!(document_data.document_base, document_id); assert_eq!(document_data.embedding, embedding); assert_eq!(document_data.coi, coi); assert_eq!(document_data.ltr, ltr); @@ -247,7 +249,7 @@ mod tests { let mab = MabComponent { rank: 3 }; let document_data = DocumentDataWithMab::from_document(document_data, mab.clone()); - assert_eq!(document_data.document_id, document_id); + assert_eq!(document_data.document_base, document_id); assert_eq!(document_data.embedding, embedding); assert_eq!(document_data.coi, coi); assert_eq!(document_data.ltr, ltr); diff --git a/xayn-ai/src/ltr.rs b/xayn-ai/src/ltr.rs index 94ee5db82..42ba230d5 100644 --- a/xayn-ai/src/ltr.rs +++ b/xayn-ai/src/ltr.rs @@ -45,7 +45,7 @@ mod tests { use super::*; use crate::data::{ document::DocumentId, - document_data::{CoiComponent, DocumentIdComponent, EmbeddingComponent}, + document_data::{CoiComponent, DocumentBaseComponent, EmbeddingComponent}, CoiId, }; @@ -59,7 +59,10 @@ mod tests { neg_distance: 0.2, }; let doc1 = DocumentDataWithCoi { - document_id: DocumentIdComponent { id }, + document_base: DocumentBaseComponent { + id, + initial_ranking: 24, + }, embedding: EmbeddingComponent { embedding }, coi, }; @@ -72,7 +75,10 @@ mod tests { neg_distance: 0.9, }; let doc2 = DocumentDataWithCoi { - document_id: DocumentIdComponent { id }, + document_base: DocumentBaseComponent { + id, + initial_ranking: 42, + }, embedding: EmbeddingComponent { embedding }, coi, }; diff --git a/xayn-ai/src/mab.rs b/xayn-ai/src/mab.rs index da993e7bc..ee01bf34e 100644 --- a/xayn-ai/src/mab.rs +++ b/xayn-ai/src/mab.rs @@ -278,7 +278,7 @@ mod tests { document_data::{ CoiComponent, ContextComponent, - DocumentIdComponent, + DocumentBaseComponent, EmbeddingComponent, LtrComponent, }, @@ -291,7 +291,10 @@ mod tests { fn with_ctx(id: DocumentId, coi_id: CoiId, context_value: f32) -> DocumentDataWithContext { DocumentDataWithContext { - document_id: DocumentIdComponent { id }, + document_base: DocumentBaseComponent { + id, + initial_ranking: 0, + }, embedding: EmbeddingComponent { embedding: arr1(&[]).into(), }, @@ -350,7 +353,7 @@ mod tests { .unwrap_or_else(|| panic!("document from coi id {:?}", coi_id)); let docs_id: HashSet = docs .iter() - .map(|doc| doc.0.document_id.id.clone()) + .map(|doc| doc.0.document_base.id.clone()) .collect(); assert_eq!(docs_id.len(), docs_id_ok.len()); @@ -386,7 +389,7 @@ mod tests { .into_iter() // into_sorted_vec returns elements in the revers order of what using pop will do .rev() - .map(|doc| doc.0.document_id.id) + .map(|doc| doc.0.document_base.id) .collect(); assert_eq!( @@ -413,7 +416,7 @@ mod tests { .into_iter() // into_sorted_vec returns elements in the revers order of what using pop will do .rev() - .map(|doc| doc.0.document_id.id) + .map(|doc| doc.0.document_base.id) .collect(); assert_eq!(docs_id, vec![doc_id_2, doc_id_0, doc_id_1]); @@ -648,7 +651,7 @@ mod tests { let (documents_by_coi, document) = pull_arms(&beta_sampler, &cois, documents_by_coi).expect("document"); - assert_eq!(doc_id, document.document_id.id); + assert_eq!(doc_id, document.document_base.id); documents_by_coi }); @@ -735,7 +738,7 @@ mod tests { let (documents_by_coi, document) = pull_arms(&beta_sampler, &cois, documents_by_coi).expect("document"); - assert_eq!(doc_id, document.document_id.id); + assert_eq!(doc_id, document.document_base.id); documents_by_coi }); @@ -787,7 +790,7 @@ mod tests { let (documents_by_coi, document) = pull_arms(&beta_sampler, &cois, documents_by_coi).expect("document"); - let ok = ok && doc_id == document.document_id.id; + let ok = ok && doc_id == document.document_base.id; (ok, documents_by_coi) }, @@ -838,7 +841,7 @@ mod tests { .expect("documents"); let documents_id: Vec<_> = documents .into_iter() - .map(|document| document.document_id.id) + .map(|document| document.document_base.id) .collect(); let documents_id_ok = vec![doc_id_5, doc_id_4, doc_id_3, doc_id_2, doc_id_1, doc_id_0]; @@ -930,7 +933,7 @@ mod tests { for document in documents { let rank = documents_id_to_rank - .get(&document.document_id.id) + .get(&document.document_base.id) .expect("rank"); assert_eq!(document.mab.rank, *rank); } diff --git a/xayn-ai/src/reranker/mod.rs b/xayn-ai/src/reranker/mod.rs index d70b65e90..93993d7da 100644 --- a/xayn-ai/src/reranker/mod.rs +++ b/xayn-ai/src/reranker/mod.rs @@ -11,11 +11,11 @@ use crate::{ data::{ document::{Document, DocumentHistory, Ranks}, document_data::{ + DocumentBaseComponent, DocumentContentComponent, DocumentDataWithDocument, DocumentDataWithEmbedding, DocumentDataWithMab, - DocumentIdComponent, }, UserInterests, }, @@ -66,8 +66,9 @@ where let documents: Vec<_> = documents .iter() .map(|document| DocumentDataWithDocument { - document_id: DocumentIdComponent { + document_base: DocumentBaseComponent { id: document.id.clone(), + initial_ranking: document.rank, }, document_content: DocumentContentComponent { snippet: document.snippet.clone(), @@ -101,7 +102,7 @@ where let ranks = documents .iter() - .map(|document| (document.document_id.id.clone(), document.mab.rank)) + .map(|document| (document.document_base.id.clone(), document.mab.rank)) .collect::>(); let ranks = original_documents .iter() diff --git a/xayn-ai/src/tests/systems.rs b/xayn-ai/src/tests/systems.rs index 339f1b932..390e2c593 100644 --- a/xayn-ai/src/tests/systems.rs +++ b/xayn-ai/src/tests/systems.rs @@ -67,7 +67,7 @@ pub(crate) fn mocked_bert_system() -> MockBertSystem { embedding.resize(128, 0.); DocumentDataWithEmbedding { - document_id: doc.document_id, + document_base: doc.document_base, embedding: EmbeddingComponent { embedding: arr1(&embedding).into(), }, diff --git a/xayn-ai/src/tests/utils.rs b/xayn-ai/src/tests/utils.rs index ce787de93..139f2b6ab 100644 --- a/xayn-ai/src/tests/utils.rs +++ b/xayn-ai/src/tests/utils.rs @@ -9,11 +9,11 @@ use crate::{ document_data::{ CoiComponent, ContextComponent, + DocumentBaseComponent, DocumentContentComponent, DocumentDataWithDocument, DocumentDataWithEmbedding, DocumentDataWithMab, - DocumentIdComponent, EmbeddingComponent, LtrComponent, MabComponent, @@ -55,8 +55,9 @@ fn cois_from_words(snippets: &[&str], bert: impl BertSystem) -> Ve .iter() .enumerate() .map(|(id, snippet)| DocumentDataWithDocument { - document_id: DocumentIdComponent { + document_base: DocumentBaseComponent { id: DocumentId(id.to_string()), + initial_ranking: id, }, document_content: DocumentContentComponent { snippet: snippet.to_string(), @@ -96,11 +97,14 @@ pub(crate) fn history_for_prev_docs( } pub(crate) fn data_with_mab( - ids_and_embeddings: impl Iterator, + ids_and_embeddings: impl Iterator, ) -> Vec { ids_and_embeddings - .map(|(id, embedding)| DocumentDataWithMab { - document_id: DocumentIdComponent { id }, + .map(|(id, initial_ranking, embedding)| DocumentDataWithMab { + document_base: DocumentBaseComponent { + id, + initial_ranking, + }, embedding: EmbeddingComponent { embedding }, coi: CoiComponent { id: CoiId(1), @@ -118,10 +122,15 @@ pub(crate) fn documents_with_embeddings_from_ids( ids: Range, ) -> Vec { from_ids(ids) - .map(|(id, embedding)| DocumentDataWithEmbedding { - document_id: DocumentIdComponent { id }, - embedding: EmbeddingComponent { embedding }, - }) + .map( + |(id, initial_ranking, embedding)| DocumentDataWithEmbedding { + document_base: DocumentBaseComponent { + id, + initial_ranking, + }, + embedding: EmbeddingComponent { embedding }, + }, + ) .collect() } @@ -160,10 +169,15 @@ pub(crate) fn document_history(docs: Vec<(u32, Relevance, UserFeedback)>) -> Vec .collect() } -pub(crate) fn from_ids(ids: Range) -> impl Iterator { +/// Return a sequence of `(document_id, initial_ranking, embedding)` tuples. +/// +/// The passed in integer ids are converted to a string and used as document_id's as +/// well as used as the initial_ranking. +pub(crate) fn from_ids(ids: Range) -> impl Iterator { ids.map(|id| { ( DocumentId(id.to_string()), + id as usize, arr1(&vec![id as f32; 128]).into(), ) })