Skip to content

Commit

Permalink
ValidationError-ify ShaderModule (vulkano-rs#2268)
Browse files Browse the repository at this point in the history
* ValidationError-ify `ShaderModule`

* Typo

* Remove leftover comments

* slice, not Iterator

* Update vulkano/src/shader/mod.rs

Co-authored-by: marc0246 <40955683+marc0246@users.noreply.github.com>

* Update vulkano/src/shader/mod.rs

Co-authored-by: marc0246 <40955683+marc0246@users.noreply.github.com>

---------

Co-authored-by: marc0246 <40955683+marc0246@users.noreply.github.com>
  • Loading branch information
2 people authored and hakolao committed Feb 20, 2024
1 parent fca970a commit 2b08260
Show file tree
Hide file tree
Showing 10 changed files with 637 additions and 517 deletions.
53 changes: 39 additions & 14 deletions examples/src/bin/runtime-shader/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
//
// Vulkano uses shaderc to build your shaders internally.

use std::{fs::File, io::Read, sync::Arc};
use std::{fs::File, io::Read, path::Path, sync::Arc};
use vulkano::{
buffer::{Buffer, BufferContents, BufferCreateInfo, BufferUsage},
command_buffer::{
Expand Down Expand Up @@ -49,7 +49,7 @@ use vulkano::{
GraphicsPipeline, PipelineLayout, PipelineShaderStageCreateInfo,
},
render_pass::{Framebuffer, FramebufferCreateInfo, RenderPass, Subpass},
shader::ShaderModule,
shader::{ShaderModule, ShaderModuleCreateInfo},
swapchain::{
acquire_next_image, AcquireError, Surface, Swapchain, SwapchainCreateInfo,
SwapchainPresentInfo,
Expand Down Expand Up @@ -178,26 +178,22 @@ fn main() {

let graphics_pipeline = {
let vs = {
let mut f = File::open("src/bin/runtime-shader/vert.spv").expect(
"can't find file `src/bin/runtime-shader/vert.spv`, this example needs to be run from \
the root of the example crate",
);
let mut v = vec![];
f.read_to_end(&mut v).unwrap();
let code = read_spirv_words_from_file("src/bin/runtime-shader/vert.spv");

// Create a ShaderModule on a device the same Shader::load does it.
// NOTE: You will have to verify correctness of the data by yourself!
let module = unsafe { ShaderModule::from_bytes(device.clone(), &v).unwrap() };
let module = unsafe {
ShaderModule::new(device.clone(), ShaderModuleCreateInfo::new(&code)).unwrap()
};
module.entry_point("main").unwrap()
};

let fs = {
let mut f = File::open("src/bin/runtime-shader/frag.spv")
.expect("can't find file `src/bin/runtime-shader/frag.spv`");
let mut v = vec![];
f.read_to_end(&mut v).unwrap();
let code = read_spirv_words_from_file("src/bin/runtime-shader/frag.spv");

let module = unsafe { ShaderModule::from_bytes(device.clone(), &v).unwrap() };
let module = unsafe {
ShaderModule::new(device.clone(), ShaderModuleCreateInfo::new(&code)).unwrap()
};
module.entry_point("main").unwrap()
};

Expand Down Expand Up @@ -427,3 +423,32 @@ fn window_size_dependent_setup(
})
.collect::<Vec<_>>()
}

fn read_spirv_words_from_file(path: impl AsRef<Path>) -> Vec<u32> {
// Read the file.
let path = path.as_ref();
let mut bytes = vec![];
let mut file = File::open(path).unwrap_or_else(|err| {
panic!(
"can't open file `{}`: {}.\n\
Note: this example needs to be run from the root of the example crate",
path.display(),
err,
)
});
file.read_to_end(&mut bytes).unwrap();

// Convert the bytes to words.
// SPIR-V is defined to be always little-endian, so this may need an endianness conversion.
assert!(
bytes.len() % 4 == 0,
"file `{}` does not contain a whole number of SPIR-V words",
path.display(),
);

// TODO: Use `slice::array_chunks` once it's stable.
bytes
.chunks_exact(4)
.map(|chunk| u32::from_le_bytes(chunk.try_into().unwrap()))
.collect()
}
8 changes: 4 additions & 4 deletions vulkano-shaders/src/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -275,20 +275,20 @@ pub(super) fn reflect(
device: ::std::sync::Arc<::vulkano::device::Device>,
) -> ::std::result::Result<
::std::sync::Arc<::vulkano::shader::ShaderModule>,
::vulkano::shader::ShaderModuleCreationError,
::vulkano::Validated<::vulkano::VulkanError>,
> {
let _bytes = ( #( #include_bytes ),* );

static WORDS: &[u32] = &[ #( #words ),* ];

unsafe {
::vulkano::shader::ShaderModule::from_words_with_data(
::vulkano::shader::ShaderModule::new_with_data(
device,
WORDS,
::vulkano::shader::ShaderModuleCreateInfo::new(&WORDS),
[ #( #entry_points ),* ],
#spirv_version,
[ #( #spirv_capabilities ),* ],
[ #( #spirv_extensions ),* ],
[ #( #entry_points ),* ],
)
}
}
Expand Down
15 changes: 7 additions & 8 deletions vulkano-shaders/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@
//!
//! The macro generates the following items of interest:
//!
//! - The `load` constructor. This function takes an `Arc<Device>`, calls
//! [`ShaderModule::from_words_with_data`] with the passed-in device and the shader data provided
//! via the macro, and returns `Result<Arc<ShaderModule>, ShaderModuleCreationError>`.
//! Before doing so, it loops through every capability instruction in the shader data,
//! - The `load` constructor. This function takes an `Arc<Device>`, constructs a
//! [`ShaderModule`] with the passed-in device and the shader data provided
//! via the macro, and returns `Result<Arc<ShaderModule>, Validated<VulkanError>>`.
//! Before doing so, it checks every capability instruction in the shader data,
//! verifying that the passed-in `Device` has the appropriate features enabled.
//! - If the `shaders` option is used, then instead of one `load` constructor, there is one for
//! each shader. They are named based on the provided names, `load_first`, `load_second` etc.
Expand All @@ -50,8 +50,7 @@
//! ```
//! # fn main() {}
//! # use std::sync::Arc;
//! # use vulkano::shader::{ShaderModuleCreationError, ShaderModule};
//! # use vulkano::device::Device;
//! # use vulkano::{device::Device, shader::ShaderModule, Validated, VulkanError};
//! #
//! # mod vs {
//! # vulkano_shaders::shader!{
Expand All @@ -75,7 +74,7 @@
//! }
//!
//! impl Shaders {
//! pub fn load(device: Arc<Device>) -> Result<Self, ShaderModuleCreationError> {
//! pub fn load(device: Arc<Device>) -> Result<Self, Validated<VulkanError>> {
//! Ok(Self {
//! vs: vs::load(device)?,
//! })
Expand Down Expand Up @@ -208,7 +207,7 @@
//!
//! [`cargo-env-vars`]: https://doc.rust-lang.org/cargo/reference/environment-variables.html
//! [cargo-expand]: https://github.com/dtolnay/cargo-expand
//! [`ShaderModule::from_words_with_data`]: vulkano::shader::ShaderModule::from_words_with_data
//! [`ShaderModule`]: vulkano::shader::ShaderModule
//! [pipeline]: vulkano::pipeline
//! [`set_target_env`]: shaderc::CompileOptions::set_target_env
//! [`set_target_spirv`]: shaderc::CompileOptions::set_target_spirv
Expand Down
37 changes: 6 additions & 31 deletions vulkano/autogen/extensions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@
// notice may not be copied, modified, or distributed except
// according to those terms.

use super::{write_file, IndexMap, VkRegistryData};
use super::{write_file, IndexMap, RequiresOneOf, VkRegistryData};
use heck::ToSnakeCase;
use once_cell::sync::Lazy;
use proc_macro2::{Ident, Literal, TokenStream};
use quote::{format_ident, quote};
use regex::Regex;
use std::{cmp::min, fmt::Write as _, ops::BitOrAssign};
use std::fmt::Write as _;
use vk_parse::Extension;

// This is not included in vk.xml, so it's added here manually
Expand Down Expand Up @@ -49,35 +49,6 @@ struct ExtensionsMember {
status: Option<ExtensionStatus>,
}

#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct RequiresOneOf {
pub api_version: Option<(u32, u32)>,
pub device_extensions: Vec<String>,
pub instance_extensions: Vec<String>,
}

impl BitOrAssign<&Self> for RequiresOneOf {
fn bitor_assign(&mut self, rhs: &Self) {
self.api_version = match (self.api_version, rhs.api_version) {
(None, None) => None,
(None, Some(x)) | (Some(x), None) => Some(x),
(Some(lhs), Some(rhs)) => Some(min(lhs, rhs)),
};

for rhs_ext in &rhs.device_extensions {
if !self.device_extensions.contains(rhs_ext) {
self.device_extensions.push(rhs_ext.to_owned());
}
}

for rhs_ext in &rhs.instance_extensions {
if !self.instance_extensions.contains(rhs_ext) {
self.instance_extensions.push(rhs_ext.to_owned());
}
}
}
}

#[derive(Clone, Debug)]
enum ExtensionStatus {
PromotedTo(Requires),
Expand Down Expand Up @@ -129,6 +100,7 @@ fn device_extensions_output(members: &[ExtensionsMember]) -> TokenStream {
api_version,
device_extensions,
instance_extensions,
features: _,
}| {
(device_extensions.is_empty()
&& (api_version.is_some() || !instance_extensions.is_empty()))
Expand Down Expand Up @@ -204,6 +176,7 @@ fn device_extensions_output(members: &[ExtensionsMember]) -> TokenStream {
api_version,
device_extensions,
instance_extensions: _,
features: _,
}| {
(!device_extensions.is_empty()).then(|| {
let condition_items = api_version
Expand Down Expand Up @@ -310,6 +283,7 @@ fn instance_extensions_output(members: &[ExtensionsMember]) -> TokenStream {
api_version,
device_extensions: _,
instance_extensions,
features: _,
}| {
api_version.filter(|_| instance_extensions.is_empty()).map(|(major, minor)| {
let version = format_ident!("V{}_{}", major, minor);
Expand Down Expand Up @@ -363,6 +337,7 @@ fn instance_extensions_output(members: &[ExtensionsMember]) -> TokenStream {
api_version,
device_extensions: _,
instance_extensions,
features: _,
}| {
(!instance_extensions.is_empty()).then(|| {
let condition_items = api_version
Expand Down
3 changes: 2 additions & 1 deletion vulkano/autogen/formats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
// notice may not be copied, modified, or distributed except
// according to those terms.

use super::{extensions::RequiresOneOf, write_file, IndexMap, VkRegistryData};
use super::{write_file, IndexMap, RequiresOneOf, VkRegistryData};
use heck::ToSnakeCase;
use once_cell::sync::Lazy;
use proc_macro2::{Ident, Literal, TokenStream};
Expand Down Expand Up @@ -271,6 +271,7 @@ fn formats_output(members: &[FormatMember]) -> TokenStream {
api_version,
device_extensions,
instance_extensions,
features: _,
}| {
let condition_items = (api_version.iter().map(|(major, minor)| {
let version = format_ident!("V{}_{}", major, minor);
Expand Down
54 changes: 54 additions & 0 deletions vulkano/autogen/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@ use ahash::HashMap;
use once_cell::sync::Lazy;
use regex::Regex;
use std::{
cmp::min,
env,
fmt::Display,
fs::File,
io::{BufWriter, Write},
ops::BitOrAssign,
path::Path,
process::Command,
};
Expand Down Expand Up @@ -451,3 +453,55 @@ fn suffix_key(name: &str) -> u32 {
0
}
}

#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct RequiresOneOf {
pub api_version: Option<(u32, u32)>,
pub device_extensions: Vec<String>,
pub instance_extensions: Vec<String>,
pub features: Vec<String>,
}

impl RequiresOneOf {
pub fn is_empty(&self) -> bool {
let Self {
api_version,
device_extensions,
instance_extensions,
features,
} = self;

api_version.is_none()
&& device_extensions.is_empty()
&& instance_extensions.is_empty()
&& features.is_empty()
}
}

impl BitOrAssign<&Self> for RequiresOneOf {
fn bitor_assign(&mut self, rhs: &Self) {
self.api_version = match (self.api_version, rhs.api_version) {
(None, None) => None,
(None, Some(x)) | (Some(x), None) => Some(x),
(Some(lhs), Some(rhs)) => Some(min(lhs, rhs)),
};

for rhs in &rhs.device_extensions {
if !self.device_extensions.contains(rhs) {
self.device_extensions.push(rhs.to_owned());
}
}

for rhs in &rhs.instance_extensions {
if !self.instance_extensions.contains(rhs) {
self.instance_extensions.push(rhs.to_owned());
}
}

for rhs in &rhs.features {
if !self.features.contains(rhs) {
self.features.push(rhs.to_owned());
}
}
}
}
Loading

0 comments on commit 2b08260

Please sign in to comment.