    @nospecialize(obj::AbstractArray{T}), field, val
) where {T}
    ancestor_obj = ancestor(obj)
-    if isbitstype(T) || ancestor_obj isa RArray
-        if val isa XLA.AsyncBuffer
-            if Reactant.Sharding.is_sharded(ancestor_obj)
-                error("`val` can't be a buffer if `obj` is sharded")
-            else
-                return Base.setfield!(obj, field, (val,))
-            end
-        end
-        return Base.setfield!(obj, field, val)
-    end
+    (isbitstype(T) || ancestor_obj isa RArray) && return Base.setfield!(obj, field, val)
    return Base.setindex!(obj, val, field)
end
@@ -75,40 +66,48 @@ function create_result(
    return Expr(:new, T, elems...)
end

+function __reconstruct_shardinfo(path, path_to_shard_info, sharding_mesh)
+    device_to_array_slices, partition_spec = path_to_shard_info[path]
+    delete!(path_to_shard_info, path)
+    sharding = Reactant.Sharding.NamedSharding(sharding_mesh, partition_spec)
+    return Reactant.Sharding.ShardInfo(sharding, device_to_array_slices)
+end
+
function create_result(
-    tocopy::ConcreteRNumber{T}, path, result_stores, path_to_shard_info, sharding_mesh
-) where {T}
+    tocopy::ConcreteRNumber{T,D,S}, path, result_stores, path_to_shard_info, sharding_mesh
+) where {T,D,S}
    if haskey(result_stores, path)
        restore = result_stores[path]
        delete!(result_stores, path)
-        return :(ConcreteRNumber{$T}($restore))
+        if path_to_shard_info !== nothing # restore sharding
+            sharding = __reconstruct_shardinfo(path, path_to_shard_info, sharding_mesh)
+            return :(ConcreteRNumber{$T,length($(restore)),$(typeof(sharding))}(
+                ($(restore)...,), $sharding
+            ))
+        else
+            return :(ConcreteRNumber{$T}($restore))
+        end
+    end
+
+    if path_to_shard_info !== nothing # restore sharding
+        sharding = __reconstruct_shardinfo(path, path_to_shard_info, sharding_mesh)
+        return :(ConcreteRNumber{$T,length($(tocopy.data)),$(typeof(sharding))}(
+            ($(tocopy.data...,)), $sharding
+        ))
    end
    # We will set the data for this later
    return :(ConcreteRNumber{$T}($(tocopy.data)))
end

-function __construct_sharding_for_carray(
-    ::ConcreteRArray{T,N,D,S}, path, _, path_to_shard_info, sharding_mesh
-) where {T,N,D,S}
-    device_to_array_slices, partition_spec = path_to_shard_info[path]
-    delete!(path_to_shard_info, path)
-    sharding = Reactant.Sharding.NamedSharding(sharding_mesh, partition_spec)
-    return Reactant.Sharding.FinalizedNamedSharding{typeof(sharding),ndims(sharding_mesh)}(
-        sharding, device_to_array_slices
-    )
-end
-
function create_result(
    tocopy::ConcreteRArray{T,N,D,S}, path, result_stores, path_to_shard_info, sharding_mesh
) where {T,N,D,S}
    if haskey(result_stores, path)
        restore = result_stores[path]
        delete!(result_stores, path)
        if path_to_shard_info !== nothing # restore sharding
-            sharding = __construct_sharding_for_carray(
-                tocopy, path, result_stores, path_to_shard_info, sharding_mesh
-            )
-            return :(ConcreteRArray{$T,$N,$(ndims(sharding_mesh)),$(typeof(sharding))}(
+            sharding = __reconstruct_shardinfo(path, path_to_shard_info, sharding_mesh)
+            return :(ConcreteRArray{$T,$N,length($(restore)),$(typeof(sharding))}(
                ($(restore)...,), $(tocopy.shape), $sharding
            ))
        else
@@ -117,10 +116,8 @@ function create_result(
        end
    end
    if path_to_shard_info !== nothing # restore sharding
-        sharding = __construct_sharding_for_carray(
-            tocopy, path, result_stores, path_to_shard_info, sharding_mesh
-        )
-        return :(ConcreteRArray{$T,$N,$(ndims(sharding_mesh)),$(typeof(sharding))}(
+        sharding = __reconstruct_shardinfo(path, path_to_shard_info, sharding_mesh)
+        return :(ConcreteRArray{$T,$N,length($(tocopy.data)),$(typeof(sharding))}(
            ($(tocopy.data)...,), $(tocopy.shape), $sharding
        ))
    end
@@ -365,6 +362,7 @@ function optimization_passes(; no_nan::Bool=false, sroa::Bool=false, inline::Boo
        "binary_op_transpose_simplify_or",
        "binary_op_transpose_simplify_and",
        "binary_op_transpose_simplify_xor",
+        "associative_binary_op_reordering<1>",
        "transpose_unary_transpose_abs",
        "transpose_unary_transpose_neg",
        "transpose_unary_transpose_sqrt",
@@ -380,12 +378,15 @@ function optimization_passes(; no_nan::Bool=false, sroa::Bool=false, inline::Boo
        "transpose_unary_transpose_sine",
        "transpose_unary_transpose_tanh",
        "transpose_broadcast_in_dim_to_broadcast_in_dim<16>",
+        "scatter_indices_are_unique",
+        "transpose_reduce_simplify",
        "replace_neg_add_with_subtract",
        "log_const_prop<1>",
        "log_plus_one_const_prop<1>",
        "binop_const_simplify",
        "transpose_broadcast_in_dim_to_broadcast_in_dim",
        "not_select_simplify",
+        "scatter_update_computation_const_prop",
        "common_compare_expression_rewrite",
        "compare_select_simplify",
        "while_simplify<1>",
@@ -794,10 +795,12 @@ function compile_mlir!(
    results = [MLIR.IR.operand(ret, i) for i in 1:MLIR.IR.noperands(ret)]
    nresults = MLIR.IR.Value[]
    linear_results2 = TracedType[]
+    results_mask = falses(length(results))
    for (i, op) in enumerate(results)
        if !MLIR.IR.is_block_arg(op)
            push!(nresults, op)
            push!(linear_results2, linear_results[i])
+            results_mask[i] = true
            continue
        end
        push!(preserved_args, (linear_results[i], MLIR.IR.block_arg_num(op)))
@@ -812,11 +815,18 @@ function compile_mlir!(

    out_tys2 = [MLIR.IR.type(a) for a in nresults]

+    res_attrs = MLIR.IR.attr(compiled_f, "res_attrs")
+    if res_attrs isa MLIR.IR.Attribute
+        res_attrs = [
+            res_attrs[i - 1] for (i, present) in enumerate(results_mask) if present
+        ]
+    end
+
    func3 = MLIR.Dialects.func.func_(;
        sym_name="main",
        function_type=MLIR.IR.FunctionType(in_tys, out_tys2),
        arg_attrs=MLIR.IR.attr(compiled_f, "arg_attrs"),
-        res_attrs=MLIR.IR.attr(compiled_f, "res_attrs"),
+        res_attrs,
        no_inline=MLIR.IR.attr(compiled_f, "no_inline"),
        body=MLIR.IR.Region(),
    )
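As a rough illustration of the mask-based selection used for `res_attrs` in the hunk above (plain Julia with made-up data; the `i - 1` in the diff presumably reflects zero-based indexing of the underlying MLIR array attribute):

    # Stand-in data: which results survived block-arg elimination, plus their attributes.
    results_mask = [true, false, true]
    attrs = ["a", "b", "c"]  # placeholder for the per-result attribute array
    # Keep an attribute only if its result survived (1-based here, unlike the MLIR attribute).
    kept = [attrs[i] for (i, present) in enumerate(results_mask) if present]
    @assert kept == ["a", "c"]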
@@ -837,7 +847,6 @@ function compile_mlir!(
        linear_args,
        in_tys,
        linear_results2,
-        mlir_fn_res.linear_result_shard_info,
        mlir_fn_res.num_partitions,
        mlir_fn_res.num_replicas,
        mlir_fn_res.is_sharded,
@@ -862,6 +871,22 @@ macro code_hlo(args...)
        $(first)($(compiled))))
end

+"""
+    @code_mhlo [optimize = ...] [no_nan = <true/false>] f(args...)
+
+Similar to `@code_hlo`, but prints the module after running the XLA compiler.
+"""
+macro code_mhlo(args...)
+    default_options = Dict{Symbol,Any}(
+        :optimize => true, :no_nan => false, :client => nothing
+    )
+    compile_expr, (; compiled) = compile_call_expr(
+        __module__, compile_xla, default_options, args...
+    )
+    return esc(:($(compile_expr);
+        $(first)($(compiled))))
+end
+
"""
    @compile [optimize = ...] [no_nan = <true/false>] [sync = <true/false>] f(args...)
"""
@@ -998,7 +1023,7 @@ function codegen_flatten!(
        if is_sharded
            carg = inv_seen_args[arg]
-            if carg isa ConcreteRArray && Reactant.Sharding.is_sharded(carg)
+            if Reactant.Sharding.is_sharded(carg)
                for j in 1:length(mesh)
                    sbuf = Symbol(:sbuf_, i, "_", j)
                    push!(flatten_names, sbuf)
@@ -1007,17 +1032,11 @@ function codegen_flatten!(
            else
                # Warn here first and then replicate the input across all devices on the
                # mesh
-                if carg isa ConcreteRArray
-                    @warn "Input $carg is not sharded, replicating across all devices. It \
-                        is recommended to replicate the input across all devices on the \
-                        mesh manually using `Reactant.Sharding.NamedSharding`" maxlog = 1
-                end
+                @warn "Input $carg is not sharded, replicating across all devices. It \
+                    is recommended to replicate the input across all devices on the \
+                    mesh manually using `Reactant.Sharding.NamedSharding`" maxlog = 1
                buf = Symbol(:buf_, i)
-                if carg isa ConcreteRArray
-                    push!(flatten_code, :($buf = XLA.synced_buffer(only($usbuf))))
-                else
-                    push!(flatten_code, :($buf = XLA.synced_buffer($usbuf)))
-                end
+                push!(flatten_code, :($buf = XLA.synced_buffer(only($usbuf))))
                for j in 1:length(mesh)
                    device_id = mesh.device_ids[j]
                    device_ordinal = XLA.device_ordinal(client, device_id)
@@ -1030,9 +1049,7 @@ function codegen_flatten!(
        else
            sbuf = Symbol(:sbuf_, i)
            push!(flatten_names, sbuf)
-            if arg isa TracedRNumber
-                push!(flatten_code, :($sbuf = XLA.synced_buffer($usbuf)))
-            elseif arg isa TracedRArray
+            if arg isa TracedRArray || arg isa TracedRNumber
                push!(flatten_code, :($sbuf = only(XLA.synced_buffer($usbuf))))
            else
                error("Unsupported type $(typeof(arg))")
@@ -1061,7 +1078,6 @@ function codegen_unflatten!(
    concrete_result,
    result_stores,
    path_to_shard_info,
-    is_sharded::Bool,
    linear_result_shard_info,
    sharding_mesh,
)
@@ -1369,26 +1385,28 @@ function compile_xla(f, args; client=nothing, kwargs...)
            mlir_fn_res.is_sharded,
        )

-        mlir_fn_res.num_partitions > 1 && (device = nothing)
-

        # Attach a name, and partitioning attributes to the module
        __add_mhlo_attributes_and_name!(
            mod, f; mlir_fn_res.num_partitions, mlir_fn_res.num_replicas
        )

        # compile MLIR module to XLA executable
-        is_sharded = mlir_fn_res.num_partitions > 1
-        if is_sharded
-            # mesh_shape = collect(Int64, size(mlir_fn_res.sharding_mesh))
-            mesh_ids = collect(Int64, vec(mlir_fn_res.sharding_mesh.device_ids))
+        mlir_fn_res.is_sharded && (device = nothing)
+        mesh_ids = if mlir_fn_res.is_sharded
+            collect(Int64, mlir_fn_res.sharding_mesh.device_ids)
        else
-            # mesh_shape = Int64[]
-            mesh_ids = Int64[]
+            Int64[]
        end
-        # exec = XLA.Compile(client, device, mod; is_sharded, mesh_ids, mesh_shape)
-        exec = XLA.Compile(client, device, mod; is_sharded, mesh_ids)
+        exec = XLA.Compile(
+            client,
+            device,
+            mod;
+            num_results=length(mlir_fn_res.linear_results),
+            mlir_fn_res.is_sharded,
+            mesh_ids,
+        )

-        return exec, mlir_fn_res, device, client
+        return mod, exec, mlir_fn_res, device, client
    finally
        MLIR.IR.deactivate!(ctx)
    end
@@ -1398,7 +1416,7 @@ function compile_xla(f, args; client=nothing, kwargs...)
end

function compile(f, args; sync=false, kwargs...)
-    exec, mlir_fn_res, device, client = compile_xla(f, args; kwargs...)
+    _, exec, mlir_fn_res, device, client = compile_xla(f, args; kwargs...)
    (; linear_args, seen_args, linear_results, preserved_args, concrete_result) =
        mlir_fn_res
@@ -1408,11 +1426,7 @@ function compile(f, args; sync=false, kwargs...)
    end

    result_stores = Dict{Tuple,Symbol}()
-    path_to_shard_info = if mlir_fn_res.is_sharded
-        Dict{Tuple,Tuple{Array{Vector{UnitRange{Int}}},Tuple}}()
-    else
-        nothing
-    end
+    path_to_shard_info = mlir_fn_res.is_sharded ? Dict{Tuple,Tuple}() : nothing

    # generate Julia `Thunk` code
    flatten_arg_names, flatten_code = codegen_flatten!(
@@ -1431,9 +1445,25 @@ function compile(f, args; sync=false, kwargs...)
        donated_args_mask,
        length(linear_results),
        mlir_fn_res.is_sharded,
-        mlir_fn_res.is_sharded ? vec(mlir_fn_res.sharding_mesh.device_ids) : Int64[],
+        if mlir_fn_res.is_sharded
+            collect(Int64, mlir_fn_res.sharding_mesh.device_ids)
+        else
+            Int64[]
+        end,
    )

+    linear_result_shard_info = if mlir_fn_res.is_sharded
+        # Generate a tuple of DeviceToArraySlices and PartitionSpecs
+        output_shardings = XLA.get_output_shardings(exec)
+        XLA.compute_array_indices_and_partition_spec.(
+            output_shardings,
+            size.(mlir_fn_res.linear_results),
+            (mlir_fn_res.sharding_mesh,),
+        )
+    else
+        ntuple(Returns(nothing), length(linear_results))
+    end
+
    unflatten_code = codegen_unflatten!(
        linear_args,
        preserved_args,
@@ -1442,8 +1472,7 @@ function compile(f, args; sync=false, kwargs...)
        concrete_result,
        result_stores,
        path_to_shard_info,
-        mlir_fn_res.is_sharded,
-        mlir_fn_res.linear_result_shard_info,
+        linear_result_shard_info,
        mlir_fn_res.sharding_mesh,
    )