Skip to content

Commit

Permalink
feat: Allow predicates to own object and evaluate against borrowed types
Browse files Browse the repository at this point in the history
It is very useful to be able to dynamically construct an object and
have that object owned by the predicate, yet evaluate against an
unowned type related to the owned one. An obvious example is a String
being owned by the predicate but being compared against &strs.

Therefore, implement Predicate for Eq/OrdPredicate that store an
object that implements Borrow for the predicate type, replacing
existing impls of Predicate<T> for Eq/OrdPredicate<T> and
Eq/OrdPredicate<&T>. This is backwards compatible as there are blanket
implementations of Borrow<T> for T and Borrow<T> for &T.

Note that Borrow imposes more requirements than are actually required
and AsRef would be sufficient. However, AsRef doesn't have a blanket
implementation for T and thus the existing impl of Predicate<T> for
EqPredicate<T> is still required, but results in a conflict since T
may also implement AsRef<T>. Requiring Borrow instead of AsRef is
sufficient for common use cases though.

This addresses assert-rs#20 more completely.
  • Loading branch information
rshearman committed Dec 28, 2022
1 parent ee57a38 commit c589b0c
Showing 1 changed file with 31 additions and 71 deletions.
102 changes: 31 additions & 71 deletions src/ord.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2018 The predicates-rs Project Developers.
// Copyright (c) 2018, 2022 The predicates-rs Project Developers.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/license/LICENSE-2.0> or the MIT license
Expand Down Expand Up @@ -35,26 +35,24 @@ impl fmt::Display for EqOps {
///
/// This is created by the `predicate::{eq, ne}` functions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct EqPredicate<T>
where
T: fmt::Debug + PartialEq,
{
pub struct EqPredicate<T> {
constant: T,
op: EqOps,
}

impl<T> Predicate<T> for EqPredicate<T>
impl<P, T> Predicate<P> for EqPredicate<T>
where
T: fmt::Debug + PartialEq,
T: std::borrow::Borrow<P> + fmt::Debug,
P: fmt::Debug + PartialEq + ?Sized,
{
fn eval(&self, variable: &T) -> bool {
fn eval(&self, variable: &P) -> bool {
match self.op {
EqOps::Equal => variable.eq(&self.constant),
EqOps::NotEqual => variable.ne(&self.constant),
EqOps::Equal => variable.eq(self.constant.borrow()),
EqOps::NotEqual => variable.ne(self.constant.borrow()),
}
}

fn find_case<'a>(&'a self, expected: bool, variable: &T) -> Option<reflection::Case<'a>> {
fn find_case<'a>(&'a self, expected: bool, variable: &P) -> Option<reflection::Case<'a>> {
utils::default_find_case(self, expected, variable).map(|case| {
case.add_product(reflection::Product::new(
"var",
Expand All @@ -64,32 +62,11 @@ where
}
}

impl<'a, T> Predicate<T> for EqPredicate<&'a T>
where
T: fmt::Debug + PartialEq + ?Sized,
{
fn eval(&self, variable: &T) -> bool {
match self.op {
EqOps::Equal => variable.eq(self.constant),
EqOps::NotEqual => variable.ne(self.constant),
}
}

fn find_case<'b>(&'b self, expected: bool, variable: &T) -> Option<reflection::Case<'b>> {
utils::default_find_case(self, expected, variable).map(|case| {
case.add_product(reflection::Product::new(
"var",
utils::DebugAdapter::new(variable).to_string(),
))
})
}
}

impl<T> reflection::PredicateReflection for EqPredicate<T> where T: fmt::Debug + PartialEq {}
impl<T> reflection::PredicateReflection for EqPredicate<T> where T: fmt::Debug {}

impl<T> fmt::Display for EqPredicate<T>
where
T: fmt::Debug + PartialEq,
T: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let palette = crate::Palette::current();
Expand Down Expand Up @@ -120,6 +97,10 @@ where
/// let predicate_fn = predicate::eq("Hello");
/// assert_eq!(true, predicate_fn.eval("Hello"));
/// assert_eq!(false, predicate_fn.eval("Goodbye"));
///
/// let predicate_fn = predicate::eq(String::from("Hello"));
/// assert_eq!(true, predicate_fn.eval("Hello"));
/// assert_eq!(false, predicate_fn.eval("Goodbye"));
/// ```
pub fn eq<T>(constant: T) -> EqPredicate<T>
where
Expand Down Expand Up @@ -178,28 +159,26 @@ impl fmt::Display for OrdOps {
///
/// This is created by the `predicate::{gt, ge, lt, le}` functions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct OrdPredicate<T>
where
T: fmt::Debug + PartialOrd,
{
pub struct OrdPredicate<T> {
constant: T,
op: OrdOps,
}

impl<T> Predicate<T> for OrdPredicate<T>
impl<P, T> Predicate<P> for OrdPredicate<T>
where
T: fmt::Debug + PartialOrd,
T: std::borrow::Borrow<P> + fmt::Debug,
P: fmt::Debug + PartialOrd + ?Sized,
{
fn eval(&self, variable: &T) -> bool {
fn eval(&self, variable: &P) -> bool {
match self.op {
OrdOps::LessThan => variable.lt(&self.constant),
OrdOps::LessThanOrEqual => variable.le(&self.constant),
OrdOps::GreaterThanOrEqual => variable.ge(&self.constant),
OrdOps::GreaterThan => variable.gt(&self.constant),
OrdOps::LessThan => variable.lt(self.constant.borrow()),
OrdOps::LessThanOrEqual => variable.le(self.constant.borrow()),
OrdOps::GreaterThanOrEqual => variable.ge(self.constant.borrow()),
OrdOps::GreaterThan => variable.gt(self.constant.borrow()),
}
}

fn find_case<'a>(&'a self, expected: bool, variable: &T) -> Option<reflection::Case<'a>> {
fn find_case<'a>(&'a self, expected: bool, variable: &P) -> Option<reflection::Case<'a>> {
utils::default_find_case(self, expected, variable).map(|case| {
case.add_product(reflection::Product::new(
"var",
Expand All @@ -209,34 +188,11 @@ where
}
}

impl<'a, T> Predicate<T> for OrdPredicate<&'a T>
where
T: fmt::Debug + PartialOrd + ?Sized,
{
fn eval(&self, variable: &T) -> bool {
match self.op {
OrdOps::LessThan => variable.lt(self.constant),
OrdOps::LessThanOrEqual => variable.le(self.constant),
OrdOps::GreaterThanOrEqual => variable.ge(self.constant),
OrdOps::GreaterThan => variable.gt(self.constant),
}
}

fn find_case<'b>(&'b self, expected: bool, variable: &T) -> Option<reflection::Case<'b>> {
utils::default_find_case(self, expected, variable).map(|case| {
case.add_product(reflection::Product::new(
"var",
utils::DebugAdapter::new(variable).to_string(),
))
})
}
}

impl<T> reflection::PredicateReflection for OrdPredicate<T> where T: fmt::Debug + PartialOrd {}
impl<T> reflection::PredicateReflection for OrdPredicate<T> where T: fmt::Debug {}

impl<T> fmt::Display for OrdPredicate<T>
where
T: fmt::Debug + PartialOrd,
T: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let palette = crate::Palette::current();
Expand Down Expand Up @@ -267,6 +223,10 @@ where
/// let predicate_fn = predicate::lt("b");
/// assert_eq!(true, predicate_fn.eval("a"));
/// assert_eq!(false, predicate_fn.eval("c"));
///
/// let predicate_fn = predicate::lt(String::from("b"));
/// assert_eq!(true, predicate_fn.eval("a"));
/// assert_eq!(false, predicate_fn.eval("c"));
/// ```
pub fn lt<T>(constant: T) -> OrdPredicate<T>
where
Expand Down

0 comments on commit c589b0c

Please sign in to comment.