
flash_attention: support also cross attention. #8427

Merged

Conversation

dudulightricks
Contributor

When q and kv have different shapes (cross attention), flash attention with SPMD fails because that case is not supported.
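
A minimal repro sketch for this case (the shapes, mesh, and exact call pattern below are assumptions modeled on the existing self-attention SPMD test, not copied from this PR):

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr
from torch_xla.experimental.custom_kernel import flash_attention

xr.use_spmd()
n_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.array(range(n_devices)), (n_devices,), ('data',))

device = xm.xla_device()
# Cross attention: q and k/v have different sequence lengths.
# Layout is [batch, heads, seq, head_dim]; batch size assumed divisible by n_devices.
q = torch.randn(4, 2, 1024, 128, device=device)
k = torch.randn(4, 2, 128, 128, device=device)
v = torch.randn(4, 2, 128, 128, device=device)

# Shard the batch dim across devices; before this change the SPMD path in
# flash_attention assumed q and kv share the same shape and failed here.
o = flash_attention(q, k, v, partition_spec=('data', None, None, None), mesh=mesh)
xm.mark_step()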

@dudulightricks
Contributor Author

@JackCaoG

@JackCaoG JackCaoG added the tpuci label Nov 28, 2024
@dudulightricks dudulightricks force-pushed the bug-fix/support-cross-flash-attn branch 3 times, most recently from 840e3fe to 1bc1fb4 Compare December 1, 2024 10:44
In case that q and kv have different shapes (cross attention) flash attention with spmd fails since it does not support it.
@dudulightricks dudulightricks force-pushed the bug-fix/support-cross-flash-attn branch from 1bc1fb4 to 9e1e24b Compare December 1, 2024 10:47
@dudulightricks
Contributor Author

dudulightricks commented Dec 1, 2024

Rerun the tests when you can. @JackCaoG

@dudulightricks
Contributor Author

dudulightricks commented Dec 2, 2024

@JackCaoG when using XLA_DISABLE_FUNCTIONALIZATION=1, the flash attention backward tests fail with the following error (unrelated to this PR specifically):


ERROR: test_flash_attention_backward_spmd_data_parallel (__main__.PallasTest)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/opt/xla/test/test_pallas_spmd.py", line 86, in test_flash_attention_backward_spmd_data_parallel
    loss.backward()
  File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/_tensor.py", line 624, in backward
    torch.autograd.backward(
  File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/autograd/__init__.py", line 347, in backward
    _engine_run_backward(
  File "/home/ubuntu/.local/lib/python3.10/site-packages/torch/autograd/graph.py", line 825, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Bad StatusOr access: INTERNAL: RET_CHECK failure (third_party/tensorflow/compiler/xla/service/sharding_propagation.cc:1466) instruction->has_sharding() Sharding instruction must have a sharding attribute

@JackCaoG
Collaborator

JackCaoG commented Dec 2, 2024

hmm, not the first time I saw this issue, let me see if I can do anything before I return my laptop...

@JackCaoG
Collaborator

JackCaoG commented Dec 2, 2024

I can repro the issue with

XLA_DISABLE_FUNCTIONALIZATION=1 python test/test_pallas_spmd.py

======================================================================
ERROR: test_flash_attention_backward_spmd_data_parallel (__main__.PallasTest)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/workspaces/dk3/pytorch/xla/test/test_pallas_spmd.py", line 85, in test_flash_attention_backward_spmd_data_parallel
    loss.backward()
  File "/workspaces/dk3/pytorch/torch/_tensor.py", line 626, in backward
    torch.autograd.backward(
  File "/workspaces/dk3/pytorch/torch/autograd/__init__.py", line 347, in backward
    _engine_run_backward(
  File "/workspaces/dk3/pytorch/torch/autograd/graph.py", line 825, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Bad StatusOr access: INTERNAL: RET_CHECK failure (third_party/tensorflow/compiler/xla/service/sharding_propagation.cc:1464) instruction->has_sharding() Sharding instruction must have a sharding attribute

so this has nothing to do with segment ID.

I checked XLA; it comes from
https://github.com/openxla/xla/blob/89358a8c397a9c6fc3ec5e4faad17d11da946a94/xla/service/sharding_propagation.cc#L1463-L1466
and happens during compilation. The issue seems to be that there is an HLO instruction with the Sharding custom call but without a sharding attribute.
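
For reference, a minimal sketch of how the pending HLO and per-tensor sharding info can be dumped before compilation (same _XLAC debug helpers as in the diff further down; the tensor here is just a placeholder for the kernel output under test):

import torch
import torch_xla
import torch_xla.core.xla_model as xm

# Placeholder lazy tensor standing in for an output of the flash attention kernel.
t = torch.randn(4, 2, 128, 8, device=xm.xla_device())

# Per-tensor IR and sharding debug info.
print(torch_xla._XLAC._get_xla_tensor_debug_info(t))

# The HLO that would be compiled for this tensor, before any mark_step.
print(torch_xla._XLAC._get_xla_tensors_hlo([t]))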

@JackCaoG
Collaborator

JackCaoG commented Dec 2, 2024

there is something more fundamental: it seems like during the backward there is a graph break. I saw

  %custom-call.58 = f32[1,2,128,8]{3,2,1,0} custom-call(f32[1,2,128,8]{3,2,1,0} %custom-call.57, f32[1,2,128,8]{3,2,1,0} %custom-call.56, f32[1,2,128,8]{3,2,1,0} %custom-call.55, f32[1,2,128,128]{3,2,1,0} %custom-call.54, f32[1,2,128,128]{3,2,1,0} %custom-call.44, /*index=5*/f32[1,2,128,8]{3,2,1,0} %custom-call.34, f32[1,2,128,128]{3,2,1,0} %custom-call.32), custom_call_target="tpu_custom_call", operand_layout_constraints={f32[1,2,128,8]{3,2,1,0}, f32[1,2,128,8]{3,2,1,0}, f32[1,2,128,8]{3,2,1,0}, f32[1,2,128,128]{3,2,1,0}, f32[1,2,128,128]{3,2,1,0}, f32[1,2,128,8]{3,2,1,0}, f32[1,2,128,128]{3,2,1,0}}, metadata={op_type="xla__tpu_custom_call" op_name="xla__tpu_custom_call" source_file="/workspaces/dk3/pytorch/xla/torch_xla/experimental/custom_kernel.py" source_line=423}, backend_config={"custom_call_config": {"body": "TUzvUgFNTElSMjAuMC4wZ2l0AAE9CQEDBQcBAwkDKQsNDxETFRcZGx0fISMlJykrLS8xA3YEAgQdAfkHFwsXFwsLCwsLCwsPCxsLExMTExMTE5MXDw8PEw8LEw8PkxMLCyMPCxcTDx8PxQ8PCwsLCwsLCwsTCwsTCwsLExMTCxcTLwsLCwsTCxMTDwsLExMTCxcPCxMPExsLD0MLGwuFC5MLCwsLKxsLGwsbCxsLGwsbCxsLGwsbBQ2NYXmRKgIqAgH5GxsbGxsbGw8TExcXCxcXFwsXFxcLFxcLGwsPExcLFwsXExMXFxMTExcbCw8TExcTExcTExcTExcTExcTExcTExcfExMXCwsPExcLDxcLExMXFx8LExMXDxMXDxMTFw8TFw8PExcXHy8TC1MTExMXDxMXExcfUw8TFxcLFwcFWVkBHQ8HHxsHKw8vIwsnI0MvAooYHwMDF7YCBTMDAxcmAhWCAo4CBTUFNwU5BTsFPQ0bBT8dPakFQQMDMgP+AwVDHRu6Ah2HxgIdh9ICHRveAh0b6gIdG/YCHRsCAyMNCUEBAAAAAAAAAAEAAAAAAAAAgAAAAAAAAAAIAAAAAAAAABU2AkICHX9HHX8/ERMAFZYDCR09OQVFFcoDMR09sx09tyMNCUEBAAAAAAAAAAEAAAAAAAAAgAAAAAAAAACAAAAAAAAAABV2AjEFRwVJAwUeA08iA08RAQUFSx1WA1oDFYYDCR0bOQMFobIDo6UdG6lhZmZpbmVfbWFwPChkMCwgZDEsIGQyLCBkMykgLT4gKGQwLCBkMSwgZDIsIGQzKT4AEQ0BEQ0FBU0FTwVRBVMFVQVXBVkFWx11KgIFXQVfAwN7XwVhBWMFZQMDe2EdfgIJHUmiAgVnAwMXDgMdSxIDAwmPAgKR/ZM3lTcFaQVrBW0FbxUqAwkFcRVqAwkdS3YDHUs5BXMFdSMBAQEddboDFdIDPwV3AwMX4gMdsbMFeRXmAz8dsbcV7gNHAwW7vQ0fBXsRDQ0DD8HDD8XJy81fz2EN0dPVBX0BCfn5+f8NGWFmZmluZV9tYXA8KGQwLCBkMSkgLT4gKGQwLCBkMSk+AAV/Iw0JQQEAAAAAAAAAAgAAAAAAAAABAAAAAAAAAAEAAAAAAAAABYEFgwWFBYcBEdfb3+Pn6+/zAwUR2RMvCWMDBRHdEy8JZQMFEeETLwlnAwUR5RNFCWkDBRHpE0UJawMFEe0TLwltAwUR8RNFCW8DBRH1Ey8JcQMFDxUNYyN0cHUuZGltZW5zaW9uX3NlbWFudGljczxwYXJhbGxlbD4AI3RwdS5tZW1vcnlfc3BhY2U8dm1lbT4AI3RwdS5jb250cmFjdF9wcmVjaXNpb248ZnAzMj4AI3RwdS5kaW1lbnNpb25fc2VtYW50aWNzPGFyYml0cmFyeT4AI3RwdS5kb3RfZGltZW5zaW9uX251bWJlcnM8WzFdLCBbMV0sIFswXSwgWzBdLCBbMCwgMCwgMSwgMF0sIFtdLCBbXT4AI3RwdS5kb3RfZGltZW5zaW9uX251bWJlcnM8WzFdLCBbMF0sIFswXSwgWzFdLCBbMCwgMCwgMSwgMV0sIFtdLCBbXT4AAwUPFQ1lAwUPFQ1nAwUPFQ1pAwUPFQ1rAwUPFQ1tAwUPFQ1vAwUPFQ1xEQEBFS4CMR0fMgIXBZISAR06Aj4CBYkXBR4XARVGAlICHUoCTgIFixd3NgIBFVYCYgIdWgJeAgWNF3cCBgEdZgJqAgWPF24CzgQBBZEdfUcdH3oCFwWWEgEFkx2GAooCBZUXBUYUARWSAjEdH5YCFwVCFAEDAxeeAhEBAgQVpgIJHQuqAhcFphIBAwOyAvoDBZcRAwEVvgIJHQvCAhcFqhIBFcoCCR0LzgIXBa4SARXWAgkdC9oCFwW6EgEV4gIJHQvmAhcFxhIBFe4CCR0L8gIXBcoSARX6AgkdC/4CFwXOEgEVBgMJHQsKAxcF0hIBJQUJAAAAABUWAwkdCxoDFwXaEgEFmQWbHVGXHQsuAxcFhhMBBZ0dmZcdPgNCAwWfFUYDCR0LSgMXBYITAQMDF1IDEwkQAADgDwWhFV4DCR0LYgMXBZITAR1Rmx0LbgMXBY4TAR1JmxV6AwkdC34DFwWqEwEdUVUdC4oDFwXCEwEdmVUdSVUdC5oDFwUCFAEDAxeiAyUHCQAAAAADCY8GApH9kzeVNx2uAzkFoyMBCSEBAAAAAQAAAAIAAAAAAAAAAwMXTxW+AzEdH8IDFwViFAEdfT8dH84DFwVmFAEdq9YDFwVqFAEDBaHeA6OlIwEJIQEAAAABAAAABAAAAAAAAAATCQEdq+oDFwVuFAEd8gP2AwWlFwWaEgEjYXJpdGgub3ZlcmZsb3c8bm9uZT4AI2FyaXRoLmZhc3RtYXRoPG5vbmU+AAECAgMnBQIEAgQJJwUCBCEJCxf7CQUFAgQhCV0BAgQX+wkFBQIEAgQJXScJBQUCBCEJAQknCQUFAgQCBAkX+wUCBCEJxwUbAQEBAQsLCw8PCw8LFwEFCQEBAQEJAQEBAQQ+EgUBEQG5BwMBJQcRAb8HA7umAhsBAQEBAQEBAQsBCwELAQ8BDwELAQ8BCwEXAQMDcwcDAQ0Hc3kDEwUHGxkGcgIDAQMdAwMzBwMBDQczgQMTBR8hGxQzAyMJAwsdAwO1rQMJFwa1AwcDuwMDQwMDAwMDQwMDAwUGQwMHBxm/wQ8FQ1
kJvRm/wREAMwMBBREAMwMDgwcDAQMDhZoCAwEjB4WuAgMBBSUnAwMhAwMDAwMhAwMDAwMhAwMDAwMhAwMDBQYhAxELCSstLzELBiEDBwMzAwMjAwMDAwMjAwMDHQYjAwMDKQMDIwMDAwUGIwMRCws3OTs9CwYjAwcDPwMDJQMDAwMDJQMDAx0GJQMDAykDAyUDAwMFBiUDEQsNQ0VHSQsGJQMHA0sDAycDAwMDAycDAwMDAycDAwMDAycDAwMFBicDFQsPT1FTVQsGJwMFA1cDAykDAwMDAykDAwMDAykDAwMDAykDAwMFBikDFQsRW11fYQsGKQMFA2MDAysDAwMDAysDAwMDAysDAwMDAysDAwMFBisDEQsTZ2lrbQsGKwMHA28DAy0DAwMDAy0DAwMDAy0DAwMDAy0DAwMFBi0DFQsVc3V3eQsGLQMFA3sDA4uJAwUTB4uNAwUHNUF/FQcmA00DBQNlHwc2Ax0DBQWBgyUHOgMdAwUDhQMDU04DAwkXBlMDBQOJJwdTHQMFBYtZFQdmA00DBQONIQdyAx0DBQWHjwMDnYkDBRMHnY0DBQdxTZMVB4IDTQMFA30fB44DHQMFBZWXIQeSAx0DBQWZkQMDVwMDAwMDVwMDAwUGVwMHBxmdnwMDn54DAwcTB5+mAwMHB5tBoykHqgMdAwcFoaUDAzsDAwMDAzsDAwMFBjsDBwcZqasPBTtZCacZqasDA4O2AwMBAwOnBwMBDQeneQMTBQexGQbGAwMBA7MDAzUHAwENBzWBAxMFtbcbFDUDuQkDH0kDA1sDAwMDA1sDAwMFBlsDBwcZu70DAxkDAwMDAxkDAwMDAxkDAwMDAxkDAwMFBhkDEQsXwcPFxwsGGQMHA8kLBhkDEQO/DwUZ2gMNzRfBw8XHAwOvrQMJFwavAwcDzwMDQQMDAwMDQQMDAwUGQQMHBxnT1Q8FQVkJ0RnT1REANQMBBREANQkAAQcRAfcHAw0PCQEBAQEBAQEBAwMBBwMBAwMBBwMBCQQBCQEDBQkHEQEKAgcDDQ8JAQEBAQEBAQEDAwEHAwEDAwEHAwEJBAEJAQMHCQcRAQ4CBwMNDwkBAQEBAQEBAQMDAQcDAQMDAQcDAQkEAQkBAwcJBxEBEgIHAw0PCQEBAQEBAQEBAwMBBwMBAwMBBwMBCQQBCQEDBQkHEQEWAgcDDQ8JAQEBAQEBAQEDAwEHAwEDAwEHAwEJBAEJAQMFCQcRARoCBwMNDwkBAQEBAQEBAQMDAQcDAQMDAQcDAQkEAQkBAwUJBxEBHgIHAw0PCQEBAQEBAQEBAwMBBwMBAwMBBwMBCQQBCQEDBQkHEQEiAgcDDQ8JAQEBAQEBAQEDAwEHAwEDAwEHAwEJBAEJAQMFCQYDAQUBADoTpycLCwsTDRUdCQ1nDRMbMR0LIyEjKS0lJxEpCx0dFSUbDS0ViQkZGRkZGRkZGREbCw03Cw0dJR0TC7cXFxMXFxcjDxkjFxcVIxclGRUZHw8NCR0RYnVpbHRpbgBzdGFibGVfbW9zYWljAHRwdQBhcml0aABtb2R1bGUAYXJpdGguY29uc3RhbnQAdmVjdG9yLmxvYWQAZnVuYy5mdW5jAGZ1bmMucmV0dXJuAHZlY3Rvci5zaGFwZV9jYXN0AGFyaXRoLmNtcGkAdHB1LnZlY3Rvcl9zdG9yZQBzY2YueWllbGQAdHB1Lm1hdG11bAB0cHUucmVwZWF0AHZlY3Rvci5icm9hZGNhc3QAYXJpdGguZXh0dWkAc2NmLmlmAGFyaXRoLmluZGV4X2Nhc3QAYXJpdGguc3ViZgBhcml0aC5tdWxmAGFyaXRoLm11bGkAbWF0aC5leHAAYXJpdGguZGl2ZgBhcml0aC5hZGRmAC91c3IvbG9jYWwvbGliL3B5dGhvbjMuMTAvc2l0ZS1wYWNrYWdlcy9qYXgvZXhwZXJpbWVudGFsL3BhbGxhcy9vcHMvdHB1L2ZsYXNoX2F0dGVudGlvbi5weQBib2R5AHN5bV9uYW1lAGZ1bmN0aW9uX3R5cGUAdHJhbnNmb3JtX2luZGljZXMAd2luZG93X2JvdW5kcwB2YWx1ZQAvZ2V0AF9mbGFzaF9hdHRlbnRpb25fZHFfa2VybmVsAC9zd2FwAC9tdWwAL2RvdF9nZW5lcmFsAC9yZXBlYXQAdHJhbnNmb3JtXzAAdHJhbnNmb3JtXzEAdHJhbnNmb3JtXzIAdHJhbnNmb3JtXzMAdHJhbnNmb3JtXzQAdHJhbnNmb3JtXzUAdHJhbnNmb3JtXzYAdHJhbnNmb3JtXzcAL2VxAC93b3Jrc3BhY2VzL2RrMy9weXRvcmNoL3hsYS90b3JjaF94bGEvZXhwZXJpbWVudGFsL2N1c3RvbV9rZXJuZWwucHkAcHJlZGljYXRlAC9jb252ZXJ0X2VsZW1lbnRfdHlwZQAvY29uZAAvbWFza2VkX2xvYWQAZGltZW5zaW9uX251bWJlcnMAcHJlY2lzaW9uAHRyYW5zcG9zZV9saHMAdHJhbnNwb3NlX3JocwAvc3ViAG9wZXJhbmRTZWdtZW50U2l6ZXMAc3RyaWRlcwBlbmRfb2Zfa3Zfc2VxdWVuY2UAL2Jyb2FkY2FzdF9pbl9kaW0Ac3RhYmxlX21vc2FpYy52ZXJzaW9uAGRpbWVuc2lvbl9zZW1hbnRpY3MAaXRlcmF0aW9uX2JvdW5kcwBzY2FsYXJfcHJlZmV0Y2gAc2NyYXRjaF9vcGVyYW5kcwBtYWluAHdpbmRvd19wYXJhbXMAX2ZsYXNoX2F0dGVudGlvbl9id2RfZHEAdHJhY2VfcGFsbGFzAGJhY2t3YXJkAGFwcGx5AC93b3Jrc3BhY2VzL2RrMy9weXRvcmNoL3RvcmNoL2F1dG9ncmFkL2Z1bmN0aW9uLnB5AC9zY2FuAHJ1bgBvdmVyZmxvd0ZsYWdzAGRpbWVuc2lvbgB0aW1lcwBmYXN0bWF0aAAvZXhwAC9kaXYAL2FkZABzdGFydF9uZXdfc2VxdWVuY2UA", "serialization_format": 1, "needs_layout_passes": true}}
  %custom-call.59 = f32[1,2,128,8]{3,2,1,0} custom-call(f32[1,2,128,8]{3,2,1,0} %custom-call.58), custom_call_target="Sharding", sharding={manual}, metadata={op_type="xla__custom_sharding" op_name="xla__custom_sharding" source_file="/workspaces/dk3/pytorch/xla/torch_xla/distributed/spmd/xla_sharding.py" source_line=537}
  %custom-call.60 = f32[4,2,128,8]{3,2,1,0} custom-call(f32[1,2,128,8]{3,2,1,0} %custom-call.59), custom_call_target="SPMDShardToFullShape", sharding={devices=[4,1,1,1]0,1,2,3}, metadata={op_type="xla__custom_sharding" op_name="xla__custom_sharding" source_file="/workspaces/dk3/pytorch/xla/torch_xla/distributed/spmd/xla_sharding.py" source_line=570}
  ROOT %tuple.61 = (f32[4,2,128,8]{3,2,1,0}) tuple(f32[4,2,128,8]{3,2,1,0} %custom-call.60)
}

before any mark_step; this is the cause of the crash above. I tried to revert this PR but I still see the same graph break. Need to first figure out where this is coming from.

@JackCaoG
Collaborator

JackCaoG commented Dec 2, 2024

ok I found an even better repro: if I just mark_step in the fwd it will crash too

HloModule IrToHlo.45, entry_computation_layout={(f32[4,2,128,8]{2,3,0,1}, f32[4,2,128,8]{2,3,0,1}, f32[4,2,128,8]{2,3,0,1})->(f32[4,2,128,8]{3,2,1,0}, f32[4,2,128]{2,1,0}, f32[4,2,128]{2,1,0}, f32[])}

%AddComputation.39 (x.40: f32[], y.41: f32[]) -> f32[] {
  %x.40 = f32[] parameter(0)
  %y.41 = f32[] parameter(1)
  ROOT %add.42 = f32[] add(f32[] %x.40, f32[] %y.41)
}

ENTRY %IrToHlo.45 (p0.1: f32[4,2,128,8], p1.3: f32[4,2,128,8], p2.5: f32[4,2,128,8]) -> (f32[4,2,128,8], f32[4,2,128], f32[4,2,128], f32[]) {
  %constant.38 = s32[] constant(8192), metadata={op_type="aten__sum" op_name="aten__sum" source_file="/workspaces/dk3/pytorch/xla/test/test_pallas_spmd.py" source_line=84}
  %p2.5 = f32[4,2,128,8]{2,3,0,1} parameter(2), sharding={devices=[4,1,1,1]0,1,2,3}, metadata={op_type="xla__device_data" op_name="xla__device_data" source_file="/workspaces/dk3/pytorch/xla/torch_xla/distributed/spmd/xla_sharding.py" source_line=625}
  %custom-call.6 = f32[1,2,128,8]{3,2,1,0} custom-call(f32[4,2,128,8]{2,3,0,1} %p2.5), custom_call_target="SPMDFullToShardShape", sharding={manual}, metadata={op_type="xla__custom_sharding" op_name="xla__custom_sharding" source_file="/workspaces/dk3/pytorch/xla/torch_xla/distributed/spmd/xla_sharding.py" source_line=555}
  %p1.3 = f32[4,2,128,8]{2,3,0,1} parameter(1), sharding={devices=[4,1,1,1]0,1,2,3}, metadata={op_type="xla__device_data" op_name="xla__device_data" source_file="/workspaces/dk3/pytorch/xla/torch_xla/distributed/spmd/xla_sharding.py" source_line=625}
  %custom-call.4 = f32[1,2,128,8]{3,2,1,0} custom-call(f32[4,2,128,8]{2,3,0,1} %p1.3), custom_call_target="SPMDFullToShardShape", sharding={manual}, metadata={op_type="xla__custom_sharding" op_name="xla__custom_sharding" source_file="/workspaces/dk3/pytorch/xla/torch_xla/distributed/spmd/xla_sharding.py" source_line=555}
  %p0.1 = f32[4,2,128,8]{2,3,0,1} parameter(0), sharding={devices=[4,1,1,1]0,1,2,3}, metadata={op_type="xla__device_data" op_name="xla__device_data" source_file="/workspaces/dk3/pytorch/xla/torch_xla/distributed/spmd/xla_sharding.py" source_line=625}
  %custom-call.2 = f32[1,2,128,8]{3,2,1,0} custom-call(f32[4,2,128,8]{2,3,0,1} %p0.1), custom_call_target="SPMDFullToShardShape", sharding={manual}, metadata={op_type="xla__custom_sharding" op_name="xla__custom_sharding" source_file="/workspaces/dk3/pytorch/xla/torch_xla/distributed/spmd/xla_sharding.py" source_line=555}
  %custom-call.7 = (f32[1,2,128,8]{3,2,1,0}, f32[1,2,128,128]{3,2,1,0}, f32[1,2,128,128]{3,2,1,0}) custom-call(f32[1,2,128,8]{3,2,1,0} %custom-call.6, f32[1,2,128,8]{3,2,1,0} %custom-call.4, f32[1,2,128,8]{3,2,1,0} %custom-call.2), custom_call_target="tpu_custom_call", operand_layout_constraints={f32[1,2,128,8]{3,2,1,0}, f32[1,2,128,8]{3,2,1,0}, f32[1,2,128,8]{3,2,1,0}}, metadata={op_type="xla__tpu_custom_call" op_name="xla__tpu_custom_call" source_file="/workspaces/dk3/pytorch/xla/torch_xla/experimental/custom_kernel.py" source_line=309}, backend_config={"custom_call_config": {"body": "TUzvUgFNTElSMjAuMC4wZ2l0AAEvCwEDBQcJAQMLAxkNDxETFRcZGx0fISMDYgPiAh8B+QcTCxMLDwsPDw8LCwsLCw8TE5MPCwsLGwsrC8UPCwsLCwuTCwsTCwsLCwsTEwsLMxMTExMXDxMPEw8TGwsPQwsbCwuTCwsLCyMbCxsLGwsbCxsLGwsbGxsbGxsPDw8PFw8XDw8LFw8PCxcPDwsXDw8LFwsPDwsXDw8LExcFDY1heZEqAioCAWcLFxMTFxcfExMXLxcfCxMXHw8LExcLEwsXHwsTFx8PCxMTFxMXC1MLExMXExMXFx8TFy8HBV1JCQNZAR8PBx8rBxsPFyMvGyc3LwsCghMfAwMduQUlAwMdtwUnFcHFBSkdMW0dMXEdMXUFKwUtBS8NGwUxHSu7HSsSAh0rtgIjDQlBAQAAAAAAAAABAAAAAAAAAIAAAAAAAAAACAAAAAAAAAARHQAFMwU1BTcDA14C3gIFOQMFngKiAqYCqgIFO2FmZmluZV9tYXA8KGQwLCBkMSwgZDIsIGQzKSAtPiAoZDAsIGQxLCBkMiwgZDMpPgARDQEFPQU/BUEFQwVFIw0JQQEAAAAAAAAAAQAAAAAAAACAAAAAAAAAAIAAAAAAAAAABUcFSR1NJgIFSwVNBU8FUQVTHT4CWRVCAgsFVQVXIw0DEQEAAAAAAAAAHVICYxVWAgsdcgJnFXYCCx2GAooCHSltFZYCCx0pcRWuAgsdTXUVygILAwV5ew01BVkRDQ0DD3+BFYOFh4k5izkNjY+RBVsBCfn5+f8NGQVdIw0JQQEAAAAAAAAAAgAAAAAAAAABAAAAAAAAAAEAAAAAAAAABV8FYQVjBWUBDZOXm5+jpwMFF5UZJQk7AwUXmRklCT0DBRedGSUJPwMFF6EZJQlBAwUXpRlFCUMDBRepGUUJRwMFFRsNOwMFFRsNPQMFFRsNPwMFFRsNQQMFFRsNQwMFFRsNRxEBAREDARW9Cx0JvxcFCggBHTXDFwVOBQEVx80dycsFZxcF/gsBFc/VHdHTBWkXLTYCARXX3R3Z2wVrFy12BAEV3+cd4eMFbRfl/ggBBW8V6e8d6+0FcRcttgcBFfH3HfP1BXMXSacBHQoCDgIjdHB1LmRpbWVuc2lvbl9zZW1hbnRpY3M8cGFyYWxsZWw+ACN0cHUubWVtb3J5X3NwYWNlPHZtZW0+ACN0cHUuY29udHJhY3RfcHJlY2lzaW9uPGZwMzI+ACN0cHUuZGltZW5zaW9uX3NlbWFudGljczxhcmJpdHJhcnk+ACN0cHUuZG90X2RpbWVuc2lvbl9udW1iZXJzPFsxXSwgWzFdLCBbMF0sIFswXSwgWzAsIDAsIDEsIDBdLCBbXSwgW10+ACN0cHUuZG90X2RpbWVuc2lvbl9udW1iZXJzPFsxXSwgWzBdLCBbMF0sIFsxXSwgWzAsIDAsIDEsIDFdLCBbXSwgW10+AAV1F0mCBQEVFgILHQkaAhcFDggBAwMdIgIlBQkAAAAAFSoCCx0JLgIXBRIIAQMJTwICUf1TJ1UnAwMdOgIlDwkAAID/BXcdCUYCFwWmCAEDBVvWAl1fHSlZBXkdCVoCFwWqCAEFex1mAmMFfQMDHW4CJQ8JAAAAAAV/HQl6AhcFrggBAwVb2gJdXx0pZwWBFY4CCx0JkgIXBbIIAR0JmgIXBb4IAQWDIwEJIQEAAAABAAAABAAAAAAAAAAFhSMBAQEdCbICFwXGCAEVugILHQm+AhcFzggBAwMdxgIlCwkAAAAAHQnOAhcF0ggBAwlPBgJR/VMnVScjdmVjdG9yLmtpbmQ8bWF4aW11bWY+ACN2ZWN0b3Iua2luZDxhZGQ+ACNhcml0aC5mYXN0bWF0aDxub25lPgABAgIDJwUCBAIECRf7CQUFAgQhCTcLJwUCBCEJAQIEJwMCBAknCQUFAgQhCRf7CQUFAgQCBAk3JwUCBAUJJwkFBQIEAgQJBRUBAQEBBwcHBxMTAQUJAQEBAQkBAQEBAQkEUgsFAREBdwcDAR0HEQF9BwOJ+xUBAQEBAQEBAQcBBwEHAQcBEwETAQMDHwMDAwMDHwMDAwMDHwMDAwMDHwMDAwsGHwMRCwkVFxkbBQYfAwsDHQMDIQMDAwMDIQMDAwMDIQMDAwMDIQMDAwsGIQMRCwshIyUnBQYhAwsDKQMDSx4CAwURB0syAgMFBx8rLQMDVzYCAw8TB1dKAgMPBS8xBQZOAgMVAzMNBmEDBQM1FQdhLwMFBS83FwdiAi8DBQM5AwNlagIDDxMHZX4CAw8FOz0FBoICAxUDPw0GaQMFA0EZB2kvAwUFO0MFBmsDFQM1DQZrAwUDRwMDDwMDAwMDDwMDAwMDDwMDAwMDDwMDAwsGDwMXCxNLTU9RBQYPAwUDUwUGDwMXA0kPBQ8zDVcTS01PUQUGbwMVA0ENBm8DBQNZAwMRAwMDAwMRAwMDAwMRAwMDAwMRAwMDCwYRAxcLEV1fYWMFBhEDBQNlBQYRAxcDWw8FETMNaRFdX2FjAwMjAwMDAwMjAwMDAwMjAwMDAwMjAwMDCwYjAxELDWttb3EFBiMDCwNzAwNzwgIDCxEHc9ICAwsHRXV3AwMTAwMDAwMTAwMDAwMTAwMDAwMTAwMDCwYTAxELD3t9f4EFBhMDCwODBQYTAxEDeQ8FEzMNhw97fX+BCQABBxEBqwcDDQ8JAQEBAQEBAQEDAwEHAwEDAwEHAwEJBAEJAQMFCQcRAa0HAw0PCQEBAQEBAQEBAwMBBwMBAwMBBwMBCQQBCQEDBwkHEQGvBwMNDwkBAQEBAQEBAQMDAQcDAQMDAQcDAQkEAQkBAwcJBxEBsQcDDQ8JAQEBAQEBAQEDAwEHAwEDAwEHAwEJBAEJAQMFCQcRAbMHAw0PCQEBAQEBAQEBAwMBBwMBAwMBBwMBCQQBCQEDBQkHEQG1BwMNDwkBAQEBAQEBA
QMDAQcDAQMDAQcDAQkEAQkBAwUJBgMBBQEAZhKHESkLGQsTCxkTYyFnDREbLR0LIyEjKS0fCx0dFSUbaxkZGRkZGTENiQslDR0lHRNjtxcTFy8XIyMZGRUlHw8NDwkdEWJ1aWx0aW4Ac3RhYmxlX21vc2FpYwB0cHUAdmVjdG9yAGFyaXRoAG1vZHVsZQBhcml0aC5jb25zdGFudAB2ZWN0b3Iuc2hhcGVfY2FzdABmdW5jLmZ1bmMAZnVuYy5yZXR1cm4AdmVjdG9yLmxvYWQAdmVjdG9yLmJyb2FkY2FzdAB0cHUudmVjdG9yX3N0b3JlAHRwdS5tYXRtdWwAdmVjdG9yLm11bHRpX3JlZHVjdGlvbgBhcml0aC5zdWJmAG1hdGguZXhwAGFyaXRoLmRpdmYAL3Vzci9sb2NhbC9saWIvcHl0aG9uMy4xMC9zaXRlLXBhY2thZ2VzL2pheC9leHBlcmltZW50YWwvcGFsbGFzL29wcy90cHUvZmxhc2hfYXR0ZW50aW9uLnB5AF9mbGFzaF9hdHRlbnRpb25fa2VybmVsX3NpbmdsZV9iYXRjaF9zaW5nbGVfc3RlcABzeW1fbmFtZQBmdW5jdGlvbl90eXBlAHRyYW5zZm9ybV9pbmRpY2VzAHdpbmRvd19ib3VuZHMAdmFsdWUAL2Jyb2FkY2FzdF9pbl9kaW0AL2dldAAvd29ya3NwYWNlcy9kazMvcHl0b3JjaC94bGEvdG9yY2hfeGxhL2V4cGVyaW1lbnRhbC9jdXN0b21fa2VybmVsLnB5AC9zd2FwAF9mbGFzaF9hdHRlbnRpb25fa2VybmVsAHRyYW5zZm9ybV8wAHRyYW5zZm9ybV8xAHRyYW5zZm9ybV8yAHRyYW5zZm9ybV8zAHRyYW5zZm9ybV80AHRyYW5zZm9ybV81AC93b3Jrc3BhY2VzL2RrMy9weXRvcmNoL3hsYS90ZXN0L3Rlc3RfcGFsbGFzX3NwbWQucHkAL2RvdF9nZW5lcmFsAGRpbWVuc2lvbl9udW1iZXJzAHByZWNpc2lvbgB0cmFuc3Bvc2VfbGhzAHRyYW5zcG9zZV9yaHMAa2luZAByZWR1Y3Rpb25fZGltcwBzdGFibGVfbW9zYWljLnZlcnNpb24AZGltZW5zaW9uX3NlbWFudGljcwBpdGVyYXRpb25fYm91bmRzAHNjYWxhcl9wcmVmZXRjaABzY3JhdGNoX29wZXJhbmRzAG1haW4Ad2luZG93X3BhcmFtcwBfZmxhc2hfYXR0ZW50aW9uX2ltcGwAdHJhY2VfcGFsbGFzAGZvcndhcmQAYXBwbHkAL3dvcmtzcGFjZXMvZGszL3B5dG9yY2gvdG9yY2gvYXV0b2dyYWQvZnVuY3Rpb24ucHkAZmxhc2hfYXR0ZW50aW9uAHRlc3RfZmxhc2hfYXR0ZW50aW9uX2JhY2t3YXJkX3NwbWRfZGF0YV9wYXJhbGxlbAA8bW9kdWxlPgAvcmVkdWNlX21heAAvc3ViAGZhc3RtYXRoAC9leHAAL3JlZHVjZV9zdW0AL2RpdgBvcGVyYW5kU2VnbWVudFNpemVzAHN0cmlkZXMA", "cost_estimate": {"flops": 1114112, "transcendentals": 32768, "bytes_accessed": 294912}, "serialization_format": 1, "needs_layout_passes": true}}
  %get-tuple-element.8 = f32[1,2,128,8]{3,2,1,0} get-tuple-element((f32[1,2,128,8]{3,2,1,0}, f32[1,2,128,128]{3,2,1,0}, f32[1,2,128,128]{3,2,1,0}) %custom-call.7), index=0, metadata={op_type="xla__tpu_custom_call" op_name="xla__tpu_custom_call" source_file="/workspaces/dk3/pytorch/xla/torch_xla/experimental/custom_kernel.py" source_line=309}
  %custom-call.11 = f32[1,2,128,8]{3,2,1,0} custom-call(f32[1,2,128,8]{3,2,1,0} %get-tuple-element.8), custom_call_target="Sharding", sharding={manual}, metadata={op_type="xla__custom_sharding" op_name="xla__custom_sharding" source_file="/workspaces/dk3/pytorch/xla/torch_xla/distributed/spmd/xla_sharding.py" source_line=537}
  %custom-call.12 = f32[4,2,128,8]{3,2,1,0} custom-call(f32[1,2,128,8]{3,2,1,0} %custom-call.11), custom_call_target="SPMDShardToFullShape", sharding={devices=[4,1,1,1]0,1,2,3}, metadata={op_type="xla__custom_sharding" op_name="xla__custom_sharding" source_file="/workspaces/dk3/pytorch/xla/torch_xla/distributed/spmd/xla_sharding.py" source_line=570}
  %get-tuple-element.9 = f32[1,2,128,128]{3,2,1,0} get-tuple-element((f32[1,2,128,8]{3,2,1,0}, f32[1,2,128,128]{3,2,1,0}, f32[1,2,128,128]{3,2,1,0}) %custom-call.7), index=1, metadata={op_type="xla__tpu_custom_call" op_name="xla__tpu_custom_call" source_file="/workspaces/dk3/pytorch/xla/torch_xla/experimental/custom_kernel.py" source_line=309}
  %slice.13 = f32[1,2,128,1]{3,2,1,0} slice(f32[1,2,128,128]{3,2,1,0} %get-tuple-element.9), slice={[0:1], [0:2], [0:128], [0:1]}, metadata={op_type="xla__generic_slice" op_name="xla__generic_slice" source_file="/workspaces/dk3/pytorch/xla/torch_xla/distributed/spmd/xla_sharding.py" source_line=537}
  %reshape.14 = f32[1,2,128]{2,1,0} reshape(f32[1,2,128,1]{3,2,1,0} %slice.13), metadata={op_type="aten__view" op_name="aten__view" source_file="/workspaces/dk3/pytorch/xla/torch_xla/distributed/spmd/xla_sharding.py" source_line=537}
  %custom-call.15 = f32[1,2,128]{2,1,0} custom-call(f32[1,2,128]{2,1,0} %reshape.14), custom_call_target="Sharding", metadata={op_type="xla__custom_sharding" op_name="xla__custom_sharding" source_file="/workspaces/dk3/pytorch/xla/torch_xla/distributed/spmd/xla_sharding.py" source_line=537}
  %reshape.16 = f32[1,2,128,1]{3,2,1,0} reshape(f32[1,2,128]{2,1,0} %custom-call.15), metadata={op_type="aten__view" op_name="aten__view" source_file="/workspaces/dk3/pytorch/xla/torch_xla/distributed/spmd/xla_sharding.py" source_line=537}
  %constant.17 = s64[] constant(0), metadata={op_type="xla__update_slice" op_name="xla__update_slice" source_file="/workspaces/dk3/pytorch/xla/torch_xla/distributed/spmd/xla_sharding.py" source_line=537}
  %constant.18 = s64[] constant(0), metadata={op_type="xla__update_slice" op_name="xla__update_slice" source_file="/workspaces/dk3/pytorch/xla/torch_xla/distributed/spmd/xla_sharding.py" source_line=537}
  %constant.19 = s64[] constant(0), metadata={op_type="xla__update_slice" op_name="xla__update_slice" source_file="/workspaces/dk3/pytorch/xla/torch_xla/distributed/spmd/xla_sharding.py" source_line=537}
  %constant.20 = s64[] constant(0), metadata={op_type="xla__update_slice" op_name="xla__update_slice" source_file="/workspaces/dk3/pytorch/xla/torch_xla/distributed/spmd/xla_sharding.py" source_line=537}
  %dynamic-update-slice.21 = f32[1,2,128,128]{3,2,1,0} dynamic-update-slice(f32[1,2,128,128]{3,2,1,0} %get-tuple-element.9, f32[1,2,128,1]{3,2,1,0} %reshape.16, s64[] %constant.17, s64[] %constant.18, s64[] %constant.19, /*index=5*/s64[] %constant.20), metadata={op_type="xla__update_slice" op_name="xla__update_slice" source_file="/workspaces/dk3/pytorch/xla/torch_xla/distributed/spmd/xla_sharding.py" source_line=537}
  %slice.22 = f32[1,2,128,1]{3,2,1,0} slice(f32[1,2,128,128]{3,2,1,0} %dynamic-update-slice.21), slice={[0:1], [0:2], [0:128], [0:1]}, metadata={op_type="xla__generic_slice" op_name="xla__generic_slice" source_file="/workspaces/dk3/pytorch/xla/torch_xla/distributed/spmd/xla_sharding.py" source_line=537}
  %reshape.23 = f32[1,2,128]{2,1,0} reshape(f32[1,2,128,1]{3,2,1,0} %slice.22), sharding={manual}, metadata={op_type="aten__view" op_name="aten__view" source_file="/workspaces/dk3/pytorch/xla/torch_xla/distributed/spmd/xla_sharding.py" source_line=537}
  %custom-call.24 = f32[4,2,128]{2,1,0} custom-call(f32[1,2,128]{2,1,0} %reshape.23), custom_call_target="SPMDShardToFullShape", sharding={devices=[4,1,1]0,1,2,3}, metadata={op_type="xla__custom_sharding" op_name="xla__custom_sharding" source_file="/workspaces/dk3/pytorch/xla/torch_xla/distributed/spmd/xla_sharding.py" source_line=570}
  %get-tuple-element.10 = f32[1,2,128,128]{3,2,1,0} get-tuple-element((f32[1,2,128,8]{3,2,1,0}, f32[1,2,128,128]{3,2,1,0}, f32[1,2,128,128]{3,2,1,0}) %custom-call.7), index=2, metadata={op_type="xla__tpu_custom_call" op_name="xla__tpu_custom_call" source_file="/workspaces/dk3/pytorch/xla/torch_xla/experimental/custom_kernel.py" source_line=309}
  %slice.25 = f32[1,2,128,1]{3,2,1,0} slice(f32[1,2,128,128]{3,2,1,0} %get-tuple-element.10), slice={[0:1], [0:2], [0:128], [0:1]}, metadata={op_type="xla__generic_slice" op_name="xla__generic_slice" source_file="/workspaces/dk3/pytorch/xla/torch_xla/distributed/spmd/xla_sharding.py" source_line=537}
  %reshape.26 = f32[1,2,128]{2,1,0} reshape(f32[1,2,128,1]{3,2,1,0} %slice.25), metadata={op_type="aten__view" op_name="aten__view" source_file="/workspaces/dk3/pytorch/xla/torch_xla/distributed/spmd/xla_sharding.py" source_line=537}
  %custom-call.27 = f32[1,2,128]{2,1,0} custom-call(f32[1,2,128]{2,1,0} %reshape.26), custom_call_target="Sharding", metadata={op_type="xla__custom_sharding" op_name="xla__custom_sharding" source_file="/workspaces/dk3/pytorch/xla/torch_xla/distributed/spmd/xla_sharding.py" source_line=537}
  %reshape.28 = f32[1,2,128,1]{3,2,1,0} reshape(f32[1,2,128]{2,1,0} %custom-call.27), metadata={op_type="aten__view" op_name="aten__view" source_file="/workspaces/dk3/pytorch/xla/torch_xla/distributed/spmd/xla_sharding.py" source_line=537}
  %constant.29 = s64[] constant(0), metadata={op_type="xla__update_slice" op_name="xla__update_slice" source_file="/workspaces/dk3/pytorch/xla/torch_xla/distributed/spmd/xla_sharding.py" source_line=537}
  %constant.30 = s64[] constant(0), metadata={op_type="xla__update_slice" op_name="xla__update_slice" source_file="/workspaces/dk3/pytorch/xla/torch_xla/distributed/spmd/xla_sharding.py" source_line=537}
  %constant.31 = s64[] constant(0), metadata={op_type="xla__update_slice" op_name="xla__update_slice" source_file="/workspaces/dk3/pytorch/xla/torch_xla/distributed/spmd/xla_sharding.py" source_line=537}
  %constant.32 = s64[] constant(0), metadata={op_type="xla__update_slice" op_name="xla__update_slice" source_file="/workspaces/dk3/pytorch/xla/torch_xla/distributed/spmd/xla_sharding.py" source_line=537}
  %dynamic-update-slice.33 = f32[1,2,128,128]{3,2,1,0} dynamic-update-slice(f32[1,2,128,128]{3,2,1,0} %get-tuple-element.10, f32[1,2,128,1]{3,2,1,0} %reshape.28, s64[] %constant.29, s64[] %constant.30, s64[] %constant.31, /*index=5*/s64[] %constant.32), metadata={op_type="xla__update_slice" op_name="xla__update_slice" source_file="/workspaces/dk3/pytorch/xla/torch_xla/distributed/spmd/xla_sharding.py" source_line=537}
  %slice.34 = f32[1,2,128,1]{3,2,1,0} slice(f32[1,2,128,128]{3,2,1,0} %dynamic-update-slice.33), slice={[0:1], [0:2], [0:128], [0:1]}, metadata={op_type="xla__generic_slice" op_name="xla__generic_slice" source_file="/workspaces/dk3/pytorch/xla/torch_xla/distributed/spmd/xla_sharding.py" source_line=537}
  %reshape.35 = f32[1,2,128]{2,1,0} reshape(f32[1,2,128,1]{3,2,1,0} %slice.34), sharding={manual}, metadata={op_type="aten__view" op_name="aten__view" source_file="/workspaces/dk3/pytorch/xla/torch_xla/distributed/spmd/xla_sharding.py" source_line=537}
  %custom-call.36 = f32[4,2,128]{2,1,0} custom-call(f32[1,2,128]{2,1,0} %reshape.35), custom_call_target="SPMDShardToFullShape", sharding={devices=[4,1,1]0,1,2,3}, metadata={op_type="xla__custom_sharding" op_name="xla__custom_sharding" source_file="/workspaces/dk3/pytorch/xla/torch_xla/distributed/spmd/xla_sharding.py" source_line=570}
  %constant.37 = f32[] constant(0), metadata={op_type="aten__sum" op_name="aten__sum" source_file="/workspaces/dk3/pytorch/xla/test/test_pallas_spmd.py" source_line=84}
  %reduce.43 = f32[] reduce(f32[4,2,128,8]{3,2,1,0} %custom-call.12, f32[] %constant.37), dimensions={0,1,2,3}, to_apply=%AddComputation.39, metadata={op_type="aten__sum" op_name="aten__sum" source_file="/workspaces/dk3/pytorch/xla/test/test_pallas_spmd.py" source_line=84}
  ROOT %tuple.44 = (f32[4,2,128,8]{3,2,1,0}, f32[4,2,128]{2,1,0}, f32[4,2,128]{2,1,0}, f32[]) tuple(f32[4,2,128,8]{3,2,1,0} %custom-call.12, f32[4,2,128]{2,1,0} %custom-call.24, f32[4,2,128]{2,1,0} %custom-call.36, f32[] %reduce.43)
}

with functionalization:

HloModule IrToHlo.29, entry_computation_layout={(f32[4,2,128,8]{2,3,0,1}, f32[4,2,128,8]{2,3,0,1}, f32[4,2,128,8]{2,3,0,1})->(f32[4,2,128,8]{3,2,1,0}, f32[4,2,128]{2,1,0}, f32[4,2,128]{2,1,0}, f32[])}

%AddComputation.23 (x.24: f32[], y.25: f32[]) -> f32[] {
  %x.24 = f32[] parameter(0)
  %y.25 = f32[] parameter(1)
  ROOT %add.26 = f32[] add(f32[] %x.24, f32[] %y.25)
}

ENTRY %IrToHlo.29 (p0.1: f32[4,2,128,8], p1.3: f32[4,2,128,8], p2.5: f32[4,2,128,8]) -> (f32[4,2,128,8], f32[4,2,128], f32[4,2,128], f32[]) {
  %constant.22 = s32[] constant(8192), metadata={op_type="aten__sum" op_name="aten__sum" source_file="/workspaces/dk3/pytorch/xla/test/test_pallas_spmd.py" source_line=84}
  %p2.5 = f32[4,2,128,8]{2,3,0,1} parameter(2), sharding={devices=[4,1,1,1]0,1,2,3}, metadata={op_type="xla__device_data" op_name="xla__device_data" source_file="/workspaces/dk3/pytorch/xla/torch_xla/distributed/spmd/xla_sharding.py" source_line=625}
  %custom-call.6 = f32[1,2,128,8]{3,2,1,0} custom-call(f32[4,2,128,8]{2,3,0,1} %p2.5), custom_call_target="SPMDFullToShardShape", sharding={manual}, metadata={op_type="xla__custom_sharding" op_name="xla__custom_sharding" source_file="/workspaces/dk3/pytorch/xla/torch_xla/distributed/spmd/xla_sharding.py" source_line=555}
  %p1.3 = f32[4,2,128,8]{2,3,0,1} parameter(1), sharding={devices=[4,1,1,1]0,1,2,3}, metadata={op_type="xla__device_data" op_name="xla__device_data" source_file="/workspaces/dk3/pytorch/xla/torch_xla/distributed/spmd/xla_sharding.py" source_line=625}
  %custom-call.4 = f32[1,2,128,8]{3,2,1,0} custom-call(f32[4,2,128,8]{2,3,0,1} %p1.3), custom_call_target="SPMDFullToShardShape", sharding={manual}, metadata={op_type="xla__custom_sharding" op_name="xla__custom_sharding" source_file="/workspaces/dk3/pytorch/xla/torch_xla/distributed/spmd/xla_sharding.py" source_line=555}
  %p0.1 = f32[4,2,128,8]{2,3,0,1} parameter(0), sharding={devices=[4,1,1,1]0,1,2,3}, metadata={op_type="xla__device_data" op_name="xla__device_data" source_file="/workspaces/dk3/pytorch/xla/torch_xla/distributed/spmd/xla_sharding.py" source_line=625}
  %custom-call.2 = f32[1,2,128,8]{3,2,1,0} custom-call(f32[4,2,128,8]{2,3,0,1} %p0.1), custom_call_target="SPMDFullToShardShape", sharding={manual}, metadata={op_type="xla__custom_sharding" op_name="xla__custom_sharding" source_file="/workspaces/dk3/pytorch/xla/torch_xla/distributed/spmd/xla_sharding.py" source_line=555}
  %custom-call.7 = (f32[1,2,128,8]{3,2,1,0}, f32[1,2,128,128]{3,2,1,0}, f32[1,2,128,128]{3,2,1,0}) custom-call(f32[1,2,128,8]{3,2,1,0} %custom-call.6, f32[1,2,128,8]{3,2,1,0} %custom-call.4, f32[1,2,128,8]{3,2,1,0} %custom-call.2), custom_call_target="tpu_custom_call", operand_layout_constraints={f32[1,2,128,8]{3,2,1,0}, f32[1,2,128,8]{3,2,1,0}, f32[1,2,128,8]{3,2,1,0}}, metadata={op_type="xla__tpu_custom_call" op_name="xla__tpu_custom_call" source_file="/workspaces/dk3/pytorch/xla/torch_xla/experimental/custom_kernel.py" source_line=309}, backend_config={"custom_call_config": {"body": "TUzvUgFNTElSMjAuMC4wZ2l0AAEvCwEDBQcJAQMLAxkNDxETFRcZGx0fISMDYgPiAh8B+QcTCxMLDwsPDw8LCwsLCw8TE5MPCwsLGwsrC8UPCwsLCwuTCwsTCwsLCwsTEwsLMxMTExMXDxMPEw8TGwsPQwsbCwuTCwsLCyMbCxsLGwsbCxsLGwsbGxsbGxsPDw8PFw8XDw8LFw8PCxcPDwsXDw8LFwsPDwsXDw8LExcFDY1heZEqAioCAWcLFxMTFxcfExMXLxcfCxMXHw8LExcLEwsXHwsTFx8PCxMTFxMXC1MLExMXExMXFx8TFy8HBV1JCQNZAR8PBx8rBxsPFyMvGyc3LwsCghMfAwMduQUlAwMdtwUnFcHFBSkdMW0dMXEdMXUFKwUtBS8NGwUxHSu7HSsSAh0rtgIjDQlBAQAAAAAAAAABAAAAAAAAAIAAAAAAAAAACAAAAAAAAAARHQAFMwU1BTcDA14C3gIFOQMFngKiAqYCqgIFO2FmZmluZV9tYXA8KGQwLCBkMSwgZDIsIGQzKSAtPiAoZDAsIGQxLCBkMiwgZDMpPgARDQEFPQU/BUEFQwVFIw0JQQEAAAAAAAAAAQAAAAAAAACAAAAAAAAAAIAAAAAAAAAABUcFSR1NJgIFSwVNBU8FUQVTHT4CWRVCAgsFVQVXIw0DEQEAAAAAAAAAHVICYxVWAgsdcgJnFXYCCx2GAooCHSltFZYCCx0pcRWuAgsdTXUVygILAwV5ew01BVkRDQ0DD3+BFYOFh4k5izkNjY+RBVsBCfn5+f8NGQVdIw0JQQEAAAAAAAAAAgAAAAAAAAABAAAAAAAAAAEAAAAAAAAABV8FYQVjBWUBDZOXm5+jpwMFF5UZJQk7AwUXmRklCT0DBRedGSUJPwMFF6EZJQlBAwUXpRlFCUMDBRepGUUJRwMFFRsNOwMFFRsNPQMFFRsNPwMFFRsNQQMFFRsNQwMFFRsNRxEBAREDARW9Cx0JvxcFCggBHTXDFwVOBQEVx80dycsFZxcF/gsBFc/VHdHTBWkXLTYCARXX3R3Z2wVrFy12BAEV3+cd4eMFbRfl/ggBBW8V6e8d6+0FcRcttgcBFfH3HfP1BXMXSacBHQoCDgIjdHB1LmRpbWVuc2lvbl9zZW1hbnRpY3M8cGFyYWxsZWw+ACN0cHUubWVtb3J5X3NwYWNlPHZtZW0+ACN0cHUuY29udHJhY3RfcHJlY2lzaW9uPGZwMzI+ACN0cHUuZGltZW5zaW9uX3NlbWFudGljczxhcmJpdHJhcnk+ACN0cHUuZG90X2RpbWVuc2lvbl9udW1iZXJzPFsxXSwgWzFdLCBbMF0sIFswXSwgWzAsIDAsIDEsIDBdLCBbXSwgW10+ACN0cHUuZG90X2RpbWVuc2lvbl9udW1iZXJzPFsxXSwgWzBdLCBbMF0sIFsxXSwgWzAsIDAsIDEsIDFdLCBbXSwgW10+AAV1F0mCBQEVFgILHQkaAhcFDggBAwMdIgIlBQkAAAAAFSoCCx0JLgIXBRIIAQMJTwICUf1TJ1UnAwMdOgIlDwkAAID/BXcdCUYCFwWmCAEDBVvWAl1fHSlZBXkdCVoCFwWqCAEFex1mAmMFfQMDHW4CJQ8JAAAAAAV/HQl6AhcFrggBAwVb2gJdXx0pZwWBFY4CCx0JkgIXBbIIAR0JmgIXBb4IAQWDIwEJIQEAAAABAAAABAAAAAAAAAAFhSMBAQEdCbICFwXGCAEVugILHQm+AhcFzggBAwMdxgIlCwkAAAAAHQnOAhcF0ggBAwlPBgJR/VMnVScjdmVjdG9yLmtpbmQ8bWF4aW11bWY+ACN2ZWN0b3Iua2luZDxhZGQ+ACNhcml0aC5mYXN0bWF0aDxub25lPgABAgIDJwUCBAIECRf7CQUFAgQhCTcLJwUCBCEJAQIEJwMCBAknCQUFAgQhCRf7CQUFAgQCBAk3JwUCBAUJJwkFBQIEAgQJBRUBAQEBBwcHBxMTAQUJAQEBAQkBAQEBAQkEUgsFAREBdwcDAR0HEQF9BwOJ+xUBAQEBAQEBAQcBBwEHAQcBEwETAQMDHwMDAwMDHwMDAwMDHwMDAwMDHwMDAwsGHwMRCwkVFxkbBQYfAwsDHQMDIQMDAwMDIQMDAwMDIQMDAwMDIQMDAwsGIQMRCwshIyUnBQYhAwsDKQMDSx4CAwURB0syAgMFBx8rLQMDVzYCAw8TB1dKAgMPBS8xBQZOAgMVAzMNBmEDBQM1FQdhLwMFBS83FwdiAi8DBQM5AwNlagIDDxMHZX4CAw8FOz0FBoICAxUDPw0GaQMFA0EZB2kvAwUFO0MFBmsDFQM1DQZrAwUDRwMDDwMDAwMDDwMDAwMDDwMDAwMDDwMDAwsGDwMXCxNLTU9RBQYPAwUDUwUGDwMXA0kPBQ8zDVcTS01PUQUGbwMVA0ENBm8DBQNZAwMRAwMDAwMRAwMDAwMRAwMDAwMRAwMDCwYRAxcLEV1fYWMFBhEDBQNlBQYRAxcDWw8FETMNaRFdX2FjAwMjAwMDAwMjAwMDAwMjAwMDAwMjAwMDCwYjAxELDWttb3EFBiMDCwNzAwNzwgIDCxEHc9ICAwsHRXV3AwMTAwMDAwMTAwMDAwMTAwMDAwMTAwMDCwYTAxELD3t9f4EFBhMDCwODBQYTAxEDeQ8FEzMNhw97fX+BCQABBxEBqwcDDQ8JAQEBAQEBAQEDAwEHAwEDAwEHAwEJBAEJAQMFCQcRAa0HAw0PCQEBAQEBAQEBAwMBBwMBAwMBBwMBCQQBCQEDBwkHEQGvBwMNDwkBAQEBAQEBAQMDAQcDAQMDAQcDAQkEAQkBAwcJBxEBsQcDDQ8JAQEBAQEBAQEDAwEHAwEDAwEHAwEJBAEJAQMFCQcRAbMHAw0PCQEBAQEBAQEBAwMBBwMBAwMBBwMBCQQBCQEDBQkHEQG1BwMNDwkBAQEBAQEBA
QMDAQcDAQMDAQcDAQkEAQkBAwUJBgMBBQEAZhKHESkLGQsTCxkTYyFnDREbLR0LIyEjKS0fCx0dFSUbaxkZGRkZGTENiQslDR0lHRNjtxcTFy8XIyMZGRUlHw8NDwkdEWJ1aWx0aW4Ac3RhYmxlX21vc2FpYwB0cHUAdmVjdG9yAGFyaXRoAG1vZHVsZQBhcml0aC5jb25zdGFudAB2ZWN0b3Iuc2hhcGVfY2FzdABmdW5jLmZ1bmMAZnVuYy5yZXR1cm4AdmVjdG9yLmxvYWQAdmVjdG9yLmJyb2FkY2FzdAB0cHUudmVjdG9yX3N0b3JlAHRwdS5tYXRtdWwAdmVjdG9yLm11bHRpX3JlZHVjdGlvbgBhcml0aC5zdWJmAG1hdGguZXhwAGFyaXRoLmRpdmYAL3Vzci9sb2NhbC9saWIvcHl0aG9uMy4xMC9zaXRlLXBhY2thZ2VzL2pheC9leHBlcmltZW50YWwvcGFsbGFzL29wcy90cHUvZmxhc2hfYXR0ZW50aW9uLnB5AF9mbGFzaF9hdHRlbnRpb25fa2VybmVsX3NpbmdsZV9iYXRjaF9zaW5nbGVfc3RlcABzeW1fbmFtZQBmdW5jdGlvbl90eXBlAHRyYW5zZm9ybV9pbmRpY2VzAHdpbmRvd19ib3VuZHMAdmFsdWUAL2Jyb2FkY2FzdF9pbl9kaW0AL2dldAAvd29ya3NwYWNlcy9kazMvcHl0b3JjaC94bGEvdG9yY2hfeGxhL2V4cGVyaW1lbnRhbC9jdXN0b21fa2VybmVsLnB5AC9zd2FwAF9mbGFzaF9hdHRlbnRpb25fa2VybmVsAHRyYW5zZm9ybV8wAHRyYW5zZm9ybV8xAHRyYW5zZm9ybV8yAHRyYW5zZm9ybV8zAHRyYW5zZm9ybV80AHRyYW5zZm9ybV81AC93b3Jrc3BhY2VzL2RrMy9weXRvcmNoL3hsYS90ZXN0L3Rlc3RfcGFsbGFzX3NwbWQucHkAL2RvdF9nZW5lcmFsAGRpbWVuc2lvbl9udW1iZXJzAHByZWNpc2lvbgB0cmFuc3Bvc2VfbGhzAHRyYW5zcG9zZV9yaHMAa2luZAByZWR1Y3Rpb25fZGltcwBzdGFibGVfbW9zYWljLnZlcnNpb24AZGltZW5zaW9uX3NlbWFudGljcwBpdGVyYXRpb25fYm91bmRzAHNjYWxhcl9wcmVmZXRjaABzY3JhdGNoX29wZXJhbmRzAG1haW4Ad2luZG93X3BhcmFtcwBfZmxhc2hfYXR0ZW50aW9uX2ltcGwAdHJhY2VfcGFsbGFzAGZvcndhcmQAYXBwbHkAL3dvcmtzcGFjZXMvZGszL3B5dG9yY2gvdG9yY2gvYXV0b2dyYWQvZnVuY3Rpb24ucHkAZmxhc2hfYXR0ZW50aW9uAHRlc3RfZmxhc2hfYXR0ZW50aW9uX2JhY2t3YXJkX3NwbWRfZGF0YV9wYXJhbGxlbAA8bW9kdWxlPgAvcmVkdWNlX21heAAvc3ViAGZhc3RtYXRoAC9leHAAL3JlZHVjZV9zdW0AL2RpdgBvcGVyYW5kU2VnbWVudFNpemVzAHN0cmlkZXMA", "cost_estimate": {"flops": 1114112, "transcendentals": 32768, "bytes_accessed": 294912}, "serialization_format": 1, "needs_layout_passes": true}}
  %get-tuple-element.8 = f32[1,2,128,8]{3,2,1,0} get-tuple-element((f32[1,2,128,8]{3,2,1,0}, f32[1,2,128,128]{3,2,1,0}, f32[1,2,128,128]{3,2,1,0}) %custom-call.7), index=0, metadata={op_type="xla__tpu_custom_call" op_name="xla__tpu_custom_call" source_file="/workspaces/dk3/pytorch/xla/torch_xla/experimental/custom_kernel.py" source_line=309}
  %custom-call.11 = f32[1,2,128,8]{3,2,1,0} custom-call(f32[1,2,128,8]{3,2,1,0} %get-tuple-element.8), custom_call_target="Sharding", sharding={manual}, metadata={op_type="xla__custom_sharding" op_name="xla__custom_sharding" source_file="/workspaces/dk3/pytorch/xla/torch_xla/distributed/spmd/xla_sharding.py" source_line=537}
  %custom-call.12 = f32[4,2,128,8]{3,2,1,0} custom-call(f32[1,2,128,8]{3,2,1,0} %custom-call.11), custom_call_target="SPMDShardToFullShape", sharding={devices=[4,1,1,1]0,1,2,3}, metadata={op_type="xla__custom_sharding" op_name="xla__custom_sharding" source_file="/workspaces/dk3/pytorch/xla/torch_xla/distributed/spmd/xla_sharding.py" source_line=570}
  %get-tuple-element.9 = f32[1,2,128,128]{3,2,1,0} get-tuple-element((f32[1,2,128,8]{3,2,1,0}, f32[1,2,128,128]{3,2,1,0}, f32[1,2,128,128]{3,2,1,0}) %custom-call.7), index=1, metadata={op_type="xla__tpu_custom_call" op_name="xla__tpu_custom_call" source_file="/workspaces/dk3/pytorch/xla/torch_xla/experimental/custom_kernel.py" source_line=309}
  %slice.13 = f32[1,2,128,1]{3,2,1,0} slice(f32[1,2,128,128]{3,2,1,0} %get-tuple-element.9), slice={[0:1], [0:2], [0:128], [0:1]}, metadata={op_type="xla__generic_slice" op_name="xla__generic_slice" source_file="/workspaces/dk3/pytorch/xla/torch_xla/experimental/custom_kernel.py" source_line=319}
  %reshape.14 = f32[1,2,128]{2,1,0} reshape(f32[1,2,128,1]{3,2,1,0} %slice.13), metadata={op_type="aten__view" op_name="aten__view" source_file="/workspaces/dk3/pytorch/xla/torch_xla/experimental/custom_kernel.py" source_line=319}
  %custom-call.15 = f32[1,2,128]{2,1,0} custom-call(f32[1,2,128]{2,1,0} %reshape.14), custom_call_target="Sharding", sharding={manual}, metadata={op_type="xla__custom_sharding" op_name="xla__custom_sharding" source_file="/workspaces/dk3/pytorch/xla/torch_xla/distributed/spmd/xla_sharding.py" source_line=537}
  %custom-call.16 = f32[4,2,128]{2,1,0} custom-call(f32[1,2,128]{2,1,0} %custom-call.15), custom_call_target="SPMDShardToFullShape", sharding={devices=[4,1,1]0,1,2,3}, metadata={op_type="xla__custom_sharding" op_name="xla__custom_sharding" source_file="/workspaces/dk3/pytorch/xla/torch_xla/distributed/spmd/xla_sharding.py" source_line=570}
  %get-tuple-element.10 = f32[1,2,128,128]{3,2,1,0} get-tuple-element((f32[1,2,128,8]{3,2,1,0}, f32[1,2,128,128]{3,2,1,0}, f32[1,2,128,128]{3,2,1,0}) %custom-call.7), index=2, metadata={op_type="xla__tpu_custom_call" op_name="xla__tpu_custom_call" source_file="/workspaces/dk3/pytorch/xla/torch_xla/experimental/custom_kernel.py" source_line=309}
  %slice.17 = f32[1,2,128,1]{3,2,1,0} slice(f32[1,2,128,128]{3,2,1,0} %get-tuple-element.10), slice={[0:1], [0:2], [0:128], [0:1]}, metadata={op_type="xla__generic_slice" op_name="xla__generic_slice" source_file="/workspaces/dk3/pytorch/xla/torch_xla/experimental/custom_kernel.py" source_line=319}
  %reshape.18 = f32[1,2,128]{2,1,0} reshape(f32[1,2,128,1]{3,2,1,0} %slice.17), metadata={op_type="aten__view" op_name="aten__view" source_file="/workspaces/dk3/pytorch/xla/torch_xla/experimental/custom_kernel.py" source_line=319}
  %custom-call.19 = f32[1,2,128]{2,1,0} custom-call(f32[1,2,128]{2,1,0} %reshape.18), custom_call_target="Sharding", sharding={manual}, metadata={op_type="xla__custom_sharding" op_name="xla__custom_sharding" source_file="/workspaces/dk3/pytorch/xla/torch_xla/distributed/spmd/xla_sharding.py" source_line=537}
  %custom-call.20 = f32[4,2,128]{2,1,0} custom-call(f32[1,2,128]{2,1,0} %custom-call.19), custom_call_target="SPMDShardToFullShape", sharding={devices=[4,1,1]0,1,2,3}, metadata={op_type="xla__custom_sharding" op_name="xla__custom_sharding" source_file="/workspaces/dk3/pytorch/xla/torch_xla/distributed/spmd/xla_sharding.py" source_line=570}
  %constant.21 = f32[] constant(0), metadata={op_type="aten__sum" op_name="aten__sum" source_file="/workspaces/dk3/pytorch/xla/test/test_pallas_spmd.py" source_line=84}
  %reduce.27 = f32[] reduce(f32[4,2,128,8]{3,2,1,0} %custom-call.12, f32[] %constant.21), dimensions={0,1,2,3}, to_apply=%AddComputation.23, metadata={op_type="aten__sum" op_name="aten__sum" source_file="/workspaces/dk3/pytorch/xla/test/test_pallas_spmd.py" source_line=84}
  ROOT %tuple.28 = (f32[4,2,128,8]{3,2,1,0}, f32[4,2,128]{2,1,0}, f32[4,2,128]{2,1,0}, f32[]) tuple(f32[4,2,128,8]{3,2,1,0} %custom-call.12, f32[4,2,128]{2,1,0} %custom-call.16, f32[4,2,128]{2,1,0} %custom-call.20, f32[] %reduce.27)
}

@JackCaoG
Collaborator

JackCaoG commented Dec 2, 2024

Ok I narrowed it down to this line

l, m = (v[..., 0] for v in aux[-2:])

if I disable functionalization and print l it will crash right away. This also makes sense, as it is doing a view op, and the view handling without functionalization likely has an issue.
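
A minimal sketch of what that line does, plus one untested experiment (the clone() variant is only an assumption about how to break the view chain, not a verified fix):

import torch

# Placeholder for the kernel's two auxiliary outputs (the softmax stats l and m).
aux = [torch.randn(1, 2, 128, 128), torch.randn(1, 2, 128, 128)]

# Original line: v[..., 0] selects index 0 along the last dim and is a *view* of v.
l, m = (v[..., 0] for v in aux[-2:])
print(l.data_ptr() == aux[-2].data_ptr())  # True: l shares storage with aux[-2]

# Hypothetical experiment for XLA_DISABLE_FUNCTIONALIZATION=1: materialize a copy
# so later sharding annotations land on a base tensor rather than on a view.
l, m = (v[..., 0].clone() for v in aux[-2:])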

@JackCaoG
Collaborator

JackCaoG commented Dec 2, 2024

here is my diff.

diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py
index 5e30ffba2..09cbc755e 100644
--- a/torch_xla/experimental/custom_kernel.py
+++ b/torch_xla/experimental/custom_kernel.py
@@ -317,6 +317,8 @@ class FlashAttention(torch.autograd.Function):
         return o
       o, *aux = o
       l, m = (v[..., 0] for v in aux[-2:])
+      print(f"before view {torch_xla._XLAC._get_xla_tensor_debug_info(l)}")
+
 
     # SPMD integration
     if partition_spec is not None:
@@ -324,6 +326,11 @@ class FlashAttention(torch.autograd.Function):
           o, partition_spec, ctx.full_shape, mesh=mesh).global_tensor
       l = xs.disable_manual_sharding(
           l, partition_spec[0:3], ctx.full_shape[0:3], mesh=mesh).global_tensor
+      print(f"after view {torch_xla._XLAC._get_xla_tensor_debug_info(l)}")    
+      breakpoint()
+      print(torch_xla._XLAC._get_xla_tensors_hlo([l]))
+      # this will crash without functionization
+      print(l)
       m = xs.disable_manual_sharding(
           m, partition_spec[0:3], ctx.full_shape[0:3], mesh=mesh).global_tensor

When I compare the printed HLO, I found that in the failing case there is a weird HLO:

  %custom-call.13 = f32[1,2,128]{2,1,0} custom-call(f32[1,2,128]{2,1,0} %reshape.12), custom_call_target="Sharding", metadata={op_type="xla__custom_sharding" op_name="xla__custom_sharding" source_file="/workspaces/dk3/pytorch/xla/torch_xla/distributed/spmd/xla_sharding.py" source_line=537}

....



  %reshape.21 = f32[1,2,128]{2,1,0} reshape(f32[1,2,128,1]{3,2,1,0} %slice.20), sharding={manual}, metadata={op_type="aten__view" op_name="aten__view" source_file="/workspaces/dk3/pytorch/xla/torch_xla/distributed/spmd/xla_sharding.py" source_line=537}

whereas in the passing case the custom_call_target="Sharding" and sharding={manual} are on the same HLO instruction:


  %custom-call.13 = f32[1,2,128]{2,1,0} custom-call(f32[1,2,128]{2,1,0} %reshape.12), custom_call_target="Sharding", sharding={manual}, metadata={op_type="xla__custom_sharding" op_name="xla__custom_sharding"
I think this is where the issue is from.

@JackCaoG
Collaborator

JackCaoG commented Dec 2, 2024

ok I think the issue is in disable_manual_sharding, more specifically

t = _mark_manual_sharding(unwrap_sharded_tensor(t))

_mark_manual_sharding should annotate that the current tensor is sharded with manual sharding so we can bring it to the right shape. However, with XLA_DISABLE_FUNCTIONALIZATION=1, pytorch/xla's view logic kicks in, applies some of the view replay logic, and moves the sharding away from the custom_call = Sharding HLO.

IR

  %8 = f32[1,2,128]{2,1,0} aten::view(%7), location=forward@custom_kernel.py:320, xla_shape=f32[1,2,128]{2,1,0}
  %9 = f32[1,2,128]{2,1,0} xla::custom_sharding(%8), location=_mark_manual_sharding@xla_sharding.py:537, xla_shape=f32[1,2,128]{2,1,0}

Next step is to look into where the view logic kicks in. @miladm I will let you find someone to follow up.
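
To make the failing path concrete, here is a rough sketch of the manual-sharding round trip that custom_kernel.py performs around the Pallas call (disable_manual_sharding's signature is taken from the diff above; the enable_manual_sharding counterpart, the mesh, and the shapes are my assumptions):

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr

xr.use_spmd()
n_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.array(range(n_devices)), (n_devices,), ('data',))

# Global tensor sharded on the batch dim (shape mirrors l/m in the test above).
full = torch.randn(4, 2, 128, device=xm.xla_device())
xs.mark_sharding(full, mesh, ('data', None, None))

# Hand the kernel a per-device shard: emits Sharding + SPMDFullToShardShape.
shard = xs.enable_manual_sharding(full, ('data', None, None), mesh=mesh).global_tensor

# Bring the shard back to the global shape: _mark_manual_sharding is supposed to
# attach sharding={manual} to the Sharding custom call, followed by
# SPMDShardToFullShape. With XLA_DISABLE_FUNCTIONALIZATION=1 the view replay
# moves that annotation onto the aten::view instead, tripping the RET_CHECK.
restored = xs.disable_manual_sharding(
    shard, ('data', None, None), full.shape, mesh=mesh).global_tensor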

@dudulightricks
Contributor Author

dudulightricks commented Dec 10, 2024

@JackCaoG Hey! Can you merge this PR? Is it fine?
Also, is there any progress on the XLA_DISABLE_FUNCTIONALIZATION issue?

@JackCaoG
Collaborator

I am no longer with the team; you should check with @miladm.

@JackCaoG JackCaoG merged commit 5790092 into pytorch:master Dec 10, 2024
12 checks passed