Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add lshift and rshift operators #7741

Merged
merged 21 commits into from
Apr 28, 2023
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
22c2c02
Initial commit
abrammer Apr 5, 2023
29dd917
Add auto generated .pyi typed_ops file
abrammer Apr 6, 2023
a8c217d
Merge branch 'pydata:main' into feat/add_lshift_and_rshift_operators
abrammer Apr 8, 2023
e14da23
Add bitshift op test for dask
abrammer Apr 21, 2023
426d559
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 21, 2023
5a78c53
Add bitshift tests to dataarray and variable
abrammer Apr 21, 2023
4091825
Merge branch 'main' into feat/add_lshift_and_rshift_operators
abrammer Apr 21, 2023
9961aa5
Apply typing suggestions from code review
abrammer Apr 21, 2023
13193bc
Fix type checking on test_dask addition
abrammer Apr 22, 2023
e35c22f
Remove new type checking on test_variable edits
abrammer Apr 22, 2023
09033f5
Add typing to test_1d_math and ignore 1 existing line
abrammer Apr 24, 2023
6c26fd0
Add simple bitshift test on groupby ops
abrammer Apr 24, 2023
f2837fe
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 24, 2023
c46fea0
Merge branch 'main' into feat/add_lshift_and_rshift_operators
abrammer Apr 24, 2023
453c3c1
Edit groupby bitshift test to use groups with len>1
abrammer Apr 24, 2023
e87533e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 24, 2023
9163321
Add rshift and lshift to docs
abrammer Apr 24, 2023
46ccf4c
Create new array in docs so examples later don't break
abrammer Apr 24, 2023
37bffa2
Indent second line on whats-new entry
abrammer Apr 25, 2023
3bebcf8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 25, 2023
3efcbe5
Merge branch 'main' into feat/add_lshift_and_rshift_operators
abrammer Apr 25, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions xarray/core/_typed_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,12 @@ def __xor__(self, other):
def __or__(self, other):
return self._binary_op(other, operator.or_)

def __lshift__(self, other):
return self._binary_op(other, operator.lshift)

def __rshift__(self, other):
return self._binary_op(other, operator.rshift)

def __lt__(self, other):
return self._binary_op(other, operator.lt)

Expand Down Expand Up @@ -123,6 +129,12 @@ def __ixor__(self, other):
def __ior__(self, other):
return self._inplace_binary_op(other, operator.ior)

def __ilshift__(self, other):
return self._inplace_binary_op(other, operator.ilshift)

def __irshift__(self, other):
return self._inplace_binary_op(other, operator.irshift)

def _unary_op(self, f, *args, **kwargs):
raise NotImplementedError

Expand Down Expand Up @@ -160,6 +172,8 @@ def conjugate(self, *args, **kwargs):
__and__.__doc__ = operator.and_.__doc__
__xor__.__doc__ = operator.xor.__doc__
__or__.__doc__ = operator.or_.__doc__
__lshift__.__doc__ = operator.lshift.__doc__
__rshift__.__doc__ = operator.rshift.__doc__
__lt__.__doc__ = operator.lt.__doc__
__le__.__doc__ = operator.le.__doc__
__gt__.__doc__ = operator.gt.__doc__
Expand All @@ -186,6 +200,8 @@ def conjugate(self, *args, **kwargs):
__iand__.__doc__ = operator.iand.__doc__
__ixor__.__doc__ = operator.ixor.__doc__
__ior__.__doc__ = operator.ior.__doc__
__ilshift__.__doc__ = operator.ilshift.__doc__
__irshift__.__doc__ = operator.irshift.__doc__
__neg__.__doc__ = operator.neg.__doc__
__pos__.__doc__ = operator.pos.__doc__
__abs__.__doc__ = operator.abs.__doc__
Expand Down Expand Up @@ -232,6 +248,12 @@ def __xor__(self, other):
def __or__(self, other):
return self._binary_op(other, operator.or_)

def __lshift__(self, other):
return self._binary_op(other, operator.lshift)

def __rshift__(self, other):
return self._binary_op(other, operator.rshift)

def __lt__(self, other):
return self._binary_op(other, operator.lt)

Expand Down Expand Up @@ -313,6 +335,12 @@ def __ixor__(self, other):
def __ior__(self, other):
return self._inplace_binary_op(other, operator.ior)

def __ilshift__(self, other):
return self._inplace_binary_op(other, operator.ilshift)

def __irshift__(self, other):
return self._inplace_binary_op(other, operator.irshift)

def _unary_op(self, f, *args, **kwargs):
raise NotImplementedError

Expand Down Expand Up @@ -350,6 +378,8 @@ def conjugate(self, *args, **kwargs):
__and__.__doc__ = operator.and_.__doc__
__xor__.__doc__ = operator.xor.__doc__
__or__.__doc__ = operator.or_.__doc__
__lshift__.__doc__ = operator.lshift.__doc__
__rshift__.__doc__ = operator.rshift.__doc__
__lt__.__doc__ = operator.lt.__doc__
__le__.__doc__ = operator.le.__doc__
__gt__.__doc__ = operator.gt.__doc__
Expand All @@ -376,6 +406,8 @@ def conjugate(self, *args, **kwargs):
__iand__.__doc__ = operator.iand.__doc__
__ixor__.__doc__ = operator.ixor.__doc__
__ior__.__doc__ = operator.ior.__doc__
__ilshift__.__doc__ = operator.ilshift.__doc__
__irshift__.__doc__ = operator.irshift.__doc__
__neg__.__doc__ = operator.neg.__doc__
__pos__.__doc__ = operator.pos.__doc__
__abs__.__doc__ = operator.abs.__doc__
Expand Down Expand Up @@ -422,6 +454,12 @@ def __xor__(self, other):
def __or__(self, other):
return self._binary_op(other, operator.or_)

def __lshift__(self, other):
return self._binary_op(other, operator.lshift)

def __rshift__(self, other):
return self._binary_op(other, operator.rshift)

def __lt__(self, other):
return self._binary_op(other, operator.lt)

Expand Down Expand Up @@ -503,6 +541,12 @@ def __ixor__(self, other):
def __ior__(self, other):
return self._inplace_binary_op(other, operator.ior)

def __ilshift__(self, other):
return self._inplace_binary_op(other, operator.ilshift)

def __irshift__(self, other):
return self._inplace_binary_op(other, operator.irshift)

def _unary_op(self, f, *args, **kwargs):
raise NotImplementedError

Expand Down Expand Up @@ -540,6 +584,8 @@ def conjugate(self, *args, **kwargs):
__and__.__doc__ = operator.and_.__doc__
__xor__.__doc__ = operator.xor.__doc__
__or__.__doc__ = operator.or_.__doc__
__lshift__.__doc__ = operator.lshift.__doc__
__rshift__.__doc__ = operator.rshift.__doc__
__lt__.__doc__ = operator.lt.__doc__
__le__.__doc__ = operator.le.__doc__
__gt__.__doc__ = operator.gt.__doc__
Expand All @@ -566,6 +612,8 @@ def conjugate(self, *args, **kwargs):
__iand__.__doc__ = operator.iand.__doc__
__ixor__.__doc__ = operator.ixor.__doc__
__ior__.__doc__ = operator.ior.__doc__
__ilshift__.__doc__ = operator.ilshift.__doc__
__irshift__.__doc__ = operator.irshift.__doc__
__neg__.__doc__ = operator.neg.__doc__
__pos__.__doc__ = operator.pos.__doc__
__abs__.__doc__ = operator.abs.__doc__
Expand Down Expand Up @@ -612,6 +660,12 @@ def __xor__(self, other):
def __or__(self, other):
return self._binary_op(other, operator.or_)

def __lshift__(self, other):
return self._binary_op(other, operator.lshift)

def __rshift__(self, other):
return self._binary_op(other, operator.rshift)

def __lt__(self, other):
return self._binary_op(other, operator.lt)

Expand Down Expand Up @@ -670,6 +724,8 @@ def __ror__(self, other):
__and__.__doc__ = operator.and_.__doc__
__xor__.__doc__ = operator.xor.__doc__
__or__.__doc__ = operator.or_.__doc__
__lshift__.__doc__ = operator.lshift.__doc__
__rshift__.__doc__ = operator.rshift.__doc__
__lt__.__doc__ = operator.lt.__doc__
__le__.__doc__ = operator.le.__doc__
__gt__.__doc__ = operator.gt.__doc__
Expand Down Expand Up @@ -724,6 +780,12 @@ def __xor__(self, other):
def __or__(self, other):
return self._binary_op(other, operator.or_)

def __lshift__(self, other):
return self._binary_op(other, operator.lshift)

def __rshift__(self, other):
return self._binary_op(other, operator.rshift)

def __lt__(self, other):
return self._binary_op(other, operator.lt)

Expand Down Expand Up @@ -782,6 +844,8 @@ def __ror__(self, other):
__and__.__doc__ = operator.and_.__doc__
__xor__.__doc__ = operator.xor.__doc__
__or__.__doc__ = operator.or_.__doc__
__lshift__.__doc__ = operator.lshift.__doc__
__rshift__.__doc__ = operator.rshift.__doc__
__lt__.__doc__ = operator.lt.__doc__
__le__.__doc__ = operator.le.__doc__
__gt__.__doc__ = operator.gt.__doc__
Expand Down
50 changes: 50 additions & 0 deletions xarray/core/_typed_ops.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ class DatasetOpsMixin:
def __and__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ...
def __xor__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ...
def __or__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ...
def __lshift__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ...
def __rshift__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ...
def __lt__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ...
def __le__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ...
def __gt__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ...
Expand Down Expand Up @@ -135,6 +137,18 @@ class DataArrayOpsMixin:
@overload
def __or__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ...
@overload
def __lshift__(self, other: T_Dataset) -> T_Dataset: ...
@overload
def __lshift__(self, other: "DatasetGroupBy") -> "Dataset": ...
abrammer marked this conversation as resolved.
Show resolved Hide resolved
@overload
def __lshift__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ...
@overload
def __rshift__(self, other: T_Dataset) -> T_Dataset: ...
@overload
def __rshift__(self, other: "DatasetGroupBy") -> "Dataset": ...
@overload
def __rshift__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ...
@overload
def __lt__(self, other: T_Dataset) -> T_Dataset: ...
@overload
def __lt__(self, other: "DatasetGroupBy") -> "Dataset": ...
Expand Down Expand Up @@ -305,6 +319,18 @@ class VariableOpsMixin:
@overload
def __or__(self: T_Variable, other: VarCompatible) -> T_Variable: ...
@overload
def __lshift__(self, other: T_Dataset) -> T_Dataset: ...
@overload
def __lshift__(self, other: T_DataArray) -> T_DataArray: ...
@overload
def __lshift__(self: T_Variable, other: VarCompatible) -> T_Variable: ...
@overload
def __rshift__(self, other: T_Dataset) -> T_Dataset: ...
@overload
def __rshift__(self, other: T_DataArray) -> T_DataArray: ...
@overload
def __rshift__(self: T_Variable, other: VarCompatible) -> T_Variable: ...
@overload
def __lt__(self, other: T_Dataset) -> T_Dataset: ...
@overload
def __lt__(self, other: T_DataArray) -> T_DataArray: ...
Expand Down Expand Up @@ -475,6 +501,18 @@ class DatasetGroupByOpsMixin:
@overload
def __or__(self, other: GroupByIncompatible) -> NoReturn: ...
@overload
def __lshift__(self, other: T_Dataset) -> T_Dataset: ...
@overload
def __lshift__(self, other: "DataArray") -> "Dataset": ...
@overload
def __lshift__(self, other: GroupByIncompatible) -> NoReturn: ...
@overload
def __rshift__(self, other: T_Dataset) -> T_Dataset: ...
@overload
def __rshift__(self, other: "DataArray") -> "Dataset": ...
@overload
def __rshift__(self, other: GroupByIncompatible) -> NoReturn: ...
@overload
def __lt__(self, other: T_Dataset) -> T_Dataset: ...
@overload
def __lt__(self, other: "DataArray") -> "Dataset": ...
Expand Down Expand Up @@ -635,6 +673,18 @@ class DataArrayGroupByOpsMixin:
@overload
def __or__(self, other: GroupByIncompatible) -> NoReturn: ...
@overload
def __lshift__(self, other: T_Dataset) -> T_Dataset: ...
@overload
def __lshift__(self, other: T_DataArray) -> T_DataArray: ...
@overload
def __lshift__(self, other: GroupByIncompatible) -> NoReturn: ...
@overload
def __rshift__(self, other: T_Dataset) -> T_Dataset: ...
@overload
def __rshift__(self, other: T_DataArray) -> T_DataArray: ...
@overload
def __rshift__(self, other: GroupByIncompatible) -> NoReturn: ...
@overload
def __lt__(self, other: T_Dataset) -> T_Dataset: ...
@overload
def __lt__(self, other: T_DataArray) -> T_DataArray: ...
Expand Down
2 changes: 2 additions & 0 deletions xarray/core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
"and",
"xor",
"or",
"lshift",
"rshift",
]

# methods which pass on the numpy return value unchanged
Expand Down
13 changes: 13 additions & 0 deletions xarray/tests/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,19 @@ def test_binary_op(self):
self.assertLazyAndIdentical(u + u, v + v)
self.assertLazyAndIdentical(u[0] + u, v[0] + v)

def test_binary_op_bitshift(self) -> None:
# bit shifts only work on ints so we need to generate
# new eager and lazy vars
rng = np.random.default_rng(0)
values = rng.integers(low=-10000, high=10000, size=(4, 6))
data = da.from_array(values, chunks=(2, 2))
u = Variable(("x", "y"), values)
v = Variable(("x", "y"), data)
self.assertLazyAndIdentical(u << 2, v << 2)
self.assertLazyAndIdentical(u << 5, v << 5)
self.assertLazyAndIdentical(u >> 2, v >> 2)
self.assertLazyAndIdentical(u >> 5, v >> 5)

def test_repr(self):
expected = dedent(
"""\
Expand Down
5 changes: 5 additions & 0 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -3926,6 +3926,11 @@ def test_binary_op_propagate_indexes(self) -> None:
actual = (self.dv > 10).xindexes["x"]
assert expected is actual

# use mda for bitshift test as it's type int
actual = (self.mda << 2).xindexes["x"]
expected = self.mda.xindexes["x"]
assert expected is actual

def test_binary_op_join_setting(self) -> None:
dim = "x"
align_type: Final = "outer"
Expand Down
23 changes: 23 additions & 0 deletions xarray/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -808,6 +808,29 @@ def test_groupby_math_more() -> None:
ds + ds.groupby("time.month")


def test_groupby_math_bitshift() -> None:
ds = Dataset(
{
"x": ("level", np.ones(4, dtype=int)),
"y": ("level", np.ones(4, dtype=int) * -1),
"level": [0, 1, 2, 3],
dcherian marked this conversation as resolved.
Show resolved Hide resolved
}
)
left_expected = Dataset(
{
"x": ("level", [1, 2, 4, 8]),
"y": ("level", [-1, -2, -4, -8]),
"level": [0, 1, 2, 3],
}
)

left_actual = ds.groupby("level") << ds.level
assert_equal(left_expected, left_actual)

right_actual = left_expected.groupby("level") >> left_expected.level
assert_equal(ds, right_actual)


@pytest.mark.parametrize("use_flox", [True, False])
def test_groupby_bins_cut_kwargs(use_flox: bool) -> None:
da = xr.DataArray(np.arange(12).reshape(6, 2), dims=("x", "y"))
Expand Down
24 changes: 16 additions & 8 deletions xarray/tests/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,9 +344,10 @@ def test_pandas_period_index(self):
assert v[0] == pd.Period("2000", freq="B")
assert "Period('2000-01-03', 'B')" in repr(v)

def test_1d_math(self):
x = 1.0 * np.arange(5)
y = np.ones(5)
@pytest.mark.parametrize("dtype", [float, int])
def test_1d_math(self, dtype: np.typing.DTypeLike) -> None:
x = np.arange(5, dtype=dtype)
y = np.ones(5, dtype=dtype)

# should we need `.to_base_variable()`?
# probably a break that `+v` changes type?
Expand All @@ -360,11 +361,18 @@ def test_1d_math(self):
assert_identical(base_v, v + 0)
assert_identical(base_v, 0 + v)
assert_identical(base_v, v * 1)
if dtype is int:
assert_identical(base_v, v << 0)
assert_array_equal(v << 3, x << 3)
assert_array_equal(v >> 2, x >> 2)
# binary ops with numpy arrays
assert_array_equal((v * x).values, x**2)
assert_array_equal((x * v).values, x**2)
assert_array_equal((x * v).values, x**2) # type: ignore[attr-defined] # TODO: Fix mypy thinking numpy takes priority, GH7780
assert_array_equal(v - y, v - 1)
assert_array_equal(y - v, 1 - v)
if dtype is int:
assert_array_equal(v << x, x << x)
assert_array_equal(v >> x, x >> x)
# verify attributes are dropped
v2 = self.cls(["x"], x, {"units": "meters"})
with set_options(keep_attrs=False):
Expand All @@ -378,10 +386,10 @@ def test_1d_math(self):
# something complicated
assert_array_equal((v**2 * w - 1 + x).values, x**2 * y - 1 + x)
# make sure dtype is preserved (for Index objects)
assert float == (+v).dtype
assert float == (+v).values.dtype
assert float == (0 + v).dtype
assert float == (0 + v).values.dtype
assert dtype == (+v).dtype
assert dtype == (+v).values.dtype
assert dtype == (0 + v).dtype
assert dtype == (0 + v).values.dtype
# check types of returned data
assert isinstance(+v, Variable)
assert not isinstance(+v, IndexVariable)
Expand Down
4 changes: 4 additions & 0 deletions xarray/util/generate_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
("__and__", "operator.and_"),
("__xor__", "operator.xor"),
("__or__", "operator.or_"),
("__lshift__", "operator.lshift"),
("__rshift__", "operator.rshift"),
)
BINOPS_REFLEXIVE = (
("__radd__", "operator.add"),
Expand All @@ -54,6 +56,8 @@
("__iand__", "operator.iand"),
("__ixor__", "operator.ixor"),
("__ior__", "operator.ior"),
("__ilshift__", "operator.ilshift"),
("__irshift__", "operator.irshift"),
)
UNARY_OPS = (
("__neg__", "operator.neg"),
Expand Down