Skip to content

Commit

Permalink
fix(core): fix max decoding length not being respected (#1626)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsxiaoys committed Mar 6, 2024
1 parent cc7803d commit bd77572
Showing 1 changed file with 18 additions and 5 deletions.
23 changes: 18 additions & 5 deletions crates/tabby-inference/src/decoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ impl<'a> StopCondition<'a> {
}

pub fn should_stop(&mut self, new_text: &str) -> (bool, usize) {
self.num_decoded += 1;
if !new_text.is_empty() {
self.reversed_text = reverse(new_text) + &self.reversed_text;

Expand All @@ -89,19 +90,17 @@ impl<'a> StopCondition<'a> {
let matched_length = matches.into_iter().map(|x| x.len()).max();
if let Some(matched_length) = matched_length {
return (true, matched_length);
} else {
return (false, 0);
};
}
}
}

self.num_decoded += 1;
(self.num_decoded >= self.max_decoding_length, 0)
}
}

#[cfg(test)]
mod tests {
use tabby_common::languages::UNKNOWN_LANGUAGE;

use super::*;

#[test]
Expand All @@ -118,4 +117,18 @@ mod tests {
]);
assert!(!trie.common_prefix_search(&text).is_empty());
}

#[test]
fn test_stop_condition_max_length() {
let factory = StopConditionFactory::default();
let mut cond = factory.create("", 4, Some(&UNKNOWN_LANGUAGE));
let (should_stop, _) = cond.should_stop("1");
assert!(!should_stop);
let (should_stop, _) = cond.should_stop("2");
assert!(!should_stop);
let (should_stop, _) = cond.should_stop("3");
assert!(!should_stop);
let (should_stop, _) = cond.should_stop("4");
assert!(should_stop)
}
}

0 comments on commit bd77572

Please sign in to comment.