Skip to content

Commit

Permalink
Fixing missing direction in TruncationParams. (#868)
Browse files Browse the repository at this point in the history
  • Loading branch information
Narsil authored Jan 4, 2022
1 parent 7069988 commit 4122a33
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 14 deletions.
12 changes: 0 additions & 12 deletions bindings/python/test.py

This file was deleted.

25 changes: 23 additions & 2 deletions tokenizers/src/utils/truncation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@ pub enum TruncationDirection {
Left,
Right,
}
impl Default for TruncationDirection {
fn default() -> Self {
TruncationDirection::Right
}
}

impl std::convert::AsRef<str> for TruncationDirection {
fn as_ref(&self) -> &str {
Expand All @@ -20,6 +25,7 @@ impl std::convert::AsRef<str> for TruncationDirection {

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TruncationParams {
#[serde(default)]
pub direction: TruncationDirection,
pub max_length: usize,
pub strategy: TruncationStrategy,
Expand All @@ -30,9 +36,9 @@ impl Default for TruncationParams {
fn default() -> Self {
Self {
max_length: 512,
strategy: TruncationStrategy::LongestFirst,
strategy: TruncationStrategy::default(),
stride: 0,
direction: TruncationDirection::Right,
direction: TruncationDirection::default(),
}
}
}
Expand Down Expand Up @@ -68,6 +74,12 @@ pub enum TruncationStrategy {
OnlySecond,
}

impl Default for TruncationStrategy {
fn default() -> Self {
TruncationStrategy::LongestFirst
}
}

impl std::convert::AsRef<str> for TruncationStrategy {
fn as_ref(&self) -> &str {
match self {
Expand Down Expand Up @@ -325,4 +337,13 @@ mod tests {
truncate_and_assert(get_medium(), get_medium(), &params, 0, 0);
truncate_and_assert(get_long(), get_long(), &params, 0, 0);
}

#[test]
fn test_deserialize_defaults() {
let old_truncation_params = r#"{"max_length":256,"strategy":"LongestFirst","stride":0}"#;

let params: TruncationParams = serde_json::from_str(old_truncation_params).unwrap();

assert_eq!(params.direction, TruncationDirection::Right);
}
}

0 comments on commit 4122a33

Please sign in to comment.