diff --git a/src/input/return_enums.rs b/src/input/return_enums.rs index 9722669cc..22faaba71 100644 --- a/src/input/return_enums.rs +++ b/src/input/return_enums.rs @@ -549,6 +549,7 @@ impl<'a> EitherInt<'a> { Ok(Self::BigInt(big_int)) } } + pub fn into_i64(self, py: Python<'a>) -> ValResult { match self { EitherInt::I64(i) => Ok(i), diff --git a/src/validators/literal.rs b/src/validators/literal.rs index 9c4eb7292..16eef090d 100644 --- a/src/validators/literal.rs +++ b/src/validators/literal.rs @@ -4,7 +4,7 @@ use core::fmt::Debug; use std::cmp::Ordering; use pyo3::prelude::*; -use pyo3::types::{PyDict, PyList}; +use pyo3::types::{PyDict, PyInt, PyList}; use pyo3::{intern, PyTraverseError, PyVisit}; use ahash::AHashMap; @@ -58,11 +58,13 @@ impl LiteralLookup { expected_bool.false_id = Some(id); } } - if let Ok(either_int) = k.exact_int() { - let int = either_int - .into_i64(py) - .map_err(|_| py_schema_error_type!("error extracting int {:?}", k))?; - expected_int.insert(int, id); + if k.is_exact_instance_of::() { + if let Ok(int_64) = k.extract::() { + expected_int.insert(int_64, id); + } else { + // cover the case of an int that's > i64::MAX etc. + expected_py_dict.set_item(k, id)?; + } } else if let Ok(either_str) = k.exact_str() { let str = either_str .as_cow() diff --git a/tests/validators/test_enums.py b/tests/validators/test_enums.py index 91701ce74..a4fb3c5c2 100644 --- a/tests/validators/test_enums.py +++ b/tests/validators/test_enums.py @@ -1,6 +1,6 @@ import re import sys -from enum import Enum, IntFlag +from enum import Enum, IntEnum, IntFlag import pytest @@ -331,3 +331,16 @@ class MyFlags(IntFlag): with pytest.raises(ValidationError): v.validate_python(None) + + +def test_big_int(): + class ColorEnum(IntEnum): + GREEN = 1 << 63 + BLUE = 1 << 64 + + v = SchemaValidator( + core_schema.with_default_schema(schema=core_schema.enum_schema(ColorEnum, list(ColorEnum.__members__.values()))) + ) + + assert v.validate_python(ColorEnum.GREEN) is ColorEnum.GREEN + assert v.validate_python(1 << 63) is ColorEnum.GREEN diff --git a/tests/validators/test_literal.py b/tests/validators/test_literal.py index d294f866c..5f9f942e2 100644 --- a/tests/validators/test_literal.py +++ b/tests/validators/test_literal.py @@ -378,3 +378,14 @@ class Foo(str, Enum): with pytest.raises(ValidationError) as exc_info: v.validate_python('bar_val') assert exc_info.value.errors(include_url=False) == err + + +def test_big_int(): + big_int = 2**64 + 1 + massive_int = 2**128 + 1 + v = SchemaValidator(core_schema.literal_schema([big_int, massive_int])) + assert v.validate_python(big_int) == big_int + assert v.validate_python(massive_int) == massive_int + m = r'Input should be 18446744073709551617 or 340282366920938463463374607431768211457 \[type=literal_error' + with pytest.raises(ValidationError, match=m): + v.validate_python(37)