Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement Input for str #1229

Merged
merged 2 commits into from
Mar 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 20 additions & 13 deletions src/errors/line_error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,27 @@ use pyo3::DowncastIntoError;

use jiter::JsonValue;

use crate::input::BorrowInput;
use crate::input::Input;

use super::location::{LocItem, Location};
use super::types::ErrorType;

pub type ValResult<T> = Result<T, ValError>;

pub trait AsErrorValue {
fn as_error_value(&self) -> InputValue;
pub trait ToErrorValue {
fn to_error_value(&self) -> InputValue;
}

impl<'a, T: Input<'a>> AsErrorValue for T {
fn as_error_value(&self) -> InputValue {
Input::as_error_value(self)
impl<'a, T: BorrowInput<'a>> ToErrorValue for T {
fn to_error_value(&self) -> InputValue {
Input::as_error_value(self.borrow_input())
}
}

impl ToErrorValue for &'_ dyn ToErrorValue {
fn to_error_value(&self) -> InputValue {
(**self).to_error_value()
}
}

Expand Down Expand Up @@ -55,11 +62,11 @@ impl From<Vec<ValLineError>> for ValError {
}

impl ValError {
pub fn new(error_type: ErrorType, input: &impl AsErrorValue) -> ValError {
pub fn new(error_type: ErrorType, input: impl ToErrorValue) -> ValError {
Self::LineErrors(vec![ValLineError::new(error_type, input)])
}

pub fn new_with_loc(error_type: ErrorType, input: &impl AsErrorValue, loc: impl Into<LocItem>) -> ValError {
pub fn new_with_loc(error_type: ErrorType, input: impl ToErrorValue, loc: impl Into<LocItem>) -> ValError {
Self::LineErrors(vec![ValLineError::new_with_loc(error_type, input, loc)])
}

Expand Down Expand Up @@ -94,26 +101,26 @@ pub struct ValLineError {
}

impl ValLineError {
pub fn new(error_type: ErrorType, input: &impl AsErrorValue) -> ValLineError {
pub fn new(error_type: ErrorType, input: impl ToErrorValue) -> ValLineError {
Self {
error_type,
input_value: input.as_error_value(),
input_value: input.to_error_value(),
location: Location::default(),
}
}

pub fn new_with_loc(error_type: ErrorType, input: &impl AsErrorValue, loc: impl Into<LocItem>) -> ValLineError {
pub fn new_with_loc(error_type: ErrorType, input: impl ToErrorValue, loc: impl Into<LocItem>) -> ValLineError {
Self {
error_type,
input_value: input.as_error_value(),
input_value: input.to_error_value(),
location: Location::new_some(loc.into()),
}
}

pub fn new_with_full_loc(error_type: ErrorType, input: &impl AsErrorValue, location: Location) -> ValLineError {
pub fn new_with_full_loc(error_type: ErrorType, input: impl ToErrorValue, location: Location) -> ValLineError {
Self {
error_type,
input_value: input.as_error_value(),
input_value: input.to_error_value(),
location,
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/errors/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ mod types;
mod validation_exception;
mod value_exception;

pub use self::line_error::{AsErrorValue, InputValue, ValError, ValLineError, ValResult};
pub use self::line_error::{InputValue, ToErrorValue, ValError, ValLineError, ValResult};
pub use self::location::LocItem;
pub use self::types::{list_all_errors, ErrorType, ErrorTypeDefaults, Number};
pub use self::validation_exception::ValidationError;
Expand Down
6 changes: 3 additions & 3 deletions src/errors/value_exception.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use pyo3::types::{PyDict, PyString};
use crate::input::InputType;
use crate::tools::extract_i64;

use super::line_error::AsErrorValue;
use super::line_error::ToErrorValue;
use super::{ErrorType, ValError};

#[pyclass(extends=PyException, module="pydantic_core._pydantic_core")]
Expand Down Expand Up @@ -106,7 +106,7 @@ impl PydanticCustomError {
}

impl PydanticCustomError {
pub fn into_val_error(self, input: &impl AsErrorValue) -> ValError {
pub fn into_val_error(self, input: impl ToErrorValue) -> ValError {
let error_type = ErrorType::CustomError {
error_type: self.error_type,
message_template: self.message_template,
Expand Down Expand Up @@ -181,7 +181,7 @@ impl PydanticKnownError {
}

impl PydanticKnownError {
pub fn into_val_error(self, input: &impl AsErrorValue) -> ValError {
pub fn into_val_error(self, input: impl ToErrorValue) -> ValError {
ValError::new(self.error_type, input)
}
}
47 changes: 24 additions & 23 deletions src/input/datetime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use std::hash::Hasher;
use strum::EnumMessage;

use super::Input;
use crate::errors::ToErrorValue;
use crate::errors::{ErrorType, ValError, ValResult};
use crate::tools::py_err;

Expand Down Expand Up @@ -285,7 +286,7 @@ impl<'a> EitherDateTime<'a> {
}
}

pub fn bytes_as_date<'a>(input: &'a impl Input<'a>, bytes: &[u8]) -> ValResult<EitherDate<'a>> {
pub fn bytes_as_date<'py>(input: &(impl Input<'py> + ?Sized), bytes: &[u8]) -> ValResult<EitherDate<'py>> {
match Date::parse_bytes(bytes) {
Ok(date) => Ok(date.into()),
Err(err) => Err(ValError::new(
Expand All @@ -298,11 +299,11 @@ pub fn bytes_as_date<'a>(input: &'a impl Input<'a>, bytes: &[u8]) -> ValResult<E
}
}

pub fn bytes_as_time<'a>(
input: &'a impl Input<'a>,
pub fn bytes_as_time<'py>(
input: &(impl Input<'py> + ?Sized),
bytes: &[u8],
microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior,
) -> ValResult<EitherTime<'a>> {
) -> ValResult<EitherTime<'py>> {
match Time::parse_bytes_with_config(
bytes,
&TimeConfig {
Expand All @@ -321,11 +322,11 @@ pub fn bytes_as_time<'a>(
}
}

pub fn bytes_as_datetime<'a, 'b>(
input: &'a impl Input<'a>,
bytes: &'b [u8],
pub fn bytes_as_datetime<'py>(
input: &(impl Input<'py> + ?Sized),
bytes: &[u8],
microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior,
) -> ValResult<EitherDateTime<'a>> {
) -> ValResult<EitherDateTime<'py>> {
match DateTime::parse_bytes_with_config(
bytes,
&TimeConfig {
Expand All @@ -344,11 +345,11 @@ pub fn bytes_as_datetime<'a, 'b>(
}
}

pub fn int_as_datetime<'a>(
input: &'a impl Input<'a>,
pub fn int_as_datetime<'py>(
input: &(impl Input<'py> + ?Sized),
timestamp: i64,
timestamp_microseconds: u32,
) -> ValResult<EitherDateTime> {
) -> ValResult<EitherDateTime<'py>> {
match DateTime::from_timestamp_with_config(
timestamp,
timestamp_microseconds,
Expand Down Expand Up @@ -382,7 +383,7 @@ macro_rules! nan_check {
};
}

pub fn float_as_datetime<'a>(input: &'a impl Input<'a>, timestamp: f64) -> ValResult<EitherDateTime> {
pub fn float_as_datetime<'py>(input: &(impl Input<'py> + ?Sized), timestamp: f64) -> ValResult<EitherDateTime<'py>> {
nan_check!(input, timestamp, DatetimeParsing);
let microseconds = timestamp.fract().abs() * 1_000_000.0;
// checking for extra digits in microseconds is unreliable with large floats,
Expand All @@ -408,11 +409,11 @@ pub fn date_as_datetime<'py>(date: &Bound<'py, PyDate>) -> PyResult<EitherDateTi

const MAX_U32: i64 = u32::MAX as i64;

pub fn int_as_time<'a>(
input: &'a impl Input<'a>,
pub fn int_as_time<'py>(
input: &(impl Input<'py> + ?Sized),
timestamp: i64,
timestamp_microseconds: u32,
) -> ValResult<EitherTime> {
) -> ValResult<EitherTime<'py>> {
let time_timestamp: u32 = match timestamp {
t if t < 0_i64 => {
return Err(ValError::new(
Expand Down Expand Up @@ -447,14 +448,14 @@ pub fn int_as_time<'a>(
}
}

pub fn float_as_time<'a>(input: &'a impl Input<'a>, timestamp: f64) -> ValResult<EitherTime> {
pub fn float_as_time<'py>(input: &(impl Input<'py> + ?Sized), timestamp: f64) -> ValResult<EitherTime<'py>> {
nan_check!(input, timestamp, TimeParsing);
let microseconds = timestamp.fract().abs() * 1_000_000.0;
// round for same reason as above
int_as_time(input, timestamp.floor() as i64, microseconds.round() as u32)
}

fn map_timedelta_err<'a>(input: &'a impl Input<'a>, err: ParseError) -> ValError {
fn map_timedelta_err(input: impl ToErrorValue, err: ParseError) -> ValError {
ValError::new(
ErrorType::TimeDeltaParsing {
error: Cow::Borrowed(err.get_documentation().unwrap_or_default()),
Expand All @@ -464,11 +465,11 @@ fn map_timedelta_err<'a>(input: &'a impl Input<'a>, err: ParseError) -> ValError
)
}

pub fn bytes_as_timedelta<'a, 'b>(
input: &'a impl Input<'a>,
bytes: &'b [u8],
pub fn bytes_as_timedelta<'py>(
input: &(impl Input<'py> + ?Sized),
bytes: &[u8],
microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior,
) -> ValResult<EitherTimedelta<'a>> {
) -> ValResult<EitherTimedelta<'py>> {
match Duration::parse_bytes_with_config(
bytes,
&TimeConfig {
Expand All @@ -481,7 +482,7 @@ pub fn bytes_as_timedelta<'a, 'b>(
}
}

pub fn int_as_duration<'a>(input: &'a impl Input<'a>, total_seconds: i64) -> ValResult<Duration> {
pub fn int_as_duration(input: impl ToErrorValue, total_seconds: i64) -> ValResult<Duration> {
let positive = total_seconds >= 0;
let total_seconds = total_seconds.unsigned_abs();
// we can safely unwrap here since we've guaranteed seconds and microseconds can't cause overflow
Expand All @@ -490,7 +491,7 @@ pub fn int_as_duration<'a>(input: &'a impl Input<'a>, total_seconds: i64) -> Val
Duration::new(positive, days, seconds, 0).map_err(|err| map_timedelta_err(input, err))
}

pub fn float_as_duration<'a>(input: &'a impl Input<'a>, total_seconds: f64) -> ValResult<Duration> {
pub fn float_as_duration(input: impl ToErrorValue, total_seconds: f64) -> ValResult<Duration> {
nan_check!(input, total_seconds, TimeDeltaParsing);
let positive = total_seconds >= 0_f64;
let total_seconds = total_seconds.abs();
Expand Down
25 changes: 16 additions & 9 deletions src/input/input_abstract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use pyo3::exceptions::PyValueError;
use pyo3::types::{PyDict, PyType};
use pyo3::{intern, prelude::*};

use crate::errors::{ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult};
use crate::errors::{ErrorTypeDefaults, InputValue, ValError, ValResult};
use crate::tools::py_err;
use crate::{PyMultiHostUrl, PyUrl};

Expand Down Expand Up @@ -46,7 +46,7 @@ impl TryFrom<&str> for InputType {
/// the convention is to either implement:
/// * `strict_*` & `lax_*` if they have different behavior
/// * or, `validate_*` and `strict_*` to just call `validate_*` if the behavior for strict and lax is the same
pub trait Input<'py>: fmt::Debug + ToPyObject + Into<LocItem> + Sized {
pub trait Input<'py>: fmt::Debug + ToPyObject {
fn as_error_value(&self) -> InputValue;

fn identity(&self) -> Option<usize> {
Expand Down Expand Up @@ -83,9 +83,9 @@ pub trait Input<'py>: fmt::Debug + ToPyObject + Into<LocItem> + Sized {
false
}

fn validate_args(&self) -> ValResult<GenericArguments<'_>>;
fn validate_args(&self) -> ValResult<GenericArguments<'_, 'py>>;

fn validate_dataclass_args<'a>(&'a self, dataclass_name: &str) -> ValResult<GenericArguments<'a>>;
fn validate_dataclass_args<'a>(&'a self, dataclass_name: &str) -> ValResult<GenericArguments<'a, 'py>>;

fn validate_str(&self, strict: bool, coerce_numbers_to_str: bool) -> ValResult<ValidationMatch<EitherString<'_>>>;

Expand Down Expand Up @@ -201,25 +201,25 @@ pub trait Input<'py>: fmt::Debug + ToPyObject + Into<LocItem> + Sized {

fn validate_iter(&self) -> ValResult<GenericIterator>;

fn validate_date(&self, strict: bool) -> ValResult<ValidationMatch<EitherDate>>;
fn validate_date(&self, strict: bool) -> ValResult<ValidationMatch<EitherDate<'py>>>;

fn validate_time(
&self,
strict: bool,
microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior,
) -> ValResult<ValidationMatch<EitherTime>>;
) -> ValResult<ValidationMatch<EitherTime<'py>>>;

fn validate_datetime(
&self,
strict: bool,
microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior,
) -> ValResult<ValidationMatch<EitherDateTime>>;
) -> ValResult<ValidationMatch<EitherDateTime<'py>>>;

fn validate_timedelta(
&self,
strict: bool,
microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior,
) -> ValResult<ValidationMatch<EitherTimedelta>>;
) -> ValResult<ValidationMatch<EitherTimedelta<'py>>>;
}

/// The problem to solve here is that iterating collections often returns owned
Expand All @@ -228,6 +228,13 @@ pub trait Input<'py>: fmt::Debug + ToPyObject + Into<LocItem> + Sized {
/// or borrowed; all we care about is that we can borrow it again with `borrow_input`
/// for some lifetime 'a.
pub trait BorrowInput<'py> {
type Input: Input<'py>;
type Input: Input<'py> + ?Sized;
fn borrow_input(&self) -> &Self::Input;
}

impl<'py, T: Input<'py> + ?Sized> BorrowInput<'py> for &'_ T {
type Input = T;
fn borrow_input(&self) -> &Self::Input {
self
}
}
Loading
Loading