Description
This issue tracks progress on removing graph breaks from the v2 transforms.
For now we restrict ourselves to pure tensor inputs (images); we can figure out TVTensors and arbitrary input structures later.
Kernels
Almost all of the low-level kernels are fine; only 4 of them are problematic.
```python
import torch
from torchvision.transforms import v2
import torchvision.transforms.v2.functional as F

img = torch.rand(3, 256, 256)

# Uncomment one line at a time. With fullgraph=True, torch.compile raises an
# error on the first graph break instead of silently splitting the graph.

# These kernels don't have graph breaks
# -------------------------------------
# torch.compile(F.get_dimensions_image, fullgraph=True)(img)
# torch.compile(F.get_num_channels_image, fullgraph=True)(img)
# torch.compile(F.get_size_image, fullgraph=True)(img)
# torch.compile(F.erase_image, fullgraph=True)(img, 0, 0, 10, 10, v=torch.tensor(0.5))
# torch.compile(F.adjust_brightness_image, fullgraph=True)(img, .5)
# torch.compile(F.adjust_contrast_image, fullgraph=True)(img, .5)
# torch.compile(F.adjust_gamma_image, fullgraph=True)(img, .5)
# torch.compile(F.adjust_hue_image, fullgraph=True)(img, .5)
# torch.compile(F.adjust_saturation_image, fullgraph=True)(img, .5)
# torch.compile(F.adjust_sharpness_image, fullgraph=True)(img, .5)
# torch.compile(F.autocontrast_image, fullgraph=True)(img)
# torch.compile(F.invert_image, fullgraph=True)(img)
# torch.compile(F.permute_channels_image, fullgraph=True)(img, [2, 1, 0])
# torch.compile(F.posterize_image, fullgraph=True)(img, bits=3)
# torch.compile(F.rgb_to_grayscale_image, fullgraph=True)(img)
# torch.compile(F.solarize_image, fullgraph=True)(img, .4)
# torch.compile(F.affine_image, fullgraph=True)(img, angle=20, translate=[1, 4], scale=1.3, shear=[0, 0])
# torch.compile(F.center_crop_image, fullgraph=True)(img, output_size=(223, 223))
# torch.compile(F.crop_image, fullgraph=True)(img, 0, 10, 10, 10)
# torch.compile(F.elastic_image, fullgraph=True)(img, displacement=torch.randn(1, *img.shape[-2:], 2))
# torch.compile(F.five_crop_image, fullgraph=True)(img, size=(223, 224))
# torch.compile(F.horizontal_flip_image, fullgraph=True)(img)
# torch.compile(F.pad_image, fullgraph=True)(img, [2, 2, 2, 2])
# torch.compile(F.rotate_image, fullgraph=True)(img, angle=30)
# torch.compile(F.ten_crop_image, fullgraph=True)(img, size=(223, 224))
# torch.compile(F.vertical_flip_image, fullgraph=True)(img)
# torch.compile(F.gaussian_blur_image, fullgraph=True)(img, kernel_size=3)
# torch.compile(F.normalize_image, fullgraph=True)(img, mean=0, std=1)
# torch.compile(F.to_dtype_image, fullgraph=True)(img, dtype=torch.uint8, scale=True)

# These ones have graph breaks (fullgraph=False lets them run despite the breaks)
# torch.compile(F.perspective_image, fullgraph=False)(img, None, None, coefficients=torch.rand(8))
# torch.compile(F.resize_image, fullgraph=False)(img, size=(223, 223))
# torch.compile(F.resized_crop_image, fullgraph=False)(img, 0, 12, 10, 34, (223, 223))

# This one doesn't even compile
# torch.compile(F.equalize_image, fullgraph=False)(img)
```
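Instead of uncommenting one line at a time, the same check can be driven from a table of cases. This is a minimal sketch, assuming torch >= 2.1. It compiles each callable with fullgraph=True, so any graph break raises instead of silently splitting the graph, and it uses the "eager" backend so that only Dynamo tracing is exercised, not code generation. The helper names has_graph_break and find_breaking_cases are hypothetical. Note that any compilation failure is reported, so a kernel that "doesn't even compile" is flagged too.

```python
import torch
import torch._dynamo


def has_graph_break(fn, *args, **kwargs) -> bool:
    """Return True if torch.compile(fullgraph=True) fails to trace fn as one graph.

    fn is run eagerly first so that ordinary errors (bad arguments, etc.)
    surface normally instead of being reported as graph breaks.
    """
    fn(*args, **kwargs)
    torch._dynamo.reset()  # start from a clean compilation cache
    compiled = torch.compile(fn, fullgraph=True, backend="eager")
    try:
        compiled(*args, **kwargs)
    except Exception:
        # With fullgraph=True a graph break raises; other compile errors also land here.
        return True
    return False


def find_breaking_cases(cases):
    """cases maps a name to (fn, args, kwargs); returns the names that fail to compile as one graph."""
    return [
        name
        for name, (fn, args, kwargs) in cases.items()
        if has_graph_break(fn, *args, **kwargs)
    ]
```

For example, with img and F defined as in the script above, find_breaking_cases({"resize_image": (F.resize_image, (img,), {"size": (223, 223)}), "invert_image": (F.invert_image, (img,), {})}) should return only "resize_image", if the observations above still hold.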
Weird thing: resize_image and resized_crop_image both break on
Functionals
As @pmeier noted offline, the functionals break on the _KERNEL_REGISTRY lookup, which, technically, can probably be avoided since the dict entry should be constant within a single execution. We still need to make sure this won't affect custom kernels that users register, and check whether it changes anything if we eventually want to allow users to override our default kernels.
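Since the lookup result is fixed for a given functional and input type, one way to keep it out of the traced code is to resolve the kernel once, eagerly, and compile only the kernel. The sketch below uses a hypothetical toy registry (_TOY_REGISTRY, toy_horizontal_flip, compile_resolved), not torchvision's _KERNEL_REGISTRY, and only illustrates the idea. It also makes the caveat above concrete: a kernel registered after compile_resolved runs is never seen by the compiled function.

```python
import torch


def _toy_flip_kernel(image: torch.Tensor) -> torch.Tensor:
    return image.flip(-1)


# Hypothetical stand-in for a kernel registry: functional name -> {input type -> kernel}.
_TOY_REGISTRY = {"horizontal_flip": {torch.Tensor: _toy_flip_kernel}}


def toy_horizontal_flip(inpt: torch.Tensor) -> torch.Tensor:
    # Dispatch on every call: when compiled, the lookup is traced together with the kernel.
    kernel = _TOY_REGISTRY["horizontal_flip"][type(inpt)]
    return kernel(inpt)


def compile_resolved(name: str, input_type: type):
    # The lookup happens once, eagerly; only the resolved kernel is compiled.
    # Kernels registered after this call are NOT picked up by the returned function.
    kernel = _TOY_REGISTRY[name][input_type]
    return torch.compile(kernel, fullgraph=True, backend="eager")
```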
TODO: figure out whether the call to log_api_usage_once() introduces a break.
Transforms
The transforms also break where the functionals break.
On top of that, the random transforms seem to break on the if rand() < self.p check, although I don't see those breaks when using TORCH_LOGS="graph_breaks"; I only see them when using torch._dynamo.explain(). And torch._dynamo.explain() in turn doesn't show the graph breaks that happen on the _KERNEL_REGISTRY lookup. 🤷♂️
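The rand() < self.p check is a plausible break point: turning a tensor comparison into a Python bool is data-dependent control flow, which Dynamo cannot keep inside a single graph. The sketch below, assuming torch >= 2.1, contrasts such a branch with a branchless torch.where variant. random_hflip_branchy and random_hflip_branchless are hypothetical stand-ins for a random transform, not torchvision code.

```python
import torch


def random_hflip_branchy(img: torch.Tensor, p: float = 0.5) -> torch.Tensor:
    # `if` calls bool() on a tensor whose value is only known at runtime,
    # so Dynamo has to break the graph here.
    if torch.rand(1) < p:
        return img.flip(-1)
    return img


def random_hflip_branchless(img: torch.Tensor, p: float = 0.5) -> torch.Tensor:
    # Both candidates are computed; torch.where selects one without Python control flow.
    do_flip = torch.rand(()) < p  # 0-dim bool tensor, broadcast against img
    return torch.where(do_flip, img.flip(-1), img)
```

Compiling each with torch.compile(..., fullgraph=True) should fail for the first and succeed for the second. The trade-off is that the branchless version always computes the flipped image, paying some extra work for a single graph.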
TODO: figure out which of the two we should trust, and assess the rest of the transforms more systematically with a script similar to the one above.