Batch: don't create new objects on getitem #1086

Closed
MischaPanch opened this issue Apr 3, 2024 · 15 comments · Fixed by #1098
Labels
Batch and Buffer Improvements in internal data structures, temporary label breaking changes Changes in public interfaces. Includes small changes or changes in keys refactoring No change to functionality

Comments

@MischaPanch
Collaborator

MischaPanch commented Apr 3, 2024

Currently Batch.__getitem__ will always create a new object. This is counterintuitive and destroys equality checks. E.g.,

b = Batch(...)
id1 = id(b[0]) 
id2 = id(b[0])

will result in id1 != id2, which leads to b[0] == b[0] being False
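To make the report concrete without depending on Tianshou, here is a minimal sketch with a hypothetical `ToyBatch` class (a stand-in, not the real `Batch`). It assumes only what is described above: `__getitem__` builds a fresh object and no `__eq__` is defined, so `==` falls back to identity.

```python
class ToyBatch:
    """Hypothetical stand-in for Batch; not Tianshou's implementation."""

    def __init__(self, **data):
        self.__dict__.update(data)

    def __getitem__(self, index):
        # Build a fresh object on every access, as described in the report.
        return ToyBatch(**{k: v[index] for k, v in self.__dict__.items()})


b = ToyBatch(a=[1, 2, 3])
first, second = b[0], b[0]  # keep both results alive while comparing

print(first is second)  # False: two distinct objects
print(first == second)  # False: the default == compares identity
```

Both results are bound to names on purpose, so that the comparison does not depend on when temporaries are freed.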

Related to #922

@MischaPanch MischaPanch added refactoring No change to functionality breaking changes Changes in public interfaces. Includes small changes or changes in keys Batch and Buffer Improvements in internal data structures, temporary label labels Apr 3, 2024
@dantp-ai
Contributor

dantp-ai commented Apr 4, 2024

I can look into this. I think the example above has a small typo. It should be id2 = id(b[0]).

@MischaPanch
Collaborator Author

Yes, you're right about the typo.

Of all the batch issues, this might be the hardest one. I'm not sure how it can be solved at all, tbh.

@dantp-ai
Contributor

dantp-ai commented Apr 4, 2024

Interesting:

>>> b = Batch(a=Batch(a=[1, 2, 3]))
>>> id1 = id(b[1])
>>> id2 = id(b[1])
>>> id1 == id2
True
>>> b[1]
Batch(
    a: Batch(
           a: 2,
       ),
)

@MischaPanch
Collaborator Author

Even more confusing, since for batches with only subbatches getitem does work as expected, but if a sequence is involved it creates a new object:

b = Batch(a=[1, 2, 3])
b[0] == b[0]
>>> False

@MischaPanch
Collaborator Author

Note that if there is a solution, it should also work for slices. Right now

b[:2] == b[:2]
>>> False

One idea: we likely can't make it return the same object, but we could add __eq__ to Batch to at least avoid the misleading equalities. This would actually be almost trivial to do! You could just compare the sorted wrapped __dict__.
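A rough sketch of that idea, again on a hypothetical `ToyBatch` stand-in rather than the real `Batch`. It assumes plain Python values; real batches often hold numpy arrays or tensors, where `==` is element-wise, so this does not carry over directly.

```python
class ToyBatch:
    """Hypothetical stand-in for Batch; not Tianshou's implementation."""

    def __init__(self, **data):
        self.__dict__.update(data)

    def __getitem__(self, index):
        # Still returns a new object on every access.
        return ToyBatch(**{k: v[index] for k, v in self.__dict__.items()})

    def __eq__(self, other):
        if not isinstance(other, ToyBatch):
            return NotImplemented
        # Compare the wrapped attributes, sorted by key, instead of identity.
        # Note: defining __eq__ without __hash__ makes instances unhashable.
        return sorted(self.__dict__.items()) == sorted(other.__dict__.items())


b = ToyBatch(a=[1, 2, 3])
print(b[0] == b[0])    # True
print(b[:2] == b[:2])  # True
print(b[0] == b[1])    # False
```

The returned objects are still distinct, but comparing them now reflects their contents.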

@dantp-ai
Contributor

dantp-ai commented Apr 4, 2024

One idea: we likely can't make it return the same object,

Yes, seems to be quite involved at this point. I wonder how torch is able to do it:

>>> a = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8 , 9]])
>>> id1 = id(a[0, :2])
>>> id2 = id(a[0, :2])
>>> id1 == id2
True
>>> a[0, :2]
tensor([1., 2.])

but we could add __eq__ to batch to at least not have the misleading equalities

Yes, this sounds good. I'll try this out. I don't think it would hurt later if we do find a solution for the object equality.

@MischaPanch
Collaborator Author

Yes, seems to be quite involved at this point. I wonder how torch is able to do it:

Torch and numpy can provide "views" of the arrays, so the id is the same.

As I found out just now, python's own list actually cannot do this, so

l = [1, 2, 3]
id(l[:2]) == id(l[:2])
>>> False

Since Batch is more of a python object than an array, it's fine if we don't do better than list :). Let's just implement __eq__ properly and resolve this issue

@MischaPanch
Collaborator Author

Huh, actually, I was slightly wrong, but in a weird way. There seems to be some magic happening when a variable is assigned the id of a list slice. Anyhow, the id of python list slices is not completely fixed.
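A likely explanation for the "magic": in CPython, `id()` is only unique among objects that are alive at the same time, so the id of a temporary that has already been freed can be reused by the next temporary. A small sketch that avoids this by holding references to both slices:

```python
values = [1, 2, 3]

# Bind both slices to names so they are alive at the same time.
left = values[:2]
right = values[:2]

print(left is right)          # False: slicing a list builds a new list
print(id(left) == id(right))  # False: both objects are alive
print(left == right)          # True: list defines value equality

# By contrast, id(values[:2]) == id(values[:2]) compares the ids of two
# temporaries with non-overlapping lifetimes. The result depends on
# whether the interpreter reuses the freed memory, so it is not reliable.
```

This is why comparing ids of unbound `b[0]` expressions says little about whether the same object was returned.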

MischaPanch pushed a commit that referenced this issue Apr 16, 2024
Closes: #1086

### Api Extensions

- Batch received new method: `to_numpy_`. #1098
- `to_dict` in Batch supports also non-recursive conversion. #1098
- Batch `__eq__` now implemented, semantic equality check of batches is
now possible. #1098

### Breaking Changes

- The method `to_numpy` in `data.utils.batch.Batch` is not in-place
anymore. Instead, a new method `to_numpy_` does the conversion in-place.
#1098
@MischaPanch
Collaborator Author

For reference: the objects returned on getitem still have different ids. This issue was resolved by implementing __eq__ on batch, which permits a meaningful comparison of the returned objects

@maxhuettenrauch
Collaborator

maxhuettenrauch commented Jun 7, 2024

I just had the case where I wanted to compare two batches that contained torch distributions logged during the training process. This comparison fails with a TypeError: iteration over a 0-d tensor. Should __eq__ also work for non array/tensor data?

@dantp-ai
Contributor

dantp-ai commented Jun 10, 2024

Thx for spotting it! It should indeed work. There are some tests that cover this, but as I was digging into it I noticed that it fails for some other cases, e.g:

In [41]: b1 = Batch(a={"b": 1})

In [42]: b2 = Batch(a={"b": 1})

In [43]: b1 == b2
Out[43]: True

In [44]: b2 = Batch(a={"c": 2})

In [45]: b1 == b2
Out[45]: False

In [46]: b2 = Batch(b={"c": 2})

In [47]: b1 == b2
Out[47]: False

In [48]: b2 = Batch(b={"b": 1})

In [49]: b1 == b2
Out[49]: False

In [50]: b2 = Batch(a={"b": 10})

In [51]: b1 == b2
...
    682 """
    683 Default compare if `iterable_compare_func` is not provided.
    684 This will compare in sequence order.
    685 """
    686 if t1_from_index is None:
    687     return [((i, i), (x, y)) for i, (x, y) in enumerate(
--> 688         zip_longest(
    689             level.t1, level.t2, fillvalue=ListItemRemovedOrAdded))]
    690 else:
    691     t1_chunk = level.t1[t1_from_index:t1_to_index]

TypeError: iteration over a 0-d array

I will look into it asap. I apologize for the inconvenience.

EDIT:

@dantp-ai
Contributor

@maxhuettenrauch So far it seems that the issue is when dealing with zero-dimensional arrays.

To remain flexible wrt DeepDiff, I suggest that we perform an additional processing step in Batch.__eq__ that uses numpy.atleast_1d to recursively convert any scalar inputs in the Batch to 1-dimensional arrays (while preserving all other arrays).
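A minimal sketch of that preprocessing step, assuming a nested dict as a stand-in for a Batch. The helper name `promote_scalars` is hypothetical and not part of Tianshou or DeepDiff; only `np.atleast_1d` is a real numpy function.

```python
import numpy as np


def promote_scalars(value):
    """Recursively turn scalars and 0-d arrays into 1-d arrays (hypothetical helper)."""
    if isinstance(value, dict):
        return {key: promote_scalars(item) for key, item in value.items()}
    # atleast_1d leaves arrays with one or more dimensions unchanged in shape.
    return np.atleast_1d(value)


zero_d = np.array(1)
try:
    iter(zero_d)
except TypeError as err:
    print(err)  # iteration over a 0-d array

nested = {"a": {"b": zero_d}, "c": np.array([1, 2])}
promoted = promote_scalars(nested)
print(promoted["a"]["b"].shape)  # (1,)
print(promoted["c"].shape)       # (2,)
```

After promotion every leaf is iterable, which is what a sequence-based comparison needs.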

@MischaPanch
Collaborator Author

In the last months I implemented a lot of helper things that also could help with this issue. Gonna open a PR tomorrow and assign you two as reviewers

@dantp-ai
Contributor

@MischaPanch Should I go ahead with the proposal above? Or does one of your helper methods already cover this edge case?

@dantp-ai
Contributor

dantp-ai commented Aug 1, 2024

@MischaPanch I experimented today with the new Batch API (#1181), specifically Batch.apply_values_transform, which I can use with np.atleast_1d to transform any 0-dimensional arrays into 1-dimensional ones. DeepDiff should then handle this edge case when checking for batch equality.


3 participants