Skip to content

Commit

Permalink
Shader Processor: process imported shader (#3290)
Browse files Browse the repository at this point in the history
# Objective

- I want to be able to use `#ifdef` and other processor directives in an imported shader

## Solution

- Process imported shader strings


Co-authored-by: François <8672791+mockersf@users.noreply.github.com>
  • Loading branch information
mockersf and mockersf committed Dec 22, 2021
1 parent b5d7ff2 commit a3c53e6
Showing 1 changed file with 157 additions and 36 deletions.
193 changes: 157 additions & 36 deletions crates/bevy_render/src/render_resource/shader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -360,16 +360,16 @@ impl ShaderProcessor {
}
};

let shader_defs = HashSet::<String>::from_iter(shader_defs.iter().cloned());
let shader_defs_unique = HashSet::<String>::from_iter(shader_defs.iter().cloned());
let mut scopes = vec![true];
let mut final_string = String::new();
for line in shader_str.split('\n') {
if let Some(cap) = self.ifdef_regex.captures(line) {
let def = cap.get(1).unwrap();
scopes.push(*scopes.last().unwrap() && shader_defs.contains(def.as_str()));
scopes.push(*scopes.last().unwrap() && shader_defs_unique.contains(def.as_str()));
} else if let Some(cap) = self.ifndef_regex.captures(line) {
let def = cap.get(1).unwrap();
scopes.push(*scopes.last().unwrap() && !shader_defs.contains(def.as_str()));
scopes.push(*scopes.last().unwrap() && !shader_defs_unique.contains(def.as_str()));
} else if self.else_regex.is_match(line) {
let mut is_parent_scope_truthy = true;
if scopes.len() > 1 {
Expand All @@ -388,19 +388,32 @@ impl ShaderProcessor {
.captures(line)
{
let import = ShaderImport::AssetPath(cap.get(1).unwrap().as_str().to_string());
apply_import(import_handles, shaders, &import, shader, &mut final_string)?;
self.apply_import(
import_handles,
shaders,
&import,
shader,
shader_defs,
&mut final_string,
)?;
} else if let Some(cap) = SHADER_IMPORT_PROCESSOR
.import_custom_path_regex
.captures(line)
{
let import = ShaderImport::Custom(cap.get(1).unwrap().as_str().to_string());
apply_import(import_handles, shaders, &import, shader, &mut final_string)?;
self.apply_import(
import_handles,
shaders,
&import,
shader,
shader_defs,
&mut final_string,
)?;
} else if *scopes.last().unwrap() {
final_string.push_str(line);
final_string.push('\n');
}
}

final_string.pop();

if scopes.len() != 1 {
Expand All @@ -417,45 +430,51 @@ impl ShaderProcessor {
}
}
}
}

fn apply_import(
import_handles: &HashMap<ShaderImport, Handle<Shader>>,
shaders: &HashMap<Handle<Shader>, Shader>,
import: &ShaderImport,
shader: &Shader,
final_string: &mut String,
) -> Result<(), ProcessShaderError> {
let imported_shader = import_handles
.get(import)
.and_then(|handle| shaders.get(handle))
.ok_or_else(|| ProcessShaderError::UnresolvedImport(import.clone()))?;
match &shader.source {
Source::Wgsl(_) => {
if let Source::Wgsl(import_source) = &imported_shader.source {
final_string.push_str(import_source);
} else {
return Err(ProcessShaderError::MismatchedImportFormat(import.clone()));
fn apply_import(
&self,
import_handles: &HashMap<ShaderImport, Handle<Shader>>,
shaders: &HashMap<Handle<Shader>, Shader>,
import: &ShaderImport,
shader: &Shader,
shader_defs: &[String],
final_string: &mut String,
) -> Result<(), ProcessShaderError> {
let imported_shader = import_handles
.get(import)
.and_then(|handle| shaders.get(handle))
.ok_or_else(|| ProcessShaderError::UnresolvedImport(import.clone()))?;
let imported_processed =
self.process(imported_shader, shader_defs, shaders, import_handles)?;

match &shader.source {
Source::Wgsl(_) => {
if let ProcessedShader::Wgsl(import_source) = &imported_processed {
final_string.push_str(import_source);
} else {
return Err(ProcessShaderError::MismatchedImportFormat(import.clone()));
}
}
}
Source::Glsl(_, _) => {
if let Source::Glsl(import_source, _) = &imported_shader.source {
final_string.push_str(import_source);
} else {
return Err(ProcessShaderError::MismatchedImportFormat(import.clone()));
Source::Glsl(_, _) => {
if let ProcessedShader::Glsl(import_source, _) = &imported_processed {
final_string.push_str(import_source);
} else {
return Err(ProcessShaderError::MismatchedImportFormat(import.clone()));
}
}
Source::SpirV(_) => {
return Err(ProcessShaderError::ShaderFormatDoesNotSupportImports);
}
}
Source::SpirV(_) => {
return Err(ProcessShaderError::ShaderFormatDoesNotSupportImports);
}
}

Ok(())
Ok(())
}
}

#[cfg(test)]
mod tests {
use bevy_asset::Handle;
use bevy_asset::{Handle, HandleUntyped};
use bevy_reflect::TypeUuid;
use bevy_utils::HashMap;
use naga::ShaderStage;

Expand Down Expand Up @@ -1081,4 +1100,106 @@ fn vertex(
.unwrap();
assert_eq!(result.get_wgsl_source().unwrap(), EXPECTED);
}

#[test]
fn process_import_ifdef() {
#[rustfmt::skip]
const FOO: &str = r"
#ifdef IMPORT_MISSING
fn in_import_missing() { }
#endif
#ifdef IMPORT_PRESENT
fn in_import_present() { }
#endif
";
#[rustfmt::skip]
const INPUT: &str = r"
#import FOO
#ifdef MAIN_MISSING
fn in_main_missing() { }
#endif
#ifdef MAIN_PRESENT
fn in_main_present() { }
#endif
";
#[rustfmt::skip]
const EXPECTED: &str = r"
fn in_import_present() { }
fn in_main_present() { }
";
let processor = ShaderProcessor::default();
let mut shaders = HashMap::default();
let mut import_handles = HashMap::default();
let foo_handle = Handle::<Shader>::default();
shaders.insert(foo_handle.clone_weak(), Shader::from_wgsl(FOO));
import_handles.insert(
ShaderImport::Custom("FOO".to_string()),
foo_handle.clone_weak(),
);
let result = processor
.process(
&Shader::from_wgsl(INPUT),
&["MAIN_PRESENT".to_string(), "IMPORT_PRESENT".to_string()],
&shaders,
&import_handles,
)
.unwrap();
assert_eq!(result.get_wgsl_source().unwrap(), EXPECTED);
}

#[test]
fn process_import_in_import() {
#[rustfmt::skip]
const BAR: &str = r"
#ifdef DEEP
fn inner_import() { }
#endif
";
const FOO: &str = r"
#import BAR
fn import() { }
";
#[rustfmt::skip]
const INPUT: &str = r"
#import FOO
fn in_main() { }
";
#[rustfmt::skip]
const EXPECTED: &str = r"
fn inner_import() { }
fn import() { }
fn in_main() { }
";
let processor = ShaderProcessor::default();
let mut shaders = HashMap::default();
let mut import_handles = HashMap::default();
{
let bar_handle = Handle::<Shader>::default();
shaders.insert(bar_handle.clone_weak(), Shader::from_wgsl(BAR));
import_handles.insert(
ShaderImport::Custom("BAR".to_string()),
bar_handle.clone_weak(),
);
}
{
let foo_handle = HandleUntyped::weak_from_u64(Shader::TYPE_UUID, 1).typed();
shaders.insert(foo_handle.clone_weak(), Shader::from_wgsl(FOO));
import_handles.insert(
ShaderImport::Custom("FOO".to_string()),
foo_handle.clone_weak(),
);
}
let result = processor
.process(
&Shader::from_wgsl(INPUT),
&["DEEP".to_string()],
&shaders,
&import_handles,
)
.unwrap();
assert_eq!(result.get_wgsl_source().unwrap(), EXPECTED);
}
}

0 comments on commit a3c53e6

Please sign in to comment.