PytorchDataset: remove internal tqdm bar (#692)
skshetry authored Dec 12, 2024
1 parent b67d599 commit e9c9f39
Showing 2 changed files with 18 additions and 21 deletions.
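In short: the per-worker progress bar that PytorchDataset used to draw internally is gone, and progress display becomes the caller's job. A minimal sketch of the caller-side pattern, mirroring the updated example below (train_loader and NUM_EPOCHS as defined there; train_step is a hypothetical helper for the forward/backward pass):

    from tqdm import tqdm

    for epoch in range(int(NUM_EPOCHS)):
        # The caller owns the bar now; the dataset just yields rows.
        with tqdm(
            train_loader, desc=f"epoch {epoch + 1}/{NUM_EPOCHS}", unit="batch"
        ) as loader:
            for inputs, labels in loader:
                loss = train_step(inputs, labels)  # hypothetical: forward + backward + step
                loader.set_postfix(loss=loss.item())  # surface the latest loss on the bar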
33 changes: 17 additions & 16 deletions examples/get_started/torch-loader.py
@@ -13,6 +13,7 @@
 from torch import nn, optim
 from torch.utils.data import DataLoader
 from torchvision.transforms import v2
+from tqdm import tqdm
 
 from datachain import C, DataChain
 from datachain.torch import label_to_int
@@ -23,7 +24,7 @@
 # Define transformation for data preprocessing
 transform = v2.Compose(
     [
-        v2.ToTensor(),
+        v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)]),
         v2.Resize((64, 64)),
         v2.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
     ]
@@ -77,18 +78,18 @@ def forward(self, x):
 
 # Train the model
 for epoch in range(int(NUM_EPOCHS)):
-    for i, data in enumerate(train_loader):
-        inputs, labels = data
-        optimizer.zero_grad()
-
-        # Forward pass
-        outputs = model(inputs)
-        loss = criterion(outputs, labels)
-
-        # Backward pass and optimize
-        loss.backward()
-        optimizer.step()
-
-        print(f"[{epoch + 1}, {i + 1:5d}] loss: {loss.item():.3f}")
-
-print("Finished Training")
+    with tqdm(
+        train_loader, desc=f"epoch {epoch + 1}/{NUM_EPOCHS}", unit="batch"
+    ) as loader:
+        for data in loader:
+            inputs, labels = data
+            optimizer.zero_grad()
+
+            # Forward pass
+            outputs = model(inputs)
+            loss = criterion(outputs, labels)
+
+            # Backward pass and optimize
+            loss.backward()
+            optimizer.step()
+            loader.set_postfix(loss=loss.item())
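Aside on the transform swap above: v2.ToTensor() is deprecated in recent torchvision releases, and the deprecation warning itself points to the ToImage + ToDtype pair used here. A standalone sketch of the resulting pipeline, flattened into a single Compose (constants copied from the example):

    import torch
    from torchvision.transforms import v2

    transform = v2.Compose(
        [
            v2.ToImage(),                           # PIL image / ndarray -> tensor image
            v2.ToDtype(torch.float32, scale=True),  # uint8 [0, 255] -> float32 [0, 1]
            v2.Resize((64, 64)),
            v2.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )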
6 changes: 1 addition & 5 deletions src/datachain/lib/pytorch.py
@@ -7,7 +7,6 @@
 from torch.distributed import get_rank, get_world_size
 from torch.utils.data import IterableDataset, get_worker_info
 from torchvision.transforms import v2
-from tqdm import tqdm
 
 from datachain import Session
 from datachain.asyn import AsyncMapper
@@ -112,10 +111,7 @@ def __iter__(self) -> Iterator[Any]:
             from datachain.lib.udf import _prefetch_input
 
             rows = AsyncMapper(_prefetch_input, rows, workers=self.prefetch).iterate()
-
-        desc = f"Parsed PyTorch dataset for rank={total_rank} worker"
-        with tqdm(rows, desc=desc, unit=" rows", position=total_rank) as rows_it:
-            yield from map(self._process_row, rows_it)
+        yield from map(self._process_row, rows)
 
     def _process_row(self, row_features):
         row = []
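If downstream code relied on the removed internal bar, the same feedback is easy to recreate outside the dataset. A minimal sketch, assuming the usual DataChain.from_storage(...).to_pytorch() entry point from the datachain examples (the bucket URI is illustrative):

    from datachain import DataChain
    from torch.utils.data import DataLoader
    from tqdm import tqdm

    ds = DataChain.from_storage("gs://datachain-demo/dogs-and-cats/").to_pytorch()
    loader = DataLoader(ds, batch_size=16)

    # The bar now lives in caller code rather than in PytorchDataset.__iter__.
    for batch in tqdm(loader, unit=" batch"):
        pass  # training / inference step goes here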