Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

vello_shaders crate for AOT compilation and lightweight integration with external renderer projects #265

Merged
merged 12 commits into from
Mar 29, 2023
Merged
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
resolver = "2"

members = [
"crates/shaders",

"integrations/vello_svg",

"examples/headless",
Expand Down
19 changes: 19 additions & 0 deletions crates/shaders/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
[package]
name = "vello_shaders"
version = "0.1.0"
edition = "2021"

[features]
default = ["compile", "wgsl", "msl"]
compile = ["naga", "thiserror"]
wgsl = []
msl = []

[dependencies]
naga = { git = "https://github.com/gfx-rs/naga", rev = "53d62b9", features = ["wgsl-in", "msl-out", "validate"], optional = true }
thiserror = { version = "1.0.40", optional = true }

[build-dependencies]
naga = { git = "https://github.com/gfx-rs/naga", rev = "53d62b9", features = ["wgsl-in", "msl-out", "validate"] }
thiserror = "1.0.40"

7 changes: 7 additions & 0 deletions crates/shaders/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
The `vello_shaders` crate provides a utility library to integrate the Vello shader modules into any
renderer project. The crate provides the necessary metadata to construct the individual compute
pipelines on any GPU API while leaving the responsibility of all API interactions (such as
resource management and command encoding) up to the client.

The shaders can be pre-compiled to any target shading language at build time based on feature flags.
Currently only WGSL and Metal Shading Language are supported.
107 changes: 107 additions & 0 deletions crates/shaders/build.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
// Copyright 2023 The Vello authors
// SPDX-License-Identifier: Apache-2.0 OR MIT

#[path = "src/compile/mod.rs"]
mod compile;
#[path = "src/types.rs"]
mod types;

use std::env;
use std::fmt::Write;
use std::path::Path;

use compile::ShaderInfo;

fn main() {
let out_dir = env::var_os("OUT_DIR").unwrap();
let dest_path = Path::new(&out_dir).join("shaders.rs");
let mut shaders = compile::ShaderInfo::from_dir("../../shader");
// Drop the HashMap and sort by name so that we get deterministic order.
let mut shaders = shaders.drain().collect::<Vec<_>>();
shaders.sort_by(|x, y| x.0.cmp(&y.0));
let mut buf = String::default();
write_types(&mut buf, &shaders).unwrap();
if cfg!(feature = "wgsl") {
write_shaders(&mut buf, "wgsl", &shaders, |info| {
info.source.as_bytes().to_owned()
})
.unwrap();
}
if cfg!(feature = "msl") {
write_shaders(&mut buf, "msl", &shaders, |info| {
compile::msl::translate(info).unwrap().as_bytes().to_owned()
})
.unwrap();
}
std::fs::write(&dest_path, &buf).unwrap();
println!("cargo:rerun-if-changed=../shader");
}

fn write_types(buf: &mut String, shaders: &[(String, ShaderInfo)]) -> Result<(), std::fmt::Error> {
writeln!(buf, "pub struct Shaders<'a> {{")?;
for (name, _) in shaders {
writeln!(buf, " pub {name}: ComputeShader<'a>,")?;
}
writeln!(buf, "}}")?;
writeln!(buf, "pub struct Pipelines<T> {{")?;
for (name, _) in shaders {
writeln!(buf, " pub {name}: T,")?;
}
writeln!(buf, "}}")?;
writeln!(buf, "impl<T> Pipelines<T> {{")?;
writeln!(buf, " pub fn from_shaders<H: PipelineHost<ComputePipeline = T>>(shaders: &Shaders, device: &H::Device, host: &mut H) -> Result<Self, H::Error> {{")?;
writeln!(buf, " Ok(Self {{")?;
for (name, _) in shaders {
writeln!(
buf,
" {name}: host.new_compute_pipeline(device, &shaders.{name})?,"
)?;
}
writeln!(buf, " }})")?;
writeln!(buf, " }}")?;
writeln!(buf, "}}")?;
Ok(())
}

fn write_shaders(
buf: &mut String,
mod_name: &str,
shaders: &[(String, ShaderInfo)],
translate: impl Fn(&ShaderInfo) -> Vec<u8>,
) -> Result<(), std::fmt::Error> {
writeln!(buf, "pub mod {mod_name} {{")?;
writeln!(buf, " use super::*;")?;
writeln!(buf, " use BindType::*;")?;
writeln!(buf, " pub const SHADERS: Shaders<'static> = Shaders {{")?;
for (name, info) in shaders {
let bind_tys = info
.bindings
.iter()
.map(|binding| binding.ty)
.collect::<Vec<_>>();
let wg_bufs = &info.workgroup_buffers;
let source = translate(info);
writeln!(buf, " {name}: ComputeShader {{")?;
writeln!(buf, " name: Cow::Borrowed({:?}),", name)?;
writeln!(
buf,
" code: Cow::Borrowed(&{:?}),",
source.as_slice()
)?;
writeln!(
buf,
" workgroup_size: {:?},",
info.workgroup_size
)?;
writeln!(buf, " bindings: Cow::Borrowed(&{:?}),", bind_tys)?;
writeln!(
buf,
" workgroup_buffers: Cow::Borrowed(&{:?}),",
wg_bufs
)?;
writeln!(buf, " }},")?;
}
writeln!(buf, " }};")?;
writeln!(buf, "}}")?;
Ok(())
}
197 changes: 197 additions & 0 deletions crates/shaders/src/compile/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
// Copyright 2023 The Vello authors
// SPDX-License-Identifier: Apache-2.0 OR MIT

use {
naga::{
front::wgsl,
valid::{Capabilities, ModuleInfo, ValidationError, ValidationFlags},
AddressSpace, ArraySize, ConstantInner, ImageClass, Module, ScalarValue, StorageAccess,
WithSpan,
},
std::{
collections::{HashMap, HashSet},
path::Path,
},
thiserror::Error,
};

pub mod permutations;
pub mod preprocess;

pub mod msl;

use crate::types::{BindType, BindingInfo, WorkgroupBufferInfo};

#[derive(Error, Debug)]
pub enum Error {
#[error("failed to parse shader: {0}")]
Parse(#[from] wgsl::ParseError),

#[error("failed to validate shader: {0}")]
Validate(#[from] WithSpan<ValidationError>),

#[error("missing entry point function")]
EntryPointNotFound,
}

#[derive(Debug)]
pub struct ShaderInfo {
pub source: String,
pub module: Module,
pub module_info: ModuleInfo,
pub workgroup_size: [u32; 3],
pub bindings: Vec<BindingInfo>,
pub workgroup_buffers: Vec<WorkgroupBufferInfo>,
}

impl ShaderInfo {
pub fn new(source: String, entry_point: &str) -> Result<ShaderInfo, Error> {
let module = wgsl::parse_str(&source)?;
let module_info = naga::valid::Validator::new(
ValidationFlags::all() & !ValidationFlags::CONTROL_FLOW_UNIFORMITY,
Capabilities::all(),
)
.validate(&module)?;
let (entry_index, entry) = module
.entry_points
.iter()
.enumerate()
.find(|(_, entry)| entry.name.as_str() == entry_point)
.ok_or(Error::EntryPointNotFound)?;
let mut bindings = vec![];
let mut workgroup_buffers = vec![];
let mut wg_buffer_idx = 0;
let entry_info = module_info.get_entry_point(entry_index);
for (var_handle, var) in module.global_variables.iter() {
if entry_info[var_handle].is_empty() {
continue;
}
let binding_ty = match module.types[var.ty].inner {
naga::TypeInner::BindingArray { base, .. } => &module.types[base].inner,
ref ty => ty,
};
let Some(binding) = &var.binding else {
if var.space == AddressSpace::WorkGroup {
let index = wg_buffer_idx;
wg_buffer_idx += 1;
let size_in_bytes = match binding_ty {
naga::TypeInner::Array {
size: ArraySize::Constant(const_handle),
stride,
..
} => {
let size: u32 = match module.constants[*const_handle].inner {
ConstantInner::Scalar { value, width: _ } => match value {
ScalarValue::Uint(value) => value.try_into().unwrap(),
ScalarValue::Sint(value) => value.try_into().unwrap(),
_ => continue,
},
ConstantInner::Composite { .. } => continue,
};
size * stride
},
naga::TypeInner::Struct { span, .. } => *span,
naga::TypeInner::Scalar { width, ..} => *width as u32,
naga::TypeInner::Vector { width, ..} => *width as u32,
naga::TypeInner::Matrix { width, ..} => *width as u32,
naga::TypeInner::Atomic { width, ..} => *width as u32,
_ => {
// Not a valid workgroup variable type. At least not one that is used
// in our shaders.
continue;
}
};
workgroup_buffers.push(WorkgroupBufferInfo {
size_in_bytes,
index,
});
}
continue;
};
let mut resource = BindingInfo {
name: var.name.clone(),
location: (binding.group, binding.binding),
ty: BindType::Buffer,
};
if let naga::TypeInner::Image { class, .. } = &binding_ty {
resource.ty = BindType::ImageRead;
if let ImageClass::Storage { access, .. } = class {
if access.contains(StorageAccess::STORE) {
resource.ty = BindType::Image;
}
}
} else {
resource.ty = BindType::BufReadOnly;
match var.space {
AddressSpace::Storage { access } => {
if access.contains(StorageAccess::STORE) {
resource.ty = BindType::Buffer;
}
}
AddressSpace::Uniform => {
resource.ty = BindType::Uniform;
}
_ => {}
}
}
bindings.push(resource);
}
bindings.sort_by_key(|res| res.location);
let workgroup_size = entry.workgroup_size;
Ok(ShaderInfo {
source,
module,
module_info,
workgroup_size,
bindings,
workgroup_buffers,
})
}

pub fn from_dir(shader_dir: impl AsRef<Path>) -> HashMap<String, Self> {
use std::fs;
let shader_dir = shader_dir.as_ref();
let permutation_map = if let Ok(permutations_source) =
std::fs::read_to_string(shader_dir.join("permutations"))
{
permutations::parse(&permutations_source)
} else {
Default::default()
};
println!("{:?}", permutation_map);
let imports = preprocess::get_imports(shader_dir);
let mut info = HashMap::default();
let mut defines = HashSet::default();
defines.insert("full".to_string());
for entry in shader_dir
.read_dir()
.expect("Can read shader import directory")
{
let entry = entry.expect("Can continue reading shader import directory");
if entry.file_type().unwrap().is_file() {
let file_name = entry.file_name();
if let Some(name) = file_name.to_str() {
let suffix = ".wgsl";
if let Some(shader_name) = name.strip_suffix(suffix) {
let contents = fs::read_to_string(shader_dir.join(&file_name))
.expect("Could read shader {shader_name} contents");
if let Some(permutations) = permutation_map.get(shader_name) {
for permutation in permutations {
let mut defines = defines.clone();
defines.extend(permutation.defines.iter().cloned());
let source = preprocess::preprocess(&contents, &defines, &imports);
let shader_info = Self::new(source.clone(), "main").unwrap();
info.insert(permutation.name.clone(), shader_info);
}
} else {
let source = preprocess::preprocess(&contents, &defines, &imports);
let shader_info = Self::new(source.clone(), "main").unwrap();
info.insert(shader_name.to_string(), shader_info);
}
}
}
}
}
info
}
}
56 changes: 56 additions & 0 deletions crates/shaders/src/compile/msl.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// Copyright 2023 The Vello authors
// SPDX-License-Identifier: Apache-2.0 OR MIT

use naga::back::msl;

use super::{BindType, ShaderInfo};

pub fn translate(shader: &ShaderInfo) -> Result<String, msl::Error> {
let mut map = msl::EntryPointResourceMap::default();
let mut buffer_index = 0u8;
let mut image_index = 0u8;
let mut binding_map = msl::BindingMap::default();
for resource in &shader.bindings {
let binding = naga::ResourceBinding {
group: resource.location.0,
binding: resource.location.1,
};
let mut target = msl::BindTarget::default();
match resource.ty {
BindType::Buffer | BindType::BufReadOnly | BindType::Uniform => {
target.buffer = Some(buffer_index);
buffer_index += 1;
}
BindType::Image | BindType::ImageRead => {
target.texture = Some(image_index);
image_index += 1;
}
}
target.mutable = resource.ty.is_mutable();
binding_map.insert(binding, target);
}
map.insert(
"main".to_string(),
msl::EntryPointResources {
resources: binding_map,
push_constant_buffer: None,
sizes_buffer: Some(30),
},
);
let options = msl::Options {
lang_version: (2, 0),
per_entry_point_map: map,
inline_samplers: vec![],
spirv_cross_compatibility: false,
fake_missing_bindings: false,
bounds_check_policies: naga::proc::BoundsCheckPolicies::default(),
zero_initialize_workgroup_memory: false,
};
let (source, _) = msl::write_string(
&shader.module,
&shader.module_info,
&options,
&msl::PipelineOptions::default(),
)?;
Ok(source)
}
Loading