-
Notifications
You must be signed in to change notification settings - Fork 0
/
lib.rs
541 lines (494 loc) · 17.5 KB
/
lib.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
#![allow(clippy::cast_ptr_alignment)]
#![allow(clippy::missing_safety_doc)]
#![allow(clippy::float_cmp)]
use async_std::task::block_on;
use rayon::prelude::*;
use std::future::Future;
use std::io::Cursor;
use std::ops::Deref;
use wgpu::*;
use zerocopy::AsBytes;
/// Adds `right` into `left` element-wise, in place, using explicit indexing.
///
/// This is the plain indexed-loop variant (see `addition_iterator`,
/// `addition_unchecked`, and `addition_rayon` for the other variants).
///
/// # Panics
///
/// Panics if `right` is shorter than `left` (out-of-bounds index on `right`).
pub fn addition(left: &mut [f32], right: &[f32]) {
    let mut idx = 0;
    while idx < left.len() {
        left[idx] += right[idx];
        idx += 1;
    }
}
/// Adds `right` into `left` element-wise, in place, skipping bounds checks.
///
/// # Safety
///
/// `right` must be at least as long as `left`; otherwise the unchecked
/// accesses read/write out of bounds, which is undefined behavior.
pub unsafe fn addition_unchecked(left: &mut [f32], right: &[f32]) {
    for i in 0..left.len() {
        // SAFETY: i < left.len(), and the caller guarantees
        // right.len() >= left.len(), so both accesses are in bounds.
        *left.get_unchecked_mut(i) += *right.get_unchecked(i);
    }
}
/// Adds `right` into `left` element-wise, in place, via iterator zipping.
///
/// Unlike `addition`, this never panics: `zip` stops at the shorter of the
/// two slices, leaving any trailing elements of `left` unchanged.
pub fn addition_iterator(left: &mut [f32], right: &[f32]) {
    for (dst, src) in left.iter_mut().zip(right.iter()) {
        *dst += *src;
    }
}
/// Adds `right` into `left` element-wise, in place, across the rayon
/// thread pool. Pairing via `zip` stops at the shorter slice, so this
/// never panics on unequal lengths.
pub fn addition_rayon(left: &mut [f32], right: &[f32]) {
    let pairs = left.par_iter_mut().zip(right.par_iter());
    pairs.for_each(|(acc, addend)| {
        *acc += *addend;
    });
}
/// Synchronously acquires a wgpu `Device`/`Queue` pair on any available
/// backend, with default power preference, no surface requirement, and
/// default limits/extensions.
///
/// # Panics
///
/// Panics if no adapter is available or device creation fails.
pub fn create_device() -> (Device, Queue) {
    let instance = Instance::new();
    let adapter_options = RequestAdapterOptions {
        compatible_surface: None,
        power_preference: PowerPreference::Default,
    };
    let adapter =
        block_on(instance.request_adapter(&adapter_options, BackendBit::all())).unwrap();
    let descriptor = DeviceDescriptor {
        extensions: Extensions {
            anisotropic_filtering: false,
        },
        limits: Limits::default(),
    };
    block_on(adapter.request_device(&descriptor, None)).unwrap()
}
/// Selects how an [`AutomatedBuffer`] moves data between the CPU and GPU.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum UploadStyle {
    /// Map the buffer directly (uses `MAP_READ`/`MAP_WRITE` buffer usage).
    Mapping,
    /// Copy through an intermediate staging buffer (uses `COPY_SRC`/`COPY_DST`).
    Staging,
}
bitflags::bitflags! {
    /// Which directions of CPU-side access an [`AutomatedBuffer`] supports.
    pub struct AutomatedBufferUsage: u8 {
        /// Contents can be read back to the CPU.
        const READ = 0b01;
        /// Contents can be written from the CPU.
        const WRITE = 0b10;
        /// Both read and write access.
        const ALL = Self::READ.bits | Self::WRITE.bits;
    }
}
impl AutomatedBufferUsage {
    /// Translates this CPU-access pattern into the wgpu usage flags the
    /// given upload style requires: direct mapping needs `MAP_READ`/`MAP_WRITE`,
    /// while staging needs copy usages on the internal buffer instead.
    pub fn into_buffer_usage(self, style: UploadStyle) -> BufferUsage {
        let mut flags = BufferUsage::empty();
        if self.contains(Self::READ) {
            flags |= match style {
                UploadStyle::Mapping => BufferUsage::MAP_READ,
                UploadStyle::Staging => BufferUsage::COPY_SRC,
            };
        }
        if self.contains(Self::WRITE) {
            flags |= match style {
                UploadStyle::Mapping => BufferUsage::MAP_WRITE,
                UploadStyle::Staging => BufferUsage::COPY_DST,
            };
        }
        flags
    }
}
/// Result of resolving a read mapping on a buffer.
type BufferReadResult = Result<BufferReadMapping, BufferAsyncError>;
/// Result of resolving a write mapping on a buffer.
type BufferWriteResult = Result<BufferWriteMapping, BufferAsyncError>;
/// Represents either a mapping future (mapped style) or a function to create
/// a mapping future (buffered style).
enum ReadMapFn<MapFut, BufFunc> {
    /// The already-started mapping future of a directly-mapped buffer.
    Mapped(MapFut),
    /// Deferred construction of the mapping future for a staging buffer.
    Buffered(BufFunc),
}
impl<MapFut, BufFunc> ReadMapFn<MapFut, BufFunc>
where
    MapFut: Future<Output = BufferReadResult>,
    BufFunc: FnOnce() -> MapFut,
{
    /// Resolves to the mapping future: passes a `Mapped` future straight
    /// through, and invokes the deferred constructor in the `Buffered` case.
    fn prepare_future(self) -> MapFut {
        match self {
            ReadMapFn::Mapped(fut) => fut,
            ReadMapFn::Buffered(make) => make(),
        }
    }
}
/// A buffer which automatically uses either staging buffers or direct mapping to read/write to its
/// internal buffer based on the provided [`UploadStyle`]
pub struct AutomatedBuffer {
    /// The wrapped wgpu buffer.
    inner: Buffer,
    /// Upload strategy fixed at construction time.
    style: UploadStyle,
    /// Which CPU-side accesses this buffer permits.
    usage: AutomatedBufferUsage,
    /// Size of the buffer in bytes.
    size: BufferAddress,
}
impl AutomatedBuffer {
    /// Creates a new AutomatedBuffer with given settings. All operations directly
    /// done on the automated buffer according to `usage` will be added to the
    /// internal buffer's usage flags.
    pub fn new(
        device: &Device,
        size: BufferAddress,
        usage: AutomatedBufferUsage,
        other_usages: BufferUsage,
        label: Option<&str>,
        style: UploadStyle,
    ) -> Self {
        // Merge the usages implied by the requested access + upload style
        // with whatever extra usages the caller needs (e.g. STORAGE).
        let inner = device.create_buffer(&BufferDescriptor {
            size,
            usage: usage.into_buffer_usage(style) | other_usages,
            label,
        });
        Self {
            inner,
            style,
            usage,
            size,
        }
    }
    /// Each of the two futures do different things based on the mapping style.
    ///
    /// Mapping:
    /// - 1st: No-op
    /// - 2nd: Resolves the mapping
    ///
    /// Buffered:
    /// - 1st: Starts the staging buffer mapping
    /// - 2nd: Resolves the mapping
    ///
    /// This is done with assistance of a generic helper type. The data for the first
    /// await is held by [`ReadMapFn`].
    fn map_read<MapFut, BufFunc>(
        mapping: ReadMapFn<MapFut, BufFunc>,
    ) -> impl Future<Output = impl Future<Output = BufferReadMapping>>
    where
        MapFut: Future<Output = BufferReadResult>,
        BufFunc: FnOnce() -> MapFut,
    {
        async move {
            // maps the staging buffer or passes forward the mapping of the real buffer
            let future = mapping.prepare_future();
            async move {
                // actually resolves the mapping
                // NOTE: panics if the mapping failed (BufferAsyncError).
                future.await.unwrap()
            }
        }
    }
    /// Reads the underlying buffer using the proper read style.
    ///
    /// This function is unusual because it returns a future which itself returns a future. We shall
    /// refer to these as the First and the Second future.
    ///
    /// This function is safe, but has the following constraints so as to not cause a panic in wgpu:
    /// - Buffer usage must contain [`READ`](AutomatedBufferUsage::READ).
    /// - The buffer must not be in use by any other command buffer between calling this function
    ///   and calling await on the Second future.
    /// - The First future must be awaited _after_ `encoder`'s command buffer is submitted to the queue and _before_ device.poll is called.
    /// - The Second future must be awaited _after_ device.poll is called and the mapping is resolved.
    ///
    /// Example:
    ///
    /// ```ignore
    /// let buffer = AutomatedBuffer::new(..);
    ///
    /// let map_read_buf1 = buffer.read_from_buffer(&device, &mut encoder);
    /// queue.submit(&[encoder.submit()]); // must happen before first await
    ///
    /// let map_read_buf2 = map_read_buf1.await;
    /// device.poll(...); // must happen before second await and after first
    ///
    /// let mapping = map_read_buf2.await;
    /// // use mapping
    /// ```
    pub fn read_from_buffer(
        &mut self,
        device: &Device,
        encoder: &mut CommandEncoder,
    ) -> impl Future<Output = impl Future<Output = BufferReadMapping>> {
        assert!(
            self.usage.contains(AutomatedBufferUsage::READ),
            "Must have usage READ to read from buffer. Current usage {:?}",
            self.usage
        );
        match self.style {
            UploadStyle::Mapping => {
                // Map the real buffer directly; the mapping future is already started.
                Self::map_read(ReadMapFn::Mapped(self.inner.map_read(0, self.size)))
            }
            UploadStyle::Staging => {
                // Record a copy into a fresh CPU-mappable staging buffer, then map that.
                let staging = device.create_buffer(&BufferDescriptor {
                    size: self.size,
                    usage: BufferUsage::MAP_READ | BufferUsage::COPY_DST,
                    label: Some("read dst buffer"),
                });
                encoder.copy_buffer_to_buffer(&self.inner, 0, &staging, 0, self.size);
                let size = self.size;
                // Defer the map_read call (via the closure) until the First
                // future is awaited, i.e. after the copy has been submitted.
                Self::map_read(ReadMapFn::Buffered(move || staging.map_read(0, size)))
            }
        }
    }
    /// When the returned future is awaited, writes the data to the buffer if it is a mapped buffer.
    /// No-op for the use of a staging buffer.
    fn map_write<'a>(
        data: &'a [u8],
        mapping: Option<impl Future<Output = BufferWriteResult> + 'a>,
    ) -> impl Future<Output = ()> + 'a {
        async move {
            if let Some(mapping) = mapping {
                // NOTE: panics if the mapping failed (BufferAsyncError).
                mapping.await.unwrap().as_slice().copy_from_slice(data);
            }
        }
    }
    /// Writes to the underlying buffer using the proper write style.
    ///
    /// This function is safe, but has the following constraints so as to not cause a panic in wgpu:
    /// - Buffer usage must contain [`WRITE`](AutomatedBufferUsage::WRITE)
    /// - The returned future must be awaited _after_ calling device.poll() to resolve it.
    /// - The command buffer created by `encoder` must **not** be submitted to a queue before this future is awaited.
    ///
    /// Example:
    ///
    /// ```ignore
    /// let buffer = AutomatedBuffer::new(..);
    ///
    /// let map_write = buffer.write_to_buffer(&device, &mut encoder, &data);
    /// device.poll(...); // must happen before await
    ///
    /// let mapping = map_write.await; // Calling await will write to the mapping
    ///
    /// queue.submit(&[encoder.submit()]); // must happen after await
    /// ```
    pub fn write_to_buffer<'a>(
        &mut self,
        device: &Device,
        encoder: &mut CommandEncoder,
        data: &'a [u8],
    ) -> impl Future<Output = ()> + 'a {
        assert!(
            self.usage.contains(AutomatedBufferUsage::WRITE),
            "Must have usage WRITE to write to buffer. Current usage {:?}",
            self.usage
        );
        match self.style {
            // Direct mapping: the returned future performs the copy on await.
            UploadStyle::Mapping => Self::map_write(
                data,
                Some(self.inner.map_write(0, data.len() as BufferAddress)),
            ),
            UploadStyle::Staging => {
                // Staging: upload `data` into a one-shot COPY_SRC buffer and
                // record the copy; the returned future is then a no-op.
                let staging = device.create_buffer_with_data(data, BufferUsage::COPY_SRC);
                encoder.copy_buffer_to_buffer(
                    &staging,
                    0,
                    &self.inner,
                    0,
                    data.len() as BufferAddress,
                );
                Self::map_write(data, None)
            }
        }
    }
}
impl Deref for AutomatedBuffer {
    type Target = Buffer;
    /// Exposes the raw inner buffer so it can be handed to wgpu APIs
    /// (e.g. bind group creation) without a dedicated accessor.
    fn deref(&self) -> &Self::Target {
        &self.inner
    }
}
/// Holds everything needed to run the element-wise addition compute shader:
/// pipeline, operand buffers, bind group, and command buffers queued for the
/// next submit.
pub struct GPUAddition<'a> {
    device: &'a Device,
    queue: &'a Queue,
    /// Compute pipeline built from `addition.spv`.
    pipeline: ComputePipeline,
    /// Left operand; also receives the result and is read back.
    left_buffer: AutomatedBuffer,
    /// Right operand; upload-only.
    right_buffer: AutomatedBuffer,
    bind_group: BindGroup,
    /// Command buffers accumulated (e.g. by `set_buffers`) and submitted by `run`.
    commands: Vec<CommandBuffer>,
}
impl<'a> GPUAddition<'a> {
    /// Builds the compute pipeline, operand buffers, and bind group needed to
    /// add two `size`-element f32 arrays on the GPU with the `addition.spv`
    /// shader, using the given upload `style` for CPU<->GPU transfers.
    pub async fn new(
        device: &'a Device,
        queue: &'a Queue,
        style: UploadStyle,
        size: usize,
    ) -> GPUAddition<'a> {
        // 4 bytes per f32 element.
        let size_bytes = size as BufferAddress * 4;
        let shader_source = include_bytes!("addition.spv");
        let shader_module =
            device.create_shader_module(&read_spirv(Cursor::new(&shader_source[..])).unwrap());
        // binding 0: left operand storage buffer (read/write — receives the result)
        // binding 1: right operand storage buffer (read-only)
        // binding 2: element-count uniform
        let bind_group_layout = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
            bindings: &[
                BindGroupLayoutEntry {
                    binding: 0,
                    visibility: ShaderStage::COMPUTE,
                    ty: BindingType::StorageBuffer {
                        dynamic: false,
                        readonly: false,
                    },
                },
                BindGroupLayoutEntry {
                    binding: 1,
                    visibility: ShaderStage::COMPUTE,
                    ty: BindingType::StorageBuffer {
                        dynamic: false,
                        readonly: true,
                    },
                },
                BindGroupLayoutEntry {
                    binding: 2,
                    visibility: ShaderStage::COMPUTE,
                    ty: BindingType::UniformBuffer { dynamic: false },
                },
            ],
            label: Some("bind group layout"),
        });
        let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor {
            bind_group_layouts: &[&bind_group_layout],
        });
        let pipeline = device.create_compute_pipeline(&ComputePipelineDescriptor {
            layout: &pipeline_layout,
            compute_stage: ProgrammableStageDescriptor {
                module: &shader_module,
                entry_point: "main",
            },
        });
        // Left buffer is read back after the dispatch, so it needs both
        // READ and WRITE CPU access.
        let left_buffer = AutomatedBuffer::new(
            &device,
            size_bytes,
            AutomatedBufferUsage::ALL,
            BufferUsage::STORAGE,
            Some("left buffer"),
            style,
        );
        // Right buffer is only ever uploaded, never read back.
        let right_buffer = AutomatedBuffer::new(
            &device,
            size_bytes,
            AutomatedBufferUsage::WRITE,
            BufferUsage::STORAGE,
            Some("right buffer"),
            style,
        );
        // Element count, handed to the shader as a u32 uniform.
        let uniform_buffer =
            device.create_buffer_with_data((size as u32).as_bytes(), BufferUsage::UNIFORM);
        let bind_group = device.create_bind_group(&BindGroupDescriptor {
            layout: &bind_group_layout,
            bindings: &[
                Binding {
                    binding: 0,
                    resource: BindingResource::Buffer(left_buffer.slice(..)),
                },
                Binding {
                    binding: 1,
                    resource: BindingResource::Buffer(right_buffer.slice(..)),
                },
                Binding {
                    binding: 2,
                    resource: BindingResource::Buffer(uniform_buffer.slice(..)),
                },
            ],
            label: Some("bind group"),
        });
        Self {
            device,
            queue,
            pipeline,
            left_buffer,
            right_buffer,
            bind_group,
            commands: Vec::default(),
        }
    }
    /// Uploads `left` and `right` into the GPU buffers using the configured
    /// upload style. The recorded copy commands are queued in `self.commands`
    /// and only execute when [`run`](Self::run) submits them.
    pub async fn set_buffers(&mut self, left: &[f32], right: &[f32]) {
        let mut encoder = self
            .device
            .create_command_encoder(&CommandEncoderDescriptor {
                label: Some("set buffers"),
            });
        let map_left =
            self.left_buffer
                .write_to_buffer(&self.device, &mut encoder, left.as_bytes());
        let map_right =
            self.right_buffer
                .write_to_buffer(&self.device, &mut encoder, right.as_bytes());
        // Per the write_to_buffer contract: poll to resolve any pending write
        // mappings before awaiting them, and only then finish the encoder.
        self.device.poll(Maintain::Wait);
        map_left.await;
        map_right.await;
        // Ensure copy takes place during run
        self.commands.push(encoder.finish());
    }
    /// Submits the queued upload commands plus the compute dispatch, then maps
    /// the left buffer for readback. The resolved mapping is awaited and
    /// immediately dropped — nothing is returned (hence the explicit `-> ()`).
    pub async fn run(&mut self, size: usize) -> () {
        let mut encoder = self
            .device
            .create_command_encoder(&CommandEncoderDescriptor {
                label: Some("compute encoder"),
            });
        let mut cpass = encoder.begin_compute_pass();
        cpass.set_pipeline(&self.pipeline);
        cpass.set_bind_group(0, &self.bind_group, &[]);
        // One workgroup per 64 elements, rounded up.
        // NOTE(review): assumes the shader's local workgroup size is 64 —
        // confirm against addition.spv.
        cpass.dispatch((size as u32 + 63) / 64, 1, 1);
        // End the compute pass before finishing the encoder.
        drop(cpass);
        self.commands.push(encoder.finish());
        self.queue.submit(self.commands.drain(..));
        let mut encoder = self
            .device
            .create_command_encoder(&CommandEncoderDescriptor {
                label: Some("read mapping encoder"),
            });
        let map_left = self
            .left_buffer
            .read_from_buffer(&self.device, &mut encoder);
        self.queue.submit(std::iter::once(encoder.finish()));
        // remove the following lines, the leak goes away.
        // calling await on map_left causes it
        let map_left = map_left.await;
        self.device.poll(Maintain::Wait);
        map_left.await;
    }
}
// NOTE(review): `notest` is not a cfg that cargo sets (the test harness sets
// `test`), so this module is currently compiled out entirely — presumably
// deliberate while reproducing the leak noted in `run`. If re-enabled as
// written, the two GPU tests look like they would not compile: `gpu.run()`
// returns `()` but its result is used as a mapping — verify before switching
// back to `#[cfg(test)]`.
#[cfg(notest)]
mod test {
    use crate::{create_device, GPUAddition, UploadStyle};
    use async_std::task::block_on;
    use itertools::{zip, Itertools};
    // Generates a #[test] that adds [0, 10000) to [1, 10001) with the given
    // function and checks every element equals 2v + 1.
    macro_rules! addition_test {
        ($name:ident, $function:expr) => {
            #[test]
            fn $name() {
                let mut left = (0..10000).map(|v| v as f32).collect_vec();
                let right = (0..10000).map(|v| (v + 1) as f32).collect_vec();
                let result = (0..10000).map(|v| (v + v + 1) as f32).collect_vec();
                $function(&mut left, &right);
                for (i, (left, result)) in zip(left, result).enumerate() {
                    assert_eq!(left, result, "Index {} failed", i);
                }
            }
        };
    }
    addition_test!(addition, |left: &mut [f32], right: &[f32]| {
        crate::addition(left, right);
    });
    addition_test!(addition_unchecked, |left: &mut [f32], right: &[f32]| {
        unsafe { crate::addition_unchecked(left, right) };
    });
    addition_test!(addition_iterator, |left: &mut [f32], right: &[f32]| {
        crate::addition_iterator(left, right);
    });
    addition_test!(addition_rayon, |left: &mut [f32], right: &[f32]| {
        crate::addition_rayon(left, right);
    });
    addition_test!(addition_gpu_mapping, |left: &mut [f32], right: &[f32]| {
        let (device, queue) = create_device();
        let mut gpu = block_on(GPUAddition::new(
            &device,
            &queue,
            UploadStyle::Mapping,
            left.len(),
        ));
        block_on(gpu.set_buffers(&left, &right));
        let result_mapping = block_on(gpu.run(left.len()));
        let bytes = result_mapping.as_slice();
        let floats =
            unsafe { std::slice::from_raw_parts(bytes.as_ptr() as *const f32, bytes.len() / 4) };
        assert_eq!(left.len(), floats.len());
        left.copy_from_slice(floats);
    });
    addition_test!(addition_gpu_staging, |left: &mut [f32], right: &[f32]| {
        let (device, queue) = create_device();
        let mut gpu = block_on(GPUAddition::new(
            &device,
            &queue,
            UploadStyle::Staging,
            left.len(),
        ));
        block_on(gpu.set_buffers(&left, &right));
        let result_mapping = block_on(gpu.run(left.len()));
        let bytes = result_mapping.as_slice();
        let floats =
            unsafe { std::slice::from_raw_parts(bytes.as_ptr() as *const f32, bytes.len() / 4) };
        assert_eq!(left.len(), floats.len());
        left.copy_from_slice(floats);
    });
}