Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding func_to_get_labels argument to DatasetTok #80

Merged
merged 2 commits into from
Oct 10, 2023
Merged

Conversation

Natooz
Copy link
Owner

@Natooz Natooz commented Oct 5, 2023

Following #78, this PR adds the ability to read labels from DatasetTok with a func_to_get_labels provided method.

@leleogere you can review if desired before merging the PR. Otherwise I'll merge in a few days :)

@codecov
Copy link

codecov bot commented Oct 5, 2023

Codecov Report

All modified lines are covered by tests ✅

Comparison is base (01cdac0) 90.29% compared to head (a0bebfa) 90.42%.

Additional details and impacted files
@@            Coverage Diff             @@
##             main      #80      +/-   ##
==========================================
+ Coverage   90.29%   90.42%   +0.13%     
==========================================
  Files          31       31              
  Lines        4305     4334      +29     
==========================================
+ Hits         3887     3919      +32     
+ Misses        418      415       -3     
Files Coverage Δ
miditok/pytorch_data/datasets.py 100.00% <100.00%> (+3.19%) ⬆️
tests/test_pytorch_data_loading.py 86.74% <100.00%> (+1.61%) ⬆️

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Copy link
Contributor

@leleogere leleogere left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have not tried it on my data yet, but at first glance it looks good!

miditok/pytorch_data/datasets.py Outdated Show resolved Hide resolved
miditok/pytorch_data/datasets.py Outdated Show resolved Hide resolved
miditok/pytorch_data/datasets.py Show resolved Hide resolved
@leleogere
Copy link
Contributor

leleogere commented Oct 6, 2023

Do not merge this immediately, I am getting an error when using the DatasetTok with the DataCollator when trying to torch.stack the labels. I'm going to investigate a bit more.

@leleogere
Copy link
Contributor

leleogere commented Oct 6, 2023

The issue is in the last line of https://github.com/Natooz/MidiTok/blob/main/miditok/pytorch_data/collators.py#L78-L82

        if y is not None:
            if isinstance(y[0], LongTensor):
                y = _pad_batch(y, self.labels_pad_idx, self.pad_on_left)
            else:  # classification
                y = torch.stack(y)

The torch.stack is expecting a list of tensors, but it gets a list of int. The labels should be turned into tensors at some point:

y = torch.stack([torch.tensor(l) for l in y])

I'm not sure that this is the right place to do that but this change makes it work.

However, this might cause an issue when training models as I think most models usually use a labels vector of shape (batch_size, 1), and the y vector would currently be shaped as (batch_size,). More generally, I think y should be allowed to be a vector, so that we can train models with OHE labels for example (shape (batch_size, N)).

@Natooz Natooz closed this Oct 6, 2023
@Natooz Natooz deleted the dataset-tok-labels branch October 6, 2023 13:14
@Natooz Natooz restored the dataset-tok-labels branch October 6, 2023 13:14
@Natooz
Copy link
Owner Author

Natooz commented Oct 6, 2023

(my apologies, wrong operation)

I'm working on passing the labels as tensors

Indeed the data loader will output a batch with labels with shape (batch_size). Now, I have mostly worked with PyTorch models (and loss functions) and it is the expected format. So I'm not sure giving them as one-hot tensors would be very useful.

@Natooz Natooz reopened this Oct 6, 2023
@Natooz Natooz merged commit 7726d7d into main Oct 10, 2023
14 checks passed
@leleogere
Copy link
Contributor

I did not have the time to have a look earlier, but I think there is still an issue with the DataCollator:

Traceback (most recent call last):
  File "/home/gerel/Documents/GrooveMIDI/data_preprocessing.py", line 133, in <module>
    model.run_training(
  File "/home/gerel/Documents/GrooveMIDI/models.py", line 60, in run_training
    writer.add_graph(self, next(iter(train_loader))['input_ids'])
                           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/gerel/anaconda3/envs/groove/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 630, in __next__
    data = self._next_data()
           ^^^^^^^^^^^^^^^^^
  File "/home/gerel/anaconda3/envs/groove/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 674, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/gerel/anaconda3/envs/groove/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 54, in fetch
    return self.collate_fn(data)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/gerel/anaconda3/envs/groove/lib/python3.11/site-packages/miditok/pytorch_data/collators.py", line 80, in __call__
    y = _pad_batch(y, self.labels_pad_idx, self.pad_on_left)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/gerel/anaconda3/envs/groove/lib/python3.11/site-packages/miditok/pytorch_data/collators.py", line 155, in _pad_batch
    length_of_first = batch[0].size(0)
                      ^^^^^^^^^^^^^^^^
IndexError: Dimension specified as 0 but tensor has no dimensions

The problem comes from the line length_of_first = batch[0].size(0) in _pad_batch. Indeed, the y variable is a list of Tensors of shape y[0].shape = torch.Size([]), meaning that we can't access batch[0].size(0).

@Natooz
Copy link
Owner Author

Natooz commented Oct 11, 2023

Thanks for pointing it, I'll take a look tonight and hopefully fix this

@Natooz
Copy link
Owner Author

Natooz commented Oct 12, 2023

It should be fixed now! (Sorry for being late 😅)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants