From e9c9f39e4c590742956da4a55b4a4ffd6bc246f4 Mon Sep 17 00:00:00 2001
From: skshetry <18718008+skshetry@users.noreply.github.com>
Date: Thu, 12 Dec 2024 06:58:40 +0545
Subject: [PATCH] PytorchDataset: remove internal tqdm bar (#692)

---
 examples/get_started/torch-loader.py | 33 ++++++++++++++--------------
 src/datachain/lib/pytorch.py         |  6 +----
 2 files changed, 18 insertions(+), 21 deletions(-)

diff --git a/examples/get_started/torch-loader.py b/examples/get_started/torch-loader.py
index a8bb43c39..19ca484ca 100644
--- a/examples/get_started/torch-loader.py
+++ b/examples/get_started/torch-loader.py
@@ -13,6 +13,7 @@
 from torch import nn, optim
 from torch.utils.data import DataLoader
 from torchvision.transforms import v2
+from tqdm import tqdm
 
 from datachain import C, DataChain
 from datachain.torch import label_to_int
@@ -23,7 +24,7 @@
 # Define transformation for data preprocessing
 transform = v2.Compose(
     [
-        v2.ToTensor(),
+        v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)]),
         v2.Resize((64, 64)),
         v2.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
     ]
@@ -77,18 +78,18 @@ def forward(self, x):
 
 # Train the model
 for epoch in range(int(NUM_EPOCHS)):
-    for i, data in enumerate(train_loader):
-        inputs, labels = data
-        optimizer.zero_grad()
-
-        # Forward pass
-        outputs = model(inputs)
-        loss = criterion(outputs, labels)
-
-        # Backward pass and optimize
-        loss.backward()
-        optimizer.step()
-
-        print(f"[{epoch + 1}, {i + 1:5d}] loss: {loss.item():.3f}")
-
-print("Finished Training")
+    with tqdm(
+        train_loader, desc=f"epoch {epoch + 1}/{NUM_EPOCHS}", unit="batch"
+    ) as loader:
+        for data in loader:
+            inputs, labels = data
+            optimizer.zero_grad()
+
+            # Forward pass
+            outputs = model(inputs)
+            loss = criterion(outputs, labels)
+
+            # Backward pass and optimize
+            loss.backward()
+            optimizer.step()
+            loader.set_postfix(loss=loss.item())
diff --git a/src/datachain/lib/pytorch.py b/src/datachain/lib/pytorch.py
index 0e1de153d..e85fb0aae 100644
--- a/src/datachain/lib/pytorch.py
+++ b/src/datachain/lib/pytorch.py
@@ -7,7 +7,6 @@
 from torch.distributed import get_rank, get_world_size
 from torch.utils.data import IterableDataset, get_worker_info
 from torchvision.transforms import v2
-from tqdm import tqdm
 
 from datachain import Session
 from datachain.asyn import AsyncMapper
@@ -112,10 +111,7 @@ def __iter__(self) -> Iterator[Any]:
         from datachain.lib.udf import _prefetch_input
 
         rows = AsyncMapper(_prefetch_input, rows, workers=self.prefetch).iterate()
-
-        desc = f"Parsed PyTorch dataset for rank={total_rank} worker"
-        with tqdm(rows, desc=desc, unit=" rows", position=total_rank) as rows_it:
-            yield from map(self._process_row, rows_it)
+        yield from map(self._process_row, rows)
 
     def _process_row(self, row_features):
         row = []