Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

General shader improvements, specifically targeting rust-gpu #2482

Merged
merged 9 commits into from
Mar 3, 2024
151 changes: 131 additions & 20 deletions vulkano-shaders/src/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::{
};
use heck::ToSnakeCase;
use proc_macro2::TokenStream;
use quote::{format_ident, quote};
pub use shaderc::{CompilationArtifact, IncludeType, ResolvedInclude, ShaderKind};
use shaderc::{CompileOptions, Compiler, EnvVersion, TargetEnv};
use std::{
Expand Down Expand Up @@ -262,9 +263,18 @@ pub(super) fn reflect(
#[cfg(test)]
mod tests {
use super::*;
use proc_macro2::Span;
use quote::ToTokens;
use shaderc::SpirvVersion;
use syn::{File, Item};
use vulkano::shader::reflect;

fn spv_to_words(data: &[u8]) -> Vec<u32> {
data.chunks(4)
.map(|c| u32::from_le_bytes([c[0], c[1], c[2], c[3]]))
.collect()
}

fn convert_paths(root_path: &Path, paths: &[PathBuf]) -> Vec<String> {
paths
.iter()
Expand All @@ -274,17 +284,28 @@ mod tests {

#[test]
fn spirv_parse() {
let data = include_bytes!("../tests/frag.spv");
let insts: Vec<_> = data
.chunks(4)
.map(|c| {
((c[3] as u32) << 24) | ((c[2] as u32) << 16) | ((c[1] as u32) << 8) | c[0] as u32
})
.collect();

let insts = spv_to_words(include_bytes!("../tests/frag.spv"));
Spirv::new(&insts).unwrap();
}

#[test]
fn spirv_reflect() {
let insts = spv_to_words(include_bytes!("../tests/frag.spv"));

let mut type_registry = TypeRegistry::default();
let (_shader_code, _structs) = reflect(
&MacroInput::empty(),
LitStr::new("../tests/frag.spv", Span::call_site()),
String::new(),
&insts,
Vec::new(),
&mut type_registry,
)
.expect("reflecting spv failed");

assert_eq!(_structs.to_string(), "", "No structs should be generated");
}

#[test]
fn include_resolution() {
let root_path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
Expand Down Expand Up @@ -536,14 +557,8 @@ mod tests {
/// ```
#[test]
fn descriptor_calculation_with_multiple_entrypoints() {
let data = include_bytes!("../tests/multiple_entrypoints.spv");
let instructions: Vec<u32> = data
.chunks(4)
.map(|c| {
((c[3] as u32) << 24) | ((c[2] as u32) << 16) | ((c[1] as u32) << 8) | c[0] as u32
})
.collect();
let spirv = Spirv::new(&instructions).unwrap();
let insts = spv_to_words(include_bytes!("../tests/multiple_entrypoints.spv"));
let spirv = Spirv::new(&insts).unwrap();

let mut descriptors = Vec::new();
for (_, info) in reflect::entry_points(&spirv) {
Expand Down Expand Up @@ -578,8 +593,52 @@ mod tests {
}

#[test]
fn descriptor_calculation_with_multiple_functions() {
let (comp, _) = compile(
fn reflect_descriptor_calculation_with_multiple_entrypoints() {
let insts = spv_to_words(include_bytes!("../tests/multiple_entrypoints.spv"));

let mut type_registry = TypeRegistry::default();
let (_shader_code, _structs) = reflect(
&MacroInput::empty(),
LitStr::new("../tests/multiple_entrypoints.spv", Span::call_site()),
String::new(),
&insts,
Vec::new(),
&mut type_registry,
)
.expect("reflecting spv failed");

let structs = _structs.to_string();
assert_ne!(structs, "", "Has some structs");

let file: File = syn::parse2(_structs).unwrap();
let structs: Vec<_> = file
.items
.iter()
.filter_map(|item| {
if let Item::Struct(s) = item {
Some(s)
} else {
None
}
})
.collect();

let buffer = structs.iter().find(|s| s.ident == "Buffer").unwrap();
assert_eq!(
buffer.fields.to_token_stream().to_string(),
quote!({pub data: u32,}).to_string()
);

let uniform = structs.iter().find(|s| s.ident == "Uniform").unwrap();
assert_eq!(
uniform.fields.to_token_stream().to_string(),
quote!({pub data: u32,}).to_string()
);
}

fn descriptor_calculation_with_multiple_functions_shader() -> (CompilationArtifact, Vec<String>)
{
compile(
&MacroInput {
spirv_version: Some(SpirvVersion::V1_6),
vulkan_version: Some(EnvVersion::Vulkan1_3),
Expand Down Expand Up @@ -615,8 +674,13 @@ mod tests {
"#,
ShaderKind::Vertex,
)
.unwrap();
let spirv = Spirv::new(comp.as_binary()).unwrap();
.unwrap()
}

#[test]
fn descriptor_calculation_with_multiple_functions() {
let (artifact, _) = descriptor_calculation_with_multiple_functions_shader();
let spirv = Spirv::new(artifact.as_binary()).unwrap();

if let Some((_, info)) = reflect::entry_points(&spirv).next() {
let mut bindings = Vec::new();
Expand All @@ -634,4 +698,51 @@ mod tests {
}
panic!("could not find entrypoint");
}

#[test]
fn reflect_descriptor_calculation_with_multiple_functions() {
let (artifact, _) = descriptor_calculation_with_multiple_functions_shader();

let mut type_registry = TypeRegistry::default();
let (_shader_code, _structs) = reflect(
&MacroInput::empty(),
LitStr::new(
"descriptor_calculation_with_multiple_functions_shader",
Span::call_site(),
),
String::new(),
artifact.as_binary(),
Vec::new(),
&mut type_registry,
)
.expect("reflecting spv failed");

let structs = _structs.to_string();
assert_ne!(structs, "", "Has some structs");

let file: File = syn::parse2(_structs).unwrap();
let structs: Vec<_> = file
.items
.iter()
.filter_map(|item| {
if let Item::Struct(s) = item {
Some(s)
} else {
None
}
})
.collect();

let buffer = structs.iter().find(|s| s.ident == "Buffer").unwrap();
assert_eq!(
buffer.fields.to_token_stream().to_string(),
quote!({pub data: [f32; 3usize],}).to_string()
);

let uniform = structs.iter().find(|s| s.ident == "Uniform").unwrap();
assert_eq!(
uniform.fields.to_token_stream().to_string(),
quote!({pub data: f32,}).to_string()
);
}
}
67 changes: 46 additions & 21 deletions vulkano-shaders/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
//! appropriate features enabled.
//! - If the `shaders` option is used, then instead of one `load` constructor, there is one for
//! each shader. They are named based on the provided names, `load_first`, `load_second` etc.
//! - A Rust struct translated from each struct contained in the shader data. By default each
//! - A Rust struct translated from each struct contained in the shader data. By default, each
//! structure has a `Clone` and a `Copy` implementation. This behavior could be customized
//! through the `custom_derives` macro option (see below for details). Each struct also has an
//! implementation of [`BufferContents`], so that it can be read from/written to a buffer.
Expand Down Expand Up @@ -122,8 +122,8 @@
//! ## `bytes: "..."`
//!
//! Provides the path to precompiled SPIR-V bytecode, relative to your `Cargo.toml`. Cannot be used
//! in conjunction with the `src` or `path` field. This allows using shaders compiled through a
//! separate build system.
//! in conjunction with the `src` or `path` field, and may also not specify a shader `ty` type.
//! This allows using shaders compiled through a separate build system.
//!
//! ## `root_path_env: "..."`
//!
Expand All @@ -143,7 +143,7 @@
//!
//! With these options the user can compile several shaders in a single macro invocation. Each
//! entry key will be the suffix of the generated `load` function (`load_first` in this case).
//! However all other Rust structs translated from the shader source will be shared between
//! However, all other Rust structs translated from the shader source will be shared between
//! shaders. The macro checks that the source structs with the same names between different shaders
//! have the same declaration signature, and throws a compile-time error if they don't.
//!
Expand Down Expand Up @@ -172,14 +172,21 @@
//! The generated code must be supported by the device at runtime. If not, then an error will be
//! returned when calling `load`.
//!
//! ## `generate_structs: true`
//!
//! Generate rust structs that represent the structs contained in the shader. They all implement
//! [`BufferContents`], which allows then to be passed to the shader, without having to worry about
//! the layout of the struct manually. However, some use-cases, such as Rust-GPU, may not have any
//! use for such structs, and may choose to disable them.
//!
//! ## `custom_derives: [Clone, Default, PartialEq, ...]`
//!
//! Extends the list of derive macros that are added to the `derive` attribute of Rust structs that
//! represent shader structs.
//!
//! By default each generated struct has a derive for `Clone` and `Copy`. If the struct has unsized
//! members none of the derives are applied on the struct, except [`BufferContents`], which is
//! always derived.
//! By default, each generated struct derives `Clone` and `Copy`. If the struct has unsized members
//! none of the derives are applied on the struct, except [`BufferContents`], which is always
//! derived.
//!
//! ## `linalg_type: "..."`
//!
Expand Down Expand Up @@ -221,26 +228,24 @@
#![allow(clippy::needless_borrowed_reference)]
#![warn(rust_2018_idioms, rust_2021_compatibility)]

#[macro_use]
extern crate quote;
#[macro_use]
extern crate syn;

use crate::codegen::ShaderKind;
use ahash::HashMap;
use proc_macro2::{Span, TokenStream};
use quote::quote;
use shaderc::{EnvVersion, SpirvVersion};
use std::{
env, fs, mem,
path::{Path, PathBuf},
};
use structs::TypeRegistry;
use syn::{
braced, bracketed, parenthesized,
parse::{Parse, ParseStream, Result},
Error, Ident, LitBool, LitStr, Path as SynPath,
parse_macro_input, parse_quote, Error, Ident, LitBool, LitStr, Path as SynPath, Token,
};

mod codegen;
mod rust_gpu;
mod structs;

#[proc_macro]
Expand Down Expand Up @@ -286,9 +291,14 @@ fn shader_inner(mut input: MacroInput) -> Result<TokenStream> {
for (name, (shader_kind, source_kind)) in shaders {
let (code, types) = match source_kind {
SourceKind::Src(source) => {
let (artifact, includes) =
codegen::compile(&input, None, root_path, &source.value(), shader_kind)
.map_err(|err| Error::new_spanned(&source, err))?;
let (artifact, includes) = codegen::compile(
&input,
None,
root_path,
&source.value(),
shader_kind.unwrap(),
)
.map_err(|err| Error::new_spanned(&source, err))?;

let words = artifact.as_binary();

Expand All @@ -313,7 +323,7 @@ fn shader_inner(mut input: MacroInput) -> Result<TokenStream> {
Some(path.value()),
root_path,
&source_code,
shader_kind,
shader_kind.unwrap(),
)
.map_err(|err| Error::new_spanned(&path, err))?;

Expand Down Expand Up @@ -371,9 +381,10 @@ struct MacroInput {
root_path_env: Option<LitStr>,
include_directories: Vec<PathBuf>,
macro_defines: Vec<(String, String)>,
shaders: HashMap<String, (ShaderKind, SourceKind)>,
shaders: HashMap<String, (Option<ShaderKind>, SourceKind)>,
spirv_version: Option<SpirvVersion>,
vulkan_version: Option<EnvVersion>,
generate_structs: bool,
custom_derives: Vec<SynPath>,
linalg_type: LinAlgType,
dump: LitBool,
Expand All @@ -389,6 +400,7 @@ impl MacroInput {
shaders: HashMap::default(),
vulkan_version: None,
spirv_version: None,
generate_structs: true,
custom_derives: Vec::new(),
linalg_type: LinAlgType::default(),
dump: LitBool::new(false, Span::call_site()),
Expand All @@ -406,6 +418,7 @@ impl Parse for MacroInput {
let mut shaders = HashMap::default();
let mut vulkan_version = None;
let mut spirv_version = None;
let mut generate_structs = true;
let mut custom_derives = None;
let mut linalg_type = None;
let mut dump = None;
Expand Down Expand Up @@ -643,6 +656,10 @@ impl Parse for MacroInput {
),
});
}
"generate_structs" => {
let lit = input.parse::<LitBool>()?;
generate_structs = lit.value;
marc0246 marked this conversation as resolved.
Show resolved Hide resolved
}
"custom_derives" => {
let in_brackets;
bracketed!(in_brackets in input);
Expand Down Expand Up @@ -696,8 +713,8 @@ impl Parse for MacroInput {
field => bail!(
field_ident,
"expected `bytes`, `src`, `path`, `ty`, `shaders`, `define`, `include`, \
`vulkan_version`, `spirv_version`, `custom_derives`, `linalg_type` or `dump` \
as a field, found `{field}`",
`vulkan_version`, `spirv_version`, `generate_structs`, `custom_derives`, \
`linalg_type` or `dump` as a field, found `{field}`",
),
}

Expand All @@ -711,6 +728,13 @@ impl Parse for MacroInput {
}

match shaders.get("") {
// if source is bytes, the shader type should not be declared
Some((None, Some(SourceKind::Bytes(_)))) => {}
Some((_, Some(SourceKind::Bytes(_)))) => {
bail!(
r#"one may not specify a shader type when including precompiled SPIR-V binaries. Please remove the `ty:` declaration"#
);
}
Some((None, _)) => {
bail!(r#"please specify the type of the shader e.g. `ty: "vertex"`"#);
}
Expand All @@ -727,11 +751,12 @@ impl Parse for MacroInput {
shaders: shaders
.into_iter()
.map(|(key, (shader_kind, shader_source))| {
(key, (shader_kind.unwrap(), shader_source.unwrap()))
(key, (shader_kind, shader_source.unwrap()))
})
.collect(),
vulkan_version,
spirv_version,
generate_structs,
custom_derives: custom_derives.unwrap_or_else(|| {
vec![
parse_quote! { ::std::clone::Clone },
Expand Down
Loading
Loading