Skip to content

Commit

Permalink
glsl-in: Allow nested accesses in lhs positions
Browse files Browse the repository at this point in the history
Also fixes access to runtime sized arrays behind named blocks
  • Loading branch information
JCapucho authored and kvark committed Mar 28, 2022
1 parent 21f89b6 commit 4146cb2
Show file tree
Hide file tree
Showing 8 changed files with 168 additions and 65 deletions.
2 changes: 2 additions & 0 deletions src/front/glsl/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,9 @@ pub struct EntryArg {
#[derive(Debug, Clone)]
pub struct VariableReference {
pub expr: Handle<Expression>,
/// Wether the variable is of a pointer type (and needs loading) or not
pub load: bool,
/// Wether the value of the variable can be changed or not
pub mutable: bool,
pub constant: Option<(Handle<Constant>, Handle<Type>)>,
pub entry_arg: Option<usize>,
Expand Down
78 changes: 52 additions & 26 deletions src/front/glsl/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,26 +16,26 @@ use crate::{
use std::{convert::TryFrom, ops::Index};

/// The position at which an expression is, used while lowering
#[derive(Clone, Copy, PartialEq, Eq)]
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum ExprPos {
/// The expression is in the left hand side of an assignment
Lhs,
/// The expression is in the right hand side of an assignment
Rhs,
/// The expression is an array being indexed, needed to allow constant
/// arrays to be dinamically indexed
ArrayBase {
AccessBase {
/// The index is a constant
constant_index: bool,
},
}

impl ExprPos {
/// Returns an lhs position if the current position is lhs otherwise ArrayBase
fn maybe_array_base(&self, constant_index: bool) -> Self {
/// Returns an lhs position if the current position is lhs otherwise AccessBase
fn maybe_access_base(&self, constant_index: bool) -> Self {
match *self {
ExprPos::Lhs => *self,
_ => ExprPos::ArrayBase { constant_index },
_ => ExprPos::AccessBase { constant_index },
}
}
}
Expand Down Expand Up @@ -492,6 +492,8 @@ impl Context {
) -> Result<(Option<Handle<Expression>>, Span)> {
let HirExpr { ref kind, meta } = stmt.hir_exprs[expr];

log::debug!("Lowering {:?}", expr);

let handle = match *kind {
HirExprKind::Access { base, index } => {
let (index, index_meta) =
Expand All @@ -508,7 +510,7 @@ impl Context {
stmt,
parser,
base,
pos.maybe_array_base(maybe_constant_index.is_some()),
pos.maybe_access_base(maybe_constant_index.is_some()),
body,
)?
.0;
Expand Down Expand Up @@ -551,18 +553,20 @@ impl Context {
pointer
}
HirExprKind::Select { base, ref field } => {
let base = self.lower_expect_inner(stmt, parser, base, pos, body)?.0;
let base = self
.lower_expect_inner(stmt, parser, base, pos.maybe_access_base(true), body)?
.0;

parser.field_selection(self, ExprPos::Lhs == pos, body, base, field, meta)?
parser.field_selection(self, pos, body, base, field, meta)?
}
HirExprKind::Constant(constant) if pos != ExprPos::Lhs => {
self.add_expression(Expression::Constant(constant), meta, body)
}
HirExprKind::Binary { left, op, right } if pos != ExprPos::Lhs => {
let (mut left, left_meta) =
self.lower_expect_inner(stmt, parser, left, pos, body)?;
self.lower_expect_inner(stmt, parser, left, ExprPos::Rhs, body)?;
let (mut right, right_meta) =
self.lower_expect_inner(stmt, parser, right, pos, body)?;
self.lower_expect_inner(stmt, parser, right, ExprPos::Rhs, body)?;

match op {
BinaryOperator::ShiftLeft | BinaryOperator::ShiftRight => self
Expand Down Expand Up @@ -1003,7 +1007,9 @@ impl Context {
}
}
HirExprKind::Unary { op, expr } if pos != ExprPos::Lhs => {
let expr = self.lower_expect_inner(stmt, parser, expr, pos, body)?.0;
let expr = self
.lower_expect_inner(stmt, parser, expr, ExprPos::Rhs, body)?
.0;

self.add_expression(Expression::Unary { op, expr }, meta, body)
}
Expand All @@ -1020,20 +1026,29 @@ impl Context {

var.expr
}
ExprPos::ArrayBase {
constant_index: false,
} => {
if let Some((constant, ty)) = var.constant {
let local = self.locals.append(
LocalVariable {
name: None,
ty,
init: Some(constant),
},
Span::default(),
);
ExprPos::AccessBase { constant_index } => {
// If the index isn't constant all accesses backed by a constant base need
// to be done trough a proxy local variable, since constants have a non
// pointer type which is required for dynamic indexing
if !constant_index {
if let Some((constant, ty)) = var.constant {
let local = self.locals.append(
LocalVariable {
name: None,
ty,
init: Some(constant),
},
Span::default(),
);

self.add_expression(Expression::LocalVariable(local), Span::default(), body)
self.add_expression(
Expression::LocalVariable(local),
Span::default(),
body,
)
} else {
var.expr
}
} else {
var.expr
}
Expand Down Expand Up @@ -1091,9 +1106,12 @@ impl Context {
let (mut value, value_meta) =
self.lower_expect_inner(stmt, parser, value, ExprPos::Rhs, body)?;

let scalar_components = self.expr_scalar_components(parser, pointer, ptr_meta)?;
let ty = match *parser.resolve_type(self, pointer, ptr_meta)? {
TypeInner::Pointer { base, .. } => &parser.module.types[base].inner,
ref ty => ty,
};

if let Some((kind, width)) = scalar_components {
if let Some((kind, width)) = scalar_components(ty) {
self.implicit_conversion(parser, &mut value, value_meta, kind, width)?;
}

Expand Down Expand Up @@ -1216,6 +1234,14 @@ impl Context {
}
};

log::trace!(
"Lowered {:?}\n\tKind = {:?}\n\tPos = {:?}\n\tResult = {:?}",
expr,
kind,
pos,
handle
);

Ok((Some(handle), meta))
}

Expand Down
38 changes: 27 additions & 11 deletions src/front/glsl/variables.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use super::{
ast::*,
context::Context,
context::{Context, ExprPos},
error::{Error, ErrorKind},
Parser, Result, Span,
};
Expand Down Expand Up @@ -231,7 +231,7 @@ impl Parser {
pub(crate) fn field_selection(
&mut self,
ctx: &mut Context,
lhs: bool,
pos: ExprPos,
body: &mut Block,
expression: Handle<Expression>,
name: &str,
Expand All @@ -250,14 +250,21 @@ impl Parser {
kind: ErrorKind::UnknownField(name.into()),
meta,
})?;
Ok(ctx.add_expression(
let pointer = ctx.add_expression(
Expression::AccessIndex {
base: expression,
index: index as u32,
},
meta,
body,
))
);

Ok(match pos {
ExprPos::Rhs if is_pointer => {
ctx.add_expression(Expression::Load { pointer }, meta, body)
}
_ => pointer,
})
}
// swizzles (xyzw, rgba, stpq)
TypeInner::Vector { size, .. } => {
Expand All @@ -277,7 +284,7 @@ impl Parser {
.or_else(|| check_swizzle_components("stpq"));

if let Some(components) = components {
if lhs {
if let ExprPos::Lhs = pos {
let not_unique = (1..components.len())
.any(|i| components[i..].contains(&components[i - 1]));
if not_unique {
Expand Down Expand Up @@ -315,14 +322,23 @@ impl Parser {
}

let size = match components.len() {
// Swizzles with just one component are accesses and not swizzles
1 => {
// only single element swizzle, like pos.y, just return that component.
if lhs {
// Because of possible nested swizzles, like pos.xy.x, we have to unwrap the potential load expr.
if let Expression::Load { ref pointer } = ctx[expression] {
expression = *pointer;
match pos {
// If the position is in the right hand side and the base
// vector is a pointer, load it, otherwise the swizzle would
// produce a pointer
ExprPos::Rhs if is_pointer => {
expression = ctx.add_expression(
Expression::Load {
pointer: expression,
},
meta,
body,
);
}
}
_ => {}
};
return Ok(ctx.add_expression(
Expression::AccessIndex {
base: expression,
Expand Down
22 changes: 22 additions & 0 deletions tests/in/glsl/buffer.frag
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#version 450

layout(set = 0, binding = 0) buffer testBufferBlock {
uint[] data;
} testBuffer;

layout(set = 0, binding = 1) writeonly buffer testBufferWriteOnlyBlock {
uint[] data;
} testBufferWriteOnly;

layout(set = 0, binding = 2) readonly buffer testBufferReadOnlyBlock {
uint[] data;
} testBufferReadOnly;

void main() {
uint a = testBuffer.data[0];
testBuffer.data[1] = 2;

testBufferWriteOnly.data[1] = 2;

uint b = testBufferReadOnly.data[0];
}
26 changes: 13 additions & 13 deletions tests/out/wgsl/bevy-pbr-frag.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -498,21 +498,21 @@ fn point_light(light: PointLight, roughness_8: f32, NdotV: f32, N: vec3<f32>, V_
R_1 = R;
F0_1 = F0_;
diffuseColor_1 = diffuseColor;
let _e56 = light_1;
let _e57 = light_1.pos;
let _e59 = v_WorldPosition_1;
light_to_frag = (_e56.pos.xyz - _e59.xyz);
light_to_frag = (_e57.xyz - _e59.xyz);
let _e65 = light_to_frag;
let _e66 = light_to_frag;
distance_square = dot(_e65, _e66);
let _e70 = light_1;
let _e71 = light_1.lightParams;
let _e73 = distance_square;
let _e74 = light_1;
let _e77 = getDistanceAttenuation(_e73, _e74.lightParams.x);
let _e75 = light_1.lightParams;
let _e77 = getDistanceAttenuation(_e73, _e75.x);
rangeAttenuation = _e77;
let _e79 = roughness_9;
a_1 = _e79;
let _e81 = light_1;
radius = _e81.lightParams.y;
let _e82 = light_1.lightParams;
radius = _e82.y;
let _e87 = light_to_frag;
let _e88 = R_1;
let _e90 = R_1;
Expand Down Expand Up @@ -611,10 +611,10 @@ fn point_light(light: PointLight, roughness_8: f32, NdotV: f32, N: vec3<f32>, V_
diffuse = (_e302 * _e311);
let _e314 = diffuse;
let _e315 = specular_1;
let _e317 = light_1;
let _e318 = light_1.color;
let _e321 = rangeAttenuation;
let _e322 = NoL_6;
return (((_e314 + _e315) * _e317.color.xyz) * (_e321 * _e322));
return (((_e314 + _e315) * _e318.xyz) * (_e321 * _e322));
}

fn dir_light(light_2: DirectionalLight, roughness_10: f32, NdotV_2: f32, normal: vec3<f32>, view: vec3<f32>, R_2: vec3<f32>, F0_2: vec3<f32>, diffuseColor_2: vec3<f32>) -> vec3<f32> {
Expand Down Expand Up @@ -643,8 +643,8 @@ fn dir_light(light_2: DirectionalLight, roughness_10: f32, NdotV_2: f32, normal:
R_3 = R_2;
F0_3 = F0_2;
diffuseColor_3 = diffuseColor_2;
let _e56 = light_3;
incident_light = _e56.direction.xyz;
let _e57 = light_3.direction;
incident_light = _e57.xyz;
let _e60 = incident_light;
let _e61 = view_1;
let _e63 = incident_light;
Expand Down Expand Up @@ -684,9 +684,9 @@ fn dir_light(light_2: DirectionalLight, roughness_10: f32, NdotV_2: f32, normal:
specular_2 = _e146;
let _e148 = specular_2;
let _e149 = diffuse_1;
let _e151 = light_3;
let _e152 = light_3.color;
let _e155 = NoL_7;
return (((_e148 + _e149) * _e151.color.xyz) * _e155);
return (((_e148 + _e149) * _e152.xyz) * _e155);
}

fn main_1() {
Expand Down
37 changes: 37 additions & 0 deletions tests/out/wgsl/buffer-frag.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
struct testBufferBlock {
data: array<u32>,
}

struct testBufferWriteOnlyBlock {
data: array<u32>,
}

struct testBufferReadOnlyBlock {
data: array<u32>,
}

@group(0) @binding(0)
var<storage, read_write> testBuffer: testBufferBlock;
@group(0) @binding(1)
var<storage, read_write> testBufferWriteOnly: testBufferWriteOnlyBlock;
@group(0) @binding(2)
var<storage> testBufferReadOnly: testBufferReadOnlyBlock;

fn main_1() {
var a: u32;
var b: u32;

let _e12 = testBuffer.data[0];
a = _e12;
testBuffer.data[1] = u32(2);
testBufferWriteOnly.data[1] = u32(2);
let _e27 = testBufferReadOnly.data[0];
b = _e27;
return;
}

@stage(fragment)
fn main() {
main_1();
return;
}
12 changes: 6 additions & 6 deletions tests/out/wgsl/declarations-vert.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@ fn main_1() {
var a_1: f32;
var b: f32;

let _e34 = in_array_2;
from_input_array = _e34[1];
let _e39 = array_2d;
a_1 = _e39[0][0];
let _e50 = array_toomanyd;
b = _e50[0][0][0][0][0][0][0];
let _e35 = in_array_2[1];
from_input_array = _e35;
let _e41 = array_2d[0][0];
a_1 = _e41;
let _e57 = array_toomanyd[0][0][0][0][0][0][0];
b = _e57;
out_array[0] = vec4<f32>(2.0);
return;
}
Expand Down
Loading

0 comments on commit 4146cb2

Please sign in to comment.