diff --git a/piccolo/columns/base.py b/piccolo/columns/base.py index 8c5b84064..1004fe6e3 100644 --- a/piccolo/columns/base.py +++ b/piccolo/columns/base.py @@ -36,6 +36,7 @@ if t.TYPE_CHECKING: # pragma: no cover from piccolo.columns.column_types import ForeignKey + from piccolo.query.methods.select import Select from piccolo.table import Table @@ -595,11 +596,21 @@ def _validate_choices( return True - def is_in(self, values: t.List[t.Any]) -> Where: - if len(values) == 0: - raise ValueError( - "The `values` list argument must contain at least one value." - ) + def is_in(self, values: t.Union[Select, t.List[t.Any]]) -> Where: + from piccolo.query.methods.select import Select + + if isinstance(values, list): + if len(values) == 0: + raise ValueError( + "The `values` list argument must contain at least one " + "value." + ) + elif isinstance(values, Select): + if len(values.columns_delegate.selected_columns) != 1: + raise ValueError( + "A sub select must only return a single column." + ) + return Where(column=self, values=values, operator=In) def not_in(self, values: t.List[t.Any]) -> Where: diff --git a/piccolo/columns/combination.py b/piccolo/columns/combination.py index e080cced2..523678ca4 100644 --- a/piccolo/columns/combination.py +++ b/piccolo/columns/combination.py @@ -13,6 +13,7 @@ if t.TYPE_CHECKING: from piccolo.columns.base import Column + from piccolo.query.methods.select import Select class CombinableMixin(object): @@ -146,18 +147,22 @@ def __init__( self, column: Column, value: t.Any = UNDEFINED, - values: t.Union[Iterable, Undefined] = UNDEFINED, + values: t.Union[Iterable, Undefined, Select] = UNDEFINED, operator: t.Type[ComparisonOperator] = ComparisonOperator, ) -> None: """ We use the UNDEFINED value to show the value was deliberately omitted, vs None, which is a valid value for a where clause. """ + from piccolo.query.methods.select import Select + self.column = column self.value = value if value == UNDEFINED else self.clean_value(value) if values == UNDEFINED: self.values = values + elif isinstance(values, Select): + self.values = values else: self.values = [self.clean_value(i) for i in values] # type: ignore @@ -190,11 +195,16 @@ def clean_value(self, value: t.Any) -> t.Any: @property def values_querystring(self) -> QueryString: + from piccolo.query.methods.select import Select + values = self.values if isinstance(values, Undefined): raise ValueError("values is undefined") + if isinstance(values, Select): + return values.querystrings[0] + template = ", ".join("{}" for _ in values) return QueryString(template, *values)