Removing graph breaks in transforms #8056

@NicolasHug

This issue tracks progress on graph breaks removal for the v2 transforms.
Restricting to pure tensor inputs (images) for now; we can figure out the TVTensors and arbitrary structures later.

Kernels

The low-level kernels are almost all fine; only 4 of them are problematic.

import torch
from torchvision.transforms import v2
import torchvision.transforms.v2.functional as F

img = torch.rand(3, 256, 256)

# These kernels don't have graph breaks
# -------------------------------------
# torch.compile(F.get_dimensions_image, fullgraph=True)(img)
# torch.compile(F.get_num_channels_image, fullgraph=True)(img)
# torch.compile(F.get_size_image, fullgraph=True)(img)
# torch.compile(F.erase_image, fullgraph=True)(img, 0, 0, 10, 10, v=torch.tensor(0.5))
# torch.compile(F.adjust_brightness_image, fullgraph=True)(img, .5)
# torch.compile(F.adjust_contrast_image, fullgraph=True)(img, .5)
# torch.compile(F.adjust_gamma_image, fullgraph=True)(img, .5)
# torch.compile(F.adjust_hue_image, fullgraph=True)(img, .5)
# torch.compile(F.adjust_saturation_image, fullgraph=True)(img, .5)
# torch.compile(F.adjust_sharpness_image, fullgraph=True)(img, .5)
# torch.compile(F.autocontrast_image, fullgraph=True)(img)
# torch.compile(F.invert_image, fullgraph=True)(img)
# torch.compile(F.permute_channels_image, fullgraph=True)(img, [2, 1, 0])
# torch.compile(F.posterize_image, fullgraph=True)(img, bits=3)
# torch.compile(F.rgb_to_grayscale_image, fullgraph=True)(img)
# torch.compile(F.solarize_image, fullgraph=True)(img, .4)
# torch.compile(F.affine_image, fullgraph=True)(img, angle=20, translate=[1, 4], scale=1.3, shear=[0, 0])
# torch.compile(F.center_crop_image, fullgraph=True)(img, output_size=(223, 223))
# torch.compile(F.crop_image, fullgraph=True)(img, 0, 10, 10, 10)
# torch.compile(F.elastic_image, fullgraph=True)(img, displacement=torch.randn(1, *img.shape[-2:], 2))
# torch.compile(F.five_crop_image, fullgraph=True)(img, size=(223, 224))
# torch.compile(F.horizontal_flip_image, fullgraph=True)(img)
# torch.compile(F.pad_image, fullgraph=True)(img, [2, 2, 2, 2])
# torch.compile(F.rotate_image, fullgraph=True)(img, angle=30)
# torch.compile(F.ten_crop_image, fullgraph=True)(img, size=(223, 224))
# torch.compile(F.vertical_flip_image, fullgraph=True)(img)
# torch.compile(F.gaussian_blur_image, fullgraph=True)(img, kernel_size=3)
# torch.compile(F.normalize_image, fullgraph=True)(img, mean=0, std=1)
# torch.compile(F.to_dtype_image, fullgraph=True)(img, dtype=torch.uint8, scale=True)


# These kernels have graph breaks
# -------------------------------

# torch.compile(F.perspective_image, fullgraph=False)(img, None, None, coefficients=torch.rand(8))
# torch.compile(F.resize_image, fullgraph=False)(img, size=(223, 223))
# torch.compile(F.resized_crop_image, fullgraph=False)(img, 0, 12, 10, 34, (223, 223))

# This one doesn't even compile
# torch.compile(F.equalize_image, fullgraph=False)(img) 

Weird thing: resize_image and resized_crop_image both break on

if (interpolation == InterpolationMode.BILINEAR and "AVX2" in torch.backends.cpu.get_cpu_capability()) or (

but when calling them both consecutively, one of them additionally starts breaking on

if image.is_contiguous(memory_format=torch.channels_last) and image.shape[0] == 1 and numel != strides[0]:

I have no idea why.

Functionals

As @pmeier noted offline, the functionals break on

registry = _KERNEL_REGISTRY.get(functional)

This break can probably be avoided, since the dict entry should be constant within a single execution. We still need to make sure that removing it won't affect custom kernels that users register, and whether it changes anything if we eventually want to allow users to override our default kernels.
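For a minimal side-by-side (using horizontal_flip as an example, nothing special about it): the kernel goes through with fullgraph=True, while the dispatching functional only compiles with fullgraph=False because of that registry lookup:

import torch
import torchvision.transforms.v2.functional as F

img = torch.rand(3, 256, 256)

# Low-level kernel: no graph break.
torch.compile(F.horizontal_flip_image, fullgraph=True)(img)

# Dispatching functional: breaks on the _KERNEL_REGISTRY.get(functional) lookup,
# so fullgraph=True would raise; fullgraph=False falls back to eager around the break.
torch.compile(F.horizontal_flip, fullgraph=False)(img)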

TODO: figure out whether the call to _log_api_usage_once() introduces a break.
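One way to check that in isolation would be to compile a tiny wrapper that only calls the logging helper (just a sketch, importing the private _log_api_usage_once helper directly):

import torch
import torch._dynamo
from torchvision.utils import _log_api_usage_once

def f(x):
    # Same logging call pattern the v2 functionals make internally.
    _log_api_usage_once(f)
    return x + 1

explanation = torch._dynamo.explain(f)(torch.rand(3))
print(explanation.graph_break_count, explanation.break_reasons)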

Transforms

The transforms also break where the functionals break.
On top of that, the random transforms seem to break on the if rand() < self.p check, although I don't see those breaks when using TORCH_LOGS="graph_breaks"; I only see them with _dynamo.explain(). And _dynamo.explain() in turn doesn't show the graph breaks that happen on the _KERNEL_REGISTRY lookup. 🤷‍♂️

TODO: figure out which one we should trust, and also assess the rest of the transforms more systematically with a script similar to the one above.
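Something along these lines could serve as a starting point for that sweep (a sketch only: the transform list is a placeholder, and it relies on _dynamo.explain(), which is precisely one of the tools whose output we're not sure we can trust yet):

import torch
import torch._dynamo
from torchvision.transforms import v2

img = torch.rand(3, 256, 256)

# A few representative transforms; extend the list to cover the rest.
transforms_to_check = [
    v2.RandomHorizontalFlip(p=1.0),
    v2.RandomVerticalFlip(p=1.0),
    v2.ColorJitter(brightness=0.5),
    v2.RandomResizedCrop(size=(223, 223), antialias=True),
]

for t in transforms_to_check:
    torch._dynamo.reset()
    explanation = torch._dynamo.explain(t)(img)
    print(f"{type(t).__name__}: {explanation.graph_break_count} graph break(s)")
    for reason in explanation.break_reasons:
        print("   ", reason.reason)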

CC @pmeier @vfdev-5
