Test communication ops for wrap lowering #1972
test/optimize_comm.jl
Outdated
end

function wrap(x)
    return Reactant.Ops.@opcall wrap(x, 7, 7; dimension=1)
you may want to change this to 2
86eedf3 to
f28f9f7
Compare

This test is supposed to fail. It fails locally with

@jumerckx I ran this on hydra (with cpu) and these seemed to pass.

Hmm... On Hydra, cpu with 8 forced devices and IFRT, the code contains all-gathers when I run it:
ENV["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
using Reactant, Test
Reactant.set_default_backend("cpu")
const addressable_devices = Reactant.addressable_devices()
@assert length(addressable_devices) == 8
function wrap(x)
return Reactant.Ops.@opcall wrap(x, 2, 2; dimension=1)
end
mesh = Sharding.Mesh(Reactant.devices(), (:x,))
sharding = Sharding.NamedSharding(mesh, (:x,))
x = Reactant.to_rarray(rand(8192); sharding)
@assert x isa ConcreteIFRTArray
hlo = repr(@code_xla wrap(x))
@test !contains(hlo, "all-to-all")
@test !contains(hlo, "all-gather")
@test contains(hlo, "collective-permute")
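When a check like the ones above fails, it can help to see how many of each collective the compiled module contains rather than only a pass/fail. The following is a minimal sketch in plain Julia; `count_collectives` and `COLLECTIVES` are hypothetical names, not part of Reactant, and the matching assumes the synchronous instruction spelling seen in HLO text (for example `all-gather(`), so async variants such as `all-gather-start` would not be counted.

```julia
# Hypothetical helper, not part of Reactant: tally collective ops in an HLO text dump.
const COLLECTIVES = ("all-gather", "all-reduce", "all-to-all", "collective-permute")

function count_collectives(hlo::AbstractString)
    counts = Dict(name => 0 for name in COLLECTIVES)
    for line in split(hlo, '\n')
        for name in COLLECTIVES
            # Match the instruction call (" all-gather("), not a result name
            # such as "%all-gather" that is merely used as an operand.
            if occursin(" " * name * "(", line)
                counts[name] += 1
            end
        end
    end
    return counts
end
```

With the `hlo` string from the snippet above, `count_collectives(hlo)` would return a dictionary keyed by op name, which could be printed before the `@test` lines to show what the partitioner actually emitted.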

I was trying inside of Docker (via act) and it doesn't fail. Let me investigate a bit more.

I think I see the issue.

Aaah!

@jumerckx try:

Hopefully this should work independent of device count [for large device counts].
That fails as well, which is good (?) |

yeah sorry I meant "this should hopefully be a test that confirms failure on a larger device set"

But it's strange that @avik-pal got the tests to pass with ndevices >= 8, how many devices does CI have?

CI forces 12 devices

That doesn't make sense then, why did it pass CI before?!

nvm, I see the input size is changed as well

do we expect EnzymeAD/Enzyme-JAX#1779 to fix this?

it should, if not we have to figure out why

Okay, that's very weird. With the debug logs, can we see what's happening? I wonder if someone is messing up our rotates. We may need to put an optimization barrier on them.

Are you sure the HLO was produced with the right optimizations enabled vs. disabled? Add select fusion and concat are both weird names to have in the IR.

Oh, I see what is happening: disable the rotate to pad.

We should leave rotate as a concat and make sure it properly hits the rotate handler in XLA.
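For readers unfamiliar with the "rotate as a concat" form, here is a minimal sketch in plain Julia of the equivalence being relied on. It assumes rotate means a left cyclic shift by `n` elements along a 1-D array; `rotate_as_concat` is a hypothetical name used only for illustration and is not an Enzyme-JAX or XLA function.

```julia
# Hypothetical illustration: a left rotation by n written as a concatenation
# of two slices of the same operand (tail first, then head).
function rotate_as_concat(x::AbstractVector, n::Integer)
    @assert 0 <= n <= length(x)
    return vcat(x[(n + 1):end], x[1:n])
end

# The concat form agrees with a cyclic shift to the left.
x = collect(1:8)
@assert rotate_as_concat(x, 3) == circshift(x, -3)
```

The point of keeping this two-slice concat shape intact is that a later pattern matcher can still recognise it as a rotation; if another pass rewrites the concat into pads and adds first, that recognition may no longer apply.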
get_optimize_comms_passes(compile_options.optimize_communications) = [
    "enzyme-hlo-generate-td{patterns=lower_rotate;concat_to_onedim_dus;concat_to_onedim_dusslice;concatreshape_to_onedim_dus}",
    "transform-interpreter",
    "enzyme-hlo-remove-transform",
    "enzyme-hlo-generate-td{patterns=reshape_to_broadcast}",
    "transform-interpreter",
    "enzyme-hlo-remove-transform",
    "optimize-communication{periodic_concat=0 rotate_comm=0 rotate_to_pad_comm=0 wrap_comm=0 extend_comm=0 dus_to_pad_manual_comp_comm=0 dus_to_pad_comm=0 concat_two_operands_comm=0 concat_to_pad_comm=1 extend_to_pad_comm=0 extend_to_pad_comm2=1 wrap_to_pad_comm=1 wrap_to_rotate=1}",
    "enzyme-hlo-generate-td{patterns=lower_rotate;lower_wrap;lower_extend}",
    "transform-interpreter",
    "enzyme-hlo-remove-transform",
    "optimize-communication{periodic_concat=0 rotate_comm=0 rotate_to_pad_comm=0 wrap_comm=0 extend_comm=0 dus_to_pad_manual_comp_comm=0 dus_to_pad_comm=0 concat_two_operands_comm=0 concat_to_pad_comm=1 extend_to_pad_comm=0 extend_to_pad_comm2=1 wrap_to_pad_comm=1 wrap_to_rotate=1}",
]

We also should disable the other wrap one |
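Since the discussion hinges on which `optimize-communication` flags are on, a quick way to read them off a pass string may be useful. This is a minimal sketch in plain Julia; `parse_pass_flags` and `enabled_flags` are hypothetical helpers, not part of Reactant, and they assume the `name{key=value key=value}` option syntax shown in the pass list, with space-separated options.

```julia
# Hypothetical helper: split "name{k1=v1 k2=v2}" into the pass name and its options.
function parse_pass_flags(pass::AbstractString)
    m = match(r"^([\w-]+)\{(.*)\}$", pass)
    # Passes without options (e.g. "transform-interpreter") have no braces.
    m === nothing && return (String(pass), Dict{String,String}())
    flags = Dict{String,String}()
    for kv in split(m.captures[2])
        k, v = split(kv, '='; limit=2)
        flags[String(k)] = String(v)
    end
    return (String(m.captures[1]), flags)
end

# Hypothetical helper: names of the options whose value is "1", sorted.
enabled_flags(flags) = sort!([k for (k, v) in flags if v == "1"])
```

Applied to the `optimize-communication{...}` entries in the list, `enabled_flags(last(parse_pass_flags(entry)))` would list the options set to `1`, which makes it easier to confirm that a flag intended to be disabled really is.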
ENTRY %main.0_spmd (param: f64[192]) -> f64[577] {
%partition-id = u32[] partition-id()
%param = f64[192]{0} parameter(0), sharding={devices=[12]<=[12]}, metadata={op_name="arg1 (path=(:args, 1))"}
%constant.12 = f64[] constant(0)
%constant.67 = s32[12]{0} constant({...})
%constant.71 = s32[12]{0} constant({...})
%wrapped_slice = f64[1]{0} fusion(%param), kind=kLoop, calls=%wrapped_slice_computation, metadata={op_name="wrap/wrap" stack_frame_id=1}
%wrapped_slice.1 = f64[9]{0} fusion(%param), kind=kLoop, calls=%wrapped_slice_computation.1, metadata={op_name="wrap/wrap" stack_frame_id=1}
%wrapped_broadcast = f64[4]{0} fusion(%constant.12), kind=kLoop, calls=%wrapped_broadcast_computation, metadata={op_name="wrap/wrap" stack_frame_id=1}
%collective-permute = f64[1]{0} collective-permute(%wrapped_slice), channel_id=1, source_target_pairs={{0,1},{1,2},{2,3},{3,4},{4,5},{5,6},{6,7},{7,8},{8,9},{9,10},{10,11}}, metadata={op_name="wrap/wrap" stack_frame_id=1}
%collective-permute.1 = f64[9]{0} collective-permute(%wrapped_slice.1), channel_id=2, source_target_pairs={{1,0},{2,1},{3,2},{4,3},{5,4},{6,5},{7,6},{8,7},{9,8},{10,9},{11,10}}, metadata={op_name="wrap/wrap" stack_frame_id=1}
%copy.5 = f64[4]{0} copy(%wrapped_broadcast)
%copy.4 = f64[4]{0} copy(%wrapped_broadcast)
%concatenate_pad_fusion = f64[204]{0} fusion(%collective-permute, %param, %collective-permute.1), kind=kLoop, calls=%fused_computation.19, metadata={op_name="wrap/wrap" stack_frame_id=1}
%dynamic-slice_slice_fusion.12 = f64[1]{0} fusion(%concatenate_pad_fusion, %partition-id), kind=kLoop, calls=%fused_computation.18, metadata={op_name="wrap/wrap" stack_frame_id=1}
%dynamic-slice_slice_fusion = f64[1]{0} fusion(%concatenate_pad_fusion, %partition-id), kind=kLoop, calls=%fused_computation.3, metadata={op_name="wrap/wrap" stack_frame_id=1}
%dynamic-slice_slice_fusion.1 = f64[1]{0} fusion(%concatenate_pad_fusion, %partition-id), kind=kLoop, calls=%fused_computation.4, metadata={op_name="wrap/wrap" stack_frame_id=1}
%dynamic-slice_slice_fusion.2 = f64[1]{0} fusion(%concatenate_pad_fusion, %partition-id), kind=kLoop, calls=%fused_computation.5, metadata={op_name="wrap/wrap" stack_frame_id=1}
%dynamic-slice_slice_fusion.4 = f64[3]{0} fusion(%concatenate_pad_fusion, %partition-id), kind=kLoop, calls=%fused_computation.8, metadata={op_name="wrap/wrap" stack_frame_id=1}
%dynamic-slice_slice_fusion.5 = f64[7]{0} fusion(%concatenate_pad_fusion, %partition-id), kind=kLoop, calls=%fused_computation.9, metadata={op_name="wrap/wrap" stack_frame_id=1}
%dynamic-slice_slice_fusion.8 = f64[11]{0} fusion(%concatenate_pad_fusion, %partition-id), kind=kLoop, calls=%fused_computation.13, metadata={op_name="wrap/wrap" stack_frame_id=1}
%dynamic-slice_slice_fusion.9 = f64[1]{0} fusion(%concatenate_pad_fusion, %partition-id), kind=kLoop, calls=%fused_computation.15, metadata={op_name="wrap/wrap" stack_frame_id=1}
%dynamic-slice_slice_fusion.10 = f64[1]{0} fusion(%concatenate_pad_fusion, %partition-id), kind=kLoop, calls=%fused_computation.16, metadata={op_name="wrap/wrap" stack_frame_id=1}
%dynamic-slice_slice_fusion.11 = f64[1]{0} fusion(%concatenate_pad_fusion, %partition-id), kind=kLoop, calls=%fused_computation.17, metadata={op_name="wrap/wrap" stack_frame_id=1}
%collective-permute.5 = f64[1]{0} collective-permute(%dynamic-slice_slice_fusion.12), channel_id=6, source_target_pairs={{11,3}}, metadata={op_name="wrap/wrap" stack_frame_id=1}
%collective-permute.12 = f64[1]{0} collective-permute(%dynamic-slice_slice_fusion), channel_id=14, source_target_pairs={{0,1}}, metadata={op_name="wrap/wrap" stack_frame_id=1}
%collective-permute.13 = f64[1]{0} collective-permute(%dynamic-slice_slice_fusion.1), channel_id=15, source_target_pairs={{0,2}}, metadata={op_name="wrap/wrap" stack_frame_id=1}
%collective-permute.14 = f64[1]{0} collective-permute(%dynamic-slice_slice_fusion.2), channel_id=16, source_target_pairs={{0,3}}, metadata={op_name="wrap/wrap" stack_frame_id=1}
%collective-permute.10 = f64[3]{0} collective-permute(%dynamic-slice_slice_fusion.4), channel_id=12, source_target_pairs={{1,0},{2,1},{3,2},{4,3},{5,4},{6,5},{7,6},{8,7},{9,8},{10,9},{11,10}}, metadata={op_name="wrap/wrap" stack_frame_id=1}
%collective-permute.9 = f64[7]{0} collective-permute(%dynamic-slice_slice_fusion.5), channel_id=11, source_target_pairs={{0,1},{1,2},{2,3},{3,4},{4,5},{5,6},{6,7},{7,8},{8,9},{9,10},{10,11}}, metadata={op_name="wrap/wrap" stack_frame_id=1}
%collective-permute.6 = f64[11]{0} collective-permute(%dynamic-slice_slice_fusion.8), channel_id=8, source_target_pairs={{0,1},{1,2},{2,3},{3,4},{4,5},{5,6},{6,7},{7,8},{8,9},{9,10},{10,11}}, metadata={op_name="wrap/wrap" stack_frame_id=1}
%collective-permute.2 = f64[1]{0} collective-permute(%dynamic-slice_slice_fusion.9), channel_id=3, source_target_pairs={{11,0}}, metadata={op_name="wrap/wrap" stack_frame_id=1}
%collective-permute.3 = f64[1]{0} collective-permute(%dynamic-slice_slice_fusion.10), channel_id=4, source_target_pairs={{11,1}}, metadata={op_name="wrap/wrap" stack_frame_id=1}
%collective-permute.4 = f64[1]{0} collective-permute(%dynamic-slice_slice_fusion.11), channel_id=5, source_target_pairs={{11,2}}, metadata={op_name="wrap/wrap" stack_frame_id=1}
%select_dynamic-update-slice_fusion = f64[4]{0} fusion(%copy.4, %collective-permute.14, %collective-permute.13, %collective-permute.12, %concatenate_pad_fusion, /*index=5*/%partition-id), kind=kLoop, calls=%fused_computation.2, metadata={op_name="wrap/wrap" stack_frame_id=1}
%dynamic-slice_concatenate_fusion = f64[203]{0} fusion(%collective-permute.9, %collective-permute.10, %concatenate_pad_fusion, %partition-id), kind=kLoop, calls=%fused_computation.7, metadata={op_name="wrap/wrap" stack_frame_id=1}
%slice_concatenate_fusion = f64[203]{0} fusion(%collective-permute.6, %concatenate_pad_fusion, %partition-id), kind=kLoop, calls=%fused_computation.12, metadata={op_name="wrap/wrap" stack_frame_id=1}
%select_dynamic-update-slice_fusion.1 = f64[4]{0} fusion(%copy.5, %collective-permute.5, %collective-permute.4, %collective-permute.3, %collective-permute.2, /*index=5*/%partition-id), kind=kLoop, calls=%fused_computation.14, metadata={op_name="wrap/wrap" stack_frame_id=1}
%dynamic-slice_slice_fusion.3 = f64[11]{0} fusion(%dynamic-slice_concatenate_fusion, %partition-id), kind=kLoop, calls=%fused_computation.6, metadata={op_name="wrap/wrap" stack_frame_id=1}
%dynamic-slice_slice_fusion.7 = f64[3]{0} fusion(%slice_concatenate_fusion, %partition-id), kind=kLoop, calls=%fused_computation.11, metadata={op_name="wrap/wrap" stack_frame_id=1}
%dynamic-slice_slice_fusion.6 = f64[7]{0} fusion(%slice_concatenate_fusion, %partition-id), kind=kLoop, calls=%fused_computation.10, metadata={op_name="wrap/wrap" stack_frame_id=1}
%all-reduce.2 = (f64[4]{0}, f64[4]{0}) all-reduce(%select_dynamic-update-slice_fusion, %select_dynamic-update-slice_fusion.1), channel_id=17, replica_groups=[1,12]<=[12], use_global_device_ids=true, to_apply=%add.1.clone, metadata={op_name="wrap/wrap" stack_frame_id=1}
%collective-permute.11 = f64[11]{0} collective-permute(%dynamic-slice_slice_fusion.3), channel_id=13, source_target_pairs={{1,0},{2,1},{3,2},{4,3},{5,4},{6,5},{7,6},{8,7},{9,8},{10,9},{11,10}}, metadata={op_name="wrap/wrap" stack_frame_id=1}
%collective-permute.7 = f64[3]{0} collective-permute(%dynamic-slice_slice_fusion.7), channel_id=9, source_target_pairs={{0,1},{1,2},{2,3},{3,4},{4,5},{5,6},{6,7},{7,8},{8,9},{9,10},{10,11}}, metadata={op_name="wrap/wrap" stack_frame_id=1}
%collective-permute.8 = f64[7]{0} collective-permute(%dynamic-slice_slice_fusion.6), channel_id=10, source_target_pairs={{1,0},{2,1},{3,2},{4,3},{5,4},{6,5},{7,6},{8,7},{9,8},{10,9},{11,10}}, metadata={op_name="wrap/wrap" stack_frame_id=1}
%get-tuple-element.3 = f64[4]{0} get-tuple-element(%all-reduce.2), index=0, metadata={op_name="wrap/wrap" stack_frame_id=1}
%get-tuple-element.4 = f64[4]{0} get-tuple-element(%all-reduce.2), index=1, metadata={op_name="wrap/wrap" stack_frame_id=1}
%add_select_fusion = f64[193]{0} fusion(%get-tuple-element.3, %collective-permute.11, %dynamic-slice_concatenate_fusion, %collective-permute.7, %collective-permute.8, /*index=5*/%slice_concatenate_fusion, %get-tuple-element.4, %concatenate_pad_fusion, %partition-id), kind=kLoop, calls=%fused_computation.1, metadata={op_name="wrap/wrap" stack_frame_id=1}
%wrapped_slice.2 = f64[6]{0} fusion(%add_select_fusion), kind=kLoop, calls=%wrapped_slice_computation.2
%collective-permute.15 = f64[6]{0} collective-permute(%wrapped_slice.2), channel_id=18, source_target_pairs={{2,3},{3,4},{4,5},{5,6},{6,7},{7,8},{8,9},{9,10},{10,11}}
%bitcast_dynamic-slice_fusion = f64[193]{0} fusion(%constant.71, %add_select_fusion, %collective-permute.15, %constant.67, %partition-id), kind=kLoop, calls=%fused_computation
%all-gather = f64[579]{0} all-gather(%bitcast_dynamic-slice_fusion), channel_id=19, replica_groups=[4,3]<=[12], dimensions={0}, use_global_device_ids=true
ROOT %wrapped_slice.3 = f64[577]{0} fusion(%all-gather), kind=kLoop, calls=%wrapped_slice_computation.3
}

Yeah the concat pad fusion makes me skeptical, try the optimization barrier when lowering the rotate. If not we need to define a rotate lowering directly to the rotate custom call from xla |