forked from coreylowman/dfdx
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(wgpu): add to_dtype kernel (coreylowman#906)
* feat(wgpu): add to_dtype kernel * fix: add WebGPUNativeType * style: clippy fix --------- Co-authored-by: Corey Lowman <clowman1993@gmail.com>
- Loading branch information
1 parent
e04dd4f
commit 4722a99
Showing
6 changed files
with
228 additions
and
20 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
use crate::shapes::Unit; | ||
|
||
/// A primitive data type natively supported by WebGPU. | ||
/// | ||
/// See: https://www.w3.org/TR/WGSL/#types | ||
/// | ||
/// todo: support packed types | ||
pub trait WebgpuNativeType: Unit { | ||
/// Name of the data type in WGSL. | ||
const NAME: &'static str; | ||
} | ||
|
||
macro_rules! webgpu_type { | ||
($RustTy:ty) => { | ||
impl WebgpuNativeType for $RustTy { | ||
const NAME: &'static str = stringify!($RustTy); | ||
} | ||
}; | ||
($RustTy:ty, $WgpuTy:expr) => { | ||
impl WebgpuNativeType for $RustTy { | ||
const NAME: &'static str = $WgpuTy; | ||
} | ||
}; | ||
} | ||
|
||
/* | ||
see: | ||
- https://docs.rs/wgpu/latest/wgpu/struct.Features.html#associatedconstant.SHADER_F16 | ||
- https://docs.rs/wgpu/latest/wgpu/struct.Features.html#associatedconstant.SHADER_F64 | ||
- https://docs.rs/wgpu/latest/wgpu/struct.Features.html#associatedconstant.SHADER_I16 | ||
*/ | ||
#[cfg(feature = "f16")] | ||
webgpu_type!(half::f16, "f16"); | ||
webgpu_type!(f32); | ||
// todo: only enable when f64 feature is enabled | ||
#[cfg(feature = "f64")] | ||
webgpu_type!(f64); | ||
|
||
#[cfg(feature = "i16")] | ||
webgpu_type!(i16); | ||
webgpu_type!(i32); | ||
|
||
webgpu_type!(u32); | ||
webgpu_type!(bool); | ||
|
||
pub(crate) trait HasGlslType { | ||
const TYPE: &'static str; | ||
} | ||
|
||
impl HasGlslType for f32 { | ||
const TYPE: &'static str = "float"; | ||
} | ||
|
||
impl HasGlslType for f64 { | ||
const TYPE: &'static str = "double"; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
alias T = __SRC__; | ||
alias U = __DST__; | ||
|
||
@group(0) @binding(0) | ||
var<storage, read> in: array<T>; | ||
|
||
@group(0) @binding(1) | ||
var<storage, read_write> out: array<U>; | ||
|
||
@compute @workgroup_size(1, 1, 1) | ||
fn main( | ||
@builtin(global_invocation_id) global_id: vec3<u32> | ||
) { | ||
let i = global_id.x; | ||
out[i] = U(in[i]); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,102 @@ | ||
use crate::prelude::{Unit, Webgpu}; | ||
use crate::{ | ||
prelude::Storage, | ||
tensor::webgpu::{Webgpu, WebgpuNativeType}, | ||
tensor_ops::utilities::webgpu_kernels::webgpu_params, | ||
}; | ||
use num_traits::AsPrimitive; | ||
use wgpu; | ||
|
||
impl<E1: Unit, E2: Unit> super::ToDtypeKernel<E1, E2> for Webgpu { | ||
/// kernel template | ||
const KERNEL: &'static str = include_str!("./to_dtype.wgsl"); | ||
|
||
const LAYOUT_DESC: wgpu::BindGroupLayoutDescriptor = wgpu::BindGroupLayoutDescriptor { | ||
label: Some("to-dtype"), | ||
entries: &[ | ||
wgpu::BindGroupLayoutEntry { | ||
binding: 0, | ||
visibility: wgpu::ShaderStages::COMPUTE, | ||
ty: wgpu::BindingType::Buffer { | ||
ty: wgpu::BufferBindingType::Storage { read_only: true }, | ||
has_dynamic_offset: false, | ||
min_binding_size: None, | ||
}, | ||
count: None, | ||
}, | ||
wgpu::BindGroupLayoutEntry { | ||
binding: 1, | ||
visibility: wgpu::ShaderStages::COMPUTE, | ||
ty: wgpu::BindingType::Buffer { | ||
ty: wgpu::BufferBindingType::Storage { read_only: false }, | ||
has_dynamic_offset: false, | ||
min_binding_size: None, | ||
}, | ||
count: None, | ||
}, | ||
], | ||
}; | ||
|
||
impl<E1: WebgpuNativeType + AsPrimitive<E2>, E2: WebgpuNativeType> super::ToDtypeKernel<E1, E2> | ||
for Webgpu | ||
{ | ||
fn forward<S: crate::prelude::Shape>( | ||
inp: crate::prelude::Tensor<S, E1, Self>, | ||
) -> Result<crate::prelude::Tensor<S, E2, Self>, crate::prelude::Error> { | ||
todo!() | ||
let module_name = std::format!("convert_{}_to_{}", E1::NAME, E2::NAME); | ||
let label = Some(module_name.as_str()); | ||
let device = inp.device; | ||
|
||
let layout = device.dev.create_bind_group_layout(&LAYOUT_DESC); | ||
let shader_source: String = KERNEL | ||
.replace("__SRC__", E1::NAME) | ||
.replace("__DST__", E2::NAME); | ||
|
||
// TODO: support WGSL shaders in device shader cache | ||
let source = wgpu::ShaderSource::Wgsl(shader_source.into()); | ||
let shader_module = device | ||
.dev | ||
.create_shader_module(wgpu::ShaderModuleDescriptor { | ||
label: Some(shader_name), | ||
source, | ||
}); | ||
let pipeline_layout = device | ||
.dev | ||
.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor { | ||
label: label.clone(), | ||
bind_group_layouts: layouts, | ||
// todo: these are useful and we should use them if the adapter supports them | ||
push_constant_ranges: &push_constant_ranges, | ||
}); | ||
|
||
let pipeline = device | ||
.dev | ||
.create_compute_pipeline(&wgpu::ComputePipelineDescriptor { | ||
label: label.clone(), | ||
layout: Some(&pipeline_layout), | ||
module: &shader_module, | ||
entry_point: fn_name, | ||
}); | ||
|
||
let numel = inp.shape.num_elements(); | ||
let shape = inp.shape; | ||
let strides = shape.strides(); | ||
let output = unsafe { device.alloc_empty::<E2>(numel) }?; | ||
|
||
let params: wgpu::BindGroup = webgpu_params!(device, pipeline; inp.data, output); | ||
|
||
let _idx = device.submit_commands(label.clone(), |encoder| { | ||
let (x, y, z) = *work_groups; | ||
let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { | ||
label: label.clone(), | ||
..Default::default() | ||
}); | ||
// TODO: should this be called before the pass, as the pass is created, or before submission? | ||
pass.set_pipeline(&pipeline); | ||
pass.set_bind_group(0, ¶ms, &[]); | ||
pass.dispatch_workgroups(numel as u32, 1, 1); | ||
}); | ||
|
||
// note: no need to sync here, buffer can remain on the gpu until to_array or to_vec gets called, | ||
// and those functions sync the device before mapping the buffer | ||
Ok(device.build_tensor(shape, strides, output)) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters