Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Jan 7, 2025
1 parent fda901c commit 95205d9
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 12 deletions.
16 changes: 8 additions & 8 deletions tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,7 +611,7 @@ def _quick_set(swap_dict, swap_td):
else:
return TensorDict._new_unsafe(_swap, batch_size=torch.Size(()))

def __ne__(self, other: object) -> T | bool:
def __ne__(self, other: Any) -> T | bool:
if is_tensorclass(other):
return other != self
if isinstance(other, (dict,)):
Expand All @@ -635,7 +635,7 @@ def __ne__(self, other: object) -> T | bool:
)
return True

def __xor__(self, other: object) -> T | bool:
def __xor__(self, other: Any) -> T | bool:
if is_tensorclass(other):
return other ^ self
if isinstance(other, (dict,)):
Expand All @@ -659,7 +659,7 @@ def __xor__(self, other: object) -> T | bool:
)
return True

def __or__(self, other: object) -> T | bool:
def __or__(self, other: Any) -> T | bool:
if is_tensorclass(other):
return other | self
if isinstance(other, (dict,)):
Expand All @@ -683,7 +683,7 @@ def __or__(self, other: object) -> T | bool:
)
return False

def __eq__(self, other: object) -> T | bool:
def __eq__(self, other: Any) -> T | bool:
if is_tensorclass(other):
return other == self
if isinstance(other, (dict,)):
Expand All @@ -705,7 +705,7 @@ def __eq__(self, other: object) -> T | bool:
)
return False

def __ge__(self, other: object) -> T | bool:
def __ge__(self, other: Any) -> T | bool:
if is_tensorclass(other):
return other <= self
if isinstance(other, (dict,)):
Expand All @@ -727,7 +727,7 @@ def __ge__(self, other: object) -> T | bool:
)
return False

def __gt__(self, other: object) -> T | bool:
def __gt__(self, other: Any) -> T | bool:
if is_tensorclass(other):
return other < self
if isinstance(other, (dict,)):
Expand All @@ -749,7 +749,7 @@ def __gt__(self, other: object) -> T | bool:
)
return False

def __le__(self, other: object) -> T | bool:
def __le__(self, other: Any) -> T | bool:
if is_tensorclass(other):
return other >= self
if isinstance(other, (dict,)):
Expand All @@ -771,7 +771,7 @@ def __le__(self, other: object) -> T | bool:
)
return False

def __lt__(self, other: object) -> T | bool:
def __lt__(self, other: Any) -> T | bool:
if is_tensorclass(other):
return other > self
if isinstance(other, (dict,)):
Expand Down
75 changes: 71 additions & 4 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
convert_ellipsis_to_idx,
DeviceType,
erase_cache,
expand_as_right,
implement_for,
IndexType,
infer_size_impl,
Expand Down Expand Up @@ -10513,7 +10514,14 @@ def clamp_max_(self, other: TensorDictBase | torch.Tensor) -> T:
else:
vals = self._values_list(True, True)
other_val = other
torch._foreach_clamp_max_(vals, other_val)
try:
torch._foreach_clamp_max_(vals, other_val)
except RuntimeError as err:
if "isDifferentiableType" in str(err):
raise RuntimeError(
"Attempted to execute _foreach_clamp_max_ with a differentiable tensor. "
"Use `td.apply(lambda x: x.clamp_max_(val)` instead."
)
return self

def clamp_max(
Expand Down Expand Up @@ -10547,7 +10555,14 @@ def clamp_max(
keys = new_keys
else:
other_val = other
vals = torch._foreach_clamp_max(vals, other_val)
try:
vals = torch._foreach_clamp_max(vals, other_val)
except RuntimeError as err:
if "isDifferentiableType" in str(err):
raise RuntimeError(
"Attempted to execute _foreach_clamp_max with a differentiable tensor. "
"Use `td.apply(lambda x: x.clamp_max(val)` instead."
)
items = dict(zip(keys, vals))

def pop(name, val):
Expand Down Expand Up @@ -10579,7 +10594,15 @@ def clamp_min_(self, other: TensorDictBase | torch.Tensor) -> T:
else:
vals = self._values_list(True, True)
other_val = other
torch._foreach_clamp_min_(vals, other_val)
try:
torch._foreach_clamp_min_(vals, other_val)
except RuntimeError as err:
if "isDifferentiableType" in str(err):
raise RuntimeError(
"Attempted to execute _foreach_clamp_min_ with a differentiable tensor. "
"Use `td.apply(lambda x: x.clamp_min_(val)` instead."
)

return self

def clamp_min(
Expand Down Expand Up @@ -10612,7 +10635,15 @@ def clamp_min(
keys = new_keys
else:
other_val = other
vals = torch._foreach_clamp_min(vals, other_val)
try:
vals = torch._foreach_clamp_min(vals, other_val)
except RuntimeError as err:
if "isDifferentiableType" in str(err):
raise RuntimeError(
"Attempted to execute _foreach_clamp_min with a differentiable tensor. "
"Use `td.apply(lambda x: x.clamp_min(val)` instead."
)

items = dict(zip(keys, vals))

def pop(name, val):
Expand All @@ -10631,6 +10662,42 @@ def pop(name, val):
result.update(items)
return result

def clamp(self, min=None, max=None, *, out=None): # noqa: W605
r"""Clamps all elements in :attr:`self` into the range `[` :attr:`min`, :attr:`max` `]`.

Letting min_value and max_value be :attr:`min` and :attr:`max`, respectively, this returns:

.. math::
y_i = \min(\max(x_i, \text{min\_value}_i), \text{max\_value}_i)

If :attr:`min` is ``None``, there is no lower bound.
Or, if :attr:`max` is ``None`` there is no upper bound.

.. note::
If :attr:`min` is greater than :attr:`max` :func:`torch.clamp(..., min, max) <torch.clamp>`
sets all elements in :attr:`input` to the value of :attr:`max`.

"""
if min is None:
if out is not None:
raise ValueError(
"clamp() with min/max=None isn't implemented with specified output."
)
return self.clamp_max(max)
if max is None:
if out is not None:
raise ValueError(
"clamp() with min/max=None isn't implemented with specified output."
)
return self.clamp_min(min)
if out is None:
return self._fast_apply(lambda x: x.clamp(min, max))
result = self._fast_apply(
lambda x, y: x.clamp(min, max, out=y), out, default=None
)
with out.unlock_() if out.is_locked else contextlib.nullcontext():
return out.update(result)

def pow_(self, other: TensorDictBase | torch.Tensor) -> T:
"""In-place version of :meth:`~.pow`.

Expand Down
12 changes: 12 additions & 0 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -3995,6 +3995,18 @@ def test_chunk(self, td_name, device, dim, chunks):
assert sum([_td.shape[dim] for _td in td_chunks]) == td.shape[dim]
assert (torch.cat(td_chunks, dim) == td).all()

def test_clamp(self, td_name, device):
td = getattr(self, td_name)(device)
tdc = td.clamp(-1, 1)
assert (tdc <= 1).all()
assert (tdc >= -1).all()
if td.requires_grad:
td = td.detach()
tdc = td.clamp(None, 1)
assert (tdc <= 1).all()
tdc = td.clamp(-1)
assert (tdc >= -1).all()

def test_clear(self, td_name, device):
td = getattr(self, td_name)(device)
with td.unlock_():
Expand Down

0 comments on commit 95205d9

Please sign in to comment.