@@ -2,47 +2,91 @@ mutable struct Array <: XLA.AbstractBuffer
22 buffer:: Ptr{Cvoid}
33
44 function Array (buffer:: Ptr{Cvoid} )
5- return finalizer (free_ifrt_array, new (buffer))
5+ # return finalizer(free_ifrt_array, new(buffer))
6+ return new (buffer)
67 end
78end
89
9- function Array (client:: Client , array:: Base.Array{T,N} , device:: Device ) where {T,N}
10+ function Array (
11+ client:: Client ,
12+ array:: Base.Array{T,N} ,
13+ device:: Device ,
14+ memory_kind:: AbstractString = string (convert (MemoryKind, XLA. default_memory (device))),
15+ ) where {T,N}
1016 sizear = collect (Int64, reverse (size (array)))
1117 buffer = GC. @preserve array sizear begin
1218 @ccall MLIR. API. mlir_c. ifrt_client_make_single_shard_array_from_host_buffer (
1319 client. client:: Ptr{Cvoid} ,
14- pointer ( array) :: Ptr{T} ,
20+ array:: Ptr{T} ,
1521 XLA. primitive_type (T):: UInt64 ,
1622 N:: Csize_t ,
17- pointer ( sizear) :: Ptr{Int64} ,
23+ sizear:: Ptr{Int64} ,
1824 0 :: Cint , # kAlwaysCopy
1925 device. device:: Ptr{Cvoid} ,
20- string (convert (MemoryKind, XLA . default_memory (device)) ):: Cstring ,
26+ string (memory_kind ):: Cstring ,
2127 ):: Ptr{Cvoid}
2228 end
2329 return Array (buffer)
2430end
2531
26- function Array (client:: Client , array:: Base.Array{T,N} , sharding:: HloSharding ) where {T,N}
27- return Array (client, array, convert (Sharding, sharding))
28- end
29-
30- function Array (client:: Client , array:: Base.Array{T,N} , sharding:: Sharding ) where {T,N}
32+ function Array (
33+ client:: Client , array:: Base.Array{T,N} , sharding:: Sharding , logical_device_ids
34+ ) where {T,N}
3135 sizear = collect (Int64, reverse (size (array)))
32- buffer = GC. @preserve array sizear begin
33- @ccall MLIR. API. mlir_c. ifrt_client_make_array_from_host_buffer (
36+
37+ if is_single_device_sharding (sharding) || is_fully_replicated (sharding)
38+ buffer = GC. @preserve array sizear begin
39+ @ccall MLIR. API. mlir_c. ifrt_client_make_array_from_host_buffer (
40+ client. client:: Ptr{Cvoid} ,
41+ array:: Ptr{T} ,
42+ XLA. primitive_type (T):: Cint ,
43+ N:: Csize_t ,
44+ sizear:: Ptr{Int64} ,
45+ sharding. ptr:: Ptr{Cvoid} ,
46+ 0 :: Cint , # kAlwaysCopy
47+ ):: Ptr{Cvoid}
48+ end
49+ return Array (buffer)
50+ end
51+
52+ all_devices = XLA. devices (sharding)
53+ array_slices = XLA. sharding_to_concrete_array_indices (
54+ convert (XLA. HloSharding, sharding), size (array), logical_device_ids
55+ )
56+ array_shape = collect (Int64, reverse (size (array)))
57+ arrays_list = [
58+ Array (client, array[slice... ], device). buffer for
59+ (device, slice) in zip (all_devices, array_slices) if XLA. is_addressable (device)
60+ ]
61+
62+ buffer = GC. @preserve client arrays_list array_shape sharding begin
63+ @ccall MLIR. API. mlir_c. ifrt_client_assemble_array_from_single_shards (
3464 client. client:: Ptr{Cvoid} ,
35- pointer (array):: Ptr{T} ,
36- XLA. primitive_type (T):: Cint ,
37- N:: Csize_t ,
38- pointer (sizear):: Ptr{Int64} ,
65+ Int32 (length (array_shape)):: Int32 ,
66+ array_shape:: Ptr{Int64} ,
3967 sharding. ptr:: Ptr{Cvoid} ,
40- 0 :: Cint , # kAlwaysCopy
68+ Int32 (length (arrays_list)):: Int32 ,
69+ arrays_list:: Ptr{Ptr{Cvoid}} ,
70+ 2 :: Cint , # kDonateInput
4171 ):: Ptr{Cvoid}
4272 end
73+
4374 return Array (buffer)
4475end
4576
77+ function Array (client:: Client , array:: Base.Array{T,N} , sharding) where {T,N}
78+ @assert sharding isa Reactant. Sharding. AbstractSharding
79+ if ! (sharding isa Reactant. Sharding. HloSharding)
80+ sharding = convert (Reactant. Sharding. HloSharding, sharding)
81+ end
82+
83+ (; hlo_sharding, mesh) = sharding
84+ devices = XLA. get_device .((client,), mesh. device_ids)
85+ ifrt_sharding = Sharding ([devices... ], hlo_sharding)
86+
87+ return Array (client, array, ifrt_sharding, mesh. logical_device_ids)
88+ end
89+
4690@inline function free_ifrt_array (buffer:: Array )
4791 sbuffer = buffer. buffer
4892 if sbuffer != C_NULL
0 commit comments