Skip to content

Commit

Permalink
Int extraction (#1155)
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelcolvin authored Jan 15, 2024
1 parent 5d3aa43 commit d7cf72d
Show file tree
Hide file tree
Showing 11 changed files with 64 additions and 25 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,6 @@ node_modules/
/foobar.py
/python/pydantic_core/*.so
/src/self_schema.py

# samply
/profile.json
2 changes: 1 addition & 1 deletion src/errors/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -786,7 +786,7 @@ impl From<Int> for Number {

impl FromPyObject<'_> for Number {
fn extract(obj: &PyAny) -> PyResult<Self> {
if let Ok(int) = extract_i64(obj) {
if let Some(int) = extract_i64(obj) {
Ok(Number::Int(int))
} else if let Ok(float) = obj.extract::<f64>() {
Ok(Number::Float(float))
Expand Down
2 changes: 1 addition & 1 deletion src/errors/value_exception.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ impl PydanticCustomError {
let key: &PyString = key.downcast()?;
if let Ok(py_str) = value.downcast::<PyString>() {
message = message.replace(&format!("{{{}}}", key.to_str()?), py_str.to_str()?);
} else if let Ok(value_int) = extract_i64(value) {
} else if let Some(value_int) = extract_i64(value) {
message = message.replace(&format!("{{{}}}", key.to_str()?), &value_int.to_string());
} else {
// fallback for anything else just in case
Expand Down
10 changes: 5 additions & 5 deletions src/input/input_python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ impl AsLocItem for PyAny {
fn as_loc_item(&self) -> LocItem {
if let Ok(py_str) = self.downcast::<PyString>() {
py_str.to_string_lossy().as_ref().into()
} else if let Ok(key_int) = extract_i64(self) {
} else if let Some(key_int) = extract_i64(self) {
key_int.into()
} else {
safe_repr(self).to_string().into()
Expand Down Expand Up @@ -292,7 +292,7 @@ impl<'a> Input<'a> for PyAny {
if !strict {
if let Some(cow_str) = maybe_as_string(self, ErrorTypeDefaults::BoolParsing)? {
return str_as_bool(self, &cow_str).map(ValidationMatch::lax);
} else if let Ok(int) = extract_i64(self) {
} else if let Some(int) = extract_i64(self) {
return int_as_bool(self, int).map(ValidationMatch::lax);
} else if let Ok(float) = self.extract::<f64>() {
if let Ok(int) = float_as_int(self, float) {
Expand Down Expand Up @@ -635,7 +635,7 @@ impl<'a> Input<'a> for PyAny {
bytes_as_time(self, py_bytes.as_bytes(), microseconds_overflow_behavior)
} else if PyBool::is_exact_type_of(self) {
Err(ValError::new(ErrorTypeDefaults::TimeType, self))
} else if let Ok(int) = extract_i64(self) {
} else if let Some(int) = extract_i64(self) {
int_as_time(self, int, 0)
} else if let Ok(float) = self.extract::<f64>() {
float_as_time(self, float)
Expand Down Expand Up @@ -669,7 +669,7 @@ impl<'a> Input<'a> for PyAny {
bytes_as_datetime(self, py_bytes.as_bytes(), microseconds_overflow_behavior)
} else if PyBool::is_exact_type_of(self) {
Err(ValError::new(ErrorTypeDefaults::DatetimeType, self))
} else if let Ok(int) = extract_i64(self) {
} else if let Some(int) = extract_i64(self) {
int_as_datetime(self, int, 0)
} else if let Ok(float) = self.extract::<f64>() {
float_as_datetime(self, float)
Expand Down Expand Up @@ -706,7 +706,7 @@ impl<'a> Input<'a> for PyAny {
bytes_as_timedelta(self, str.as_bytes(), microseconds_overflow_behavior)
} else if let Ok(py_bytes) = self.downcast::<PyBytes>() {
bytes_as_timedelta(self, py_bytes.as_bytes(), microseconds_overflow_behavior)
} else if let Ok(int) = extract_i64(self) {
} else if let Some(int) = extract_i64(self) {
Ok(int_as_duration(self, int)?.into())
} else if let Ok(float) = self.extract::<f64>() {
Ok(float_as_duration(self, float)?.into())
Expand Down
6 changes: 3 additions & 3 deletions src/input/return_enums.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use pyo3::PyTypeInfo;
use serde::{ser::Error, Serialize, Serializer};

use crate::errors::{py_err_string, ErrorType, ErrorTypeDefaults, InputValue, ValError, ValLineError, ValResult};
use crate::tools::py_err;
use crate::tools::{extract_i64, py_err};
use crate::validators::{CombinedValidator, Exactness, ValidationState, Validator};

use super::input_string::StringMapping;
Expand Down Expand Up @@ -863,7 +863,7 @@ pub enum EitherInt<'a> {
impl<'a> EitherInt<'a> {
pub fn upcast(py_any: &'a PyAny) -> ValResult<Self> {
// Safety: we know that py_any is a python int
if let Ok(int_64) = py_any.extract::<i64>() {
if let Some(int_64) = extract_i64(py_any) {
Ok(Self::I64(int_64))
} else {
let big_int: BigInt = py_any.extract()?;
Expand Down Expand Up @@ -1021,7 +1021,7 @@ impl<'a> Rem for &'a Int {

impl<'a> FromPyObject<'a> for Int {
fn extract(obj: &'a PyAny) -> PyResult<Self> {
if let Ok(i) = obj.extract::<i64>() {
if let Some(i) = extract_i64(obj) {
Ok(Int::I64(i))
} else if let Ok(b) = obj.extract::<BigInt>() {
Ok(Int::Big(b))
Expand Down
2 changes: 1 addition & 1 deletion src/lookup_key.rs
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ impl PathItem {
} else {
Ok(Self::Pos(usize_key))
}
} else if let Ok(int_key) = extract_i64(obj) {
} else if let Some(int_key) = extract_i64(obj) {
if index == 0 {
py_err!(PyTypeError; "The first item in an alias path should be a string")
} else {
Expand Down
5 changes: 4 additions & 1 deletion src/serializers/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,10 @@ pub(crate) fn infer_to_python_known(
// `bool` and `None` can't be subclasses, `ObType::Int`, `ObType::Float`, `ObType::Str` refer to exact types
ObType::None | ObType::Bool | ObType::Int | ObType::Str => value.into_py(py),
// have to do this to make sure subclasses of for example str are upcast to `str`
ObType::IntSubclass => extract_i64(value)?.into_py(py),
ObType::IntSubclass => match extract_i64(value) {
Some(v) => v.into_py(py),
None => return py_err!(PyTypeError; "expected int, got {}", safe_repr(value)),
},
ObType::Float | ObType::FloatSubclass => {
let v = value.extract::<f64>()?;
if (v.is_nan() || v.is_infinite()) && extra.config.inf_nan_mode == InfNanMode::Null {
Expand Down
4 changes: 2 additions & 2 deletions src/serializers/type_serializers/literal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ impl BuildSerializer for LiteralSerializer {
repr_args.push(item.repr()?.extract()?);
if let Ok(bool) = item.downcast::<PyBool>() {
expected_py.append(bool)?;
} else if let Ok(int) = extract_i64(item) {
} else if let Some(int) = extract_i64(item) {
expected_int.insert(int);
} else if let Ok(py_str) = item.downcast::<PyString>() {
expected_str.insert(py_str.to_str()?.to_string());
Expand Down Expand Up @@ -79,7 +79,7 @@ impl LiteralSerializer {
fn check<'a>(&self, value: &'a PyAny, extra: &Extra) -> PyResult<OutputValue<'a>> {
if extra.check.enabled() {
if !self.expected_int.is_empty() && !PyBool::is_type_of(value) {
if let Ok(int) = extract_i64(value) {
if let Some(int) = extract_i64(value) {
if self.expected_int.contains(&int) {
return Ok(OutputValue::OkInt(int));
}
Expand Down
28 changes: 21 additions & 7 deletions src/tools.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use std::borrow::Cow;

use pyo3::exceptions::{PyKeyError, PyTypeError};
use pyo3::exceptions::PyKeyError;
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyInt, PyString};
use pyo3::{intern, FromPyObject, PyTypeInfo};
use pyo3::types::{PyDict, PyString};
use pyo3::{ffi, intern, FromPyObject};

pub trait SchemaDict<'py> {
fn get_as<T>(&'py self, key: &PyString) -> PyResult<Option<T>>
Expand Down Expand Up @@ -99,10 +99,24 @@ pub fn safe_repr(v: &PyAny) -> Cow<str> {
}
}

pub fn extract_i64(v: &PyAny) -> PyResult<i64> {
if PyInt::is_type_of(v) {
v.extract()
/// Extract an i64 from a python object more quickly, see
/// https://github.com/PyO3/pyo3/pull/3742#discussion_r1451763928
#[cfg(not(any(target_pointer_width = "32", windows, PyPy)))]
pub fn extract_i64(obj: &PyAny) -> Option<i64> {
let val = unsafe { ffi::PyLong_AsLong(obj.as_ptr()) };
if val == -1 && PyErr::occurred(obj.py()) {
unsafe { ffi::PyErr_Clear() };
None
} else {
py_err!(PyTypeError; "expected int, got {}", safe_repr(v))
Some(val)
}
}

#[cfg(any(target_pointer_width = "32", windows, PyPy))]
pub fn extract_i64(v: &PyAny) -> Option<i64> {
if v.is_instance_of::<pyo3::types::PyInt>() {
v.extract().ok()
} else {
None
}
}
12 changes: 12 additions & 0 deletions tests/benchmarks/test_micro_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1232,6 +1232,18 @@ def test_strict_int(benchmark):
benchmark(v.validate_python, 42)


@pytest.mark.benchmark(group='strict_int')
def test_strict_int_fails(benchmark):
v = SchemaValidator(core_schema.int_schema(strict=True))

@benchmark
def t():
try:
v.validate_python(())
except ValidationError:
pass


@pytest.mark.benchmark(group='int_range')
def test_int_range(benchmark):
v = SchemaValidator(core_schema.int_schema(gt=0, lt=100))
Expand Down
15 changes: 11 additions & 4 deletions tests/validators/test_int.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
('123456789123456.00001', Err('Input should be a valid integer, unable to parse string as an integer')),
(int(1e10), int(1e10)),
(i64_max, i64_max),
(i64_max + 1, i64_max + 1),
(i64_max * 2, i64_max * 2),
pytest.param(
12.5,
Err('Input should be a valid integer, got a number with a fractional part [type=int_from_float'),
Expand Down Expand Up @@ -106,10 +108,15 @@ def test_int(input_value, expected):
@pytest.mark.parametrize(
'input_value,expected',
[
(Decimal('1'), 1),
(Decimal('1.0'), 1),
(i64_max, i64_max),
(i64_max + 1, i64_max + 1),
pytest.param(Decimal('1'), 1),
pytest.param(Decimal('1.0'), 1),
pytest.param(i64_max, i64_max, id='i64_max'),
pytest.param(i64_max + 1, i64_max + 1, id='i64_max+1'),
pytest.param(
-1,
Err('Input should be greater than 0 [type=greater_than, input_value=-1, input_type=int]'),
id='-1',
),
(
-i64_max + 1,
Err('Input should be greater than 0 [type=greater_than, input_value=-9223372036854775806, input_type=int]'),
Expand Down

0 comments on commit d7cf72d

Please sign in to comment.