Skip to content

Commit

Permalink
[Performance] Minor efficiency improvements (#703)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Mar 24, 2024
1 parent 8ea3f13 commit 82e83af
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 2 deletions.
8 changes: 8 additions & 0 deletions tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,14 @@ def __eq__(self, other: object) -> T | bool:
keys1 = set(self.keys())
keys2 = set(other.keys())
if len(keys1.difference(keys2)) or len(keys1) != len(keys2):
keys1 = sorted(
keys1,
key=lambda key: "".join(key) if isinstance(key, tuple) else key,
)
keys2 = sorted(
keys2,
key=lambda key: "".join(key) if isinstance(key, tuple) else key,
)
raise KeyError(f"keys in tensordicts mismatch, got {keys1} and {keys2}")
d = {}
for key, item1 in self.items():
Expand Down
5 changes: 3 additions & 2 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5252,8 +5252,9 @@ def _register_tensor_class(cls):


def _is_tensor_collection(datatype):
out = _TENSOR_COLLECTION_MEMO.get(datatype, None)
if out is None:
try:
out = _TENSOR_COLLECTION_MEMO[datatype]
except KeyError:
if issubclass(datatype, TensorDictBase):
out = True
elif _is_tensorclass(datatype):
Expand Down

0 comments on commit 82e83af

Please sign in to comment.