Removing graph breaks in transforms #8056
I've run a few quick benchmarks on whether or not it is useful to compile kernels in the first place. I used a simple classification pipeline (random_resized_crop, horizontal_flip, to_dtype, normalize) and pure tensor input:
The slowdown in the functionals stems from the graph break mentioned above. If we patch out the dispatch, i.e. replace `kernel = _get_kernel(horizontal_flip, type(inpt))` with a hard-coded `kernel = horizontal_flip_image`, we get the following results.

Meaning, if we can somehow resolve the graph break, compiling the functionals will net us the same speedup as compiling the kernels directly. Note that, for now, this only applies to pure tensors and thus image-only pipelines.
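For reference, a minimal sketch of such a kernel-vs-functional comparison (this is an illustrative assumption, not the script that produced the numbers above; it assumes the kernel is importable as `horizontal_flip_image`):

```python
import torch
from torch.utils.benchmark import Timer
from torchvision.transforms.v2 import functional as F
from torchvision.transforms.v2.functional import horizontal_flip_image

img = torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8)

# Compile the low-level kernel and the dispatching functional separately.
compiled_kernel = torch.compile(horizontal_flip_image)
compiled_functional = torch.compile(F.horizontal_flip)

for name, fn in [
    ("eager functional", F.horizontal_flip),
    ("compiled kernel", compiled_kernel),
    ("compiled functional", compiled_functional),
]:
    fn(img)  # warm up so compilation time is excluded from the measurement
    timer = Timer(stmt="fn(img)", globals={"fn": fn, "img": img})
    print(name, timer.timeit(1000))
```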
I'll be working on this item:
=> PR on pytorch: pytorch/pytorch#112753
Description:
- Fixed cat uint8 lowering

Otherwise, it gives the following issue on the repro code:

```python
def func(x):
    batch_shape = x.shape[:1]
    out = torch.cat([x.new_zeros(1).expand(batch_shape + (1,)), x], dim=-1)
    return out

cfunc = torch.compile(func)
x = torch.randint(0, 256, size=(3, 255), dtype=torch.uint8)
out = cfunc(x)
```

Error message:

```
  File "/pytorch/torch/_inductor/lowering.py", line 1037, in <genexpr>
    if all(len(input.layout.size) == 4 for input in inputs):
  File "/pytorch/torch/_inductor/ir.py", line 5795, in __getattr__
    fn = getattr(self.data, name)
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
LoweringException: AttributeError: 'ExpandView' object has no attribute 'layout'
  target: aten.cat.default
  args[0]: [TensorBox(
    ExpandView(data=StorageBox(
      ComputedBuffer(name='buf0', layout=FlexibleLayout('cpu', torch.uint8, size=[1], stride=[1]), data=Pointwise(
        'cpu',
        torch.uint8,
        def inner_fn(index):
            _ = index
            tmp0 = ops.constant(0, torch.uint8)
            return tmp0
        ,
        ranges=[1],
        origin_node=full,
        origins={full}
      ))
    ), size=[3, 1])
  ), TensorBox(StorageBox(
    InputBuffer(name='arg0_1', layout=FixedLayout('cpu', torch.uint8, size=[3, 255], stride=[255, 1]))
  ))]
  args[1]: 1

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
```

Context: compiling is not working for torchvision's `F.equalize` op: pytorch/vision#8056

Pull Request resolved: #112753
Approved by: https://github.com/peterbell10
EDIT: Wrong conclusion:
you can see that a tensor image likely made its way into a bounding box kernel. What exactly are you testing there? That bounding box / mask inputs work properly on a compiled functional?
Well, I was running tests from #8092 and it is partially my fault, as I was running the dispatched functions on tensors instead of subclasses... Now, the problem is a recursion error due ...
There are two sources of graph breaks in the way we currently dispatch:
Apart from that, nothing needs to change on our side. Dynamo is fine with all the other things we worried about, i.e. global dicts, MRO traversal, ... 🎉 I've re-run my benchmark with fixes for the points above and this is what I got out:
I've re-run it a couple of times and the 10µs gap between compiled kernels and functionals is reproducible. Meaning the compiled functionals don't fully reach the same level as the kernels, but they still outperform their eager counterparts.
One thing that I noticed while playing around with the benchmarks is that dynamo does not give us a strict improvement for individual ops.
Note that what's going to be great for torchvision is that I expect pretty much any combination of transformations to be fused into one kernel. That is where the main speed-ups will be coming from. To this end, it'd be useful to benchmark a list of transformations applied one after the other (see the sketch below). As I told Victor, I expect these wins to heavily outweigh the slight regression in resize and flips. On a different note, I'd expect the ...
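A hedged sketch of what such a chained benchmark could look like (the specific transforms, parameters, and sizes are illustrative assumptions, not a prescribed pipeline):

```python
import torch
from torch.utils.benchmark import Timer
from torchvision.transforms.v2 import functional as F

def pipeline(img):
    # A chain of functionals applied one after the other; compiled as a single
    # function, Inductor gets the chance to fuse work across the ops.
    img = F.resize(img, [224, 224], antialias=True)
    img = F.horizontal_flip(img)
    img = F.to_dtype(img, torch.float32, scale=True)
    return F.normalize(img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

compiled_pipeline = torch.compile(pipeline)

img = torch.randint(0, 256, (3, 256, 256), dtype=torch.uint8)
compiled_pipeline(img)  # warm up once so compilation is not measured

for name, fn in [("eager", pipeline), ("compiled", compiled_pipeline)]:
    print(name, Timer(stmt="fn(img)", globals={"fn": fn, "img": img}).timeit(500))
```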
Thanks a lot for this great investigation, Philip. @lezcano I tend to have a different intuition from yours: if ...
A few other findings on failing tests when kernels are compiled with variable input shapes: https://gist.github.com/vfdev-5/5b2733b5641d08c6889a17eda6267aba (the logs contain 32k lines in total, so the browser may hang for a few seconds while loading...)
This issue tracks progress on graph break removal for the v2 transforms.
Restricting to pure tensor inputs (images) for now; we can figure out the TVTensors and arbitrary structures later.
Kernels
The low-level kernels are almost all fine. Only 4 kernels are problematic.
Weird thing: `resize_image` and `resized_crop_image` both break on `vision/torchvision/transforms/v2/functional/_geometry.py` line 228 (at 68161e9) and line 234 (at 68161e9).
Functionals
As @pmeier noted offline, the functionals break on `vision/torchvision/transforms/v2/functional/_utils.py` line 99 (at 68161e9), which, technically, can probably be avoided since the dict entry should be constant across one execution (we still need to make sure it won't affect custom kernels that users register, or whether it changes anything if we eventually want to allow users to override our default kernels).
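For illustration only, the dispatch pattern in question looks roughly like the following simplified sketch; this is not torchvision's actual implementation, just the shape of the type-keyed registry lookup that the functionals perform:

```python
import torch

# Simplified, hypothetical version of the type-keyed kernel registry.
_KERNEL_REGISTRY = {}

def register_kernel(functional, input_type):
    def decorator(kernel):
        _KERNEL_REGISTRY.setdefault(functional, {})[input_type] = kernel
        return kernel
    return decorator

def horizontal_flip(inpt):
    # The kernel is resolved dynamically from type(inpt) at call time. This is
    # the kind of lookup where the graph break was observed; since the resolved
    # kernel is constant for a given input type within one run, the lookup could
    # in principle be specialized away by the compiler.
    kernel = _KERNEL_REGISTRY[horizontal_flip][type(inpt)]
    return kernel(inpt)

@register_kernel(horizontal_flip, torch.Tensor)
def horizontal_flip_image(image):
    return image.flip(-1)

print(horizontal_flip(torch.arange(6).reshape(2, 3)))
```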
TODO: figure out whether the call to `log_api_usage_once()` introduces a break.

Transforms
The transforms also break where the functionals break.
On top of that, the random transforms seem to break on the call to `if rand() < self.p`, although I don't see those breaks when using `TORCH_LOGS="graph_breaks"`; I only see them when using `_dynamo.explain()`. And `_dynamo.explain()` in turn doesn't show the graph breaks that happen on the `_KERNEL_REGISTRY`. 🤷‍♂️

TODO: figure out which one we should trust, and also assess the rest of the transforms more systematically with a script similar to the one above (see the sketch below).
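A sketch of what such an assessment script could look like, assuming `torch._dynamo.explain` as exposed in recent PyTorch releases (the attribute names on the returned object may differ between versions, and the transform list is only a placeholder):

```python
import torch
import torch._dynamo as dynamo
from torchvision.transforms import v2

# A few representative transforms; extend this list to cover the rest.
transforms_to_check = [
    v2.RandomHorizontalFlip(p=1.0),
    v2.RandomResizedCrop(224, antialias=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]

img = torch.rand(3, 256, 256)  # pure tensor image input

for transform in transforms_to_check:
    dynamo.reset()  # start each transform from a clean compilation cache
    explanation = dynamo.explain(transform)(img)
    print(
        f"{type(transform).__name__}: {explanation.graph_break_count} graph break(s); "
        f"reasons: {[str(reason) for reason in explanation.break_reasons]}"
    )
```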
CC @pmeier @vfdev-5