Skip to content

Commit

Permalink
Pure rewrite improvements (#353)
Browse files Browse the repository at this point in the history
  • Loading branch information
eldritchconundrum authored Apr 21, 2024
1 parent 08902cc commit 0828467
Show file tree
Hide file tree
Showing 9 changed files with 95 additions and 32 deletions.
49 changes: 34 additions & 15 deletions src/rewriter.fs
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,21 @@ let renameField field =
field |> String.map (fun c -> options.canonicalFieldNames.[swizzleIndex c])
else field

let rec private isPure = function
| Var v when v.Name = "true" || v.Name = "false" -> true
let private commaSeparatedExprs = List.reduce (fun a b -> FunCall(Op ",", [a; b]))

let rec private sideEffects = function
| Var _ -> []
| Int _
| Float _ -> true
| FunCall(Var fct, args) ->
Builtin.pureBuiltinFunctions.Contains fct.Name && List.forall isPure args
| FunCall(Op op, args) -> not (Builtin.assignOps.Contains op) && List.forall isPure args
| _ -> false
| Float _ -> []
| Dot(v, _) -> sideEffects v
| Subscript(e1, e2) -> (e1 :: (Option.toList e2)) |> List.collect sideEffects
| FunCall(Var fct, args) when Builtin.pureBuiltinFunctions.Contains(fct.Name) -> args |> List.collect sideEffects
| FunCall(Op op, args) when not(Builtin.assignOps.Contains(op)) -> args |> List.collect sideEffects
| FunCall(Dot(d, field) as e, args) when field = "length" -> (e :: args) |> List.collect sideEffects
| FunCall(Subscript _ as e, args) -> (e :: args) |> List.collect sideEffects
| e -> [e]

let rec private isPure e = sideEffects e = []

module private RewriterImpl =

Expand Down Expand Up @@ -465,15 +472,20 @@ module private RewriterImpl =
// Compact a pure declaration immediately followed by re-assignment: float m=14.;m=58.; -> float m=58.;
| Decl (ty, [declElt]), (Expr (FunCall (Op "=", [Var name2; init2])) as assign2)
when declElt.name.Name = name2.Name
&& not (exprUsesIdentName init2 declElt.name.Name)
&& declElt.init |> Option.map isPure |> Option.defaultValue true ->
Some [Decl (ty, [{declElt with init = Some init2}])]
&& not (exprUsesIdentName init2 declElt.name.Name) ->
match declElt.init |> Option.map sideEffects |> Option.defaultValue [] with
| [] -> Some [Decl (ty, [{declElt with init = Some init2}])]
| es -> Some [Decl (ty, [{declElt with init = Some (commaSeparatedExprs (es @ [init2]))}])]
| _ -> None)

// Remove pure expression statements.
let b = b |> List.filter (function
| Expr e when isPure e -> false
| _ -> true)
// Reduces impure expression statements to their side effects.
let b = b |> List.collect (function
| Expr e ->
match sideEffects e with
| [] -> [] // Remove pure statements.
| [e] -> [Expr e]
| sideEffects -> [Expr (commaSeparatedExprs sideEffects)]
| s -> [s])

// Inline inner decl-less blocks. (Presence of decl could lead to redefinitions.) a();{b();}c(); -> a();b();c();
let b = b |> List.collect (function
Expand Down Expand Up @@ -620,7 +632,14 @@ let reorderFunctions code =
// Inline the argument of a function call into the function body.
module private ArgumentInlining =

let isInlinableExpr e = isPure e
let rec isInlinableExpr = function
// This is different that purity: reading a variable is pure, but non-inlinable in general.
| Var v when v.Name = "true" || v.Name = "false" -> true
| Int _
| Float _ -> true
| FunCall(Var fct, args) -> Builtin.pureBuiltinFunctions.Contains fct.Name && List.forall isInlinableExpr args
| FunCall(Op op, args) -> not (Builtin.assignOps.Contains op) && List.forall isInlinableExpr args
| _ -> false

type [<NoComparison>] Inlining = {
func: TopLevel
Expand Down
3 changes: 0 additions & 3 deletions tests/real/from-the-seas-to-the-stars.frag.expected
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,6 @@ bool jellyfish()
w.xy*=5.,w.z=pow(w.z,5.),p.xz=(disc(fract(w.xy))*.01/(1.01-w.z)+floor(w.xy)/2.-1.)*.25,p.y=w.z*2.-1.5+(3.-length(fract(w.xy)-.5))*.3-.8-(.5+.5*cos(floor(w.xy).x))*.2,p.y*=2,p.xz*=1.5,p.xyz+=of,p.x+=(fbm(p.yz/4.*vec2(1,6)-time2/10)-1.2)*pow(1.-w.z,.75),p.z+=(fbm(p.yx/4.*vec2(1,6)+time2/10)-1.2)*pow(1.-w.z,.75);
p-=of;
p.yz*=rotmat(cos(of.x+p.y/2)*.2);
p*.8+.4*sin(of.x*2+of.z*10);
p+=of;
p*=.2;
return false;
Expand Down Expand Up @@ -424,7 +423,6 @@ void main()
w.y=R();
w.z=R();
w.w=R();
w.w<.5;
}
}
p.xyz+=(w.xyz*2-1)*pow(length(p),2)*.001;
Expand Down Expand Up @@ -476,7 +474,6 @@ void main()
w.y=R();
w.z=R();
w.w=R();
w.w<.5;
}
}
p.xyz+=(w.xyz*2-1)*pow(length(p),2)*.001;
Expand Down
2 changes: 1 addition & 1 deletion tests/real/terrarium.frag.expected
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ void main()
else if(r<1.07)
col=mix(col,0.,.2),p.xy=p.xy+vec2(-.5,-.866025),p=S(p);
else
mix(col,0.,.2),col=1.,p.xy=p.xy+vec2(-.5,.866025),p=S(p),p.z+=.1;
col=1.,p.xy=p.xy+vec2(-.5,.866025),p=S(p),p.z+=.1;
if(time>48.)
p.z=mix(p.z,mod(p.z+time/2.+2,4.)-2,smoothstep(48.,49.,time));
}
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/array.frag.expected
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ const char *array_frag =
"void main()"
"{"
"float[2] test=float[](5.,7.);"
"func_bank(test)[0];"
"func_bank(test)[1];"
"func_bank(test);"
"func_bank(test);"
"}";

#endif // ARRAY_FRAG_EXPECTED_
3 changes: 0 additions & 3 deletions tests/unit/blocks.expected
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ const char *blocks_frag =
"}"
"int test_block()"
"{"
"k==0;"
"for(int i=0;i<2;i++)"
"{"
"if(k==1)"
Expand Down Expand Up @@ -69,8 +68,6 @@ const char *blocks_frag =
"removeUselessElseAfterReturn2(0.);"
"replaceIfReturnsByReturnTernary1(0.);"
"gl_FragColor=vec4(.2,a,b,0);"
"a<b;"
"a<b;"
"}";

#endif // BLOCKS_EXPECTED_
4 changes: 2 additions & 2 deletions tests/unit/forward_declaration.frag
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,6 @@ void e() {
}

int x;
void c() {x;}
void c() {x++;}

void d() {x;}
void d() {x++;}
4 changes: 2 additions & 2 deletions tests/unit/forward_declaration.frag.expected
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
int x;
void c()
{
x;
x++;
}
void d()
{
x;
x++;
}
void b()
{
Expand Down
25 changes: 23 additions & 2 deletions tests/unit/simplify.expected
Original file line number Diff line number Diff line change
@@ -1,7 +1,28 @@
#version 330

float bar(float x)
{
float a=6.;
a+=x;
float b=x;
b+=x;
b=34.;
b+=x;
float arr[2]=float[2](7.,8.);
arr=float[2](5.,float(arr.length()));
float m=a*=10,a*=20,a*=30,b++,i++,58.;
return a+b+m;
}
float baz(float a)
{
float b=a+4.;
b+=sin(a);
float c=b+5.;
c+=sin(b);
return-c+-c+c;
}
out vec3 output;
void main()
void notMain(float x)
{
output.xyz=vec3(92);
output.xyz=vec3(92)+vec3(bar(x)+baz(x));
}
33 changes: 31 additions & 2 deletions tests/unit/simplify.frag
Original file line number Diff line number Diff line change
@@ -1,13 +1,42 @@
#version 330

float bar(float x)
{
float a = x;
a = 6.;
a = a + x;

float b = x;
b = b + x;
b = 34.;
b = b + x;

float arr[2] = float[2](7.,8.);
arr = float[2](5.,float(arr.length()));

float m = 14.+length(vec3(sin(a *= 10), sin(a *= 20), sin(a *= 30)))+b++;
m = arr[1] + arr[i++];
m = 58.;

return a + b + m;
}
float baz(float a)
{
float b = a + 4.;
b += sin(a);
float c = b + 5.;
c += sin(b);
return c + (-(c - -c));
}

int n = 2;
float y = 47.;
out vec3 output;

float foo(float a) { if (n == 1) return 0.; if (n == 2) return a; return 1.; }
float foo(float a, float b) { if (y > a) return 0.; if (y < b) return a; return 1.; }

void main() {
void notMain(float x) {
vec3 v = vec3(foo(42.) + foo(50., 70.));
output.rgb = v;
output.rgb = v + vec3(bar(x) + baz(x));
}

0 comments on commit 0828467

Please sign in to comment.