Skip to content

Conversation

@jumerckx
Copy link
Collaborator

No description provided.

end

function wrap(x)
return Reactant.Ops.@opcall wrap(x, 7, 7; dimension=1)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you may want to change this to 2

@jumerckx
Copy link
Collaborator Author

This test is supposed to fail. It fails locally with xla_force_host_platform_device_count=8. I disabled CI jobs for testing here f28f9f7 (#1972)
But that presumably disabled the job I'm looking for.
Which job runs the multi-device tests? (@avik-pal)

Base automatically changed from jm/wrap_dims_fix to main December 15, 2025 04:30
@avik-pal
Copy link
Collaborator

@jumerckx I ran this on Hydra (with CPU) and the tests seemed to pass.

@jumerckx
Copy link
Collaborator Author

Hmm... On Hydra, using the CPU backend with 8 forced devices and IFRT, the generated HLO contains all-gathers when I run it:

ENV["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"  # request 8 host devices; set before Reactant is loaded
using Reactant, Test

Reactant.set_default_backend("cpu")  # run the reproducer on the CPU backend

const addressable_devices = Reactant.addressable_devices()
@assert length(addressable_devices) == 8  # sanity check: the forced device count took effect

function wrap(x)  # thin wrapper so `@code_xla` compiles only the wrap op
    return Reactant.Ops.@opcall wrap(x, 2, 2; dimension=1)  # presumably wraps 2 elements on each side of dim 1 -- TODO confirm
end

mesh = Sharding.Mesh(Reactant.devices(), (:x,))  # 1-D mesh over all devices, single axis named :x
sharding = Sharding.NamedSharding(mesh, (:x,))  # shard the array's only dimension along :x

x = Reactant.to_rarray(rand(8192); sharding)  # 8192 elements over 8 devices (1024 per device)
@assert x isa ConcreteIFRTArray  # NOTE(review): presumably confirms the IFRT runtime is in use -- confirm

hlo = repr(@code_xla wrap(x))  # compiled XLA module as text, searched by the tests below

@test !contains(hlo, "all-to-all")  # expectation: no all-to-all collective in the compiled module
@test !contains(hlo, "all-gather")  # expectation: no all-gather collective; this is the check reported as failing
@test contains(hlo, "collective-permute")  # expectation: communication is done with collective-permute
  Expression: !(contains(hlo, "all-gather"))
   Evaluated: !(contains("HloModule reactant_wrap, is_scheduled=true, entry_computation_layout={(f64[1024]{0})->f64[2049]{0}}, num_partitions=8\n\nFileNames\n1 \"Untitled-2\"\n\nFunctionNames\n1 \"wrap/wrap\"\n\nFileLocations\n1 {file_name_id=1 function_name_id=1 line=10 end_line=10 column=0 end_column=0}\n\nStackFrames\n1 {file_location_id=1 parent_frame_id=1}\n\n\n%fused_computation (param_0.1: s32[8], param_1.4: f64[1025], param_2.3: f64[3], param_3.4: s32[8], param_4.6: u32[]) -> f64[1025] {\n  %param_3.4 = s32[8]{0} parameter(3)\n  %param_4.6 = u32[] parameter(4)\n  %convert.1 = s32[] convert(%param_4.6)\n  %dynamic-slice.24 = s32[1]{0} dynamic-slice(%param_3.4, %convert.1), dynamic_slice_sizes={1}\n  %constant.50 = s32[1]{0} constant({1}), metadata={op_name=\"wrap/wrap\" stack_frame_id=1}\n  %compare.7 = pred[1]{0} compare(%dynamic-slice.24, %constant.50), direction=EQ\n  %bitcast.5 = pred[] bitcast(%compare.7)\n  %broadcast.14 = pred[1025]{0} broadcast(%bitcast.5), dimensions={}\n  %param_2.3 = f64[3]{0} parameter(2)\n  %constant.49 = f64[] constant(0)\n  %pad.4 = f64[1025]{0} pad(%param_2.3, %constant.49), padding=1022_0\n  %param_1.4 = f64[1025]{0} parameter(1)\n  %select.5 = f64[1025]{0} select(%broadcast.14, %pad.4, %param_1.4)\n  %slice.20 = f64[1024]{0} slice(%param_1.4), slice={[0:1024]}\n  %concatenate.3 = f64[2049]{0} concatenate(%select.5, %slice.20), dimensions={0}\n  %param_0.1 = s32[8]{0} parameter(0)\n  %dynamic-slice.23 = s32[1]{0} dynamic-slice(%param_0.1, %convert.1), dynamic_slice_sizes={1}\n  %bitcast.4 = s32[] bitcast(%dynamic-slice.23)\n  ROOT %dynamic-slice.22 = f64[1025]{0} dynamic-slice(%concatenate.3, %bitcast.4), dynamic_slice_sizes={1025}\n}\n\n%fused_computation.1 (param_0.4: f64[2], param_1.11: f64[1], param_2.9: f64[1024], param_3.10: f64[5], param_4.13: f64[2], param_5.8: u32[]) -> f64[1025] {\n  %param_4.13 = f64[2]{0} parameter(4)\n  %constant.52 = f64[] constant(0)\n  %pad.7 = f64[8200]{0} pad(%param_4.13, %constant.52), 
padding=0_8198, metadata={op_name=\"wrap/wrap\" stack_frame_id=1}\n  %param_5.8 = u32[] parameter(5)\n  %convert.2 = s32[] convert(%param_5.8)\n  %constant.51 = s32[] constant(1025), metadata={op_name=\"wrap/wrap\" stack_frame_id=1}\n  %multiply.6 = s32[] multiply(%convert.2, %constant.51), metadata={op_name=\"wrap/wrap\" stack_frame_id=1}\n  %dynamic-slice.27 = f64[1025]{0} dynamic-slice(%pad.7, %multiply.6), dynamic_slice_sizes={1025}, metadata={op_name=\"wrap/wrap\" stack_frame_id=1}\n  %param_1.11 = f64[1]{0} parameter(1)\n  %param_2.9 = f64[1024]{0} parameter(2)\n  %param_3.10 = f64[5]{0} parameter(3)\n  %concatenate.4 = f64[1030]{0} concatenate(%param_1.11, %param_2.9, %param_3.10), dimensions={0}, metadata={op_name=\"wrap/wrap\" stack_frame_id=1}\n  %pad.6 = f64[1032]{0} pad(%concatenate.4, %constant.52), padding=1_1, metadata={op_name=\"wrap/wrap\" stack_frame_id=1}\n  %dynamic-slice.26 = f64[1025]{0} dynamic-slice(%pad.6, %convert.2), dynamic_slice_sizes={1025}, metadata={op_name=\"wrap/wrap\" stack_frame_id=1}\n  %wrap.18 = f64[1025]{0} add(%dynamic-slice.27, %dynamic-slice.26), metadata={op_name=\"wrap/wrap\" stack_frame_id=1}\n  %param_0.4 = f64[2]{0} parameter(0)\n  %pad.5 = f64[8200]{0} pad(%param_0.4, %constant.52), padding=8194_4, metadata={op_name=\"wrap/wrap\" stack_frame_id=1}\n  %dynamic-slice.25 = f64[1025]{0} dynamic-slice(%pad.5, %multiply.6), dynamic_slice_sizes={1025}, metadata={op_name=\"wrap/wrap\" stack_frame_id=1}\n  ROOT %wrap.17 = f64[1025]{0} add(%wrap.18, %dynamic-slice.25), metadata={op_name=\"wrap/wrap\" stack_frame_id=1}\n}\n\n%fused_computation.2 (param_0.5: f64[2], param_1.14: f64[1], param_2.13: f64[1024], param_3.16: u32[]) -> f64[2] {\n  %param_0.5 = f64[2]{0} parameter(0)\n  %param_3.16 = u32[] parameter(3)\n  %convert.3 = s32[] convert(%param_3.16)\n  %constant.56 = s32[] constant(2), metadata={op_name=\"wrap/wrap\" stack_frame_id=1}\n  %compare.8 = pred[] compare(%convert.3, %constant.56), direction=LT, 
metadata={op_name=\"wrap/wrap\" stack_frame_id=1}\n  %bitcast.6 = pred[1]{0} bitcast(%compare.8), metadata={op_name=\"wrap/wrap\" stack_frame_id=1}\n  %param_1.14 = f64[1]{0} parameter(1)\n  %param_2.13 = f64[1024]{0} parameter(2)\n  %slice.21 = f64[1]{0} slice(%param_2.13), slice={[0:1]}, metadata={op_name=\"wrap/wrap\" stack_frame_id=1}\n  %concatenate.5 = f64[2]{0} concatenate(%param_1.14, %slice.21), dimensions={0}, metadata={op_name=\"wrap/wrap\" stack_frame_id=1}\n  %constant.53 = s32[] constant(-1), metadata={op_name=\"wrap/wrap\" stack_frame_id=1}\n  %multiply.7 = s32[] multiply(%convert.3, %constant.53), metadata={op_name=\"wrap/wrap\" stack_frame_id=1}\n  %constant.55 = s32[] constant(1), metadata={op_name=\"wrap/wrap\" stack_frame_id=1}\n  %add.8 = s32[] add(%multiply.7, %constant.55), metadata={op_name=\"wrap/wrap\" stack_frame_id=1}\n  %dynamic-slice.28 = f64[1]{0} dynamic-slice(%concatenate.5, %add.8), dynamic_slice_sizes={1}, metadata={op_name=\"wrap/wrap\" stack_frame_id=1}\n  %constant.54 = f64[1]{0} constant({0}), metadata={op_name=\"wrap/wrap\" stack_frame_id=1}\n  %select.6 = f64[1]{0} select(%bitcast.6, %dynamic-slice.28, %constant.54), metadata={op_name=\"wrap/wrap\" stack_frame_id=1}\n  ROOT %dynamic-update-slice.2 = f64[2]{0} dynamic-update-slice(%param_0.5, %select.6, %convert.3), metadata={op_name=\"wrap/wrap\" stack_frame_id=1}\n}\n\n%fused_computation.3 (param_0.6: f64[2], param_1.17: f64[1], param_2.17: f64[1], param_3.21: u32[]) -> f64[2] {\n  %param_0.6 = f64[2]{0} parameter(0)\n  %param_3.21 = u32[] parameter(3)\n  %convert.4 = s32[] convert(%param_3.21)\n  %constant.61 = s32[] constant(2), metadata={op_name=\"wrap/wrap\" stack_frame_id=1}\n  %compare.10 = pred[] compare(%convert.4, %constant.61), direction=LT, metadata={op_name=\"wrap/wrap\" stack_frame_id=1}\n  %bitcast.8 = pred[1]{0} bitcast(%compare.10), metadata={op_name=\"wrap/wrap\" stack_frame_id=1}\n  %constant.60 = s32[] constant(0), metadata={op_name=\"wrap/wrap\" 
stack_frame_id=1}\n  %constant.59 = s32[] constant(-2), metadata={op_name=\"wrap/wrap\" stack_frame_id=1}\n  %add.9 = s32[] add(%convert.4, %constant.59), metadata={op_name=\"wrap/wrap\" stack_frame_id=1}\n  %select.9 = s32[] select(%compare.10, %convert.4, %add.9), metadata={op_name=\"wrap/wrap\" stack_frame_id=1}\n  %clamp.2 = s32[] clamp(%constant.60, %select.9, %constant.61), metadata={op_name=\"wrap/wrap\" stack_frame_id=1}\n  %constant.58 = s32[] constant(1), metadata={op_name=\"wrap/wrap\" stack_frame_id=1}\n  %compare.9 = pred[] compare(%clamp.2, %constant.58), direction=EQ, metadata={op_name=\"wrap/wrap\" stack_frame_id=1}\n  %bitcast.7 = pred[1]{0} bitcast(%compare.9), metadata={op_name=\"wrap/wrap\" stack_frame_id=1}\n  %param_1.17 = f64[1]{0} parameter(1)\n  %param_2.17 = f64[1]{0} parameter(2)\n  %select.8 = f64[1]{0} select(%bitcast.7, %param_1.17, %param_2.17), metadata={op_name=\"wrap/wrap\" stack_frame_id=1}\n  %constant.57 = f64[1]{0} constant({0}), metadata={op_name=\"wrap/wrap\" stack_frame_id=1}\n  %select.7 = f64[1]{0} select(%bitcast.8, %select.8, %constant.57), metadata={op_name=\"wrap/wrap\" stack_frame_id=1}\n  ROOT %dynamic-update-slice.3 = f64[2]{0} dynamic-update-slice(%param_0.6, %select.7, %convert.4), metadata={op_name=\"wrap/wrap\" stack_frame_id=1}\n}\n\n%wrapped_broadcast_computation (param_0.7: f64[]) -> f64[2] {\n  %param_0.7 = f64[] parameter(0)\n  ROOT %broadcast.15 = f64[2]{0} broadcast(%param_0.7), dimensions={}, metadata={op_name=\"wrap/wrap\" stack_frame_id=1}\n}\n\n%wrapped_slice_computation (param_0.8: f64[1024]) -> f64[1] {\n  %param_0.8 = f64[1024]{0} parameter(0)\n  ROOT %slice.22 = f64[1]{0} slice(%param_0.8), slice={[1:2]}, metadata={op_name=\"wrap/wrap\" stack_frame_id=1}\n}\n\n%wrapped_slice_computation.1 (param_0.9: f64[1024]) -> f64[1] {\n  %param_0.9 = f64[1024]{0} parameter(0)\n  ROOT %slice.23 = f64[1]{0} slice(%param_0.9), slice={[1023:1024]}, metadata={op_name=\"wrap/wrap\" 
stack_frame_id=1}\n}\n\n%wrapped_slice_computation.2 (param_0.10: f64[1024]) -> f64[5] {\n  %param_0.10 = f64[1024]{0} parameter(0)\n  ROOT %slice.24 = f64[5]{0} slice(%param_0.10), slice={[0:5]}, metadata={op_name=\"wrap/wrap\" stack_frame_id=1}\n}\n\n%wrapped_slice_computation.3 (param_0.11: f64[1024]) -> f64[1] {\n  %param_0.11 = f64[1024]{0} parameter(0)\n  ROOT %slice.25 = f64[1]{0} slice(%param_0.11), slice={[1022:1023]}, metadata={op_name=\"wrap/wrap\" stack_frame_id=1}\n}\n\n%wrapped_slice_computation.4 (param_0.12: f64[1025]) -> f64[3] {\n  %param_0.12 = f64[1025]{0} parameter(0)\n  ROOT %slice.26 = f64[3]{0} slice(%param_0.12), slice={[1022:1025]}\n}\n\n%wrapped_slice_computation.5 (param_0.13: f64[2050]) -> f64[2049] {\n  %param_0.13 = f64[2050]{0} parameter(0)\n  ROOT %slice.27 = f64[2049]{0} slice(%param_0.13), slice={[0:2049]}\n}\n\n%add.1.clone (x.3: f64[], y.3: f64[]) -> f64[] {\n  %x.3 = f64[] parameter(0)\n  %y.3 = f64[] parameter(1)\n  ROOT %add.3 = f64[] add(%x.3, %y.3)\n}\n\nENTRY %main.0_spmd (param: f64[1024]) -> f64[2049] {\n  %partition-id = u32[] partition-id()\n  %param = f64[1024]{0} parameter(0), sharding={devices=[8]<=[8]}, metadata={op_name=\"arg1 (path=(:args, 1))\"}\n  %constant.8 = f64[] constant(0)\n  %constant.35 = s32[8]{0} constant({0, 0, 1, 1, 1, 1, 1, 1})\n  %constant.39 = s32[8]{0} constant({0, 0, 1024, 1024, 1023, 1023, 1022, 1022})\n  %wrapped_slice.3 = f64[1]{0} fusion(%param), kind=kLoop, calls=%wrapped_slice_computation.3, metadata={op_name=\"wrap/wrap\" stack_frame_id=1}\n  %wrapped_slice = f64[1]{0} fusion(%param), kind=kLoop, calls=%wrapped_slice_computation, metadata={op_name=\"wrap/wrap\" stack_frame_id=1}\n  %wrapped_slice.2 = f64[5]{0} fusion(%param), kind=kLoop, calls=%wrapped_slice_computation.2, metadata={op_name=\"wrap/wrap\" stack_frame_id=1}\n  %wrapped_slice.1 = f64[1]{0} fusion(%param), kind=kLoop, calls=%wrapped_slice_computation.1, metadata={op_name=\"wrap/wrap\" stack_frame_id=1}\n  %wrapped_broadcast 
= f64[2]{0} fusion(%constant.8), kind=kLoop, calls=%wrapped_broadcast_computation, metadata={op_name=\"wrap/wrap\" stack_frame_id=1}\n  %collective-permute = f64[1]{0} collective-permute(%wrapped_slice.3), channel_id=1, source_target_pairs={{7,0}}, metadata={op_name=\"wrap/wrap\" stack_frame_id=1}\n  %collective-permute.4 = f64[1]{0} collective-permute(%wrapped_slice), channel_id=6, source_target_pairs={{0,1}}, metadata={op_name=\"wrap/wrap\" stack_frame_id=1}\n  %collective-permute.3 = f64[5]{0} collective-permute(%wrapped_slice.2), channel_id=5, source_target_pairs={{1,0},{2,1},{3,2},{4,3},{5,4},{6,5},{7,6}}, metadata={op_name=\"wrap/wrap\" stack_frame_id=1}\n  %collective-permute.1 = f64[1]{0} collective-permute(%wrapped_slice.1), channel_id=2, source_target_pairs={{7,1}}, metadata={op_name=\"wrap/wrap\" stack_frame_id=1}\n  %collective-permute.2 = f64[1]{0} collective-permute(%wrapped_slice.1), channel_id=4, source_target_pairs={{0,1},{1,2},{2,3},{3,4},{4,5},{5,6},{6,7}}, metadata={op_name=\"wrap/wrap\" stack_frame_id=1}\n  %copy.3 = f64[2]{0} copy(%wrapped_broadcast)\n  %copy.2 = f64[2]{0} copy(%wrapped_broadcast)\n  %select_dynamic-update-slice_fusion.1 = f64[2]{0} fusion(%copy.3, %collective-permute.1, %collective-permute, %partition-id), kind=kLoop, calls=%fused_computation.3, metadata={op_name=\"wrap/wrap\" stack_frame_id=1}\n  %select_dynamic-update-slice_fusion = f64[2]{0} fusion(%copy.2, %collective-permute.4, %param, %partition-id), kind=kLoop, calls=%fused_computation.2, metadata={op_name=\"wrap/wrap\" stack_frame_id=1}\n  %all-reduce.2 = (f64[2]{0}, f64[2]{0}) all-reduce(%select_dynamic-update-slice_fusion, %select_dynamic-update-slice_fusion.1), channel_id=7, replica_groups=[1,8]<=[8], use_global_device_ids=true, to_apply=%add.1.clone, metadata={op_name=\"wrap/wrap\" stack_frame_id=1}\n  %get-tuple-element.3 = f64[2]{0} get-tuple-element(%all-reduce.2), index=0, metadata={op_name=\"wrap/wrap\" stack_frame_id=1}\n  %get-tuple-element.4 = f64[2]{0} 
get-tuple-element(%all-reduce.2), index=1, metadata={op_name=\"wrap/wrap\" stack_frame_id=1}\n  %dynamic-slice_add_fusion = f64[1025]{0} fusion(%get-tuple-element.3, %collective-permute.2, %param, %collective-permute.3, %get-tuple-element.4, /*index=5*/%partition-id), kind=kLoop, calls=%fused_computation.1, metadata={op_name=\"wrap/wrap\" stack_frame_id=1}\n  %wrapped_slice.4 = f64[3]{0} fusion(%dynamic-slice_add_fusion), kind=kLoop, calls=%wrapped_slice_computation.4\n  %collective-permute.5 = f64[3]{0} collective-permute(%wrapped_slice.4), channel_id=8, source_target_pairs={{1,2},{2,3},{3,4},{4,5},{5,6},{6,7}}\n  %bitcast_dynamic-slice_fusion = f64[1025]{0} fusion(%constant.39, %dynamic-slice_add_fusion, %collective-permute.5, %constant.35, %partition-id), kind=kLoop, calls=%fused_computation\n  %all-gather = f64[2050]{0} all-gather(%bitcast_dynamic-slice_fusion), channel_id=9, replica_groups=[4,2]<=[8], dimensions={0}, use_global_device_ids=true\n  ROOT %wrapped_slice.5 = f64[2049]{0} fusion(%all-gather), kind=kLoop, calls=%wrapped_slice_computation.5\n}\n\n", "all-gather"))
(Reactant) pkg> st
Project Reactant v0.2.184
Status `/mnt/jumerckx/Reactant2/Project.toml`
  [79e6a3ab] Adapt v4.4.0
  [fa961155] CEnum v0.5.0
  [4e289a0a] EnumX v1.0.5
  [7da242da] Enzyme v0.13.109
  [f151be2c] EnzymeCore v0.8.17
  [d9f16b24] Functors v0.5.2
  [46192b85] GPUArraysCore v0.2.0
  [cd3eb016] HTTP v1.10.19
  [929cbde3] LLVM v9.4.4
  [bac558e1] OrderedCollections v1.8.1
⌅ [aea7be01] PrecompileTools v1.2.1
  [21216c6a] Preferences v1.5.0
  [a3311ec8] ReactantCore v0.1.16 `lib/ReactantCore`
  [7e506255] ScopedValues v1.5.0
  [6c6a2e73] Scratch v1.3.0
  [1d63c593] LLVMOpenMP_jll v18.1.8+0
  [0192cb87] Reactant_jll v0.0.275+1
  [f43a241f] Downloads v1.6.0
  [8f399da3] Libdl v1.11.0
  [37e2e46d] LinearAlgebra v1.11.0
  [9a3f8284] Random v1.11.0
  [6462fe0b] Sockets v1.11.0
  [3f19e933] p7zip_jll v17.4.0+2

@avik-pal
Copy link
Collaborator

I was trying inside Docker (via act), and it doesn't fail there. Let me investigate a bit more.

@avik-pal
Copy link
Collaborator

| HloModule reactant_wrap, is_scheduled=true, input_output_alias={ {1}: (0, {}, may-alias) }, entry_computation_layout={(f64[683]{0})->(f64[683]{0}, f64[683]{0})}, num_partitions=12                                                                                                                                                                                                                                                                                                                                                                                                                                                                      
|                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                          
| FileNames                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                
| 1 "/home/avik-pal/reactant/Reactant.jl/test/optimize_comm.jl"                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            
|                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                          
| FunctionNames                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            
| 1 "wrap/wrap"                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            
|                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                          
| FileLocations                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            
| 1 {file_name_id=1 function_name_id=1 line=26 end_line=26 column=0 end_column=0}                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                          
|                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                          
| StackFrames                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                              
| 1 {file_location_id=1 parent_frame_id=1}                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                 
|                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                          
|                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                          
| %fused_computation (param_0.1: f64[683], param_1.4: u32[]) -> f64[683] {                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                 
|   %iota.6 = s32[683]{0} iota(), iota_dimension=0, metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        
|   %param_1.4 = u32[] parameter(1)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        
|   %convert.1 = s32[] convert(%param_1.4), metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                
|   %constant.48 = s32[] constant(683), metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    
|   %multiply.8 = s32[] multiply(%convert.1, %constant.48), metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                
|   %broadcast.20 = s32[683]{0} broadcast(%multiply.8), dimensions={}, metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                     
|   %add.10 = s32[683]{0} add(%iota.6, %broadcast.20), metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                     
|   %constant.47 = s32[] constant(8192), metadata={op_name="pad.1"}                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        
|   %broadcast.18 = s32[683]{0} broadcast(%constant.47), dimensions={}, metadata={op_name="pad.1"}                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                         
|   %compare.8 = pred[683]{0} compare(%add.10, %broadcast.18), direction=LT, metadata={op_name="pad.1"}                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    
|   %param_0.1 = f64[683]{0} parameter(0)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  
|   %constant.49 = f64[] constant(0)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       
|   %broadcast.19 = f64[683]{0} broadcast(%constant.49), dimensions={}, metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    
|   ROOT %select.6 = f64[683]{0} select(%compare.8, %param_0.1, %broadcast.19), metadata={op_name="pad.1"}                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                 
| }                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        
|                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                          
| %fused_computation.1 (param_0.4: f64[2], param_1.10: f64[2], param_2.11: f64[683], param_3.9: f64[2], param_4.7: u32[]) -> f64[683] {                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    
|   %param_3.9 = f64[2]{0} parameter(3)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    
|   %constant.53 = f64[] constant(0)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       
|   %pad.5 = f64[8196]{0} pad(%param_3.9, %constant.53), padding=0_8194, metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   
|   %param_4.7 = u32[] parameter(4)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        
|   %convert.2 = s32[] convert(%param_4.7), metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                
|   %constant.51 = s32[] constant(683), metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    
|   %multiply.9 = s32[] multiply(%convert.2, %constant.51), metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                
|   %dynamic-slice.18 = f64[683]{0} dynamic-slice(%pad.5, %multiply.9), dynamic_slice_sizes={683}, metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                         
|   %iota.7 = s32[683]{0} iota(), iota_dimension=0, metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        
|   %broadcast.24 = s32[683]{0} broadcast(%multiply.9), dimensions={}, metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                     
|   %add.11 = s32[683]{0} add(%iota.7, %broadcast.24), metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                     
|   %constant.52 = s32[] constant(2), metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                      
|   %broadcast.23 = s32[683]{0} broadcast(%constant.52), dimensions={}, metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    
|   %compare.10 = pred[683]{0} compare(%add.11, %broadcast.23), direction=GE, metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                              
|   %constant.50 = s32[] constant(8194), metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   
|   %broadcast.22 = s32[683]{0} broadcast(%constant.50), dimensions={}, metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    
|   %compare.9 = pred[683]{0} compare(%add.11, %broadcast.22), direction=LT, metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                               
|   %and.1 = pred[683]{0} and(%compare.10, %compare.9), metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    
|   %param_1.10 = f64[2]{0} parameter(1)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   
|   %param_2.11 = f64[683]{0} parameter(2)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                 
|   %slice.14 = f64[681]{0} slice(%param_2.11), slice={[0:681]}, metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                           
|   %concatenate.2 = f64[683]{0} concatenate(%param_1.10, %slice.14), dimensions={0}, metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                      
|   %broadcast.21 = f64[683]{0} broadcast(%constant.53), dimensions={}, metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    
|   %select.7 = f64[683]{0} select(%and.1, %concatenate.2, %broadcast.21), metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                 
|   %wrap.18 = f64[683]{0} add(%dynamic-slice.18, %select.7), metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                              
|   %param_0.4 = f64[2]{0} parameter(0)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    
|   %pad.4 = f64[8196]{0} pad(%param_0.4, %constant.53), padding=8194_0, metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   
|   %dynamic-slice.17 = f64[683]{0} dynamic-slice(%pad.4, %multiply.9), dynamic_slice_sizes={683}, metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                         
|   ROOT %wrap.17 = f64[683]{0} add(%wrap.18, %dynamic-slice.17), metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                          
| }                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        
|                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                          
| %fused_computation.2 (param_0.5: f64[2], param_1.13: f64[1], param_2.15: f64[683], param_3.15: u32[]) -> f64[2] {                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        
|   %param_0.5 = f64[2]{0} parameter(0)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    
|   %param_3.15 = u32[] parameter(3)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       
|   %convert.3 = s32[] convert(%param_3.15), metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                               
|   %constant.57 = s32[] constant(2), metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                      
|   %compare.11 = pred[] compare(%convert.3, %constant.57), direction=LT, metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                                                    
|   %bitcast.2 = pred[1]{0} bitcast(%compare.11), metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                                                                            
|   %param_1.13 = f64[1]{0} parameter(1)                                                                                                                                                                                                                                                                                     
|   %param_2.15 = f64[683]{0} parameter(2)                                                                                                                                                                                                                                                                                   
|   %slice.15 = f64[1]{0} slice(%param_2.15), slice={[0:1]}, metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                                                                 
|   %concatenate.3 = f64[2]{0} concatenate(%param_1.13, %slice.15), dimensions={0}, metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                                          
|   %constant.54 = s32[] constant(-1), metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                                                                                       
|   %multiply.10 = s32[] multiply(%convert.3, %constant.54), metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                                                                 
|   %constant.56 = s32[] constant(1), metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                                                                                        
|   %add.12 = s32[] add(%multiply.10, %constant.56), metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                                                                         
|   %dynamic-slice.19 = f64[1]{0} dynamic-slice(%concatenate.3, %add.12), dynamic_slice_sizes={1}, metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                           
|   %constant.55 = f64[1]{0} constant({0}), metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                                                                                  
|   %select.8 = f64[1]{0} select(%bitcast.2, %dynamic-slice.19, %constant.55), metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                                               
|   ROOT %dynamic-update-slice.2 = f64[2]{0} dynamic-update-slice(%param_0.5, %select.8, %convert.3), metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                        
| }                                                                                                                                                                                                                                                                                                                          
|                                                                                                                                                                                                                                                                                                                            
| %fused_computation.3 (param_0.6: f64[2], param_1.16: f64[1], param_2.19: f64[1], param_3.20: u32[]) -> f64[2] {                                                                                                                                                                                                            
|   %param_0.6 = f64[2]{0} parameter(0)                                                                                                                                                                                                                                                                                      
|   %param_3.20 = u32[] parameter(3)                                                                                                                                                                                                                                                                                         
|   %convert.4 = s32[] convert(%param_3.20), metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                                                                                 
|   %constant.62 = s32[] constant(2), metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                                                                                        
|   %compare.13 = pred[] compare(%convert.4, %constant.62), direction=LT, metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                                                    
|   %bitcast.4 = pred[1]{0} bitcast(%compare.13), metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                                                                            
|   %constant.61 = s32[] constant(0), metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                                                                                        
|   %constant.60 = s32[] constant(-2), metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                                                                                       
|   %add.13 = s32[] add(%convert.4, %constant.60), metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                                                                           
|   %select.11 = s32[] select(%compare.13, %convert.4, %add.13), metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                                                             
|   %clamp.2 = s32[] clamp(%constant.61, %select.11, %constant.62), metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                                                          
|   %constant.59 = s32[] constant(1), metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                                                                                        
|   %compare.12 = pred[] compare(%clamp.2, %constant.59), direction=EQ, metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                                                      
|   %bitcast.3 = pred[1]{0} bitcast(%compare.12), metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                                                                            
|   %param_1.16 = f64[1]{0} parameter(1)                                                                                                                                                                                                                                                                                     
|   %param_2.19 = f64[1]{0} parameter(2)                                                                                                                                                                                                                                                                                     
|   %select.10 = f64[1]{0} select(%bitcast.3, %param_1.16, %param_2.19), metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                                                     
|   %constant.58 = f64[1]{0} constant({0}), metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                                                                                  
|   %select.9 = f64[1]{0} select(%bitcast.4, %select.10, %constant.58), metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                                                      
|   ROOT %dynamic-update-slice.3 = f64[2]{0} dynamic-update-slice(%param_0.6, %select.9, %convert.4), metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                        
| }                                                                                                                                                                                                                                                                                                                          
|                                                                                                                                                                                                                                                                                                                            
| %wrapped_broadcast_computation (param_0.7: f64[]) -> f64[2] {                                                                                                                                                                                                                                                              
|   %param_0.7 = f64[] parameter(0)                                                                                                                                                                                                                                                                                          
|   ROOT %broadcast.25 = f64[2]{0} broadcast(%param_0.7), dimensions={}, metadata={op_name="wrap/wrap" stack_frame_id=1}
| }                                                                                                                                                                                                                                                                                                                          
|                                                                                                                                                             
| %wrapped_slice_computation (param_0.8: f64[683]) -> f64[1] {                                                                                                
|   %param_0.8 = f64[683]{0} parameter(0)                                                                                                                     
|   ROOT %slice.16 = f64[1]{0} slice(%param_0.8), slice={[1:2]}, metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                                                             
| }                                                                                                                                                           
|                                                                                                                                                             
| %wrapped_slice_computation.1 (param_0.9: f64[683]) -> f64[2] {                                                                                              
|   %param_0.9 = f64[683]{0} parameter(0)                                                                                                                     
|   ROOT %slice.17 = f64[2]{0} slice(%param_0.9), slice={[681:683]}, metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                                                         
| }                                                                                                                                                           
|                                                                                                                                                             
| %wrapped_slice_computation.2 (param_0.10: f64[683]) -> f64[1] {                                                                                             
|   %param_0.10 = f64[683]{0} parameter(0)                                                                                                                    
|   ROOT %slice.18 = f64[1]{0} slice(%param_0.10), slice={[678:679]}, metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                                                        
| }                                                                                                                                                           
|                                                                                                                                                             
| %wrapped_slice_computation.3 (param_0.11: f64[683]) -> f64[1] {                                                                                             
|   %param_0.11 = f64[683]{0} parameter(0)                                                                                                                    
|   ROOT %slice.19 = f64[1]{0} slice(%param_0.11), slice={[677:678]}, metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                                                        
| }                                                                                                                                                           
|                                                                                                                                                             
| %add.1.clone (x.3: f64[], y.3: f64[]) -> f64[] {                                                                                                            
|   %x.3 = f64[] parameter(0)                                                                                                                                 
|   %y.3 = f64[] parameter(1)                                                                                                                                 
|   ROOT %add.3 = f64[] add(%x.3, %y.3)                                                                                                                       
| }                                                                                                                                                           
|                                                                                                                                                             
| ENTRY %main.0_spmd (param: f64[683]) -> (f64[683], f64[683]) {                                                                                              
|   %partition-id = u32[] partition-id(), metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                     
|   %param = f64[683]{0} parameter(0), sharding={devices=[12]<=[12]}, metadata={op_name="arg1 (path=(:args, 1))"}                                                                                                                                                                                                            
|   %constant.11 = f64[] constant(0)                                                                                                                          
|   %copy.5 = f64[683]{0} copy(%param)                                                                                                                        
|   %wrapped_broadcast = f64[2]{0} fusion(%constant.11), kind=kLoop, calls=%wrapped_broadcast_computation, metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                   
|   %compare_select_fusion = f64[683]{0} fusion(%copy.5, %partition-id), kind=kLoop, calls=%fused_computation, metadata={op_name="pad.1"}                                                                                                                                                                                    
|   %wrapped_slice.3 = f64[1]{0} fusion(%copy.5), kind=kLoop, calls=%wrapped_slice_computation.3, metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                            
|   %wrapped_slice.1 = f64[2]{0} fusion(%copy.5), kind=kLoop, calls=%wrapped_slice_computation.1, metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                            
|   %wrapped_slice.2 = f64[1]{0} fusion(%copy.5), kind=kLoop, calls=%wrapped_slice_computation.2, metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                            
|   %wrapped_slice = f64[1]{0} fusion(%copy.5), kind=kLoop, calls=%wrapped_slice_computation, metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                                
|   %copy.4 = f64[2]{0} copy(%wrapped_broadcast)                                                                                                              
|   %copy.3 = f64[2]{0} copy(%wrapped_broadcast)                                                                                                              
|   %collective-permute = f64[1]{0} collective-permute(%wrapped_slice.3), channel_id=1, source_target_pairs={{11,0}}, metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                        
|   %collective-permute.2 = f64[2]{0} collective-permute(%wrapped_slice.1), channel_id=4, source_target_pairs={{0,1},{1,2},{2,3},{3,4},{4,5},{5,6},{6,7},{7,8},{8,9},{9,10},{10,11}}, metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                        
|   %collective-permute.1 = f64[1]{0} collective-permute(%wrapped_slice.2), channel_id=2, source_target_pairs={{11,1}}, metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                      
|   %collective-permute.3 = f64[1]{0} collective-permute(%wrapped_slice), channel_id=5, source_target_pairs={{0,1}}, metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                         
|   %select_dynamic-update-slice_fusion.1 = f64[2]{0} fusion(%copy.4, %collective-permute.1, %collective-permute, %partition-id), kind=kLoop, calls=%fused_computation.3, metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                    
|   %select_dynamic-update-slice_fusion = f64[2]{0} fusion(%copy.3, %collective-permute.3, %copy.5, %partition-id), kind=kLoop, calls=%fused_computation.2, metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                  
|   %all-reduce.2 = (f64[2]{0}, f64[2]{0}) all-reduce(%select_dynamic-update-slice_fusion, %select_dynamic-update-slice_fusion.1), channel_id=6, replica_groups=[1,12]<=[12], use_global_device_ids=true, to_apply=%add.1.clone, metadata={op_name="wrap/wrap" stack_frame_id=1}                                             
|   %get-tuple-element = f64[2]{0} get-tuple-element(%all-reduce.2), index=0, metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                                                
|   %get-tuple-element.1 = f64[2]{0} get-tuple-element(%all-reduce.2), index=1, metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                                                                                                                              
|   %dynamic-slice_add_fusion = f64[683]{0} fusion(%get-tuple-element, %collective-permute.2, %copy.5, %get-tuple-element.1, %partition-id), kind=kLoop, calls=%fused_computation.1, metadata={op_name="wrap/wrap" stack_frame_id=1}                                                                                         
|   ROOT %tuple.3 = (f64[683]{0}, f64[683]{0}) tuple(%dynamic-slice_add_fusion, %compare_select_fusion)                                                                                                                                                                                                                      
| }                                                                                                                                                           
|       

@avik-pal
Copy link
Collaborator

I think I see the issue

@jumerckx
Copy link
Collaborator Author

Aaah!

@wsmoses
Copy link
Member

wsmoses commented Dec 15, 2025

@jumerckx try:

# Reproducer: inspects which collective ops appear in the XLA HLO generated for a
# sharded `wrap` call.
# XLA_FLAGS is set before `using Reactant`; the flag requests 12 host-platform devices.
ENV["XLA_FLAGS"] = "--xla_force_host_platform_device_count=12"
using Reactant, Test

Reactant.set_default_backend("cpu")

const addressable_devices = Reactant.addressable_devices()
# Stop early if the number of addressable devices is not the 12 requested above.
@assert length(addressable_devices) == 12

# Invokes the `wrap` op on `x` with arguments (2, 2) and `dimension=1`.
# NOTE(review): presumably the two 2s are the left/right wrap widths — confirm
# against the Reactant.Ops.wrap definition.
function wrap(x)
    return Reactant.Ops.@opcall wrap(x, 2, 2; dimension=1)
end

# Mesh over `Reactant.devices()` with a single named axis `:x`; the sharding refers
# to that same axis name.
mesh = Sharding.Mesh(Reactant.devices(), (:x,))
sharding = Sharding.NamedSharding(mesh, (:x,))

# Input length is 192 times the device count, so it scales with the number of devices.
x = Reactant.to_rarray(rand(192 * length(addressable_devices)); sharding)
# Require that `to_rarray` produced a ConcreteIFRTArray before continuing.
@assert x isa ConcreteIFRTArray

# String form of the `@code_xla` output for `wrap(x)`; searched by the tests below.
hlo = repr(@code_xla wrap(x))

# Expectation: no all-to-all or all-gather in the HLO text, and at least one
# collective-permute.
@test !contains(hlo, "all-to-all")
@test !contains(hlo, "all-gather")
@test contains(hlo, "collective-permute")

@wsmoses
Copy link
Member

wsmoses commented Dec 15, 2025

Hopefully this should work independent of device count [for large device counts]

@jumerckx
Copy link
Collaborator Author

jumerckx commented Dec 15, 2025

@jumerckx try:
...

That fails as well, which is good (?)

@wsmoses
Copy link
Member

wsmoses commented Dec 15, 2025

yeah sorry I meant "this should hopefully be a test that confirms failure on a larger device set"

@jumerckx
Copy link
Collaborator Author

But it's strange that @avik-pal got the tests to pass with ndevices >= 8, how many devices does CI have?

@avik-pal
Copy link
Collaborator

CI forces 12 devices

@jumerckx
Copy link
Collaborator Author

That doesn't make sense then, why did it pass CI before?!

@jumerckx
Copy link
Collaborator Author

nvm, I see the input size is changed as well

@avik-pal
Copy link
Collaborator

do we expect EnzymeAD/Enzyme-JAX#1779 to fix this?

@wsmoses
Copy link
Member

wsmoses commented Dec 16, 2025

@avik-pal
Copy link
Collaborator

module @reactant_wrap attributes {mhlo.num_partitions = 12 : i64, mhlo.num_replicas = 1 : i64} {
  sdy.mesh @mesh = <["x"=12]>
  func.func @main(%arg0: tensor<2304xf64> {enzymexla.memory_effects = [], sdy.sharding = #sdy.sharding<@mesh, [{"x"}]>}) -> (tensor<2308xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"x":(1)4}]>}) attributes {enzymexla.memory_effects = []} {
    %c = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}]>]>} dense<2306> : tensor<2308xi32>
    %c_0 = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}]>]>} dense<2> : tensor<2308xi32>
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f64>
    %0 = stablehlo.pad %arg0, %cst, low = [2], high = [2], interior = [0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}]>]>} : (tensor<2304xf64>, tensor<f64>) -> tensor<2308xf64>
    %1 = stablehlo.slice %0 [4:2308] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}]>]>} : (tensor<2308xf64>) -> tensor<2304xf64>
    %2 = stablehlo.slice %0 [0:4] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}]>]>} : (tensor<2308xf64>) -> tensor<4xf64>
    %3 = stablehlo.pad %1, %cst, low = [0], high = [4], interior = [0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}]>]>} : (tensor<2304xf64>, tensor<f64>) -> tensor<2308xf64>
    %4 = stablehlo.pad %2, %cst, low = [2304], high = [0], interior = [0] : (tensor<4xf64>, tensor<f64>) -> tensor<2308xf64>
    %5 = stablehlo.add %3, %4 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}]>]>} : tensor<2308xf64>
    %6 = stablehlo.slice %0 [2304:2308] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}]>]>} : (tensor<2308xf64>) -> tensor<4xf64>
    %7 = stablehlo.slice %0 [0:2304] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}]>]>} : (tensor<2308xf64>) -> tensor<2304xf64>
    %8 = stablehlo.pad %6, %cst, low = [0], high = [2304], interior = [0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}]>]>} : (tensor<4xf64>, tensor<f64>) -> tensor<2308xf64>
    %9 = stablehlo.pad %7, %cst, low = [4], high = [0], interior = [0] : (tensor<2304xf64>, tensor<f64>) -> tensor<2308xf64>
    %10 = stablehlo.add %8, %9 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}]>]>} : tensor<2308xf64>
    %11 = stablehlo.iota dim = 0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}]>]>} : tensor<2308xi32>
    %12 = stablehlo.compare  LT, %11, %c_0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}]>]>} : (tensor<2308xi32>, tensor<2308xi32>) -> tensor<2308xi1>
    %13 = stablehlo.compare  LT, %11, %c {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}]>]>} : (tensor<2308xi32>, tensor<2308xi32>) -> tensor<2308xi1>
    %14 = stablehlo.select %12, %10, %0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}]>]>} : tensor<2308xi1>, tensor<2308xf64>
    %15 = stablehlo.select %13, %14, %5 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}]>]>} : tensor<2308xi1>, tensor<2308xf64>
    return %15 : tensor<2308xf64>
  }
}
ENTRY %main.0_spmd (param: f64[192]) -> f64[577] {
  %partition-id = u32[] partition-id()
  %param = f64[192]{0} parameter(0), sharding={devices=[12]<=[12]}, metadata={op_name="arg1 (path=(:args, 1))"}
  %constant.12 = f64[] constant(0)
  %constant.67 = s32[12]{0} constant({...})
  %constant.71 = s32[12]{0} constant({...})
  %wrapped_slice = f64[1]{0} fusion(%param), kind=kLoop, calls=%wrapped_slice_computation, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %wrapped_slice.1 = f64[9]{0} fusion(%param), kind=kLoop, calls=%wrapped_slice_computation.1, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %wrapped_broadcast = f64[4]{0} fusion(%constant.12), kind=kLoop, calls=%wrapped_broadcast_computation, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %collective-permute = f64[1]{0} collective-permute(%wrapped_slice), channel_id=1, source_target_pairs={{0,1},{1,2},{2,3},{3,4},{4,5},{5,6},{6,7},{7,8},{8,9},{9,10},{10,11}}, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %collective-permute.1 = f64[9]{0} collective-permute(%wrapped_slice.1), channel_id=2, source_target_pairs={{1,0},{2,1},{3,2},{4,3},{5,4},{6,5},{7,6},{8,7},{9,8},{10,9},{11,10}}, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %copy.5 = f64[4]{0} copy(%wrapped_broadcast)
  %copy.4 = f64[4]{0} copy(%wrapped_broadcast)
  %concatenate_pad_fusion = f64[204]{0} fusion(%collective-permute, %param, %collective-permute.1), kind=kLoop, calls=%fused_computation.19, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %dynamic-slice_slice_fusion.12 = f64[1]{0} fusion(%concatenate_pad_fusion, %partition-id), kind=kLoop, calls=%fused_computation.18, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %dynamic-slice_slice_fusion = f64[1]{0} fusion(%concatenate_pad_fusion, %partition-id), kind=kLoop, calls=%fused_computation.3, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %dynamic-slice_slice_fusion.1 = f64[1]{0} fusion(%concatenate_pad_fusion, %partition-id), kind=kLoop, calls=%fused_computation.4, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %dynamic-slice_slice_fusion.2 = f64[1]{0} fusion(%concatenate_pad_fusion, %partition-id), kind=kLoop, calls=%fused_computation.5, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %dynamic-slice_slice_fusion.4 = f64[3]{0} fusion(%concatenate_pad_fusion, %partition-id), kind=kLoop, calls=%fused_computation.8, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %dynamic-slice_slice_fusion.5 = f64[7]{0} fusion(%concatenate_pad_fusion, %partition-id), kind=kLoop, calls=%fused_computation.9, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %dynamic-slice_slice_fusion.8 = f64[11]{0} fusion(%concatenate_pad_fusion, %partition-id), kind=kLoop, calls=%fused_computation.13, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %dynamic-slice_slice_fusion.9 = f64[1]{0} fusion(%concatenate_pad_fusion, %partition-id), kind=kLoop, calls=%fused_computation.15, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %dynamic-slice_slice_fusion.10 = f64[1]{0} fusion(%concatenate_pad_fusion, %partition-id), kind=kLoop, calls=%fused_computation.16, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %dynamic-slice_slice_fusion.11 = f64[1]{0} fusion(%concatenate_pad_fusion, %partition-id), kind=kLoop, calls=%fused_computation.17, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %collective-permute.5 = f64[1]{0} collective-permute(%dynamic-slice_slice_fusion.12), channel_id=6, source_target_pairs={{11,3}}, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %collective-permute.12 = f64[1]{0} collective-permute(%dynamic-slice_slice_fusion), channel_id=14, source_target_pairs={{0,1}}, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %collective-permute.13 = f64[1]{0} collective-permute(%dynamic-slice_slice_fusion.1), channel_id=15, source_target_pairs={{0,2}}, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %collective-permute.14 = f64[1]{0} collective-permute(%dynamic-slice_slice_fusion.2), channel_id=16, source_target_pairs={{0,3}}, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %collective-permute.10 = f64[3]{0} collective-permute(%dynamic-slice_slice_fusion.4), channel_id=12, source_target_pairs={{1,0},{2,1},{3,2},{4,3},{5,4},{6,5},{7,6},{8,7},{9,8},{10,9},{11,10}}, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %collective-permute.9 = f64[7]{0} collective-permute(%dynamic-slice_slice_fusion.5), channel_id=11, source_target_pairs={{0,1},{1,2},{2,3},{3,4},{4,5},{5,6},{6,7},{7,8},{8,9},{9,10},{10,11}}, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %collective-permute.6 = f64[11]{0} collective-permute(%dynamic-slice_slice_fusion.8), channel_id=8, source_target_pairs={{0,1},{1,2},{2,3},{3,4},{4,5},{5,6},{6,7},{7,8},{8,9},{9,10},{10,11}}, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %collective-permute.2 = f64[1]{0} collective-permute(%dynamic-slice_slice_fusion.9), channel_id=3, source_target_pairs={{11,0}}, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %collective-permute.3 = f64[1]{0} collective-permute(%dynamic-slice_slice_fusion.10), channel_id=4, source_target_pairs={{11,1}}, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %collective-permute.4 = f64[1]{0} collective-permute(%dynamic-slice_slice_fusion.11), channel_id=5, source_target_pairs={{11,2}}, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %select_dynamic-update-slice_fusion = f64[4]{0} fusion(%copy.4, %collective-permute.14, %collective-permute.13, %collective-permute.12, %concatenate_pad_fusion, /*index=5*/%partition-id), kind=kLoop, calls=%fused_computation.2, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %dynamic-slice_concatenate_fusion = f64[203]{0} fusion(%collective-permute.9, %collective-permute.10, %concatenate_pad_fusion, %partition-id), kind=kLoop, calls=%fused_computation.7, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %slice_concatenate_fusion = f64[203]{0} fusion(%collective-permute.6, %concatenate_pad_fusion, %partition-id), kind=kLoop, calls=%fused_computation.12, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %select_dynamic-update-slice_fusion.1 = f64[4]{0} fusion(%copy.5, %collective-permute.5, %collective-permute.4, %collective-permute.3, %collective-permute.2, /*index=5*/%partition-id), kind=kLoop, calls=%fused_computation.14, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %dynamic-slice_slice_fusion.3 = f64[11]{0} fusion(%dynamic-slice_concatenate_fusion, %partition-id), kind=kLoop, calls=%fused_computation.6, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %dynamic-slice_slice_fusion.7 = f64[3]{0} fusion(%slice_concatenate_fusion, %partition-id), kind=kLoop, calls=%fused_computation.11, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %dynamic-slice_slice_fusion.6 = f64[7]{0} fusion(%slice_concatenate_fusion, %partition-id), kind=kLoop, calls=%fused_computation.10, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %all-reduce.2 = (f64[4]{0}, f64[4]{0}) all-reduce(%select_dynamic-update-slice_fusion, %select_dynamic-update-slice_fusion.1), channel_id=17, replica_groups=[1,12]<=[12], use_global_device_ids=true, to_apply=%add.1.clone, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %collective-permute.11 = f64[11]{0} collective-permute(%dynamic-slice_slice_fusion.3), channel_id=13, source_target_pairs={{1,0},{2,1},{3,2},{4,3},{5,4},{6,5},{7,6},{8,7},{9,8},{10,9},{11,10}}, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %collective-permute.7 = f64[3]{0} collective-permute(%dynamic-slice_slice_fusion.7), channel_id=9, source_target_pairs={{0,1},{1,2},{2,3},{3,4},{4,5},{5,6},{6,7},{7,8},{8,9},{9,10},{10,11}}, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %collective-permute.8 = f64[7]{0} collective-permute(%dynamic-slice_slice_fusion.6), channel_id=10, source_target_pairs={{1,0},{2,1},{3,2},{4,3},{5,4},{6,5},{7,6},{8,7},{9,8},{10,9},{11,10}}, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %get-tuple-element.3 = f64[4]{0} get-tuple-element(%all-reduce.2), index=0, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %get-tuple-element.4 = f64[4]{0} get-tuple-element(%all-reduce.2), index=1, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %add_select_fusion = f64[193]{0} fusion(%get-tuple-element.3, %collective-permute.11, %dynamic-slice_concatenate_fusion, %collective-permute.7, %collective-permute.8, /*index=5*/%slice_concatenate_fusion, %get-tuple-element.4, %concatenate_pad_fusion, %partition-id), kind=kLoop, calls=%fused_computation.1, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %wrapped_slice.2 = f64[6]{0} fusion(%add_select_fusion), kind=kLoop, calls=%wrapped_slice_computation.2
  %collective-permute.15 = f64[6]{0} collective-permute(%wrapped_slice.2), channel_id=18, source_target_pairs={{2,3},{3,4},{4,5},{5,6},{6,7},{7,8},{8,9},{9,10},{10,11}}
  %bitcast_dynamic-slice_fusion = f64[193]{0} fusion(%constant.71, %add_select_fusion, %collective-permute.15, %constant.67, %partition-id), kind=kLoop, calls=%fused_computation
  %all-gather = f64[579]{0} all-gather(%bitcast_dynamic-slice_fusion), channel_id=19, replica_groups=[4,3]<=[12], dimensions={0}, use_global_device_ids=true
  ROOT %wrapped_slice.3 = f64[577]{0} fusion(%all-gather), kind=kLoop, calls=%wrapped_slice_computation.3
}

@wsmoses
Copy link
Member

wsmoses commented Dec 16, 2025

Okay, that's very weird. Can we look at the debug logs to see what's happening?

I wonder if someone is messing up our rotates. We may need to put an optimization barrier on them

@wsmoses
Copy link
Member

wsmoses commented Dec 16, 2025

Are you sure the HLO was generated with the right optimizations enabled vs. disabled? Add select fusion and concat are both weird names to have in the IR

@wsmoses
Copy link
Member

wsmoses commented Dec 16, 2025

Oh I see what is happening, disable the rotate to pad

@wsmoses
Copy link
Member

wsmoses commented Dec 16, 2025

We should leave rotate as a concat and make sure it properly hits the rotate handler in xla

@avik-pal
Copy link
Collaborator

ENTRY %main.0_spmd (param: f64[192]) -> f64[577] {
  %partition-id = u32[] partition-id()
  %param = f64[192]{0} parameter(0), sharding={devices=[12]<=[12]}, metadata={op_name="arg1 (path=(:args, 1))"}
  %constant.8 = f64[] constant(0)
  %constant.35 = s32[12]{0} constant({...})
  %constant.39 = s32[12]{0} constant({...})
  %wrapped_slice.3 = f64[1]{0} fusion(%param), kind=kLoop, calls=%wrapped_slice_computation.3, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %wrapped_slice = f64[1]{0} fusion(%param), kind=kLoop, calls=%wrapped_slice_computation, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %wrapped_slice.2 = f64[9]{0} fusion(%param), kind=kLoop, calls=%wrapped_slice_computation.2, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %wrapped_slice.1 = f64[1]{0} fusion(%param), kind=kLoop, calls=%wrapped_slice_computation.1, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %wrapped_broadcast = f64[2]{0} fusion(%constant.8), kind=kLoop, calls=%wrapped_broadcast_computation, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %collective-permute = f64[1]{0} collective-permute(%wrapped_slice.3), channel_id=1, source_target_pairs={{11,0}}, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %collective-permute.4 = f64[1]{0} collective-permute(%wrapped_slice), channel_id=6, source_target_pairs={{0,1}}, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %collective-permute.3 = f64[9]{0} collective-permute(%wrapped_slice.2), channel_id=5, source_target_pairs={{1,0},{2,1},{3,2},{4,3},{5,4},{6,5},{7,6},{8,7},{9,8},{10,9},{11,10}}, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %collective-permute.1 = f64[1]{0} collective-permute(%wrapped_slice.1), channel_id=2, source_target_pairs={{11,1}}, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %collective-permute.2 = f64[1]{0} collective-permute(%wrapped_slice.1), channel_id=4, source_target_pairs={{0,1},{1,2},{2,3},{3,4},{4,5},{5,6},{6,7},{7,8},{8,9},{9,10},{10,11}}, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %copy.3 = f64[2]{0} copy(%wrapped_broadcast)
  %copy.2 = f64[2]{0} copy(%wrapped_broadcast)
  %select_dynamic-update-slice_fusion.1 = f64[2]{0} fusion(%copy.3, %collective-permute.1, %collective-permute, %partition-id), kind=kLoop, calls=%fused_computation.3, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %select_dynamic-update-slice_fusion = f64[2]{0} fusion(%copy.2, %collective-permute.4, %param, %partition-id), kind=kLoop, calls=%fused_computation.2, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %all-reduce.2 = (f64[2]{0}, f64[2]{0}) all-reduce(%select_dynamic-update-slice_fusion, %select_dynamic-update-slice_fusion.1), channel_id=7, replica_groups=[1,12]<=[12], use_global_device_ids=true, to_apply=%add.1.clone, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %get-tuple-element.3 = f64[2]{0} get-tuple-element(%all-reduce.2), index=0, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %get-tuple-element.4 = f64[2]{0} get-tuple-element(%all-reduce.2), index=1, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %dynamic-slice_add_fusion = f64[193]{0} fusion(%get-tuple-element.3, %collective-permute.2, %param, %collective-permute.3, %get-tuple-element.4, /*index=5*/%partition-id), kind=kLoop, calls=%fused_computation.1, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %wrapped_slice.4 = f64[6]{0} fusion(%dynamic-slice_add_fusion), kind=kLoop, calls=%wrapped_slice_computation.4
  %collective-permute.5 = f64[6]{0} collective-permute(%wrapped_slice.4), channel_id=8, source_target_pairs={{2,3},{3,4},{4,5},{5,6},{6,7},{7,8},{8,9},{9,10},{10,11}}
  %bitcast_dynamic-slice_fusion = f64[193]{0} fusion(%constant.39, %dynamic-slice_add_fusion, %collective-permute.5, %constant.35, %partition-id), kind=kLoop, calls=%fused_computation
  %all-gather = f64[579]{0} all-gather(%bitcast_dynamic-slice_fusion), channel_id=9, replica_groups=[4,3]<=[12], dimensions={0}, use_global_device_ids=true
  ROOT %wrapped_slice.5 = f64[577]{0} fusion(%all-gather), kind=kLoop, calls=%wrapped_slice_computation.5
}

get_optimize_comms_passes(compile_options.optimize_communications) = ["enzyme-hlo-generate-td{patterns=lower_rotate;concat_to_onedim_dus;concat_to_onedim_dusslice;concatreshape_to_onedim_dus}", "transform-interpreter", "enzyme-hlo-remove-transform", "enzyme-hlo-generate-td{patterns=reshape_to_broadcast}", "transform-interpreter", "enzyme-hlo-remove-transform", "optimize-communication{periodic_concat=0 rotate_comm=0 rotate_to_pad_comm=0 wrap_comm=0 extend_comm=0 dus_to_pad_manual_comp_comm=0 dus_to_pad_comm=0 concat_two_operands_comm=0 concat_to_pad_comm=1 extend_to_pad_comm=0 extend_to_pad_comm2=1 wrap_to_pad_comm=1 wrap_to_rotate=1}", "enzyme-hlo-generate-td{patterns=lower_rotate;lower_wrap;lower_extend}", "transform-interpreter", "enzyme-hlo-remove-transform", "optimize-communication{periodic_concat=0 rotate_comm=0 rotate_to_pad_comm=0 wrap_comm=0 extend_comm=0 dus_to_pad_manual_comp_comm=0 dus_to_pad_comm=0 concat_two_operands_comm=0 concat_to_pad_comm=1 extend_to_pad_comm=0 extend_to_pad_comm2=1 wrap_to_pad_comm=1 wrap_to_rotate=1}"]

module @reactant_wrap attributes {mhlo.num_partitions = 12 : i64, mhlo.num_replicas = 1 : i64} {
  sdy.mesh @mesh = <["x"=12]>
  func.func @main(%arg0: tensor<2304xf64> {enzymexla.memory_effects = [], sdy.sharding = #sdy.sharding<@mesh, [{"x"}]>}) -> (tensor<2308xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"x":(1)4}]>}) attributes {enzymexla.memory_effects = []} {
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f64>
    %0 = stablehlo.slice %arg0 [0:2] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}]>]>} : (tensor<2304xf64>) -> tensor<2xf64>
    %1 = stablehlo.slice %arg0 [2302:2304] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}]>]>} : (tensor<2304xf64>) -> tensor<2xf64>
    %2 = stablehlo.pad %0, %cst, low = [2306], high = [0], interior = [0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}]>]>} : (tensor<2xf64>, tensor<f64>) -> tensor<2308xf64>
    %3 = stablehlo.pad %1, %cst, low = [0], high = [2306], interior = [0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}]>]>} : (tensor<2xf64>, tensor<f64>) -> tensor<2308xf64>
    %4 = stablehlo.pad %arg0, %cst, low = [2], high = [2], interior = [0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}]>]>} : (tensor<2304xf64>, tensor<f64>) -> tensor<2308xf64>
    %5 = stablehlo.add %3, %4 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}]>]>} : tensor<2308xf64>
    %6 = stablehlo.add %5, %2 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}]>]>} : tensor<2308xf64>
    return %6 : tensor<2308xf64>
  }
}

@wsmoses
Copy link
Member

wsmoses commented Dec 16, 2025

We also should disable the other wrap one

@avik-pal
Copy link
Collaborator

ENTRY %main.0_spmd (param: f64[192]) -> f64[577] {
  %partition-id = u32[] partition-id()
  %param = f64[192]{0} parameter(0), sharding={devices=[12]<=[12]}, metadata={op_name="arg1 (path=(:args, 1))"}
  %constant.12 = f64[] constant(0)
  %constant.67 = s32[12]{0} constant({...})
  %constant.71 = s32[12]{0} constant({...})
  %wrapped_slice = f64[1]{0} fusion(%param), kind=kLoop, calls=%wrapped_slice_computation, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %wrapped_slice.1 = f64[9]{0} fusion(%param), kind=kLoop, calls=%wrapped_slice_computation.1, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %wrapped_broadcast = f64[4]{0} fusion(%constant.12), kind=kLoop, calls=%wrapped_broadcast_computation, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %collective-permute = f64[1]{0} collective-permute(%wrapped_slice), channel_id=1, source_target_pairs={{0,1},{1,2},{2,3},{3,4},{4,5},{5,6},{6,7},{7,8},{8,9},{9,10},{10,11}}, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %collective-permute.1 = f64[9]{0} collective-permute(%wrapped_slice.1), channel_id=2, source_target_pairs={{1,0},{2,1},{3,2},{4,3},{5,4},{6,5},{7,6},{8,7},{9,8},{10,9},{11,10}}, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %copy.5 = f64[4]{0} copy(%wrapped_broadcast)
  %copy.4 = f64[4]{0} copy(%wrapped_broadcast)
  %concatenate_pad_fusion = f64[204]{0} fusion(%collective-permute, %param, %collective-permute.1), kind=kLoop, calls=%fused_computation.19, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %dynamic-slice_slice_fusion.12 = f64[1]{0} fusion(%concatenate_pad_fusion, %partition-id), kind=kLoop, calls=%fused_computation.18, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %dynamic-slice_slice_fusion = f64[1]{0} fusion(%concatenate_pad_fusion, %partition-id), kind=kLoop, calls=%fused_computation.3, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %dynamic-slice_slice_fusion.1 = f64[1]{0} fusion(%concatenate_pad_fusion, %partition-id), kind=kLoop, calls=%fused_computation.4, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %dynamic-slice_slice_fusion.2 = f64[1]{0} fusion(%concatenate_pad_fusion, %partition-id), kind=kLoop, calls=%fused_computation.5, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %dynamic-slice_slice_fusion.4 = f64[3]{0} fusion(%concatenate_pad_fusion, %partition-id), kind=kLoop, calls=%fused_computation.8, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %dynamic-slice_slice_fusion.5 = f64[7]{0} fusion(%concatenate_pad_fusion, %partition-id), kind=kLoop, calls=%fused_computation.9, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %dynamic-slice_slice_fusion.8 = f64[11]{0} fusion(%concatenate_pad_fusion, %partition-id), kind=kLoop, calls=%fused_computation.13, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %dynamic-slice_slice_fusion.9 = f64[1]{0} fusion(%concatenate_pad_fusion, %partition-id), kind=kLoop, calls=%fused_computation.15, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %dynamic-slice_slice_fusion.10 = f64[1]{0} fusion(%concatenate_pad_fusion, %partition-id), kind=kLoop, calls=%fused_computation.16, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %dynamic-slice_slice_fusion.11 = f64[1]{0} fusion(%concatenate_pad_fusion, %partition-id), kind=kLoop, calls=%fused_computation.17, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %collective-permute.5 = f64[1]{0} collective-permute(%dynamic-slice_slice_fusion.12), channel_id=6, source_target_pairs={{11,3}}, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %collective-permute.12 = f64[1]{0} collective-permute(%dynamic-slice_slice_fusion), channel_id=14, source_target_pairs={{0,1}}, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %collective-permute.13 = f64[1]{0} collective-permute(%dynamic-slice_slice_fusion.1), channel_id=15, source_target_pairs={{0,2}}, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %collective-permute.14 = f64[1]{0} collective-permute(%dynamic-slice_slice_fusion.2), channel_id=16, source_target_pairs={{0,3}}, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %collective-permute.10 = f64[3]{0} collective-permute(%dynamic-slice_slice_fusion.4), channel_id=12, source_target_pairs={{1,0},{2,1},{3,2},{4,3},{5,4},{6,5},{7,6},{8,7},{9,8},{10,9},{11,10}}, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %collective-permute.9 = f64[7]{0} collective-permute(%dynamic-slice_slice_fusion.5), channel_id=11, source_target_pairs={{0,1},{1,2},{2,3},{3,4},{4,5},{5,6},{6,7},{7,8},{8,9},{9,10},{10,11}}, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %collective-permute.6 = f64[11]{0} collective-permute(%dynamic-slice_slice_fusion.8), channel_id=8, source_target_pairs={{0,1},{1,2},{2,3},{3,4},{4,5},{5,6},{6,7},{7,8},{8,9},{9,10},{10,11}}, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %collective-permute.2 = f64[1]{0} collective-permute(%dynamic-slice_slice_fusion.9), channel_id=3, source_target_pairs={{11,0}}, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %collective-permute.3 = f64[1]{0} collective-permute(%dynamic-slice_slice_fusion.10), channel_id=4, source_target_pairs={{11,1}}, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %collective-permute.4 = f64[1]{0} collective-permute(%dynamic-slice_slice_fusion.11), channel_id=5, source_target_pairs={{11,2}}, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %select_dynamic-update-slice_fusion = f64[4]{0} fusion(%copy.4, %collective-permute.14, %collective-permute.13, %collective-permute.12, %concatenate_pad_fusion, /*index=5*/%partition-id), kind=kLoop, calls=%fused_computation.2, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %dynamic-slice_concatenate_fusion = f64[203]{0} fusion(%collective-permute.9, %collective-permute.10, %concatenate_pad_fusion, %partition-id), kind=kLoop, calls=%fused_computation.7, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %slice_concatenate_fusion = f64[203]{0} fusion(%collective-permute.6, %concatenate_pad_fusion, %partition-id), kind=kLoop, calls=%fused_computation.12, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %select_dynamic-update-slice_fusion.1 = f64[4]{0} fusion(%copy.5, %collective-permute.5, %collective-permute.4, %collective-permute.3, %collective-permute.2, /*index=5*/%partition-id), kind=kLoop, calls=%fused_computation.14, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %dynamic-slice_slice_fusion.3 = f64[11]{0} fusion(%dynamic-slice_concatenate_fusion, %partition-id), kind=kLoop, calls=%fused_computation.6, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %dynamic-slice_slice_fusion.7 = f64[3]{0} fusion(%slice_concatenate_fusion, %partition-id), kind=kLoop, calls=%fused_computation.11, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %dynamic-slice_slice_fusion.6 = f64[7]{0} fusion(%slice_concatenate_fusion, %partition-id), kind=kLoop, calls=%fused_computation.10, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %all-reduce.2 = (f64[4]{0}, f64[4]{0}) all-reduce(%select_dynamic-update-slice_fusion, %select_dynamic-update-slice_fusion.1), channel_id=17, replica_groups=[1,12]<=[12], use_global_device_ids=true, to_apply=%add.1.clone, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %collective-permute.11 = f64[11]{0} collective-permute(%dynamic-slice_slice_fusion.3), channel_id=13, source_target_pairs={{1,0},{2,1},{3,2},{4,3},{5,4},{6,5},{7,6},{8,7},{9,8},{10,9},{11,10}}, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %collective-permute.7 = f64[3]{0} collective-permute(%dynamic-slice_slice_fusion.7), channel_id=9, source_target_pairs={{0,1},{1,2},{2,3},{3,4},{4,5},{5,6},{6,7},{7,8},{8,9},{9,10},{10,11}}, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %collective-permute.8 = f64[7]{0} collective-permute(%dynamic-slice_slice_fusion.6), channel_id=10, source_target_pairs={{1,0},{2,1},{3,2},{4,3},{5,4},{6,5},{7,6},{8,7},{9,8},{10,9},{11,10}}, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %get-tuple-element.3 = f64[4]{0} get-tuple-element(%all-reduce.2), index=0, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %get-tuple-element.4 = f64[4]{0} get-tuple-element(%all-reduce.2), index=1, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %add_select_fusion = f64[193]{0} fusion(%get-tuple-element.3, %collective-permute.11, %dynamic-slice_concatenate_fusion, %collective-permute.7, %collective-permute.8, /*index=5*/%slice_concatenate_fusion, %get-tuple-element.4, %concatenate_pad_fusion, %partition-id), kind=kLoop, calls=%fused_computation.1, metadata={op_name="wrap/wrap" stack_frame_id=1}
  %wrapped_slice.2 = f64[6]{0} fusion(%add_select_fusion), kind=kLoop, calls=%wrapped_slice_computation.2
  %collective-permute.15 = f64[6]{0} collective-permute(%wrapped_slice.2), channel_id=18, source_target_pairs={{2,3},{3,4},{4,5},{5,6},{6,7},{7,8},{8,9},{9,10},{10,11}}
  %bitcast_dynamic-slice_fusion = f64[193]{0} fusion(%constant.71, %add_select_fusion, %collective-permute.15, %constant.67, %partition-id), kind=kLoop, calls=%fused_computation
  %all-gather = f64[579]{0} all-gather(%bitcast_dynamic-slice_fusion), channel_id=19, replica_groups=[4,3]<=[12], dimensions={0}, use_global_device_ids=true
  ROOT %wrapped_slice.3 = f64[577]{0} fusion(%all-gather), kind=kLoop, calls=%wrapped_slice_computation.3
}

@wsmoses
Copy link
Member

wsmoses commented Dec 16, 2025

Yeah, the concat pad fusion makes me skeptical; try the optimization barrier when lowering the rotate. If that doesn't help, we need to define a rotate lowering directly to the rotate custom call from XLA

@wsmoses
Copy link
Member

wsmoses commented Dec 16, 2025

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants