Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Attribute filtering during export #5320

Merged
merged 3 commits into from
Aug 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions components/collator/src/provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ fn data_ce_to_primary(data_ce: u64, c: char) -> u32 {
"collator/data@1",
// TODO(#3867): Use script fallback
fallback_by = "language",
attributes_domain = "collator",
))]
#[derive(Debug, PartialEq, Clone)]
#[cfg_attr(feature = "datagen", derive(serde::Serialize, databake::Bake), databake(path = icu_collator::provider))]
Expand Down Expand Up @@ -233,6 +234,7 @@ impl<'data> CollationDataV1<'data> {
CollationDiacriticsV1Marker,
"collator/dia@1",
fallback_by = "language",
attributes_domain = "collator",
))]
#[derive(Debug, PartialEq, Clone)]
#[cfg_attr(feature = "datagen", derive(serde::Serialize, databake::Bake), databake(path = icu_collator::provider))]
Expand Down Expand Up @@ -275,6 +277,7 @@ pub struct CollationJamoV1<'data> {
CollationReorderingV1Marker,
"collator/reord@1",
fallback_by = "language",
attributes_domain = "collator",
))]
#[derive(Debug, PartialEq, Clone)]
#[cfg_attr(feature = "datagen", derive(serde::Serialize, databake::Bake), databake(path = icu_collator::provider))]
Expand Down Expand Up @@ -363,6 +366,7 @@ impl<'data> CollationReorderingV1<'data> {
CollationMetadataV1Marker,
"collator/meta@1",
fallback_by = "language",
attributes_domain = "collator",
))]
#[derive(Debug, PartialEq, Clone, Copy)]
#[cfg_attr(feature = "datagen", derive(serde::Serialize, databake::Bake), databake(path = icu_collator::provider))]
Expand Down
6 changes: 5 additions & 1 deletion components/experimental/src/dimension/provider/units.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@ use icu_plurals::PluralCategory;
use icu_provider::prelude::*;
use zerovec::ZeroMap;

#[icu_provider::data_struct(UnitsDisplayNameV1Marker = "units/displaynames@1")]
#[icu_provider::data_struct(marker(
UnitsDisplayNameV1Marker,
"units/displaynames@1",
attributes_domain = "units"
))]
#[derive(Clone, PartialEq, Debug)]
#[cfg_attr(
feature = "datagen",
Expand Down
6 changes: 5 additions & 1 deletion components/segmenter/src/provider/lstm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,11 @@ impl databake::BakeSize for LstmDataFloat32<'_> {
/// including in SemVer minor releases. While the serde representation of data structs is guaranteed
/// to be stable, their Rust representation might not be. Use with caution.
/// </div>
#[icu_provider::data_struct(LstmForWordLineAutoV1Marker = "segmenter/lstm/wl_auto@1")]
#[icu_provider::data_struct(marker(
LstmForWordLineAutoV1Marker,
"segmenter/lstm/wl_auto@1",
attributes_domain = "segmenter"
))]
#[derive(Debug, PartialEq, Clone)]
#[cfg_attr(
feature = "datagen",
Expand Down
12 changes: 10 additions & 2 deletions components/segmenter/src/provider/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,16 @@ pub struct RuleBreakDataV2<'data> {
/// to be stable, their Rust representation might not be. Use with caution.
/// </div>
#[icu_provider::data_struct(
DictionaryForWordOnlyAutoV1Marker = "segmenter/dictionary/w_auto@1",
DictionaryForWordLineExtendedV1Marker = "segmenter/dictionary/wl_ext@1"
marker(
DictionaryForWordOnlyAutoV1Marker,
"segmenter/dictionary/w_auto@1",
attributes_domain = "segmenter"
),
marker(
DictionaryForWordLineExtendedV1Marker,
"segmenter/dictionary/wl_ext@1",
attributes_domain = "segmenter"
)
)]
#[derive(Debug, PartialEq, Clone)]
#[cfg_attr(
Expand Down
18 changes: 18 additions & 0 deletions provider/core/macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ struct DataStructArg {
marker_name: Path,
path_lit: Option<LitStr>,
fallback_by: Option<LitStr>,
attributes_domain: Option<LitStr>,
singleton: bool,
}

Expand All @@ -124,6 +125,7 @@ impl DataStructArg {
marker_name,
path_lit: None,
fallback_by: None,
attributes_domain: None,
singleton: false,
}
}
Expand Down Expand Up @@ -155,6 +157,7 @@ impl Parse for DataStructArg {
let mut marker_name: Option<Path> = None;
let mut path_lit: Option<LitStr> = None;
let mut fallback_by: Option<LitStr> = None;
let mut attributes_domain: Option<LitStr> = None;
let mut singleton = false;
let punct = content.parse_terminated(DataStructMarkerArg::parse, Token![,])?;

Expand All @@ -171,6 +174,13 @@ impl Parse for DataStructArg {
"fallback_by",
paren.span.join(),
)?;
} else if name == "attributes_domain" {
at_most_one_option(
&mut attributes_domain,
value,
"attributes_domain",
paren.span.join(),
)?;
} else {
return Err(parse::Error::new(
name.span(),
Expand Down Expand Up @@ -199,6 +209,7 @@ impl Parse for DataStructArg {
marker_name,
path_lit,
fallback_by,
attributes_domain,
singleton,
})
} else {
Expand Down Expand Up @@ -282,6 +293,7 @@ fn data_struct_impl(attr: DataStructArgs, input: DeriveInput) -> TokenStream2 {
marker_name,
path_lit,
fallback_by,
attributes_domain,
singleton,
} = single_attr;

Expand Down Expand Up @@ -324,12 +336,18 @@ fn data_struct_impl(attr: DataStructArgs, input: DeriveInput) -> TokenStream2 {
} else {
quote! {icu_provider::_internal::LocaleFallbackPriority::const_default()}
};
let attributes_domain_setter = if let Some(attributes_domain_lit) = attributes_domain {
quote! { info.attributes_domain = #attributes_domain_lit; }
} else {
quote!()
};
result.extend(quote!(
impl icu_provider::DataMarker for #marker_name {
const INFO: icu_provider::DataMarkerInfo = {
let mut info = icu_provider::DataMarkerInfo::from_path(icu_provider::marker::data_marker_path!(#path_str));
info.is_singleton = #singleton;
info.fallback_config.priority = #fallback_by_expr;
#attributes_domain_setter
info
};
}
Expand Down
3 changes: 3 additions & 0 deletions provider/core/src/marker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,8 @@ pub struct DataMarkerInfo {
/// Useful for reading and writing data to a file system.
pub path: DataMarkerPath,
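/// The attributes domain of this marker. During export it is the key used to look up
/// the marker-attributes filter for this marker; when empty (the default) no filter is applied.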
pub attributes_domain: &'static str,
/// TODO
pub is_singleton: bool,
/// TODO
pub fallback_config: LocaleFallbackConfig,
Expand Down Expand Up @@ -560,6 +562,7 @@ impl DataMarkerInfo {
Self {
path,
is_singleton: false,
attributes_domain: "",
fallback_config: LocaleFallbackConfig::const_default(),
}
}
Expand Down
70 changes: 24 additions & 46 deletions provider/export/src/export_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use icu_provider::prelude::*;
use std::collections::HashMap;
use std::collections::HashSet;
use std::fmt;
use std::sync::Arc;
use std::time::Duration;
use std::time::Instant;
use writeable::Writeable;
Expand Down Expand Up @@ -53,8 +54,7 @@ impl ExportDriver {
include_full,
fallbacker,
deduplication_strategy,
additional_collations,
segmenter_models,
attributes_filters,
} = self;

let markers = markers.unwrap_or_else(|| provider.supported_markers());
Expand Down Expand Up @@ -177,9 +177,8 @@ impl ExportDriver {
provider,
marker,
&requested_families,
&attributes_filters,
include_full,
&additional_collations,
&segmenter_models,
&fallbacker,
)?;

Expand Down Expand Up @@ -262,13 +261,16 @@ impl ExportDriver {

/// Selects the maximal set of locales to export based on a [`DataMarkerInfo`] and this datagen
/// provider's options bag. The locales may be later optionally deduplicated for fallback.
#[allow(clippy::type_complexity)] // sigh
fn select_locales_for_marker<'a>(
provider: &'a dyn ExportableProvider,
marker: DataMarkerInfo,
requested_families: &HashMap<DataLocale, DataLocaleFamilyAnnotations>,
attributes_filters: &HashMap<
String,
Arc<Box<dyn Fn(&DataMarkerAttributes) -> bool + Send + Sync + 'static>>,
>,
include_full: bool,
additional_collations: &HashSet<String>,
segmenter_models: &[String],
fallbacker: &LocaleFallbacker,
) -> Result<HashSet<DataIdentifierCow<'a>>, DataError> {
// Map from all supported DataLocales to their corresponding supported DataIdentifiers.
Expand All @@ -283,41 +285,13 @@ fn select_locales_for_marker<'a>(
.insert(id);
}

if marker.path.as_str().starts_with("segmenter/dictionary/") {
supported_map.retain(|_, ids| {
ids.retain(|id| {
segmenter_models
.iter()
.any(|m| **m == **id.marker_attributes)
});
!ids.is_empty()
});
// Don't perform additional locale filtering
return Ok(supported_map.into_values().flatten().collect());
} else if marker.path.as_str().starts_with("segmenter/lstm/") {
supported_map.retain(|_, locales| {
locales.retain(|id| {
segmenter_models
.iter()
.any(|m| **m == **id.marker_attributes)
});
!locales.is_empty()
});
// Don't perform additional locale filtering
return Ok(supported_map.into_values().flatten().collect());
} else if marker.path.as_str().starts_with("collator/") {
supported_map.retain(|_, ids| {
ids.retain(|id| {
id.marker_attributes.is_empty()
|| additional_collations.contains(id.marker_attributes.as_str())
|| if id.marker_attributes.as_str().starts_with("search") {
additional_collations.contains("search*")
} else {
!["big5han", "gb2312"].contains(&id.marker_attributes.as_str())
}
if !marker.attributes_domain.is_empty() {
if let Some(filter) = attributes_filters.get(marker.attributes_domain) {
supported_map.retain(|_, ids| {
ids.retain(|id| filter(&id.marker_attributes));
!ids.is_empty()
});
!ids.is_empty()
});
}
}

if include_full && requested_families.is_empty() {
Expand Down Expand Up @@ -510,6 +484,7 @@ impl fmt::Display for DisplayDuration {

#[test]
fn test_collation_filtering() {
use crate::DataLocaleFamily;
use icu::locale::locale;
use std::collections::BTreeSet;

Expand Down Expand Up @@ -619,16 +594,19 @@ fn test_collation_filtering() {
},
];
for cas in cases {
let driver = ExportDriver::new(
[DataLocaleFamily::single(cas.language.clone())],
DeduplicationStrategy::None.into(),
LocaleFallbacker::new_without_data(),
)
.with_additional_collations(cas.include_collations.iter().copied().map(String::from));
let resolved_locales = select_locales_for_marker(
&Provider,
icu::collator::provider::CollationDataV1Marker::INFO,
&[(cas.language.clone(), DataLocaleFamilyAnnotations::single())]
.into_iter()
.collect(),
&driver.requested_families,
&driver.attributes_filters,
false,
&HashSet::from_iter(cas.include_collations.iter().copied().map(String::from)),
&[],
&LocaleFallbacker::new_without_data(),
&driver.fallbacker,
)
.unwrap()
.into_iter()
Expand Down
56 changes: 43 additions & 13 deletions provider/export/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ use icu_provider::prelude::*;
use std::collections::HashMap;
use std::collections::HashSet;
use std::hash::Hash;
use std::sync::Arc;

/// Configuration for a data export operation.
///
Expand All @@ -115,15 +116,29 @@ use std::hash::Hash;
/// )
/// .unwrap();
/// ```
#[derive(Debug, Clone)]
#[derive(Clone)]
pub struct ExportDriver {
markers: Option<HashSet<DataMarkerInfo>>,
requested_families: HashMap<DataLocale, DataLocaleFamilyAnnotations>,
#[allow(clippy::type_complexity)] // sigh
attributes_filters:
HashMap<String, Arc<Box<dyn Fn(&DataMarkerAttributes) -> bool + Send + Sync + 'static>>>,
fallbacker: LocaleFallbacker,
include_full: bool,
deduplication_strategy: DeduplicationStrategy,
additional_collations: HashSet<String>,
segmenter_models: Vec<String>,
}

/// Hand-written `Debug` for `ExportDriver`.
///
/// The values of `attributes_filters` are boxed `dyn Fn` closures that carry no `Debug`
/// bound, so only the keys of that map are printed; every other field is printed as-is.
impl core::fmt::Debug for ExportDriver {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Only the domain names are shown for the filters, never the closures themselves.
        let filter_domains = self.attributes_filters.keys();

        let mut out = f.debug_struct("ExportDriver");
        out.field("markers", &self.markers);
        out.field("requested_families", &self.requested_families);
        out.field("attributes_filters", &filter_domains);
        out.field("fallbacker", &self.fallbacker);
        out.field("include_full", &self.include_full);
        out.field("deduplication_strategy", &self.deduplication_strategy);
        out.finish()
    }
}

impl ExportDriver {
Expand Down Expand Up @@ -158,13 +173,24 @@ impl ExportDriver {
))
})
.collect(),
attributes_filters: Default::default(),
include_full,
fallbacker,
deduplication_strategy: options.deduplication_strategy,
additional_collations: Default::default(),
segmenter_models: Default::default(),
}
.with_recommended_segmenter_models()
.with_additional_collations([])
}

/// Registers a marker-attributes filter for the given attributes domain.
///
/// The `filter` is stored under `domain`; a filter previously registered for the same
/// domain is replaced. During export, for markers whose `attributes_domain` equals
/// `domain`, only data identifiers whose marker attributes make `filter` return `true`
/// are kept.
///
/// Returns the updated driver so that calls can be chained.
pub fn with_marker_attributes_filter(
    mut self,
    domain: &str,
    filter: impl Fn(&DataMarkerAttributes) -> bool + Send + Sync + 'static,
) -> Self {
    // Erase the concrete closure type first so it fits the map's value type.
    let boxed: Box<dyn Fn(&DataMarkerAttributes) -> bool + Send + Sync + 'static> =
        Box::new(filter);
    self.attributes_filters
        .insert(domain.to_owned(), Arc::new(boxed));
    self
}

/// Sets this driver to generate the given data markers.
Expand All @@ -187,10 +213,16 @@ impl ExportDriver {
self,
additional_collations: impl IntoIterator<Item = String>,
) -> Self {
Self {
additional_collations: additional_collations.into_iter().collect(),
..self
}
let set = additional_collations.into_iter().collect::<HashSet<_>>();
self.with_marker_attributes_filter("collator", move |attrs| {
attrs.is_empty()
|| set.contains(attrs.as_str())
|| if attrs.as_str().starts_with("search") {
set.contains("search*")
} else {
!["big5han", "gb2312"].contains(&attrs.as_str())
}
})
}

/// This option is only relevant if using `icu::segmenter`.
Expand Down Expand Up @@ -236,10 +268,8 @@ impl ExportDriver {
/// If multiple models for the same language and segmentation type (dictionary/LSTM) are
/// listed, the first one will be used.
pub fn with_segmenter_models(self, models: impl IntoIterator<Item = String>) -> Self {
Self {
segmenter_models: models.into_iter().collect(),
..self
}
let set = models.into_iter().collect::<HashSet<_>>();
self.with_marker_attributes_filter("segmenter", move |attrs| set.contains(attrs.as_str()))
}
}

Expand Down
Loading