Add support for extended transforms in Lambda and SplitLambda #51

Open
dxoigmn opened this issue Jan 24, 2023 · 0 comments

Right now it is not possible to use extended transforms with Lambda and SplitLambda. For example, it would be useful to do something like:

```yaml
_target_: mart.transforms.SplitLambda
lambd:
  _target_: mart.transforms.Compose
  transforms:
    - _target_: mart.transforms.RandomHorizontalFlip
split_size_or_sections: 3
lambd_section: -1
dim: 0
```

Here is a failing example in Python:

```python
import torch
import mart

transform = mart.transforms.SplitLambda(
    lambd=mart.transforms.Compose(
        transforms=[mart.transforms.RandomHorizontalFlip()]
    ),
    split_size_or_sections=3,
    lambd_section=-1,
    dim=0,
)
transform(tensor=torch.zeros((6, 320, 240)), target={})
```

If you don't pass `target={}`, it works as expected.
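A minimal sketch of how a target-aware `SplitLambda` could behave, assuming an extended transform is a callable that takes `(tensor, target)` and returns `(tensor, target)`. The class name `TargetAwareSplitLambda` is hypothetical and not part of MART's API; it only illustrates splitting the input with `torch.split`, applying `lambd` to the section selected by `lambd_section`, concatenating with `torch.cat`, and passing `target` through when one is given.

```python
import torch


class TargetAwareSplitLambda:
    """Hypothetical SplitLambda variant that threads an optional target through lambd."""

    def __init__(self, lambd, split_size_or_sections, dim=0, lambd_section=0):
        self.lambd = lambd
        self.split_size_or_sections = split_size_or_sections
        self.dim = dim
        self.lambd_section = lambd_section

    def __call__(self, tensor, target=None):
        # Split along `dim`, e.g. a (6, H, W) tensor with size 3 -> two (3, H, W) sections.
        sections = list(torch.split(tensor, self.split_size_or_sections, dim=self.dim))
        section = sections[self.lambd_section]

        if target is None:
            # Tensor-only path: lambd receives and returns just the section.
            sections[self.lambd_section] = self.lambd(section)
            return torch.cat(sections, dim=self.dim)

        # Assumption: an extended transform takes (tensor, target) and returns both.
        new_section, target = self.lambd(section, target)
        sections[self.lambd_section] = new_section
        return torch.cat(sections, dim=self.dim), target
```

Keeping the tensor-only path when `target` is `None` preserves the behavior that already works, while an explicit `target={}` (as in the failing call) is routed through the assumed extended convention.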

It would also be nice to support original torchvision transforms:

```yaml
_target_: mart.transforms.SplitLambda
lambd:
  _target_: mart.transforms.Compose
  transforms:
    - _target_: torchvision.transforms.Normalize
      mean: 0
      std: 255
split_size_or_sections: 3
lambd_section: -1
dim: 0
```

Here's a failing example in Python:

```python
import torch
import mart
import torchvision

transform = mart.transforms.SplitLambda(
    lambd=mart.transforms.Compose(
        transforms=[torchvision.transforms.Normalize(mean=0, std=255)]
    ),
    split_size_or_sections=3,
    lambd_section=-1,
    dim=0,
)
transform(tensor=torch.zeros((6, 320, 240)), target={})
```
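For the torchvision case, one option is an adapter that lifts a tensor-only transform into the assumed `(tensor, target)` convention, leaving the target untouched. `TensorOnly` and `ExtendedCompose` below are hypothetical names used for illustration, not MART classes; with them, `torchvision.transforms.Normalize(mean=0, std=255)` could be wrapped as `TensorOnly(torchvision.transforms.Normalize(mean=0, std=255))` before being placed in the composed pipeline.

```python
class TensorOnly:
    """Hypothetical adapter: wraps a transform that only accepts a tensor."""

    def __init__(self, transform):
        self.transform = transform

    def __call__(self, tensor, target):
        # Apply the wrapped transform to the tensor; pass the target through unchanged.
        return self.transform(tensor), target


class ExtendedCompose:
    """Hypothetical compose that threads (tensor, target) through every step."""

    def __init__(self, transforms):
        self.transforms = list(transforms)

    def __call__(self, tensor, target):
        for transform in self.transforms:
            tensor, target = transform(tensor, target)
        return tensor, target


# Usage with a plain function standing in for a tensor-only transform.
pipeline = ExtendedCompose([TensorOnly(lambda x: x / 255)])
value, target = pipeline(510.0, {"label": 1})  # value == 2.0, target unchanged
```

An explicit wrapper keeps the choice visible in configs; an alternative would be for the compose step to inspect each transform's signature and wrap tensor-only callables automatically, which is more convenient but more implicit.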