Closed
Labels
feature, help wanted, won't fix
Description
🚀 Feature
Motivation
NamedTuple is often used to wrap batch data, but a class that inherits from NamedTuple cannot usefully be inherited from again, so it is difficult to add new fields. A dataclass is a better choice, but it is not supported by move_data_to_device.
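The limitation described above can be seen without any Lightning code. The sketch below uses illustrative class names (BaseBatch, ExtendedBatch, and their dataclass counterparts are not from the issue) and compares what happens when a field is added in a subclass:

```python
from dataclasses import dataclass, fields
from typing import List, NamedTuple


class BaseBatch(NamedTuple):
    example_id: List[str]


class ExtendedBatch(BaseBatch):
    # This annotation only creates a class attribute; the namedtuple
    # machinery does not register it as a new tuple field.
    weight: float = 1.0


@dataclass
class BaseData:
    example_id: List[str]


@dataclass
class ExtendedData(BaseData):
    # Dataclass inheritance appends the new field after the inherited ones.
    weight: float = 1.0


namedtuple_fields = ExtendedBatch._fields
dataclass_fields = tuple(f.name for f in fields(ExtendedData))
```

Here namedtuple_fields still contains only example_id, while dataclass_fields contains both example_id and weight, which is why a dataclass is the more convenient container when batches need to be extended.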
Pitch
Alternatives
Additional context
from dataclasses import dataclass
from typing import List, NamedTuple

import torch
from pytorch_lightning.utilities import move_data_to_device


class Data_1(NamedTuple):
    example_id: List[str]
    x: torch.Tensor
    y: torch.Tensor


@dataclass
class Data_2:
    example_id: List[str]
    x: torch.Tensor
    y: torch.Tensor


batch_size = 5
batch_1 = Data_1(
    example_id=[f"e-{i}" for i in range(batch_size)],
    x=torch.rand(batch_size),
    y=torch.rand(batch_size),
)
batch_2 = Data_2(
    example_id=[f"e-{i}" for i in range(batch_size)],
    x=torch.rand(batch_size),
    y=torch.rand(batch_size),
)
device = torch.device("cuda:0")

# NamedTuple batch: the tensors are moved to cuda:0.
move_data_to_device(batch=batch_1, device=device)
# Data_1(
#     example_id=['e-0', 'e-1', 'e-2', 'e-3', 'e-4'],
#     x=tensor([0.3385, 0.6415, 0.8117, 0.6030, 0.2551], device='cuda:0'),
#     y=tensor([0.2586, 0.8260, 0.0066, 0.0321, 0.3881], device='cuda:0')
# )

# Dataclass batch: returned unchanged, the tensors stay on the CPU.
move_data_to_device(batch=batch_2, device=device)
# Data_2(
#     example_id=['e-0', 'e-1', 'e-2', 'e-3', 'e-4'],
#     x=tensor([0.9459, 0.0063, 0.3763, 0.4537, 0.6941]),
#     y=tensor([0.8080, 0.8041, 0.2999, 0.5154, 0.3706])
# )
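Until dataclasses are handled by move_data_to_device, one possible user-side workaround is a small recursive helper. The following is only a sketch: to_device is a hypothetical name, not part of pytorch_lightning, and it assumes that anything exposing a callable .to(device) method (as torch.Tensor does) should be moved, while everything else is returned as-is.

```python
from dataclasses import dataclass, fields, is_dataclass, replace
from typing import Any


def to_device(obj: Any, device: Any) -> Any:
    """Hypothetical helper: recursively move values that expose .to(device)."""
    # Dataclass instances: rebuild with every init field moved.
    if is_dataclass(obj) and not isinstance(obj, type):
        moved = {
            f.name: to_device(getattr(obj, f.name), device)
            for f in fields(obj)
            if f.init  # replace() rejects fields declared with init=False
        }
        return replace(obj, **moved)
    # Namedtuples: the constructor takes positional fields.
    if isinstance(obj, tuple) and hasattr(obj, "_fields"):
        return type(obj)(*(to_device(v, device) for v in obj))
    if isinstance(obj, (list, tuple)):
        return type(obj)(to_device(v, device) for v in obj)
    if isinstance(obj, dict):
        return {k: to_device(v, device) for k, v in obj.items()}
    # Duck-typed leaf: tensors and modules have a .to method.
    to = getattr(obj, "to", None)
    if callable(to):
        return to(device)
    # Strings, numbers, and other leaves are left untouched.
    return obj
```

With the definitions from the example above, to_device(batch_2, torch.device("cuda:0")) would return a new Data_2 whose x and y live on the GPU and whose example_id list is unchanged. Note that dataclasses.replace calls __init__ again (and __post_init__ if defined), so this sketch is only suitable for dataclasses whose construction has no side effects.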