Description
🐛 Describe the bug
I found that Batcher fails when it is applied after Shuffler.
Example:
from torchdata.datapipes.map import SequenceWrapper, Batcher, Shuffler
dp = SequenceWrapper(range(10))
dp = Shuffler(dp)
b_dp = Batcher(dp, batch_size=3, drop_last=False)
print("Length of batcher:", len(b_dp))
print("Error occurred when:", b_dp[len(b_dp) - 1])
The error comes from the exception handling inside Batcher's __getitem__.
I changed __getitem__ to the following (adding one more exception type to the except clause) to fix my problem:
def __getitem__(self, index) -> DataChunk:
    batch: List = []
    indices = range(index * self.batch_size, (index + 1) * self.batch_size)
    try:
        for i in indices:
            batch.append(self.datapipe[i])
        return self.wrapper_class(batch)
    except (IndexError, KeyError):
        if not self.drop_last and len(batch) > 0:
            return self.wrapper_class(batch)
        else:
            raise IndexError(f"Index {index} is out of bound.")
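For anyone who wants to see the mechanism without installing torchdata, here is a minimal pure-Python sketch. It assumes the shuffled pipe looks indices up in a dict-like mapping, so an out-of-range index raises KeyError rather than IndexError; I have not confirmed that this is how Shuffler is implemented. DictShuffled and SimpleBatcher are hypothetical stand-ins, not the library classes.

```python
import random
from typing import Dict, List, Sequence, Tuple, Type


class DictShuffled:
    """Hypothetical stand-in for a shuffled map-style pipe (assumption:
    indices are resolved through a dict, so a miss raises KeyError)."""

    def __init__(self, data: Sequence, seed: int = 0) -> None:
        order = list(range(len(data)))
        random.Random(seed).shuffle(order)
        self._data = data
        self._index_map: Dict[int, int] = dict(enumerate(order))

    def __len__(self) -> int:
        return len(self._data)

    def __getitem__(self, index: int):
        # A dict lookup raises KeyError (not IndexError) for unknown keys.
        return self._data[self._index_map[index]]


class SimpleBatcher:
    """Hypothetical stand-in for a map-style batcher; `caught` selects
    which exceptions mark the end of the underlying data."""

    def __init__(self, source, batch_size: int, drop_last: bool = False,
                 caught: Tuple[Type[BaseException], ...] = (IndexError,)) -> None:
        self.source = source
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.caught = caught

    def __len__(self) -> int:
        n = len(self.source)
        if self.drop_last:
            return n // self.batch_size
        return (n + self.batch_size - 1) // self.batch_size

    def __getitem__(self, index: int) -> List:
        batch: List = []
        try:
            for i in range(index * self.batch_size, (index + 1) * self.batch_size):
                batch.append(self.source[i])
            return batch
        except self.caught:
            # Return the partial final batch unless drop_last is set.
            if not self.drop_last and len(batch) > 0:
                return batch
            raise IndexError(f"Index {index} is out of bound.")


source = DictShuffled(range(10))
strict = SimpleBatcher(source, batch_size=3, caught=(IndexError,))
lenient = SimpleBatcher(source, batch_size=3, caught=(IndexError, KeyError))
last = len(lenient) - 1

try:
    strict[last]
    strict_error = None
except KeyError as exc:
    # The KeyError escapes because only IndexError is caught.
    strict_error = exc

print("Strict batcher raised:", repr(strict_error))
print("Lenient batcher last batch size:", len(lenient[last]))
```

With 10 items and batch_size=3, the last batch asks for indices 9, 10 and 11. Index 9 succeeds, index 10 misses the mapping, and only the batcher that also catches KeyError returns the one-element partial batch.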
Versions
[conda] pytorch-lightning 1.5.10 pypi_0 pypi
[conda] pytorch-ranger 0.1.1 pypi_0 pypi
[conda] torch 1.12.1+cu116 pypi_0 pypi
[conda] torch-optimizer 0.3.0 pypi_0 pypi
[conda] torch-stoi 0.1.2 pypi_0 pypi
[conda] torchaudio 0.12.1+cu116 pypi_0 pypi
[conda] torchdata 0.4.1 pypi_0 pypi
[conda] torchmetrics 0.8.2 pypi_0 pypi
[conda] torchvision 0.13.1+cu116 pypi_0 pypi