Add dataclass support in move_data_to_device #7878

@dalek-who

🚀 Feature

Motivation

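NamedTuple is often used to wrap batch data, but a class that inherits from NamedTuple cannot be extended with new fields by subclassing it again, which makes it hard to add fields later. A dataclass is a better choice, but move_data_to_device does not support it: dataclass batches are returned with their tensors still on the original device.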
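
To illustrate the limitation, here is a minimal sketch using only the standard library. The class names (`Base`, `Child`, `BaseDC`, `ChildDC`) are hypothetical and exist only for this example.

```python
from dataclasses import dataclass, fields
from typing import NamedTuple


class Base(NamedTuple):
    x: int


class Child(Base):
    # This annotation does NOT become a tuple field; it is only a class attribute.
    y: int = 0


@dataclass
class BaseDC:
    x: int


@dataclass
class ChildDC(BaseDC):
    # Dataclass inheritance appends the new field to the parent's fields.
    y: int = 0


namedtuple_fields = Child._fields
dataclass_fields = tuple(f.name for f in fields(ChildDC))

# Passing a second positional value to Child fails, because the inherited
# constructor only knows about the single field "x".
try:
    Child(1, 2)
    child_accepts_extra_field = True
except TypeError:
    child_accepts_extra_field = False
```

After this runs, `namedtuple_fields` still contains only `x`, while `dataclass_fields` contains both `x` and `y`.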

Pitch

Alternatives

Additional context

from pytorch_lightning.utilities import move_data_to_device
import torch
from dataclasses import dataclass
from typing import NamedTuple, List

class Data_1(NamedTuple):
    example_id: List[str]
    x: torch.Tensor
    y: torch.Tensor

@dataclass
class Data_2:
    example_id: List[str]
    x: torch.Tensor
    y: torch.Tensor

batch_size = 5

batch_1 = Data_1(
    example_id=[f"e-{i}" for i in range(batch_size)],
    x=torch.rand(batch_size),
    y=torch.rand(batch_size),
)

batch_2 = Data_2(
    example_id=[f"e-{i}" for i in range(batch_size)],
    x=torch.rand(batch_size),
    y=torch.rand(batch_size),
)

device = torch.device("cuda:0")

# NamedTuple batch: the tensors are moved to cuda:0.
move_data_to_device(batch=batch_1, device=device)
# Data_1(
# example_id=['e-0', 'e-1', 'e-2', 'e-3', 'e-4'], 
# x=tensor([0.3385, 0.6415, 0.8117, 0.6030, 0.2551], device='cuda:0'), 
# y=tensor([0.2586, 0.8260, 0.0066, 0.0321, 0.3881], device='cuda:0')
# )

# dataclass batch: the tensors are NOT moved and stay on the CPU.
move_data_to_device(batch=batch_2, device=device)
# Data_2(
# example_id=['e-0', 'e-1', 'e-2', 'e-3', 'e-4'], 
# x=tensor([0.9459, 0.0063, 0.3763, 0.4537, 0.6941]), 
# y=tensor([0.8080, 0.8041, 0.2999, 0.5154, 0.3706])
# )
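
One possible shape for the requested behaviour is sketched below. This is not the Lightning implementation: `move_to_device_sketch` is a hypothetical helper, and it assumes that anything with a callable `.to(device)` method (such as `torch.Tensor`) should be moved. It walks dataclass instances, NamedTuples, lists, tuples, and dicts, and leaves everything else untouched.

```python
import dataclasses
from typing import Any


def move_to_device_sketch(batch: Any, device: Any) -> Any:
    """Hypothetical helper: recursively move tensor-like objects to `device`."""
    # Dataclass instances (not dataclass types): rebuild a copy with every
    # init field moved. Fields declared with init=False are left to the
    # dataclass itself, since dataclasses.replace() rejects them.
    if dataclasses.is_dataclass(batch) and not isinstance(batch, type):
        moved = {
            f.name: move_to_device_sketch(getattr(batch, f.name), device)
            for f in dataclasses.fields(batch)
            if f.init
        }
        return dataclasses.replace(batch, **moved)
    # NamedTuples: their constructors take positional fields, not an iterable.
    if isinstance(batch, tuple) and hasattr(batch, "_fields"):
        return type(batch)(*(move_to_device_sketch(v, device) for v in batch))
    # Plain lists and tuples.
    if isinstance(batch, (list, tuple)):
        return type(batch)(move_to_device_sketch(v, device) for v in batch)
    # Dicts: move the values, keep the keys.
    if isinstance(batch, dict):
        return {k: move_to_device_sketch(v, device) for k, v in batch.items()}
    # Assumption: tensor-like objects expose a callable .to(device).
    if callable(getattr(batch, "to", None)):
        return batch.to(device)
    # Strings, numbers, and other objects are returned unchanged.
    return batch
```

With the example above, `move_to_device_sketch(batch_2, device)` would be expected to return a new `Data_2` whose `x` and `y` tensors are on `cuda:0`, provided a CUDA device is available. The sketch returns a copy rather than mutating the dataclass in place, which also keeps it usable with frozen dataclasses.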

Metadata
Labels: feature (Is an improvement or enhancement), help wanted (Open to be worked on), won't fix (This will not be worked on)
