
Add support for extended transforms in Lambda and SplitLambda #51

Open
@dxoigmn

Description

Currently, it is not possible to use extended transforms with Lambda or SplitLambda. For example, it would be useful to write a config like the following:

_target_: mart.transforms.SplitLambda
lambd:
  _target_: mart.transforms.Compose
  transforms:
    - _target_: mart.transforms.RandomHorizontalFlip
split_size_or_sections: 3
lambd_section: -1
dim: 0

Here is a failing example in Python:

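import torch
import mart

# Split the 6-channel tensor along dim 0 into sections of size 3 and
# apply the composed transform to the last section only.
transform = mart.transforms.SplitLambda(
    lambd=mart.transforms.Compose(transforms=[mart.transforms.RandomHorizontalFlip()]),
    split_size_or_sections=3,
    lambd_section=-1,
    dim=0,
)

# Fails when a target is passed.
transform(tensor=torch.zeros((6, 320, 240)), target={})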

Without target={}, the call works as expected.
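For discussion, here is a minimal sketch of what a target-aware SplitLambda could look like. It assumes that extended transforms are called as lambd(section, target=target) and return a (section, target) pair; the function name split_lambda is hypothetical and not part of mart. To stay dependency-free, a Python list stands in for dimension 0 of the tensor, where a real implementation would presumably use torch.split and torch.cat. When no target is given, it falls back to calling lambd with the section alone.

```python
from typing import Any, Callable, List, Optional


def split_lambda(
    channels: List[Any],
    lambd: Callable,
    split_size: int,
    lambd_section: int,
    target: Optional[dict] = None,
):
    """Hypothetical target-aware SplitLambda; `channels` stands in for dim 0."""
    extended = target is not None
    # Consecutive chunks of `split_size`, like torch.split(x, split_size, dim=0).
    sections = [channels[i : i + split_size] for i in range(0, len(channels), split_size)]
    if extended:
        # Extended transform (assumed interface): pass the target in, keep the updated one.
        sections[lambd_section], target = lambd(sections[lambd_section], target=target)
    else:
        # Plain callable: only the selected section is transformed.
        sections[lambd_section] = lambd(sections[lambd_section])
    # Concatenate back in order, like torch.cat(sections, dim=0).
    out = [c for section in sections for c in section]
    return (out, target) if extended else out
```

With six "channels" and split_size=3, lambd_section=-1 selects the last three, matching the SplitLambda arguments in the example above.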

It would also be useful to support standard torchvision transforms:

_target_: mart.transforms.SplitLambda
lambd:
  _target_: mart.transforms.Compose
  transforms:
    - _target_: torchvision.transforms.Normalize
      mean: 0
      std: 255
split_size_or_sections: 3
lambd_section: -1
dim: 0

Here is a failing example in Python:

import torch
import mart
import torchvision

transform = mart.transforms.SplitLambda(
    lambd=mart.transforms.Compose(transforms=[torchvision.transforms.Normalize(mean=0, std=255)]),
    split_size_or_sections=3,
    lambd_section=-1,
    dim=0,
)

# Fails when a target is passed.
transform(tensor=torch.zeros((6, 320, 240)), target={})
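Supporting torchvision transforms inside the same Compose means each transform has to be called according to its own signature. Below is a minimal sketch, assuming extended transforms accept a parameter named target and return (tensor, target), while plain transforms take only the tensor. MixedCompose and accepts_target are hypothetical names, not mart APIs. Because nn.Module instances (such as torchvision.transforms.Normalize) dispatch __call__ to forward() and expose only a generic *args/**kwargs call signature, the sketch inspects forward() when it exists.

```python
import inspect
from typing import Any, Callable, Sequence


def accepts_target(transform: Callable) -> bool:
    """Heuristic: does this transform take a `target` argument?"""
    # nn.Module instances dispatch __call__ to forward(), so inspect forward()
    # when it exists; otherwise inspect the callable itself.
    fn = getattr(transform, "forward", transform)
    try:
        params = inspect.signature(fn).parameters
    except (TypeError, ValueError):
        # No introspectable signature: treat it as a plain transform.
        return False
    return "target" in params


class MixedCompose:
    """Hypothetical Compose that chains plain and extended transforms."""

    def __init__(self, transforms: Sequence[Callable]):
        self.transforms = list(transforms)

    def __call__(self, tensor: Any, target: Any = None):
        extended = target is not None
        for t in self.transforms:
            if extended and accepts_target(t):
                # Extended transform: it may update both tensor and target.
                tensor, target = t(tensor, target=target)
            else:
                # Plain transform (e.g. torchvision): the target passes through unchanged.
                tensor = t(tensor)
        return (tensor, target) if extended else tensor
```

Signature inspection is only a heuristic; an explicit wrapper that marks plain transforms would be more predictable if the parameter naming is not consistent across transforms.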
