@@ -18,16 +18,14 @@ use super::*;
1818use anyhow:: Result ;
1919use nixl_sys:: { MemoryRegion , NixlDescriptor , OptArgs , XferDescList , XferOp } ;
2020use std:: future:: { poll_fn, Future } ;
21- use std:: ops:: Range ;
2221use std:: task:: Poll ;
2322
24- /// Copy a block from a source to a destination using CUDA memcpy
25- pub fn write_block_to < ' a , Source , Destination > (
26- src : & ' a Source ,
27- dst : & ' a mut Destination ,
28- ctx : Arc < TransferContext > ,
29- notify : Option < String > ,
30- ) -> Result < Box < dyn Future < Output = ( ) > + Send + Sync + Unpin > >
23+ fn append_xfer_request < Source , Destination > (
24+ src : & Arc < Source > ,
25+ dst : & mut Destination ,
26+ src_dl : & mut XferDescList ,
27+ dst_dl : & mut XferDescList ,
28+ ) -> Result < ( ) >
3129where
3230 Source : BlockDataProvider ,
3331 Destination : BlockDataProviderMut ,
3634 let dst_data = dst. block_data_mut ( private:: PrivateToken ) ;
3735
3836 if src_data. is_fully_contiguous ( ) && dst_data. is_fully_contiguous ( ) {
39- // Keep the arc to use in the returned future.
40- let nixl_agent_arc = ctx. as_ref ( ) . nixl_agent ( ) ;
41-
42- let nixl_agent = nixl_agent_arc
43- . as_ref ( )
44- . as_ref ( )
45- . expect ( "NIXL agent not found" ) ;
46-
47- let mut src_dl = XferDescList :: new ( src_data. storage_type ( ) . nixl_mem_type ( ) ) ?;
48- let mut dst_dl = XferDescList :: new ( dst_data. storage_type ( ) . nixl_mem_type ( ) ) ?;
49-
5037 let src_desc = src_data. block_view ( ) ?. as_nixl_descriptor ( ) ;
5138 let dst_desc = dst_data. block_view_mut ( ) ?. as_nixl_descriptor_mut ( ) ;
5239
@@ -64,121 +51,99 @@ where
6451 ) ?;
6552 }
6653
67- let xfer_req = nixl_agent
68- . create_xfer_req ( XferOp :: Write , & src_dl, & dst_dl, & nixl_agent. name ( ) , None )
69- . unwrap ( ) ;
70-
71- let mut xfer_args = OptArgs :: new ( ) ?;
72-
73- if let Some ( notify) = notify {
74- xfer_args. set_has_notification ( true ) ?;
75- xfer_args. set_notification_message ( notify. as_bytes ( ) ) ?;
76- }
77-
78- let _ = nixl_agent. post_xfer_req ( & xfer_req, Some ( & xfer_args) ) ?;
79-
80- // Return a future that completes when the transfer is complete.
81- // TODO: How efficient is this? Can we do better?
82- Ok ( Box :: new ( poll_fn ( move |_cx| {
83- let nixl_agent = nixl_agent_arc
84- . as_ref ( )
85- . as_ref ( )
86- . expect ( "NIXL agent not found" ) ;
87-
88- // The nixl agent returns true if the transfer is still in progress.
89- if !nixl_agent. get_xfer_status ( & xfer_req) . unwrap ( ) {
90- Poll :: Ready ( ( ) )
91- } else {
92- Poll :: Pending
93- }
94- } ) ) )
54+ Ok ( ( ) )
9555 } else {
9656 assert_eq ! ( src_data. num_layers( ) , dst_data. num_layers( ) ) ;
97- write_layers_to ( 0 ..src_data. num_layers ( ) , src, dst, ctx, notify)
57+ for layer_idx in 0 ..src_data. num_layers ( ) {
58+ for outer_idx in 0 ..src_data. num_outer_dims ( ) {
59+ let src_view = src_data. layer_view ( layer_idx, outer_idx) ?;
60+ let mut dst_view = dst_data. layer_view_mut ( layer_idx, outer_idx) ?;
61+
62+ debug_assert_eq ! ( src_view. size( ) , dst_view. size( ) ) ;
63+
64+ let src_desc = src_view. as_nixl_descriptor ( ) ;
65+ let dst_desc = dst_view. as_nixl_descriptor_mut ( ) ;
66+
67+ unsafe {
68+ src_dl. add_desc (
69+ src_desc. as_ptr ( ) as usize ,
70+ src_desc. size ( ) ,
71+ src_desc. device_id ( ) ,
72+ ) ?;
73+
74+ dst_dl. add_desc (
75+ dst_desc. as_ptr ( ) as usize ,
76+ dst_desc. size ( ) ,
77+ dst_desc. device_id ( ) ,
78+ ) ?;
79+ }
80+ }
81+ }
82+ Ok ( ( ) )
9883 }
9984}
10085
101- /// Copy a range of layers from a source to a destination using CUDA memcpy
102- pub fn write_layers_to < ' a , Source , Destination > (
103- layer_range : Range < usize > ,
104- src : & ' a Source ,
105- dst : & ' a mut Destination ,
86+ /// Copy a batch of blocks from sources to destinations using a single NIXL write transfer
87+ pub fn write_blocks_to < Source , Destination > (
88+ src : & [ Arc < Source > ] ,
89+ dst : & mut [ Destination ] ,
10690 ctx : Arc < TransferContext > ,
10791 notify : Option < String > ,
10892) -> Result < Box < dyn Future < Output = ( ) > + Send + Sync + Unpin > >
10993where
11094 Source : BlockDataProvider ,
11195 Destination : BlockDataProviderMut ,
11296{
113- let src_data = src. block_data ( private:: PrivateToken ) ;
114- let dst_data = dst. block_data_mut ( private:: PrivateToken ) ;
97+ if src. is_empty ( ) || dst. is_empty ( ) {
98+ return Ok ( Box :: new ( std:: future:: ready ( ( ) ) ) ) ;
99+ }
100+ assert_eq ! ( src. len( ) , dst. len( ) ) ;
115101
116102 let nixl_agent_arc = ctx. as_ref ( ) . nixl_agent ( ) ;
117103 let nixl_agent = nixl_agent_arc
118104 . as_ref ( )
119105 . as_ref ( )
120106 . expect ( "NIXL agent not found" ) ;
121107
122- let remote_worker_id = dst_data. worker_id . to_string ( ) ;
123- let mut src_dl = XferDescList :: new ( src_data. storage_type ( ) . nixl_mem_type ( ) ) ?;
124- let mut dst_dl = XferDescList :: new ( dst_data. storage_type ( ) . nixl_mem_type ( ) ) ?;
125-
126- // #[cfg(debug_assertions)]
127- // {
128- // let expected_strategy = <<Source as BlockDataProvider>::StorageType as WriteToStrategy<
129- // Destination::StorageType,
130- // >>::write_to_strategy();
131- // assert_eq!(strategy, expected_strategy);
132- // }
133-
134- for layer_idx in layer_range {
135- for outer_idx in 0 ..src_data. num_outer_dims ( ) {
136- let src_view = src_data. layer_view ( layer_idx, outer_idx) ?;
137- let mut dst_view = dst_data. layer_view_mut ( layer_idx, outer_idx) ?;
138-
139- debug_assert_eq ! ( src_view. size( ) , dst_view. size( ) ) ;
140-
141- let src_desc = src_view. as_nixl_descriptor ( ) ;
142- let dst_desc = dst_view. as_nixl_descriptor_mut ( ) ;
143-
144- unsafe {
145- src_dl. add_desc (
146- src_desc. as_ptr ( ) as usize ,
147- src_desc. size ( ) ,
148- src_desc. device_id ( ) ,
149- ) ?;
150-
151- dst_dl. add_desc (
152- dst_desc. as_ptr ( ) as usize ,
153- dst_desc. size ( ) ,
154- dst_desc. device_id ( ) ,
155- ) ?;
156- }
157- }
108+ let src_mem_type = src
109+ . first ( )
110+ . unwrap ( )
111+ . block_data ( private:: PrivateToken )
112+ . storage_type ( )
113+ . nixl_mem_type ( ) ;
114+ let dst_mem_type = dst
115+ . first ( )
116+ . unwrap ( )
117+ . block_data ( private:: PrivateToken )
118+ . storage_type ( )
119+ . nixl_mem_type ( ) ;
120+
121+ let mut src_dl = XferDescList :: new ( src_mem_type) ?;
122+ let mut dst_dl = XferDescList :: new ( dst_mem_type) ?;
123+
124+ for ( src, dst) in src. iter ( ) . zip ( dst. iter_mut ( ) ) {
125+ append_xfer_request ( src, dst, & mut src_dl, & mut dst_dl) ?;
158126 }
159127
128+ let xfer_req =
129+ nixl_agent. create_xfer_req ( XferOp :: Write , & src_dl, & dst_dl, & nixl_agent. name ( ) , None ) ?;
130+
160131 let mut xfer_args = OptArgs :: new ( ) ?;
161132
162133 if let Some ( notify) = notify {
163134 xfer_args. set_has_notification ( true ) ?;
164135 xfer_args. set_notification_message ( notify. as_bytes ( ) ) ?;
165136 }
166137
167- let xfer_req = nixl_agent. create_xfer_req (
168- XferOp :: Write ,
169- & src_dl,
170- & dst_dl,
171- & remote_worker_id,
172- Some ( & xfer_args) ,
173- ) ?;
174-
175138 let _ = nixl_agent. post_xfer_req ( & xfer_req, Some ( & xfer_args) ) ?;
176139
177140 Ok ( Box :: new ( poll_fn ( move |_cx| {
178141 let nixl_agent = nixl_agent_arc
179142 . as_ref ( )
180143 . as_ref ( )
181144 . expect ( "NIXL agent not found" ) ;
145+
146+ // The nixl agent returns true if the transfer is still in progress.
182147 if !nixl_agent. get_xfer_status ( & xfer_req) . unwrap ( ) {
183148 Poll :: Ready ( ( ) )
184149 } else {
0 commit comments