diff --git a/src/compose/error.rs b/src/compose/error.rs index 255cf1d..112d080 100644 --- a/src/compose/error.rs +++ b/src/compose/error.rs @@ -1,4 +1,4 @@ -use std::ops::Range; +use std::{borrow::Cow, collections::HashMap, ops::Range}; use codespan_reporting::{ diagnostic::{Diagnostic, Label}, @@ -8,12 +8,19 @@ use codespan_reporting::{ use thiserror::Error; use tracing::trace; -use super::Composer; +use super::{ + preprocess::{PreprocessOutput, PreprocessorMetaData}, + Composer, ShaderDefValue, +}; use crate::{compose::SPAN_SHIFT, redirect::RedirectError}; #[derive(Debug)] pub enum ErrSource { - Module(String, usize), + Module { + name: String, + offset: usize, + defs: HashMap, + }, Constructing { path: String, source: String, @@ -24,21 +31,38 @@ pub enum ErrSource { impl ErrSource { pub fn path<'a>(&'a self, composer: &'a Composer) -> &'a String { match self { - ErrSource::Module(c, _) => &composer.module_sets.get(c).unwrap().file_path, + ErrSource::Module { name, .. } => &composer.module_sets.get(name).unwrap().file_path, ErrSource::Constructing { path, .. } => path, } } - pub fn source<'a>(&'a self, composer: &'a Composer) -> &'a String { + pub fn source<'a>(&'a self, composer: &'a Composer) -> Cow<'a, String> { match self { - ErrSource::Module(c, _) => &composer.module_sets.get(c).unwrap().substituted_source, - ErrSource::Constructing { source, .. } => source, + ErrSource::Module { name, defs, .. } => { + let raw_source = &composer.module_sets.get(name).unwrap().sanitized_source; + let Ok(PreprocessOutput { + preprocessed_source: source, + meta: PreprocessorMetaData { imports, .. }, + }) = composer + .preprocessor + .preprocess(raw_source, defs, composer.validate) + else { + return Default::default() + }; + + let Ok(source) = composer + .substitute_shader_string(&source, &imports) + else { return Default::default() }; + + Cow::Owned(source) + } + ErrSource::Constructing { source, .. } => Cow::Borrowed(source), } } pub fn offset(&self) -> usize { match self { - ErrSource::Module(_, offset) | ErrSource::Constructing { offset, .. } => *offset, + ErrSource::Module { offset, .. } | ErrSource::Constructing { offset, .. } => *offset, } } } @@ -159,7 +183,7 @@ impl ComposerError { ..((rng.end & ((1 << SPAN_SHIFT) - 1)).saturating_sub(source_offset)) }; - let files = SimpleFile::new(path, source); + let files = SimpleFile::new(path, source.as_str()); let config = term::Config::default(); #[cfg(test)] let mut writer = term::termcolor::NoColor::new(Vec::new()); diff --git a/src/compose/mod.rs b/src/compose/mod.rs index 666ead7..0543a10 100644 --- a/src/compose/mod.rs +++ b/src/compose/mod.rs @@ -243,8 +243,8 @@ pub struct ComposableModule { #[derive(Debug)] pub struct ComposableModuleDefinition { pub name: String, - // shader text - pub substituted_source: String, + // shader text (with auto bindings replaced - we do this on module add as we only want to do it once to avoid burning slots) + pub sanitized_source: String, // language pub language: ShaderLanguage, // source path for error display @@ -421,21 +421,45 @@ impl Composer { undecor.to_string() } - fn sanitize_and_substitute_shader_string( - &mut self, - source: &str, - imports: &[ImportDefWithOffset], - ) -> Result { + fn sanitize_and_set_auto_bindings(&mut self, source: &str) -> String { let mut substituted_source = source.replace("\r\n", "\n").replace('\r', "\n"); if !substituted_source.ends_with('\n') { substituted_source.push('\n'); } + // replace @binding(auto) with an incrementing index + struct AutoBindingReplacer<'a> { + auto: &'a mut u32, + } + + impl<'a> regex::Replacer for AutoBindingReplacer<'a> { + fn replace_append(&mut self, _: ®ex::Captures<'_>, dst: &mut String) { + dst.push_str(&format!("@binding({})", self.auto)); + *self.auto += 1; + } + } + + let substituted_source = self.auto_binding_regex.replace_all( + &substituted_source, + AutoBindingReplacer { + auto: &mut self.auto_binding_index, + }, + ); + + substituted_source.into_owned() + } + + fn substitute_shader_string( + &self, + source: &str, + imports: &[ImportDefWithOffset], + ) -> Result { // sort imports by decreasing length so we don't accidentally replace substrings of a longer import let mut imports = imports.to_vec(); imports.sort_by_key(|import| usize::MAX - import.definition.as_name().len()); let mut imported_items = HashMap::new(); + let mut substituted_source = source.to_owned(); for import in imports { match import.definition.items { @@ -515,26 +539,7 @@ impl Composer { } substituted_source = item_substituted_source; - // replace @binding(auto) with an incrementing index - struct AutoBindingReplacer<'a> { - auto: &'a mut u32, - } - - impl<'a> regex::Replacer for AutoBindingReplacer<'a> { - fn replace_append(&mut self, _: ®ex::Captures<'_>, dst: &mut String) { - dst.push_str(&format!("@binding({})", self.auto)); - *self.auto += 1; - } - } - - let substituted_source = self.auto_binding_regex.replace_all( - &substituted_source, - AutoBindingReplacer { - auto: &mut self.auto_binding_index, - }, - ); - - Ok(substituted_source.into_owned()) + Ok(substituted_source) } fn naga_to_string( @@ -699,7 +704,11 @@ impl Composer { .naga_to_string(&mut header_module.into(), language, name) .map_err(|inner| ComposerError { inner, - source: ErrSource::Module(name.to_owned(), 0), + source: ErrSource::Module { + name: name.to_owned(), + offset: 0, + defs: shader_defs.clone(), + }, })?; module_string.push_str(&composed_header); @@ -719,7 +728,11 @@ impl Composer { debug!("full err'd source file: \n---\n{}\n---", module_string); ComposerError { inner: ComposerErrorInner::WgslParseError(e), - source: ErrSource::Module(name.to_owned(), start_offset), + source: ErrSource::Module { + name: name.to_owned(), + offset: start_offset, + defs: shader_defs.clone(), + }, } })?, ShaderLanguage::Glsl => naga::front::glsl::Frontend::default() @@ -734,7 +747,11 @@ impl Composer { debug!("full err'd source file: \n---\n{}\n---", module_string); ComposerError { inner: ComposerErrorInner::GlslParseError(e), - source: ErrSource::Module(name.to_owned(), start_offset), + source: ErrSource::Module { + name: name.to_owned(), + offset: start_offset, + defs: shader_defs.clone(), + }, } })?, }; @@ -877,7 +894,7 @@ impl Composer { // - record any types/vars/constants/functions that are defined within this module // - build headers for each supported language fn create_composable_module( - &self, + &mut self, module_definition: &ComposableModuleDefinition, module_decoration: String, shader_defs: &HashMap, @@ -887,7 +904,11 @@ impl Composer { let wrap_err = |inner: ComposerErrorInner| -> ComposerError { ComposerError { inner, - source: ErrSource::Module(module_definition.name.to_owned(), 0), + source: ErrSource::Module { + name: module_definition.name.to_owned(), + offset: 0, + defs: shader_defs.clone(), + }, } }; @@ -897,12 +918,16 @@ impl Composer { } = self .preprocessor .preprocess( - &module_definition.substituted_source, + &module_definition.sanitized_source, shader_defs, self.validate, ) .map_err(wrap_err)?; + let source = self + .substitute_shader_string(&source, &imports) + .map_err(wrap_err)?; + let mut imports: Vec<_> = imports .into_iter() .map(|import_with_offset| import_with_offset.definition) @@ -1029,7 +1054,11 @@ impl Composer { let wrap_err = |inner: ComposerErrorInner| -> ComposerError { ComposerError { inner, - source: ErrSource::Module(module_definition.name.to_owned(), start_offset), + source: ErrSource::Module { + name: module_definition.name.to_owned(), + offset: start_offset, + defs: shader_defs.clone(), + }, } }; @@ -1334,10 +1363,14 @@ impl Composer { ) -> Result { let imports = self .preprocessor - .preprocess(&module_set.substituted_source, shader_defs, self.validate) + .preprocess(&module_set.sanitized_source, shader_defs, self.validate) .map_err(|inner| ComposerError { inner, - source: ErrSource::Module(module_set.name.to_owned(), 0), + source: ErrSource::Module { + name: module_set.name.to_owned(), + offset: 0, + defs: shader_defs.clone(), + }, })? .meta .imports; @@ -1461,6 +1494,8 @@ impl Composer { }); } + let substituted_source = self.sanitize_and_set_auto_bindings(source); + let ( PreprocessorMetaData { name: module_name, @@ -1469,7 +1504,7 @@ impl Composer { _, ) = self .preprocessor - .get_preprocessor_metadata(source, false) + .get_preprocessor_metadata(&substituted_source, false) .map_err(|inner| ComposerError { inner, source: ErrSource::Constructing { @@ -1508,17 +1543,6 @@ impl Composer { }), ); - let substituted_source = self - .sanitize_and_substitute_shader_string(source, &imports) - .map_err(|e| ComposerError { - inner: e, - source: ErrSource::Constructing { - path: file_path.to_owned(), - source: source.to_owned(), - offset: 0, - }, - })?; - let mut effective_defs = HashSet::new(); for import in &imports { // we require modules already added so that we can capture the shader_defs that may impact us by impacting our dependencies @@ -1532,7 +1556,7 @@ impl Composer { ), source: ErrSource::Constructing { path: file_path.to_owned(), - source: substituted_source.clone(), + source: substituted_source.to_owned(), offset: 0, }, })?; @@ -1557,7 +1581,7 @@ impl Composer { let module_set = ComposableModuleDefinition { name: module_name.clone(), - substituted_source, + sanitized_source: substituted_source, file_path: file_path.to_owned(), language, effective_defs: effective_defs.into_iter().collect(), @@ -1610,14 +1634,16 @@ impl Composer { additional_imports, } = desc; + let sanitized_source = self.sanitize_and_set_auto_bindings(source); + let (_, defines) = self .preprocessor - .get_preprocessor_metadata(source, true) + .get_preprocessor_metadata(&sanitized_source, true) .map_err(|inner| ComposerError { inner, source: ErrSource::Constructing { path: file_path.to_owned(), - source: source.to_owned(), + source: sanitized_source.to_owned(), offset: 0, }, })?; @@ -1628,24 +1654,24 @@ impl Composer { meta: PreprocessorMetaData { name, imports }, } = self .preprocessor - .preprocess(source, &shader_defs, false) + .preprocess(&sanitized_source, &shader_defs, false) .map_err(|inner| ComposerError { inner, source: ErrSource::Constructing { path: file_path.to_owned(), - source: source.to_owned(), + source: sanitized_source.to_owned(), offset: 0, }, })?; let name = name.unwrap_or_default(); let substituted_source = self - .sanitize_and_substitute_shader_string(source, &imports) + .substitute_shader_string(&sanitized_source, &imports) .map_err(|inner| ComposerError { inner, source: ErrSource::Constructing { path: file_path.to_owned(), - source: source.to_owned(), + source: sanitized_source.to_owned(), offset: 0, }, })?; @@ -1667,7 +1693,7 @@ impl Composer { }, source: ErrSource::Constructing { path: file_path.to_owned(), - source: substituted_source, + source: sanitized_source.to_owned(), offset: 0, }, }); @@ -1679,7 +1705,7 @@ impl Composer { inner: ComposerErrorInner::ImportNotFound(import_name.clone(), offset), source: ErrSource::Constructing { path: file_path.to_owned(), - source: substituted_source, + source: sanitized_source, offset: 0, }, }); @@ -1693,7 +1719,7 @@ impl Composer { let definition = ComposableModuleDefinition { name, - substituted_source, + sanitized_source: substituted_source, language: shader_type.into(), file_path: file_path.to_owned(), module_index: 0, @@ -1711,7 +1737,7 @@ impl Composer { inner: e.inner, source: ErrSource::Constructing { path: definition.file_path.to_owned(), - source: definition.substituted_source.to_owned(), + source: definition.sanitized_source.to_owned(), offset: e.source.offset(), }, })?; @@ -1804,7 +1830,7 @@ impl Composer { match module_index { 0 => ErrSource::Constructing { path: file_path.to_owned(), - source: definition.substituted_source, + source: definition.sanitized_source, offset: composable.start_offset, }, _ => { @@ -1817,7 +1843,11 @@ impl Composer { .get_module(&shader_defs) .unwrap() .start_offset; - ErrSource::Module(module_name, offset) + ErrSource::Module { + name: module_name, + offset, + defs: shader_defs.clone(), + } } } } diff --git a/src/compose/test.rs b/src/compose/test.rs index 5555ac1..f64c47d 100644 --- a/src/compose/test.rs +++ b/src/compose/test.rs @@ -1008,6 +1008,81 @@ mod test { output_eq!(wgsl, "tests/expected/item_sub_point.txt"); } + #[test] + fn conditional_import() { + let mut composer = Composer::default(); + + composer + .add_composable_module(ComposableModuleDescriptor { + source: include_str!("tests/conditional_import/mod_a.wgsl"), + file_path: "tests/conditional_import/mod_a.wgsl", + ..Default::default() + }) + .unwrap(); + composer + .add_composable_module(ComposableModuleDescriptor { + source: include_str!("tests/conditional_import/mod_b.wgsl"), + file_path: "tests/conditional_import/mod_b.wgsl", + ..Default::default() + }) + .unwrap(); + + let module_a = composer + .make_naga_module(NagaModuleDescriptor { + source: include_str!("tests/conditional_import/top.wgsl"), + file_path: "tests/conditional_import/top.wgsl", + shader_defs: HashMap::from_iter([("USE_A".to_owned(), ShaderDefValue::Bool(true))]), + ..Default::default() + }) + .unwrap(); + + let info = naga::valid::Validator::new( + naga::valid::ValidationFlags::all(), + naga::valid::Capabilities::default(), + ) + .validate(&module_a) + .unwrap(); + let wgsl = naga::back::wgsl::write_string( + &module_a, + &info, + naga::back::wgsl::WriterFlags::EXPLICIT_TYPES, + ) + .unwrap(); + + // let mut f = std::fs::File::create("conditional_import_a.txt").unwrap(); + // f.write_all(wgsl.as_bytes()).unwrap(); + // drop(f); + + output_eq!(wgsl, "tests/expected/conditional_import_a.txt"); + + let module_b = composer + .make_naga_module(NagaModuleDescriptor { + source: include_str!("tests/conditional_import/top.wgsl"), + file_path: "tests/conditional_import/top.wgsl", + ..Default::default() + }) + .unwrap(); + + let info = naga::valid::Validator::new( + naga::valid::ValidationFlags::all(), + naga::valid::Capabilities::default(), + ) + .validate(&module_b) + .unwrap(); + let wgsl = naga::back::wgsl::write_string( + &module_b, + &info, + naga::back::wgsl::WriterFlags::EXPLICIT_TYPES, + ) + .unwrap(); + + // let mut f = std::fs::File::create("conditional_import_b.txt").unwrap(); + // f.write_all(wgsl.as_bytes()).unwrap(); + // drop(f); + + output_eq!(wgsl, "tests/expected/conditional_import_b.txt"); + } + // actually run a shader and extract the result // needs the composer to contain a module called "test_module", with a function called "entry_point" returning an f32. fn test_shader(composer: &mut Composer) -> f32 { diff --git a/src/compose/tests/conditional_import/mod_a.wgsl b/src/compose/tests/conditional_import/mod_a.wgsl new file mode 100644 index 0000000..054e49f --- /dev/null +++ b/src/compose/tests/conditional_import/mod_a.wgsl @@ -0,0 +1,3 @@ +#define_import_path a + +const C: u32 = 1u; \ No newline at end of file diff --git a/src/compose/tests/conditional_import/mod_b.wgsl b/src/compose/tests/conditional_import/mod_b.wgsl new file mode 100644 index 0000000..9caaac0 --- /dev/null +++ b/src/compose/tests/conditional_import/mod_b.wgsl @@ -0,0 +1,3 @@ +#define_import_path b + +const C: u32 = 2u; \ No newline at end of file diff --git a/src/compose/tests/conditional_import/top.wgsl b/src/compose/tests/conditional_import/top.wgsl new file mode 100644 index 0000000..8182f45 --- /dev/null +++ b/src/compose/tests/conditional_import/top.wgsl @@ -0,0 +1,9 @@ +#ifdef USE_A + #import a C +#else + #import b C +#endif + +fn main() -> u32 { + return C; +} \ No newline at end of file diff --git a/src/compose/tests/expected/conditional_import_a.txt b/src/compose/tests/expected/conditional_import_a.txt new file mode 100644 index 0000000..2f41734 --- /dev/null +++ b/src/compose/tests/expected/conditional_import_a.txt @@ -0,0 +1,6 @@ +const _naga_oil_mod_ME_memberC: u32 = 1u; + +fn main() -> u32 { + return _naga_oil_mod_ME_memberC; +} + diff --git a/src/compose/tests/expected/conditional_import_b.txt b/src/compose/tests/expected/conditional_import_b.txt new file mode 100644 index 0000000..31a7a56 --- /dev/null +++ b/src/compose/tests/expected/conditional_import_b.txt @@ -0,0 +1,6 @@ +const _naga_oil_mod_MI_memberC: u32 = 2u; + +fn main() -> u32 { + return _naga_oil_mod_MI_memberC; +} +