Skip to content

Commit

Permalink
[wgsl-out] Write correct scalar kind when width != 4 (#1514)
Browse files Browse the repository at this point in the history
* [wgsl-out] Write correct scalar kind when width != 4

* slight refactoring

* Also handle matrix scalar widths

* Fix formatting
  • Loading branch information
fintelia authored Dec 12, 2022
1 parent 5a1f43d commit 4f77cba
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 29 deletions.
66 changes: 38 additions & 28 deletions src/back/wgsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -452,11 +452,11 @@ impl<W: Write> Writer<W> {
/// Adds no trailing or leading whitespace
fn write_value_type(&mut self, module: &Module, inner: &TypeInner) -> BackendResult {
match *inner {
TypeInner::Vector { size, kind, .. } => write!(
TypeInner::Vector { size, kind, width } => write!(
self.out,
"vec{}<{}>",
back::vector_size_str(size),
scalar_kind_str(kind),
scalar_kind_str(kind, width),
)?,
TypeInner::Sampler { comparison: false } => {
write!(self.out, "sampler")?;
Expand All @@ -478,7 +478,7 @@ impl<W: Write> Writer<W> {
Ic::Sampled { kind, multi } => (
"",
if multi { "multisampled_" } else { "" },
scalar_kind_str(kind),
scalar_kind_str(kind, 4),
"",
),
Ic::Depth { multi } => {
Expand Down Expand Up @@ -508,11 +508,11 @@ impl<W: Write> Writer<W> {
write!(self.out, "<{}{}>", format_str, storage_str)?;
}
}
TypeInner::Scalar { kind, .. } => {
write!(self.out, "{}", scalar_kind_str(kind))?;
TypeInner::Scalar { kind, width } => {
write!(self.out, "{}", scalar_kind_str(kind, width))?;
}
TypeInner::Atomic { kind, .. } => {
write!(self.out, "atomic<{}>", scalar_kind_str(kind))?;
TypeInner::Atomic { kind, width } => {
write!(self.out, "atomic<{}>", scalar_kind_str(kind, width))?;
}
TypeInner::Array {
base,
Expand Down Expand Up @@ -582,12 +582,12 @@ impl<W: Write> Writer<W> {
TypeInner::ValuePointer {
size: None,
kind,
width: _,
width,
space,
} => {
let (address, maybe_access) = address_space_str(space);
if let Some(space) = address {
write!(self.out, "ptr<{}, {}", space, scalar_kind_str(kind))?;
write!(self.out, "ptr<{}, {}", space, scalar_kind_str(kind, width))?;
if let Some(access) = maybe_access {
write!(self.out, ", {}", access)?;
}
Expand All @@ -602,7 +602,7 @@ impl<W: Write> Writer<W> {
TypeInner::ValuePointer {
size: Some(size),
kind,
width: _,
width,
space,
} => {
let (address, maybe_access) = address_space_str(space);
Expand All @@ -612,7 +612,7 @@ impl<W: Write> Writer<W> {
"ptr<{}, vec{}<{}>",
space,
back::vector_size_str(size),
scalar_kind_str(kind)
scalar_kind_str(kind, width)
)?;
if let Some(access) = maybe_access {
write!(self.out, ", {}", access)?;
Expand Down Expand Up @@ -1424,17 +1424,24 @@ impl<W: Write> Writer<W> {
} => {
let inner = func_ctx.info[expr].ty.inner_with(&module.types);
match *inner {
TypeInner::Matrix { columns, rows, .. } => {
TypeInner::Matrix {
columns,
rows,
width,
..
} => {
let scalar_kind_str = scalar_kind_str(kind, convert.unwrap_or(width));
write!(
self.out,
"mat{}x{}<f32>",
"mat{}x{}<{}>",
back::vector_size_str(columns),
back::vector_size_str(rows)
back::vector_size_str(rows),
scalar_kind_str
)?;
}
TypeInner::Vector { size, .. } => {
TypeInner::Vector { size, width, .. } => {
let vector_size_str = back::vector_size_str(size);
let scalar_kind_str = scalar_kind_str(kind);
let scalar_kind_str = scalar_kind_str(kind, convert.unwrap_or(width));
if convert.is_some() {
write!(self.out, "vec{}<{}>", vector_size_str, scalar_kind_str)?;
} else {
Expand All @@ -1445,11 +1452,12 @@ impl<W: Write> Writer<W> {
)?;
}
}
TypeInner::Scalar { .. } => {
TypeInner::Scalar { width, .. } => {
let scalar_kind_str = scalar_kind_str(kind, convert.unwrap_or(width));
if convert.is_some() {
write!(self.out, "{}", scalar_kind_str(kind))?
write!(self.out, "{}", scalar_kind_str)?
} else {
write!(self.out, "bitcast<{}>", scalar_kind_str(kind))?
write!(self.out, "bitcast<{}>", scalar_kind_str)?
}
}
_ => {
Expand All @@ -1465,16 +1473,16 @@ impl<W: Write> Writer<W> {
}
Expression::Splat { size, value } => {
let inner = func_ctx.info[value].ty.inner_with(&module.types);
let scalar_kind = match *inner {
crate::TypeInner::Scalar { kind, .. } => kind,
let (scalar_kind, scalar_width) = match *inner {
crate::TypeInner::Scalar { kind, width } => (kind, width),
_ => {
return Err(Error::Unimplemented(format!(
"write_expr expression::splat {:?}",
inner
)));
}
};
let scalar = scalar_kind_str(scalar_kind);
let scalar = scalar_kind_str(scalar_kind, scalar_width);
let size = back::vector_size_str(size);

write!(self.out, "vec{}<{}>(", size, scalar)?;
Expand Down Expand Up @@ -1931,14 +1939,16 @@ const fn image_dimension_str(dim: crate::ImageDimension) -> &'static str {
}
}

const fn scalar_kind_str(kind: crate::ScalarKind) -> &'static str {
const fn scalar_kind_str(kind: crate::ScalarKind, width: u8) -> &'static str {
use crate::ScalarKind as Sk;

match kind {
Sk::Float => "f32",
Sk::Sint => "i32",
Sk::Uint => "u32",
Sk::Bool => "bool",
match (kind, width) {
(Sk::Float, 8) => "f64",
(Sk::Float, 4) => "f32",
(Sk::Sint, 4) => "i32",
(Sk::Uint, 4) => "u32",
(Sk::Bool, 1) => "bool",
_ => unreachable!(),
}
}

Expand Down
2 changes: 1 addition & 1 deletion tests/out/wgsl/extra.wgsl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
struct PushConstants {
index: u32,
double: vec2<f32>,
double: vec2<f64>,
}

struct FragmentIn {
Expand Down

0 comments on commit 4f77cba

Please sign in to comment.