diff --git a/Cargo.lock b/Cargo.lock index e120e74c4..441840c77 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -138,15 +138,16 @@ checksum = "62b02a5381cc465bd3041d84623d0fa3b66738b52b8e2fc3bab8ad63ab032f4a" [[package]] name = "jiter" -version = "0.1.1" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2c0b7c896d2b1da897be13affb0bbf7bff95437e9c50823ede962addadae58d8" +checksum = "8e1177860adcf80c1ae7d7c1d41561f008c7530664caebbfa5ddd8a7f7316b98" dependencies = [ "ahash", "lexical-parse-float", "num-bigint", "num-traits", "pyo3", + "pyo3-build-config", "smallvec", ] diff --git a/Cargo.toml b/Cargo.toml index 420509290..69acef742 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -44,7 +44,7 @@ base64 = "0.21.7" num-bigint = "0.4.4" python3-dll-a = "0.2.7" uuid = "1.7.0" -jiter = { version = "0.1.1", features = ["python"] } +jiter = { version = "0.2.1", features = ["python"] } [lib] name = "_pydantic_core" diff --git a/src/input/return_enums.rs b/src/input/return_enums.rs index 5d1019f05..d28bac2ec 100644 --- a/src/input/return_enums.rs +++ b/src/input/return_enums.rs @@ -19,7 +19,7 @@ use serde::{ser::Error, Serialize, Serializer}; use crate::errors::{ py_err_string, ErrorType, ErrorTypeDefaults, InputValue, ToErrorValue, ValError, ValLineError, ValResult, }; -use crate::tools::{extract_i64, py_err}; +use crate::tools::{extract_i64, new_py_string, py_err}; use crate::validators::{CombinedValidator, Exactness, ValidationState, Validator}; use super::{py_error_on_minusone, BorrowInput, Input}; @@ -437,13 +437,7 @@ impl<'a> EitherString<'a> { pub fn as_py_string(&'a self, py: Python<'a>, cache_str: StringCacheMode) -> Bound<'a, PyString> { match self { - Self::Cow(cow) => { - if matches!(cache_str, StringCacheMode::All) { - jiter::cached_py_string(py, cow.as_ref()) - } else { - PyString::new_bound(py, cow.as_ref()) - } - } + Self::Cow(cow) => new_py_string(py, cow.as_ref(), cache_str), Self::Py(py_string) => py_string.clone(), } } diff --git a/src/input/shared.rs b/src/input/shared.rs index 5f0040e3e..e99bfabcf 100644 --- a/src/input/shared.rs +++ b/src/input/shared.rs @@ -2,7 +2,7 @@ use pyo3::prelude::*; use pyo3::sync::GILOnceCell; use pyo3::{intern, Py, PyAny, Python}; -use num_bigint::BigInt; +use jiter::{JsonErrorType, NumberInt}; use crate::errors::{ErrorTypeDefaults, ValError, ValResult}; @@ -68,29 +68,24 @@ fn strip_underscores(s: &str) -> Option { } /// parse a string as an int -/// -/// max length of the input is 4300, see -/// https://docs.python.org/3/whatsnew/3.11.html#other-cpython-implementation-changes and -/// https://github.com/python/cpython/issues/95778 for more info in that length bound pub fn str_as_int<'py>(input: &(impl Input<'py> + ?Sized), str: &str) -> ValResult> { let str = str.trim(); - let len = str.len(); - if len > 4300 { - Err(ValError::new(ErrorTypeDefaults::IntParsingSize, input)) - } else if let Some(int) = _parse_str(input, str, len) { - Ok(int) - } else if let Some(str_stripped) = strip_decimal_zeros(str) { - if let Some(int) = _parse_str(input, str_stripped, len) { - Ok(int) - } else { - Err(ValError::new(ErrorTypeDefaults::IntParsing, input)) + + // we have to call `NumberInt::try_from` directly first so we fail fast if the string is too long + match NumberInt::try_from(str.as_bytes()) { + Ok(NumberInt::Int(i)) => return Ok(EitherInt::I64(i)), + Ok(NumberInt::BigInt(i)) => return Ok(EitherInt::BigInt(i)), + Err(e) => { + if e.error_type == JsonErrorType::NumberOutOfRange { + return Err(ValError::new(ErrorTypeDefaults::IntParsingSize, input)); + } } + } + + if let Some(str_stripped) = strip_decimal_zeros(str) { + _parse_str(input, str_stripped) } else if let Some(str_stripped) = strip_underscores(str) { - if let Some(int) = _parse_str(input, &str_stripped, len) { - Ok(int) - } else { - Err(ValError::new(ErrorTypeDefaults::IntParsing, input)) - } + _parse_str(input, &str_stripped) } else { Err(ValError::new(ErrorTypeDefaults::IntParsing, input)) } @@ -108,16 +103,18 @@ pub fn str_as_float<'py>(input: &(impl Input<'py> + ?Sized), str: &str) -> ValRe } /// parse a string as an int, `input` is required here to get lifetimes to match up -/// -fn _parse_str<'py>(_input: &(impl Input<'py> + ?Sized), str: &str, len: usize) -> Option> { - if len < 19 { - if let Ok(i) = str.parse::() { - return Some(EitherInt::I64(i)); - } - } else if let Ok(i) = str.parse::() { - return Some(EitherInt::BigInt(i)); +/// max length of the input is 4300 which is checked by jiter, see +/// https://docs.python.org/3/whatsnew/3.11.html#other-cpython-implementation-changes and +/// https://github.com/python/cpython/issues/95778 for more info in that length bound +fn _parse_str<'py>(input: &(impl Input<'py> + ?Sized), str: &str) -> ValResult> { + match NumberInt::try_from(str.as_bytes()) { + Ok(jiter::NumberInt::Int(i)) => Ok(EitherInt::I64(i)), + Ok(jiter::NumberInt::BigInt(i)) => Ok(EitherInt::BigInt(i)), + Err(e) => match e.error_type { + JsonErrorType::NumberOutOfRange => Err(ValError::new(ErrorTypeDefaults::IntParsingSize, input)), + _ => Err(ValError::new(ErrorTypeDefaults::IntParsing, input)), + }, } - None } /// we don't want to parse as f64 then call `float_as_int` as it can loose precision for large ints, therefore diff --git a/src/tools.rs b/src/tools.rs index a823311a0..28c661d8f 100644 --- a/src/tools.rs +++ b/src/tools.rs @@ -5,6 +5,8 @@ use pyo3::prelude::*; use pyo3::types::{PyDict, PyString}; use pyo3::{ffi, intern, FromPyObject}; +use jiter::{cached_py_string, pystring_fast_new, StringCacheMode}; + pub trait SchemaDict<'py> { fn get_as(&self, key: &Bound<'_, PyString>) -> PyResult> where @@ -143,3 +145,13 @@ pub fn extract_i64(v: &Bound<'_, PyAny>) -> Option { None } } + +pub(crate) fn new_py_string<'py>(py: Python<'py>, s: &str, cache_str: StringCacheMode) -> Bound<'py, PyString> { + // we could use `bytecount::num_chars(s.as_bytes()) == s.len()` as orjson does, but it doesn't appear to be faster + let ascii_only = false; + if matches!(cache_str, StringCacheMode::All) { + cached_py_string(py, s, ascii_only) + } else { + pystring_fast_new(py, s, ascii_only) + } +} diff --git a/src/validators/validation_state.rs b/src/validators/validation_state.rs index 8e68cd8d9..ef6954618 100644 --- a/src/validators/validation_state.rs +++ b/src/validators/validation_state.rs @@ -4,6 +4,7 @@ use pyo3::types::PyString; use jiter::StringCacheMode; use crate::recursion_guard::{ContainsRecursionState, RecursionState}; +use crate::tools::new_py_string; use super::Extra; @@ -72,11 +73,7 @@ impl<'a, 'py> ValidationState<'a, 'py> { } pub fn maybe_cached_str(&self, py: Python<'py>, s: &str) -> Bound<'py, PyString> { - if matches!(self.extra.cache_str, StringCacheMode::All) { - jiter::cached_py_string(py, s) - } else { - PyString::new_bound(py, s) - } + new_py_string(py, s, self.extra.cache_str) } }