Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[glsl-in] Compute shader fixes and additions #955

Merged
merged 12 commits into from
Jun 9, 2021
195 changes: 174 additions & 21 deletions src/front/glsl/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ pub enum GlobalLookupKind {
pub struct GlobalLookup {
pub kind: GlobalLookupKind,
pub entry_arg: Option<usize>,
pub mutable: bool,
}

#[derive(Debug, PartialEq, Eq, Hash)]
Expand Down Expand Up @@ -77,11 +78,13 @@ pub struct Program<'a> {
pub profile: Profile,
pub entry_points: &'a FastHashMap<String, ShaderStage>,

pub workgroup_size: [u32; 3],
pub early_fragment_tests: bool,

pub lookup_function: FastHashMap<FunctionSignature, FunctionDeclaration>,
pub lookup_type: FastHashMap<String, Handle<Type>>,

pub global_variables: Vec<(String, GlobalLookup)>,
pub constants: Vec<(String, Handle<Constant>)>,

pub entry_args: Vec<EntryArg>,
pub entries: Vec<(String, ShaderStage, Handle<Function>)>,
Expand All @@ -98,10 +101,12 @@ impl<'a> Program<'a> {
profile: Profile::Core,
entry_points,

workgroup_size: [1; 3],
early_fragment_tests: false,

lookup_function: FastHashMap::default(),
lookup_type: FastHashMap::default(),
global_variables: Vec::new(),
constants: Vec::new(),

entry_args: Vec::new(),
entries: Vec::new(),
Expand Down Expand Up @@ -219,7 +224,7 @@ impl<'function> Context<'function> {

scopes: vec![FastHashMap::default()],
lookup_global_var_exps: FastHashMap::with_capacity_and_hasher(
program.constants.len() + program.global_variables.len(),
program.global_variables.len(),
Default::default(),
),
typifier: Typifier::new(),
Expand All @@ -229,23 +234,15 @@ impl<'function> Context<'function> {
emitter: Emitter::default(),
};

for &(ref name, handle) in program.constants.iter() {
let expr = this.expressions.append(Expression::Constant(handle));
let var = VariableReference {
expr,
load: None,
mutable: false,
entry_arg: None,
};

this.lookup_global_var_exps.insert(name.into(), var);
}

this.emit_start();

for &(ref name, lookup) in program.global_variables.iter() {
this.emit_flush(body);
let GlobalLookup { kind, entry_arg } = lookup;
let GlobalLookup {
kind,
entry_arg,
mutable,
} = lookup;
let (expr, load) = match kind {
GlobalLookupKind::Variable(v) => {
let res = (
Expand All @@ -263,7 +260,24 @@ impl<'function> Context<'function> {
.expressions
.append(Expression::AccessIndex { base, index });

(expr, true)
(expr, {
let ty = program.module.global_variables[handle].ty;

match program.module.types[ty].inner {
TypeInner::Struct { ref members, .. } => {
if let TypeInner::Array {
size: crate::ArraySize::Dynamic,
..
} = program.module.types[members[index as usize].ty].inner
{
false
} else {
true
}
}
_ => true,
}
})
}
};

Expand All @@ -274,8 +288,7 @@ impl<'function> Context<'function> {
} else {
None
},
// TODO: respect constant qualifier
mutable: true,
mutable,
entry_arg,
};

Expand Down Expand Up @@ -382,7 +395,41 @@ impl<'function> Context<'function> {
None
};

if let Some(current) = self.scopes.last_mut() {
if mutable && load.is_none() {
let handle = self.locals.append(LocalVariable {
name: Some(name.clone()),
ty,
init: None,
});
let local_expr = self.add_expression(Expression::LocalVariable(handle), body);

self.emit_flush(body);
self.emit_start();

body.push(Statement::Store {
pointer: local_expr,
value: expr,
});

let local_load = self.add_expression(
Expression::Load {
pointer: local_expr,
},
body,
);

if let Some(current) = self.scopes.last_mut() {
(*current).insert(
name,
VariableReference {
expr: local_expr,
load: Some(local_load),
mutable,
entry_arg: None,
},
);
}
} else if let Some(current) = self.scopes.last_mut() {
(*current).insert(
name,
VariableReference {
Expand Down Expand Up @@ -441,7 +488,18 @@ impl<'function> Context<'function> {
let base = self.lower_expect(program, base, lhs, body)?.0;
let index = self.lower_expect(program, index, false, body)?.0;

self.add_expression(Expression::Access { base, index }, body)
let pointer = self.add_expression(Expression::Access { base, index }, body);

if let TypeInner::Pointer { .. } = *program.resolve_type(self, pointer, meta)? {
if !lhs {
return Ok((
Some(self.add_expression(Expression::Load { pointer }, body)),
meta,
));
}
}

pointer
}
HirExprKind::Select { base, field } => {
let base = self.lower_expect(program, base, lhs, body)?.0;
Expand Down Expand Up @@ -574,6 +632,95 @@ impl<'function> Context<'function> {

value
}
HirExprKind::IncDec {
increment,
postfix,
expr,
} => {
let op = match increment {
true => BinaryOperator::Add,
false => BinaryOperator::Subtract,
};

let pointer = self.lower_expect(program, expr, true, body)?.0;
let left = self.add_expression(Expression::Load { pointer }, body);

let uint = if let Some(kind) = program.resolve_type(self, left, meta)?.scalar_kind()
{
match kind {
ScalarKind::Sint => false,
ScalarKind::Uint => true,
_ => {
return Err(ErrorKind::SemanticError(
meta,
"Increment/decrement operations must operate in integers".into(),
))
}
}
} else {
return Err(ErrorKind::SemanticError(
meta,
"Increment/decrement operations must operate in integers".into(),
));
};

let one = program.module.constants.append(Constant {
name: None,
specialization: None,
inner: crate::ConstantInner::Scalar {
width: 4,
value: match uint {
true => crate::ScalarValue::Uint(1),
false => crate::ScalarValue::Sint(1),
},
},
});
let right = self.add_expression(Expression::Constant(one), body);

let value = self.add_expression(Expression::Binary { op, left, right }, body);

if postfix {
let local = self.locals.append(LocalVariable {
name: None,
ty: program.module.types.fetch_or_append(Type {
name: None,
inner: TypeInner::Scalar {
kind: match uint {
true => ScalarKind::Uint,
false => ScalarKind::Sint,
},
width: 4,
},
}),
init: None,
});

let expr = self.add_expression(Expression::LocalVariable(local), body);
let load = self.add_expression(Expression::Load { pointer: expr }, body);

self.emit_flush(body);
self.emit_start();

body.push(Statement::Store {
pointer: expr,
value: left,
});

self.emit_flush(body);
self.emit_start();

body.push(Statement::Store { pointer, value });

load
} else {
self.emit_flush(body);
self.emit_start();

body.push(Statement::Store { pointer, value });

left
}
}
_ => {
return Err(ErrorKind::SemanticError(
meta,
Expand Down Expand Up @@ -719,6 +866,11 @@ pub enum HirExprKind {
tgt: Handle<HirExpr>,
value: Handle<HirExpr>,
},
IncDec {
increment: bool,
postfix: bool,
expr: Handle<HirExpr>,
},
}

#[derive(Debug)]
Expand All @@ -727,6 +879,7 @@ pub enum TypeQualifier {
Interpolation(Interpolation),
ResourceBinding(ResourceBinding),
Location(u32),
WorkGroupSize(usize, u32),
Sampling(Sampling),
Layout(StructLayout),
EarlyFragmentTests,
Expand Down
43 changes: 39 additions & 4 deletions src/front/glsl/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,37 @@ impl Program<'_> {
body,
)))
}
"mod" => {
if args.len() != 2 {
return Err(ErrorKind::wrong_function_args(name, 2, args.len(), meta));
}

let (mut left, left_meta) = args[0];
let (mut right, right_meta) = args[1];

ctx.binary_implicit_conversion(
self, &mut left, left_meta, &mut right, right_meta,
)?;

let expr = if let Some(ScalarKind::Float) =
self.resolve_type(ctx, args[0].0, args[1].1)?.scalar_kind()
{
Expression::Math {
fun: MathFunction::Modf,
arg: left,
arg1: Some(right),
arg2: None,
}
} else {
Expression::Binary {
op: BinaryOperator::Modulo,
left,
right,
}
};

Ok(Some(ctx.add_expression(expr, body)))
}
"pow" | "dot" | "max" => {
if args.len() != 2 {
return Err(ErrorKind::wrong_function_args(name, 2, args.len(), meta));
Expand Down Expand Up @@ -515,7 +546,7 @@ impl Program<'_> {
.take(callee_len.saturating_sub(caller_len)),
);

for i in 0..callee_len.max(caller_len) {
for i in 0..callee_len.min(caller_len) {
let callee_use = function_arg_use[function.index()][i];
function_arg_use[caller.index()][i] |= callee_use
}
Expand Down Expand Up @@ -623,9 +654,13 @@ impl Program<'_> {
self.module.entry_points.push(EntryPoint {
name,
stage,
// TODO
early_depth_test: None,
workgroup_size: [0; 3],
early_depth_test: Some(crate::EarlyDepthTest { conservative: None })
.filter(|_| self.early_fragment_tests && stage == crate::ShaderStage::Fragment),
workgroup_size: if let crate::ShaderStage::Compute = stage {
self.workgroup_size
} else {
[0; 3]
},
function: Function {
arguments,
expressions,
Expand Down
1 change: 1 addition & 0 deletions src/front/glsl/lex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ impl<'a> Iterator for Lexer<'a> {
"in" => TokenValue::In,
"out" => TokenValue::Out,
"uniform" => TokenValue::Uniform,
"buffer" => TokenValue::Buffer,
"flat" => TokenValue::Interpolation(crate::Interpolation::Flat),
"noperspective" => TokenValue::Interpolation(crate::Interpolation::Linear),
"smooth" => TokenValue::Interpolation(crate::Interpolation::Perspective),
Expand Down
Loading