
Advanced optimization #656

Closed
18 tasks done
xumingkuan opened this issue Mar 25, 2020 · 64 comments
Labels
feature request Suggest an idea on this project

Comments

@xumingkuan
Collaborator

xumingkuan commented Mar 25, 2020

Concisely describe the proposed feature
With the new extensions introduced by #581, there is a lot of room to optimize the IR. I also found some feasible optimizations that are not directly related to the new extensions. For example, in this fragment of IR,

...
<f32 x1> $5 = alloca
if $26 {
  ...
} else {
  ...
}
if $26 {
  ...
} else {
  ...
}
<f32 x1> $83 = local load [ [$5[0]]] (the only statement about $5)
...

we could merge the two ifs together, change $83 to const [0] (nothing is ever stored to $5, so the load can only observe its initial value), and then delete $5.
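As a rough illustration of the if-merging step, here is a sketch on a hypothetical toy IR; `ToyStmt`, `Block`, `make_plain`, `make_if` and `merge_adjacent_ifs` are made-up names, not Taichi classes. It assumes the condition is an immutable IR value, so two adjacent ifs testing the same value can be fused by appending the second if's branches to the first's, which keeps side effects like x++ in their original order.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy IR for illustration only (not Taichi's classes).
struct ToyStmt {
  enum Kind { kPlain, kIf };
  Kind kind = kPlain;
  std::string text;  // kPlain: printed form of the statement
  int cond = -1;     // kIf: id of the (immutable) condition value
  std::vector<std::unique_ptr<ToyStmt>> true_body, false_body;  // kIf only
};
using Block = std::vector<std::unique_ptr<ToyStmt>>;

std::unique_ptr<ToyStmt> make_plain(const std::string &text) {
  auto s = std::make_unique<ToyStmt>();
  s->text = text;
  return s;
}

std::unique_ptr<ToyStmt> make_if(int cond) {
  auto s = std::make_unique<ToyStmt>();
  s->kind = ToyStmt::kIf;
  s->cond = cond;
  return s;
}

// Move every statement of `src` to the end of `dst`, preserving order.
static void append_all(Block &dst, Block &src) {
  for (auto &s : src) dst.push_back(std::move(s));
  src.clear();
}

// Fuse each if into the immediately preceding if when both test the same
// condition value. Recurses into branches first. Returns #ifs removed.
int merge_adjacent_ifs(Block &block) {
  int removed = 0;
  Block result;
  for (auto &s : block) {
    if (s->kind == ToyStmt::kIf) {
      removed += merge_adjacent_ifs(s->true_body);
      removed += merge_adjacent_ifs(s->false_body);
      if (!result.empty() && result.back()->kind == ToyStmt::kIf &&
          result.back()->cond == s->cond) {
        append_all(result.back()->true_body, s->true_body);
        append_all(result.back()->false_body, s->false_body);
        ++removed;
        continue;  // the now-empty second if is dropped
      }
    }
    result.push_back(std::move(s));
  }
  block = std::move(result);
  return removed;
}
```

This is a single pass; a real implementation would probably iterate to a fixed point, since fusing two bodies can create new adjacent ifs inside them.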

A list of optimizations I have done and am going to do:

Additional comments
For benchmarking, we may want to introduce a temporary boolean variable as a switch to enable or disable the optimization.

Some nice slides: https://courses.cs.washington.edu/courses/cse401/08wi/lecture/opt-mark.v2.pdf

@xumingkuan xumingkuan added the feature request Suggest an idea on this project label Mar 25, 2020
@xumingkuan
Collaborator Author

@yuanming-hu please assign me. It seems that I can't assign myself...

@yuanming-hu
Member

Awesome!! This is vitally important for improving run-time performance & reducing compilation time. Thanks for taking charge of this.

@archibate
Collaborator

archibate commented Mar 26, 2020

Merge adjacent ifs with identical conditions

What if these ifs contain statements with side effects, like x = x + 1? E.g.

if (cond) x++;
if (cond) x++;

We want to obtain:

if (cond) { x++; x++; }

and the duplicated x++ can be dealt with by other, lower-level passes.

Merge identical local loads if no statements between them modify the variable, even if there are ifs

What if the two local loads are in different blocks? E.g.

if (cond) {
print 'yes';
x = local load 233;
} else {
print 'no';
x = local load 233;
}

What if a statement appears once in the IR but runs multiple times? Should we optimize it? E.g.

while (cond) {
x = local load 233
... (no changes stored to 233)
}

We may move this out of the while loop.

First add an analysis pass to detect whether a block stores to an address.
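A minimal sketch of the suggested analysis plus the hoisting it enables, on a hypothetical toy IR (`Op`, `block_may_store` and `hoist_invariant_loads` are illustrative names, not Taichi APIs). `block_may_store` answers whether a block, including nested loops, may store to a given alloca; loads of allocas that a loop never stores to are moved in front of the loop. It assumes a local load has no side effects, which matters because the hoisted load now runs even if the loop body never executes.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical toy IR: loads/stores name an alloca by id; kWhile owns a body.
struct Op {
  enum Kind { kLoad, kStore, kOther, kWhile };
  Kind kind = kOther;
  int alloca_id = -1;    // meaningful for kLoad / kStore
  std::vector<Op> body;  // meaningful for kWhile
};

// Analysis: may any statement in `block` (recursively) store to `alloca_id`?
bool block_may_store(const std::vector<Op> &block, int alloca_id) {
  for (const Op &op : block) {
    if (op.kind == Op::kStore && op.alloca_id == alloca_id) return true;
    if (op.kind == Op::kWhile && block_may_store(op.body, alloca_id))
      return true;
  }
  return false;
}

// Move top-level loads out of each while body when the loop never stores to
// their alloca. Inner loops are processed first, so loads can bubble up
// through several loops. Assumes loads are side-effect free.
int hoist_invariant_loads(std::vector<Op> &block) {
  int hoisted = 0;
  std::vector<Op> result;
  for (Op &op : block) {
    if (op.kind == Op::kWhile) {
      hoisted += hoist_invariant_loads(op.body);
      std::vector<Op> kept;
      for (const Op &inner : op.body) {
        if (inner.kind == Op::kLoad &&
            !block_may_store(op.body, inner.alloca_id)) {
          result.push_back(inner);  // now executes once, before the loop
          ++hoisted;
        } else {
          kept.push_back(inner);
        }
      }
      op.body = std::move(kept);
    }
    result.push_back(std::move(op));
  }
  block = std::move(result);
  return hoisted;
}
```

In a real IR the uses of the hoisted load would keep pointing at the same statement; the toy IR has no uses, so only the placement changes.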

@xumingkuan
Collaborator Author

Merge adjacent ifs with identical conditions

What if these ifs contain statements with side effects, like x = x + 1? E.g.

if (cond) x++;
if (cond) x++;

We want to obtain:

if (cond) { x++; x++; }

and the duplicated x++ can be dealt with by other, lower-level passes.

Exactly.

Merge identical local loads if no statements between them modify the variable, even if there are ifs

What if the two local loads are in different blocks? E.g.

if (cond) {
print 'yes';
x = local load 233;
} else {
print 'no';
x = local load 233;
}

This is non-trivial. We could find the code fragments common to the true branch and the false branch and hoist them out of the if, but I'm not sure it would make a big difference.

What if a statement appears once in the IR but runs multiple times? Should we optimize it? E.g.

while (cond) {
x = local load 233
... (no changes stored to 233)
}

We may move this out of the while loop.

If cond is false (so the loop body never runs), does moving the load out of the loop have side effects?

First add an analysis pass to detect whether a block stores to an address.

This is not necessary for merging identical local loads when no statement between them modifies the variable: I think directly searching for modifications when we encounter a local load fits the existing code framework better. We can add this pass later if needed.
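To illustrate searching for modifications when a local load is encountered, here is a sketch restricted to a single basic block of a hypothetical toy IR (`Stmt` and `merge_identical_loads` are made-up names). On each load it scans backwards over the statements kept so far: an earlier load of the same alloca makes the new load redundant, while a store to that alloca stops the search. It assumes statements other than loads and stores never modify allocas.

```cpp
#include <cassert>
#include <map>
#include <vector>

// Hypothetical flat toy IR for one basic block; `id` plays the role of $N.
struct Stmt {
  enum Kind { kLoad, kStore, kOther };
  int id;
  Kind kind;
  int alloca_id;  // alloca read by kLoad / written by kStore; unused for kOther
};

// Drops every local load preceded by a load of the same alloca with no store
// to that alloca in between. Returns dropped id -> id to use instead; a real
// pass would rewrite the uses of each dropped load with this map.
std::map<int, int> merge_identical_loads(std::vector<Stmt> &block) {
  std::map<int, int> replaced;
  std::vector<Stmt> kept;
  for (const Stmt &s : block) {
    if (s.kind == Stmt::kLoad) {
      int earlier = -1;
      // Search backwards for the nearest load or store of the same alloca.
      for (auto it = kept.rbegin(); it != kept.rend(); ++it) {
        if (it->kind == Stmt::kOther || it->alloca_id != s.alloca_id) continue;
        if (it->kind == Stmt::kLoad) earlier = it->id;
        break;  // found an equivalent load, or a store that blocks reuse
      }
      if (earlier != -1) {
        replaced[s.id] = earlier;
        continue;
      }
    }
    kept.push_back(s);
  }
  block = kept;
  return replaced;
}
```

Extending this across ifs, as the item asks, would mean continuing the search in enclosing blocks and treating any container that may store to the alloca as a barrier.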

@archibate
Collaborator

If cond is false (so the loop body never runs), does moving the load out of the loop have side effects?

No, it's just a load whose result is never used, so it will be optimized out by other, lower-level passes.

@archibate
Collaborator

How about first transforming:

if (cond) {
print 'yes';
x = local load 233;
} else {
print 'no';
x = local load 233;
}

into:

if (cond) print 'yes'; else print 'no';
if (cond) xxx; else xxx;

since cond is a constant IR value, and the second if (whose two branches are identical) can then be safely eliminated.

@xumingkuan
Collaborator Author

How about first transforming:

if (cond) {
print 'yes';
x = local load 233;
} else {
print 'no';
x = local load 233;
}

into:

if (cond) print 'yes'; else print 'no';
if (cond) xxx; else xxx;

since cond is a constant IR value, and the second if (whose two branches are identical) can then be safely eliminated.

I just thought of a case:

if (cond) {
  print 'yes';
  x = local load 233;
  print 'yes';
} else {
  print 'no';
  x = local load 233;
  print 'no';
}

I can't tell if the following is more efficient than the above:

if (cond) print 'yes'; else print 'no';
x = local load 233;
if (cond) print 'yes'; else print 'no';

(especially when the common code fragment is short relative to the rest of the branches)

We can restrict this optimization to the first and last statements of the if's two bodies.
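A sketch of that restricted version on a hypothetical toy `if` whose branches are lists of printed statements (`ToyIf`, `Hoisted` and `hoist_common_ends` are illustrative names). String equality stands in for a real structural equivalence check. At most one leading and one trailing statement are hoisted, so a middle statement, like the local load in the example above, stays inside both branches.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical toy `if`: each branch holds printed statements, and two
// statements are treated as identical when their printed forms match.
struct ToyIf {
  std::vector<std::string> true_body, false_body;
};

struct Hoisted {
  std::vector<std::string> before, after;  // placed around the if
};

// Hoist at most one identical leading statement (to just before the if) and
// at most one identical trailing statement (to just after it).
Hoisted hoist_common_ends(ToyIf &s) {
  Hoisted h;
  auto &t = s.true_body;
  auto &f = s.false_body;
  if (!t.empty() && !f.empty() && t.front() == f.front()) {
    h.before.push_back(t.front());
    t.erase(t.begin());
    f.erase(f.begin());
  }
  if (!t.empty() && !f.empty() && t.back() == f.back()) {
    h.after.push_back(t.back());
    t.pop_back();
    f.pop_back();
  }
  return h;
}
```

Hoisting a leading statement in front of the if is fine here because the condition is already an evaluated IR value when the if is reached.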

@xumingkuan
Collaborator Author

xumingkuan commented Mar 27, 2020

@yuanming-hu What do

Stmt *true_mask, *false_mask;
mean?

May I just ignore them when merging two adjacent ifs?

@yuanming-hu
Member

Quick answer for now: yes. I'll document this in greater detail later. You don't have to worry about that until we start doing vectorization.

@xumingkuan
Collaborator Author

I just found a piece of IR:

<i32 x1> $8 = const [0]
...
if $19 {
  ...
  <i32 x1> $25 = const [0]
  ...
} else {
  ...
  <i32 x1> $40 = const [0]
  ...
}

I think we could optimize them all to $8. Currently, void visit(ConstStmt*) only searches statements before the current statement in the same basic block, so $25 cannot find $8 because they are not in the same basic block.

There are two ways to do this optimization:

  1. Search statements after the current statement (say $8) instead, and dive into container statements to replace them with $8.
  2. Search statements before the current statement (say $25), and do this recursively for parent blocks.

Which do you think is better?

@yuanming-hu
Member

yuanming-hu commented Mar 27, 2020

I think 2 is better. At compile time it's hard to determine whether $25 or $40 will be executed after $8, but it's certain that $8 comes before both $25 and $40.
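Option 2 can be sketched on a hypothetical toy IR in which each block records its parent block and the position of the container statement that owns it (`Stmt`, `Block`, `find_earlier_const`, `dedup_consts`, `make_const` and `add_container` are made-up names). The search walks backwards in the current block, then continues in each enclosing block from just before the owning container. A constant found there executes before the container is entered, so it is available everywhere inside it. A real pass would also have to compare data types, since const [0] and const [0.0] are different.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <utility>
#include <vector>

struct Block;

// Hypothetical toy IR (not Taichi's classes).
struct Stmt {
  int id = -1;            // plays the role of $N
  bool is_const = false;
  int value = 0;          // constant value, if is_const
  std::vector<std::unique_ptr<Block>> bodies;  // non-empty for containers
};

struct Block {
  Block *parent = nullptr;  // enclosing block; nullptr for the kernel root
  int owner_pos = -1;       // index of the owning container in `parent`
  std::vector<Stmt> stmts;
};

// Search backwards in `b` from `pos`, then in each enclosing block from just
// before the container that owns the current block.
int find_earlier_const(const Block *b, int pos, int value,
                       const std::map<int, int> &replaced) {
  for (; b != nullptr; pos = b->owner_pos, b = b->parent) {
    for (int i = pos - 1; i >= 0; --i) {
      const Stmt &s = b->stmts[i];
      if (s.is_const && s.value == value && replaced.count(s.id) == 0)
        return s.id;
    }
  }
  return -1;
}

// Record, for every constant, an equivalent constant that executes before it.
void dedup_consts(Block &b, std::map<int, int> &replaced) {
  for (int i = 0; i < static_cast<int>(b.stmts.size()); ++i) {
    Stmt &s = b.stmts[i];
    if (s.is_const) {
      int earlier = find_earlier_const(&b, i, s.value, replaced);
      if (earlier != -1) replaced[s.id] = earlier;
    }
    for (auto &child : s.bodies) dedup_consts(*child, replaced);
  }
}

Stmt make_const(int id, int value) {
  Stmt s;
  s.id = id;
  s.is_const = true;
  s.value = value;
  return s;
}

// Append a container with `n` empty child blocks and return it.
Stmt &add_container(Block &b, int id, int n) {
  Stmt s;
  s.id = id;
  b.stmts.push_back(std::move(s));
  Stmt &c = b.stmts.back();
  for (int k = 0; k < n; ++k) {
    auto child = std::make_unique<Block>();
    child->parent = &b;
    child->owner_pos = static_cast<int>(b.stmts.size()) - 1;
    c.bodies.push_back(std::move(child));
  }
  return c;
}
```

The usage mirrors the IR above: $8 sits in the outer block, and $25 and $40 sit in the two branches of if $19.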

@xumingkuan
Collaborator Author

Should this pass (identical ConstStmt elimination) still be in BasicBlockSimplify? It no longer operates within a single basic block, so maybe I should implement it in Simplify?

@yuanming-hu
Member

Let's add a WholeKernelCSE (common subexpression elimination) pass then.

@xumingkuan
Collaborator Author

For checking whether the first statements (which can be container statements) in both branches of an if are exactly the same, should we add a function like bool same_statements(IRNode *root1, IRNode *root2) in ir.h and implement it using visitors in taichi/analysis/?

@yuanming-hu
Member

Very good question. I need to think about this a little bit. One very important IR functionality is to test if two IRNodes are equivalent. IRNode can be not only one statement but also a hierarchy. We might need to use some hashing here.

@yuanming-hu
Member

yuanming-hu commented Mar 29, 2020

A few things to think about here

  • We have to support not only a single statement but also a container with multiple statements.
  • There are many statements to support, each with special fields. We do have a common std::vector<Stmt **> Stmt::operands that keeps track of all operands of a statement in a unified manner, but the special fields (e.g. BinaryOpType BinaryOpStmt::op_type) are not covered by it.
  • We don't have to worry about Expressions since they only live in the frontend.
  • Binary DNA
  • (Advanced) Reject fast.

@xumingkuan
Collaborator Author

There are 3 kinds of solutions I thought about. Denote by n the number of statements in the container IRNode we want to test (if it's not a container, n = 1).

  1. Do nothing extra when modifying statements. Then it takes O(n) time to determine that two IRNodes are the same, and O(n) time in the worst case to determine that two IRNodes are different. I think in most cases we can determine that two IRNodes are different in O(1).
  2. Spend O(depth) extra time when modifying statements, where "depth" means the number of container statements directly or indirectly containing the modified statement. We can update the Binary DNA and its hash in O(1) for each container statement. (Note that even if we only set a boolean flag to mark a container statement as modified, it still takes O(1) for each container statement!) So we can determine that two IRNodes are different in O(1) in expectation, but we still need O(n) time to determine that two IRNodes are the same, because the Binary DNA's length is Ω(n).
  3. Spend O(depth * log(n)) extra time when modifying statements. Then we can determine that two IRNodes are the same in O(log(n)) with some fancy data structures.

I prefer the 1st solution. I think it's unacceptable to spend O(depth) extra time on every statement modification just to avoid the worst-case O(n) time of determining that two IRNodes are different: we modify statements far more often than we check whether two IRNodes are equivalent.

If there is a stage after which statements no longer change, we can build data structures for comparing IRNodes at that point.

@yuanming-hu
Member

Thanks for the detailed analysis. I agree with your decision and we should probably go with the 1st solution.

Meanwhile, a very easy-to-implement (and slightly hacky) way to test if two statements are equivalent:

  • First do a re_id pass to minimize the statement indices
  • Then use print_ir to convert the statements to an std::string
  • Then compare if the two strings are equal

This should work for most cases (assuming the print_ir pass is doing a correct job) and can probably be implemented within 20 LoC.
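A sketch of this idea working directly on printed IR text; the line format is assumed from the dumps in this thread, and `canonicalize` and `same_printed_ir` are hypothetical helpers, not Taichi functions. One subtlety: only ids defined inside the fragment should be renumbered, while references to outer values (such as the if condition) must keep their names, otherwise `add $26 $1` and `add $27 $1` would compare equal. Renumbered ids get a `%` prefix so they cannot collide with outer ids.

```cpp
#include <cassert>
#include <cctype>
#include <map>
#include <sstream>
#include <string>

// Digits of a "$<digits>" token starting at s[i], or "" if there is none.
static std::string id_at(const std::string &s, size_t i) {
  if (i >= s.size() || s[i] != '$') return "";
  size_t j = i + 1;
  while (j < s.size() && std::isdigit(static_cast<unsigned char>(s[j]))) ++j;
  return s.substr(i + 1, j - i - 1);
}

// Ids defined inside the fragment: the "$N" that begins a statement, after
// optional indentation and an optional "<type> " prefix such as "<f32 x1> ".
// Each is mapped to its definition order.
static std::map<std::string, int> defined_ids(const std::string &ir) {
  std::map<std::string, int> ids;
  std::istringstream in(ir);
  std::string line;
  while (std::getline(in, line)) {
    size_t i = line.find_first_not_of(' ');
    if (i == std::string::npos) continue;
    if (line[i] == '<') {
      size_t close = line.find('>', i);
      if (close == std::string::npos) continue;
      i = close + 2;  // skip "> "
    }
    std::string id = id_at(line, i);
    if (!id.empty() && ids.count(id) == 0) {
      int next = static_cast<int>(ids.size());
      ids[id] = next;
    }
  }
  return ids;
}

// Rename ids defined in the fragment to %0, %1, ...; keep outer ids as is.
std::string canonicalize(const std::string &ir) {
  const std::map<std::string, int> ids = defined_ids(ir);
  std::string out;
  for (size_t i = 0; i < ir.size();) {
    std::string id = id_at(ir, i);
    if (id.empty()) {
      out += ir[i++];
      continue;
    }
    auto it = ids.find(id);
    out += it != ids.end() ? "%" + std::to_string(it->second) : "$" + id;
    i += 1 + id.size();
  }
  return out;
}

// Hacky equivalence test: equal text after canonical renumbering.
bool same_printed_ir(const std::string &a, const std::string &b) {
  return canonicalize(a) == canonicalize(b);
}
```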

@xumingkuan
Collaborator Author

Thanks for the hacky approach, but I want to implement a reject-fast solution. I think most queries will be on IRNodes that are different.

@xumingkuan
Collaborator Author

Maybe I can implement a visitor that visits one of the IRNodes while storing the corresponding IRNode in the visitor class?

@yuanming-hu
Member

Sounds good. I champion your decision :-)

Maybe I can implement a visitor that visits one of the IRNodes while storing the corresponding IRNode in the visitor class?

Right, you have to use one IRNode to guide the other.
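A minimal sketch of this "one IRNode guides the other" comparison on a hypothetical toy tree (`Node` and `SameStatements` are illustrative names, not Taichi's visitor classes). It walks both trees in lockstep and rejects as soon as a kind, special field, operand count or body count differs. It keeps a map from ids defined in the first tree to their counterparts in the second, so operands defined outside the compared trees must match exactly. It assumes statement ids are unique within the kernel.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical toy statement: kind + special fields, operands by id, and
// optional nested bodies (e.g. the two branches of an if).
struct Node {
  int id;
  std::string kind;    // e.g. "binary", "const", "if"
  std::string fields;  // special fields, e.g. "add" for a binary op type
  std::vector<int> operands;
  std::vector<std::vector<Node>> bodies;
};

class SameStatements {
 public:
  bool run(const std::vector<Node> &a, const std::vector<Node> &b) {
    id_map_.clear();
    return same_block(a, b);
  }

 private:
  // Ids defined inside the compared trees: id in `a` -> matching id in `b`.
  std::map<int, int> id_map_;

  bool same_block(const std::vector<Node> &a, const std::vector<Node> &b) {
    if (a.size() != b.size()) return false;  // cheap rejection first
    for (size_t i = 0; i < a.size(); ++i)
      if (!same_node(a[i], b[i])) return false;
    return true;
  }

  bool same_node(const Node &a, const Node &b) {
    if (a.kind != b.kind || a.fields != b.fields ||
        a.operands.size() != b.operands.size() ||
        a.bodies.size() != b.bodies.size())
      return false;
    for (size_t i = 0; i < a.operands.size(); ++i) {
      // Defined inside: must be b's counterpart. Defined outside: same id.
      auto it = id_map_.find(a.operands[i]);
      int expected = it != id_map_.end() ? it->second : a.operands[i];
      if (expected != b.operands[i]) return false;
    }
    id_map_[a.id] = b.id;  // record before visiting nested bodies
    for (size_t k = 0; k < a.bodies.size(); ++k)
      if (!same_block(a.bodies[k], b.bodies[k])) return false;
    return true;
  }
};
```

In most cases the first mismatching kind or field ends the walk immediately, which is the reject-fast behavior discussed above; only truly equivalent trees pay the full O(n).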

@xumingkuan
Collaborator Author

I wonder if this IR is valid:

<f32 x1> $238 = alloca
<f32 x1> $197 = alloca
<f32 x1> $239 : local store [$238 <- $197]
<f32 x1> $199 = ...
<f32 x1> $200 : local store [$197 <- $199]
<f32 x1> $242 = local load [ [$238[0]]]
<f32 x1> $218 = local load [ [$242[0]]]

It causes simplify.cpp to crash because the alloca here

auto alloca = stmt->ptr[0].var;

is not an AllocaStmt when we are visiting $218.

@yuanming-hu
Member

Good question. LocalLoad must take Allocas as inputs. $218 is invalid.
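A sketch of a checker for this rule on a hypothetical flat toy IR (`Stmt` and `invalid_local_loads` are made-up names, not a Taichi verifier). The test input mirrors the IR above, where $218 loads from $242, which is itself a local load rather than an alloca.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical flat toy IR: each statement has an id, a kind and operand ids.
struct Stmt {
  int id;
  std::string kind;  // "alloca", "local_load", "local_store", ...
  std::vector<int> operands;
};

// Returns the ids of local loads whose pointer operand is not an alloca,
// i.e. statements violating "LocalLoad must take Allocas as inputs".
std::vector<int> invalid_local_loads(const std::vector<Stmt> &block) {
  std::map<int, std::string> kind_of;
  std::vector<int> bad;
  for (const Stmt &s : block) {
    if (s.kind == "local_load") {
      auto it = kind_of.find(s.operands.at(0));
      if (it == kind_of.end() || it->second != "alloca") bad.push_back(s.id);
    }
    kind_of[s.id] = s.kind;
  }
  return bad;
}
```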

@xumingkuan
Collaborator Author

@yuanming-hu I found an issue when doing CSE for global pointers:
Case: test_ad_if_parallel_complex
Before (good):

[I 06/30/20 20:38:44.108] [compile_to_offloads.cpp:taichi::lang::irpass::compile_to_offloads::<lambda_a4464fe7c75e1f42a3a490ee54c7ec3e>::operator ()@23] Simplified I:
kernel {
  <f32 x1> $0 = const [1.0]
  <i32 x1> $1 = const [1]
  <i32 x1> $2 = const [0]
  <f32 x1> $3 = const [0.0]
  <i32 x1> $4 = const [2]
  $5 : for in range($2, $4) (vectorize 1) block_dim=adaptive {
    <i32 x1> $6 = loop $5 index 0
    <f32 x1> $7 = alloca
    <f32 x1> $8 : local store [$7 <- $3]
    <f32*x1> $9 = global ptr [S2place_f32], index [$6] activate=true
    <f32 x1> $10 = global load $9
    <i32 x1> $11 = cmp_gt $10 $3
    <i32 x1> $12 = bit_and $11 $1
    $13 : if $12 {
      <f32*x1> $14 = global ptr [S2place_f32], index [$6] activate=true
      <f32 x1> $15 = global load $14
      <f32 x1> $16 = div $0 $15
      <f32 x1> $17 : local store [$7 <- $16]
    }
    <f32 x1> $18 = local load [ [$7[0]]]
    <f32*x1> $19 = global ptr [S4place_f32], index [$6] activate=true
    <f32*x1> $20 : global store [$19 <- $18]
  }
}
[I 06/30/20 20:38:44.110] [compile_to_offloads.cpp:taichi::lang::irpass::compile_to_offloads::<lambda_a4464fe7c75e1f42a3a490ee54c7ec3e>::operator ()@23] Gradient:
kernel {
  <f32 x1> $0 = const [1.0]
  <i32 x1> $1 = const [1]
  <i32 x1> $2 = const [0]
  <f32 x1> $3 = const [0.0]
  <i32 x1> $4 = const [2]
  $5 : for in range($2, $4) (vectorize 1) block_dim=adaptive {
    <f32 x1> $6 = alloca
    <f32 x1> $7 = alloca
    <f32 x1> $8 = alloca
    <f32 x1> $9 = alloca
    <i32 x1> $10 = loop $5 index 0
    <f32 x1> $11 = stack alloc (max_size=16)
    <f32 x1> $12 : stack push $11, val = $3
    <f32*x1> $13 = global ptr [S2place_f32], index [$10] activate=true
    <f32 x1> $14 = global load $13
    <i32 x1> $15 = cmp_gt $14 $3
    <i32 x1> $16 = bit_and $15 $1
    $17 : if $16 {
      <f32*x1> $18 = global ptr [S2place_f32], index [$10] activate=true
      <f32 x1> $19 = global load $18
      <f32 x1> $20 : local store [$6 <- $19]
      <f32 x1> $21 = div $0 $19
      <f32 x1> $22 : stack push $11, val = $21
    }
    <f32*x1> $23 = global ptr [S4place_f32], index [$10] activate=true
    <f32*x1> $24 = global ptr [S6place_f32], index [$10] activate=true
    <f32 x1> $25 = global load $24
    <f32 x1> $26 : stack acc adj $11, val = $25
    $27 : if $16 {
      <f32 x1> $28 = stack load top adj $11
      <f32 x1> $29 = local load [ [$9[0]]]
      <f32 x1> $30 = add $29 $28
      <f32 x1> $31 : local store [$9 <- $30]
      <f32 x1> $32 : stack pop $11
      <f32 x1> $33 = local load [ [$6[0]]]
      <f32 x1> $34 = div $30 $33
      <f32 x1> $35 = local load [ [$8[0]]]
      <f32 x1> $36 = add $35 $34
      <f32 x1> $37 : local store [$8 <- $36]
      <f32 x1> $38 = mul $33 $33
      <f32 x1> $39 = div $30 $38
      <f32 x1> $40 = neg $39
      <f32 x1> $41 = local load [ [$7[0]]]
      <f32 x1> $42 = add $41 $40
      <f32 x1> $43 : local store [$7 <- $42]
      <f32*x1> $44 = global ptr [S5place_f32], index [$10] activate=true
      <f32 x1> $45 = atomic add($44, $42)
    }
    <f32*x1> $46 = global ptr [S5place_f32], index [$10] activate=true
    <f32 x1> $47 = atomic add($46, $3)
    <f32 x1> $48 : stack pop $11
  }
}

After (bad, with some debug output in full_simplify()):

[I 06/30/20 20:43:33.360] [compile_to_offloads.cpp:taichi::lang::irpass::compile_to_offloads::<lambda_a4464fe7c75e1f42a3a490ee54c7ec3e>::operator ()@23] Simplified I:
kernel {
  <f32 x1> $0 = const [1.0]
  <i32 x1> $1 = const [1]
  <i32 x1> $2 = const [0]
  <f32 x1> $3 = const [0.0]
  <i32 x1> $4 = const [2]
  $5 : for in range($2, $4) (vectorize 1) block_dim=adaptive {
    <i32 x1> $6 = loop $5 index 0
    <f32*x1> $7 = global ptr [S2place_f32], index [$6] activate=true
    <f32 x1> $8 = global load $7
    <i32 x1> $9 = cmp_gt $8 $3
    <i32 x1> $10 = bit_and $9 $1
    <f32 x1> $11 = div $0 $8
    <f32 x1> $12 = select($10, $11, $3)
    <f32*x1> $13 = global ptr [S4place_f32], index [$6] activate=true
    <f32*x1> $14 : global store [$13 <- $12]
  }
}
before simplify
kernel {
  <f32 x1> $0 = const [1.0]
  <i32 x1> $1 = const [1]
  <i32 x1> $2 = const [0]
  <f32 x1> $3 = const [0.0]
  <i32 x1> $4 = const [2]
  $5 : for in range($2, $4) (vectorize 1) block_dim=adaptive {
    <i32 x1> $6 = loop $5 index 0
    <f32*x1> $7 = global ptr [S2place_f32], index [$6] activate=true
    <f32 x1> $8 = global load $7
    <i32 x1> $9 = cmp_gt $8 $3
    <i32 x1> $10 = bit_and $9 $1
    <f32 x1> $11 = div $0 $8
    <f32 x1> $12 = select($10, $11, $3)
    <f32*x1> $13 = global ptr [S4place_f32], index [$6] activate=true
    <f32*x1> $14 : global store [$13 <- $12]
  }
}
after simplify
kernel {
  <f32 x1> $0 = const [1.0]
  <i32 x1> $1 = const [1]
  <i32 x1> $2 = const [0]
  <f32 x1> $3 = const [0.0]
  <i32 x1> $4 = const [2]
  $5 : for in range($2, $4) (vectorize 1) block_dim=adaptive {
    <i32 x1> $6 = loop $5 index 0
    <f32*x1> $7 = global ptr [S2place_f32], index [$6] activate=true
    <f32 x1> $8 = global load $7
    <i32 x1> $9 = cmp_gt $8 $3
    <i32 x1> $10 = bit_and $9 $1
    <f32 x1> $11 = div $0 $8
    <f32 x1> $12 = select($10, $11, $3)
    <f32*x1> $13 = global ptr [S4place_f32], index [$6] activate=true
    <f32*x1> $14 : global store [$13 <- $12]
  }
}
after cse
kernel {
  <f32 x1> $0 = const [1.0]
  <i32 x1> $1 = const [1]
  <i32 x1> $2 = const [0]
  <f32 x1> $3 = const [0.0]
  <i32 x1> $4 = const [2]
  $5 : for in range($2, $4) (vectorize 1) block_dim=adaptive {
    <i32 x1> $6 = loop $5 index 0
    <f32*x1> $7 = global ptr [S2place_f32], index [$6] activate=true
    <f32 x1> $8 = global load $7
    <i32 x1> $9 = cmp_gt $8 $3
    <i32 x1> $10 = bit_and $9 $1
    <f32 x1> $11 = div $0 $8
    <f32 x1> $12 = select($10, $11, $3)
    <f32*x1> $13 = global ptr [S4place_f32], index [$6] activate=true
    <f32*x1> $14 : global store [$13 <- $12]
  }
}
before simplify
kernel {
  <f32 x1> $205 = const [0.0]
  <i32 x1> $1 = const [1]
  <i32 x1> $2 = const [0]
  <f32 x1> $3 = const [0.0]
  <i32 x1> $4 = const [2]
  $5 : for in range($2, $4) (vectorize 1) block_dim=adaptive {
    <f32 x1> $229 = alloca
    <f32 x1> $220 = alloca
    <f32 x1> $214 = alloca
    <f32 x1> $208 = alloca
    <f32 x1> $201 = alloca
    <i32 x1> $6 = loop $5 index 0
    <f32*x1> $7 = global ptr [S2place_f32], index [$6] activate=true
    <f32 x1> $8 = global load $7
    <i32 x1> $9 = cmp_gt $8 $3
    <i32 x1> $10 = bit_and $9 $1
    <f32*x1> $13 = global ptr [S4place_f32], index [$6] activate=true
    <f32*x1> $199 = global ptr [S6place_f32], index [$6] activate=true
    <f32 x1> $200 = global load $199
    <f32 x1> $202 = local load [ [$201[0]]]
    <f32 x1> $203 = add $202 $200
    <f32 x1> $204 : local store [$201 <- $203]
    <f32 x1> $206 = local load [ [$201[0]]]
    <f32 x1> $207 = select($10, $206, $205)
    <f32 x1> $209 = local load [ [$208[0]]]
    <f32 x1> $210 = add $209 $207
    <f32 x1> $211 : local store [$208 <- $210]
    <f32 x1> $212 = local load [ [$201[0]]]
    <f32 x1> $213 = select($10, $205, $212)
    <f32 x1> $215 = local load [ [$214[0]]]
    <f32 x1> $216 = add $215 $213
    <f32 x1> $217 : local store [$214 <- $216]
    <f32 x1> $218 = local load [ [$208[0]]]
    <f32 x1> $219 = div $218 $8
    <f32 x1> $221 = local load [ [$220[0]]]
    <f32 x1> $222 = add $221 $219
    <f32 x1> $223 : local store [$220 <- $222]
    <f32 x1> $224 = mul $8 $8
    <f32 x1> $225 = local load [ [$208[0]]]
    <f32 x1> $227 = div $225 $224
    <f32 x1> $228 = neg $227
    <f32 x1> $230 = local load [ [$229[0]]]
    <f32 x1> $231 = add $230 $228
    <f32 x1> $232 : local store [$229 <- $231]
    <f32*x1> $233 = global ptr [S5place_f32], index [$6] activate=true
    <f32 x1> $234 = local load [ [$229[0]]]
    <f32 x1> $235 = atomic add($233, $234)
  }
}
after simplify
kernel {
  <f32 x1> $205 = const [0.0]
  <i32 x1> $1 = const [1]
  <i32 x1> $2 = const [0]
  <f32 x1> $3 = const [0.0]
  <i32 x1> $4 = const [2]
  $5 : for in range($2, $4) (vectorize 1) block_dim=adaptive {
    <i32 x1> $6 = loop $5 index 0
    <f32*x1> $7 = global ptr [S2place_f32], index [$6] activate=true
    <f32 x1> $8 = global load $7
    <i32 x1> $9 = cmp_gt $8 $3
    <i32 x1> $10 = bit_and $9 $1
    <f32*x1> $13 = global ptr [S4place_f32], index [$6] activate=true
    <f32*x1> $199 = global ptr [S6place_f32], index [$6] activate=true
    <f32 x1> $200 = global load $199
    <f32 x1> $236 = const [0.0]
    <f32 x1> $203 = add $236 $200
    <f32 x1> $207 = select($10, $203, $205)
    <f32 x1> $237 = const [0.0]
    <f32 x1> $210 = add $237 $207
    <f32 x1> $224 = mul $8 $8
    <f32 x1> $227 = div $210 $224
    <f32 x1> $228 = neg $227
    <f32 x1> $240 = const [0.0]
    <f32 x1> $231 = add $240 $228
    <f32*x1> $233 = global ptr [S5place_f32], index [$6] activate=true
    <f32 x1> $235 = atomic add($233, $231)
  }
}
after cse
kernel {
  <f32 x1> $205 = const [0.0]
  <i32 x1> $1 = const [1]
  <i32 x1> $2 = const [0]
  <i32 x1> $4 = const [2]
  $5 : for in range($2, $4) (vectorize 1) block_dim=adaptive {
    <i32 x1> $6 = loop $5 index 0
    <f32*x1> $7 = global ptr [S2place_f32], index [$6] activate=true
    <f32 x1> $8 = global load $7
    <i32 x1> $9 = cmp_gt $8 $205
    <i32 x1> $10 = bit_and $9 $1
    <f32*x1> $13 = global ptr [S4place_f32], index [$6] activate=true
    <f32*x1> $199 = global ptr [S6place_f32], index [$6] activate=true
    <f32 x1> $200 = global load $199
    <f32 x1> $203 = add $205 $200
    <f32 x1> $207 = select($10, $203, $205)
    <f32 x1> $210 = add $205 $207
    <f32 x1> $224 = mul $8 $8
    <f32 x1> $227 = div $210 $224
    <f32 x1> $228 = neg $227
    <f32 x1> $231 = add $205 $228
    <f32*x1> $233 = global ptr [S5place_f32], index [$6] activate=true
    <f32 x1> $235 = atomic add($233, $231)
  }
}
before simplify
kernel {
  <f32 x1> $205 = const [0.0]
  <i32 x1> $1 = const [1]
  <i32 x1> $2 = const [0]
  <i32 x1> $4 = const [2]
  $5 : for in range($2, $4) (vectorize 1) block_dim=adaptive {
    <i32 x1> $6 = loop $5 index 0
    <f32*x1> $7 = global ptr [S2place_f32], index [$6] activate=true
    <f32 x1> $8 = global load $7
    <i32 x1> $9 = cmp_gt $8 $205
    <i32 x1> $10 = bit_and $9 $1
    <f32*x1> $13 = global ptr [S4place_f32], index [$6] activate=true
    <f32*x1> $199 = global ptr [S6place_f32], index [$6] activate=true
    <f32 x1> $200 = global load $199
    <f32 x1> $207 = select($10, $200, $205)
    <f32 x1> $224 = mul $8 $8
    <f32 x1> $227 = div $207 $224
    <f32 x1> $228 = neg $227
    <f32*x1> $233 = global ptr [S5place_f32], index [$6] activate=true
    <f32 x1> $235 = atomic add($233, $228)
  }
}
after simplify
kernel {
  <f32 x1> $205 = const [0.0]
  <i32 x1> $1 = const [1]
  <i32 x1> $2 = const [0]
  <i32 x1> $4 = const [2]
  $5 : for in range($2, $4) (vectorize 1) block_dim=adaptive {
    <i32 x1> $6 = loop $5 index 0
    <f32*x1> $7 = global ptr [S2place_f32], index [$6] activate=true
    <f32 x1> $8 = global load $7
    <i32 x1> $9 = cmp_gt $8 $205
    <i32 x1> $10 = bit_and $9 $1
    <f32*x1> $13 = global ptr [S4place_f32], index [$6] activate=true
    <f32*x1> $199 = global ptr [S6place_f32], index [$6] activate=true
    <f32 x1> $200 = global load $199
    <f32 x1> $207 = select($10, $200, $205)
    <f32 x1> $224 = mul $8 $8
    <f32 x1> $227 = div $207 $224
    <f32 x1> $228 = neg $227
    <f32*x1> $233 = global ptr [S5place_f32], index [$6] activate=true
    <f32 x1> $235 = atomic add($233, $228)
  }
}
after cse
kernel {
  <f32 x1> $205 = const [0.0]
  <i32 x1> $1 = const [1]
  <i32 x1> $2 = const [0]
  <i32 x1> $4 = const [2]
  $5 : for in range($2, $4) (vectorize 1) block_dim=adaptive {
    <i32 x1> $6 = loop $5 index 0
    <f32*x1> $7 = global ptr [S2place_f32], index [$6] activate=true
    <f32 x1> $8 = global load $7
    <i32 x1> $9 = cmp_gt $8 $205
    <i32 x1> $10 = bit_and $9 $1
    <f32*x1> $13 = global ptr [S4place_f32], index [$6] activate=true
    <f32*x1> $199 = global ptr [S6place_f32], index [$6] activate=true
    <f32 x1> $200 = global load $199
    <f32 x1> $207 = select($10, $200, $205)
    <f32 x1> $224 = mul $8 $8
    <f32 x1> $227 = div $207 $224
    <f32 x1> $228 = neg $227
    <f32*x1> $233 = global ptr [S5place_f32], index [$6] activate=true
    <f32 x1> $235 = atomic add($233, $228)
  }
}
before simplify
kernel {
  <f32 x1> $205 = const [0.0]
  <i32 x1> $1 = const [1]
  <i32 x1> $2 = const [0]
  <i32 x1> $4 = const [2]
  $5 : for in range($2, $4) (vectorize 1) block_dim=adaptive {
    <i32 x1> $6 = loop $5 index 0
    <f32*x1> $7 = global ptr [S2place_f32], index [$6] activate=true
    <f32 x1> $8 = global load $7
    <i32 x1> $9 = cmp_gt $8 $205
    <i32 x1> $10 = bit_and $9 $1
    <f32*x1> $13 = global ptr [S4place_f32], index [$6] activate=true
    <f32*x1> $199 = global ptr [S6place_f32], index [$6] activate=true
    <f32 x1> $200 = global load $199
    <f32 x1> $207 = select($10, $200, $205)
    <f32 x1> $224 = mul $8 $8
    <f32 x1> $227 = div $207 $224
    <f32 x1> $228 = neg $227
    <f32*x1> $233 = global ptr [S5place_f32], index [$6] activate=true
    <f32 x1> $235 = atomic add($233, $228)
  }
}
after simplify
kernel {
  <f32 x1> $205 = const [0.0]
  <i32 x1> $1 = const [1]
  <i32 x1> $2 = const [0]
  <i32 x1> $4 = const [2]
  $5 : for in range($2, $4) (vectorize 1) block_dim=adaptive {
    <i32 x1> $6 = loop $5 index 0
    <f32*x1> $7 = global ptr [S2place_f32], index [$6] activate=true
    <f32 x1> $8 = global load $7
    <i32 x1> $9 = cmp_gt $8 $205
    <i32 x1> $10 = bit_and $9 $1
    <f32*x1> $13 = global ptr [S4place_f32], index [$6] activate=true
    <f32*x1> $199 = global ptr [S6place_f32], index [$6] activate=true
    <f32 x1> $200 = global load $199
    <f32 x1> $207 = select($10, $200, $205)
    <f32 x1> $224 = mul $8 $8
    <f32 x1> $227 = div $207 $224
    <f32 x1> $228 = neg $227
    <f32*x1> $233 = global ptr [S5place_f32], index [$6] activate=true
    <f32 x1> $235 = atomic add($233, $228)
  }
}
after cse
kernel {
  <f32 x1> $205 = const [0.0]
  <i32 x1> $1 = const [1]
  <i32 x1> $2 = const [0]
  <i32 x1> $4 = const [2]
  $5 : for in range($2, $4) (vectorize 1) block_dim=adaptive {
    <i32 x1> $6 = loop $5 index 0
    <f32*x1> $7 = global ptr [S2place_f32], index [$6] activate=true
    <f32 x1> $8 = global load $7
    <i32 x1> $9 = cmp_gt $8 $205
    <i32 x1> $10 = bit_and $9 $1
    <f32*x1> $13 = global ptr [S4place_f32], index [$6] activate=true
    <f32*x1> $199 = global ptr [S6place_f32], index [$6] activate=true
    <f32 x1> $200 = global load $199
    <f32 x1> $207 = select($10, $200, $205)
    <f32 x1> $224 = mul $8 $8
    <f32 x1> $227 = div $207 $224
    <f32 x1> $228 = neg $227
    <f32*x1> $233 = global ptr [S5place_f32], index [$6] activate=true
    <f32 x1> $235 = atomic add($233, $228)
  }
}
[I 06/30/20 20:43:33.386] [compile_to_offloads.cpp:taichi::lang::irpass::compile_to_offloads::<lambda_a4464fe7c75e1f42a3a490ee54c7ec3e>::operator ()@23] Gradient:
kernel {
  <f32 x1> $0 = const [0.0]
  <i32 x1> $1 = const [1]
  <i32 x1> $2 = const [0]
  <i32 x1> $3 = const [2]
  $4 : for in range($2, $3) (vectorize 1) block_dim=adaptive {
    <i32 x1> $5 = loop $4 index 0
    <f32*x1> $6 = global ptr [S2place_f32], index [$5] activate=true
    <f32 x1> $7 = global load $6
    <i32 x1> $8 = cmp_gt $7 $0
    <i32 x1> $9 = bit_and $8 $1
    <f32*x1> $10 = global ptr [S4place_f32], index [$5] activate=true
    <f32*x1> $11 = global ptr [S6place_f32], index [$5] activate=true
    <f32 x1> $12 = global load $11
    <f32 x1> $13 = select($9, $12, $0)
    <f32 x1> $14 = mul $7 $7
    <f32 x1> $15 = div $13 $14
    <f32 x1> $16 = neg $15
    <f32*x1> $17 = global ptr [S5place_f32], index [$5] activate=true
    <f32 x1> $18 = atomic add($17, $16)
  }
}

Although the IRs after Simplified I are different, I think both look pretty good... But after Gradient, the latter becomes wrong.

@xumingkuan
Collaborator Author

xumingkuan commented Jul 1, 2020

Final IR:
Good:

[I 06/30/20 20:38:44.166] [compile_to_offloads.cpp:taichi::lang::irpass::compile_to_offloads::<lambda_a4464fe7c75e1f42a3a490ee54c7ec3e>::operator ()@23] Simplified III:
kernel {
  $0 = offloaded range_for(0, 2) block_dim=adaptive
  body {
    <f32 x1> $1 = const [1.0]
    <f32 x1> $2 = alloca
    <i32 x1> $3 = loop $0 index 0
    <f32 x1> $4 = stack alloc (max_size=16)
    <f32 x1> $5 = const [0.0]
    <f32 x1> $6 : stack push $4, val = $5
    <gen*x1> $7 = get root
    <i32 x1> $8 = const [0]
    <gen*x1> $9 = [S0root][root]::lookup($7, $8) activate = false
    <gen*x1> $10 = get child [S0root->S1dense] $9
    <i32 x1> $11 = const [1]
    <gen*x1> $12 = [S1dense][dense]::lookup($10, $3) activate = false
    <f32*x1> $13 = get child [S1dense->S2place_f32] $12
    <f32 x1> $14 = global load $13
    <i32 x1> $15 = cmp_gt $14 $5
    <i32 x1> $16 = bit_and $15 $11
    $17 : if $16 {
      <f32 x1> $18 = global load $13
      <f32 x1> $19 : local store [$2 <- $18]
      <f32 x1> $20 = div $1 $18
      <f32 x1> $21 : stack push $4, val = $20
    }
    <gen*x1> $22 = get child [S0root->S3dense] $9
    <gen*x1> $23 = [S3dense][dense]::lookup($22, $3) activate = false
    <f32*x1> $24 = get child [S3dense->S6place_f32] $23
    <f32 x1> $25 = global load $24
    <f32 x1> $26 : stack acc adj $4, val = $25
    <f32 x1> $27 = stack load top adj $4
    <f32 x1> $28 = local load [ [$2[0]]]
    <f32 x1> $29 = mul $28 $28   <--- probably 0*0
    <f32 x1> $30 = div $27 $29   <--- nan
    <f32 x1> $31 = neg $30   <--- nan
    <f32*x1> $32 = get child [S1dense->S5place_f32] $12
    <f32 x1> $33 = global load $32
    <f32 x1> $34 = add $33 $31   <--- nan
    $35 : if $16 {   <--- good!
      <f32*x1> $36 : global store [$32 <- $34]
    }
    <f32 x1> $37 = global load $32
    <f32 x1> $38 : global store [$32 <- $37]
  }
}

Bad (nan):

[I 06/30/20 20:43:33.481] [compile_to_offloads.cpp:taichi::lang::irpass::compile_to_offloads::<lambda_a4464fe7c75e1f42a3a490ee54c7ec3e>::operator ()@23] Simplified III:
kernel {
  $0 = offloaded range_for(0, 2) block_dim=adaptive
  body {
    <i32 x1> $1 = loop $0 index 0
    <gen*x1> $2 = get root
    <i32 x1> $3 = const [0]
    <gen*x1> $4 = [S0root][root]::lookup($2, $3) activate = false
    <gen*x1> $5 = get child [S0root->S1dense] $4
    <i32 x1> $6 = const [1]
    <gen*x1> $7 = [S1dense][dense]::lookup($5, $1) activate = false
    <f32*x1> $8 = get child [S1dense->S2place_f32] $7
    <f32 x1> $9 = global load $8
    <f32 x1> $10 = const [0.0]
    <i32 x1> $11 = cmp_gt $9 $10
    <i32 x1> $12 = bit_and $11 $6
    <gen*x1> $13 = get child [S0root->S3dense] $4
    <gen*x1> $14 = [S3dense][dense]::lookup($13, $1) activate = false
    <f32*x1> $15 = get child [S3dense->S6place_f32] $14
    <f32 x1> $16 = global load $15
    <f32 x1> $17 = select($12, $16, $10)
    <f32 x1> $18 = mul $9 $9   <--- probably 0*0
    <f32 x1> $19 = div $17 $18  <--- nan
    <f32 x1> $20 = neg $19   <--- nan
    <f32*x1> $21 = get child [S1dense->S5place_f32] $7
    <f32 x1> $22 = global load $21
    <f32 x1> $23 = add $22 $20   <--- nan
    <f32 x1> $24 : global store [$21 <- $23]   <--- bad
  }
}
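A scalar sketch of why the second version produces nan, following the annotations in the two final IRs (`grad_guarded` and `grad_select` are hypothetical models, not generated code; IEEE 754 float semantics, as GCC and Clang provide, are assumed). In both versions the adjoint division runs unconditionally and can divide by zero. The good version only stores the result under the if, while in the bad version the if has become a select and the nan reaches the unconditional global store.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical scalar models of the two final IRs above.
// x    : value loaded from S2place (the input)
// adj  : upstream adjoint loaded from S6place
// grad : current value of S5place (the gradient being accumulated)

// "Good" IR: the adjoint math runs unconditionally, but the store is guarded.
float grad_guarded(float x, float adj, float grad) {
  bool cond = x > 0.0f;
  float saved_x = 0.0f;  // alloca $2, still zero if the if body is skipped
  if (cond) saved_x = x;
  float delta = -(adj / (saved_x * saved_x));  // inf or nan when !cond
  float updated = grad + delta;
  if (cond) grad = updated;  // like `$35 : if $16 { global store }`
  return grad;
}

// "Bad" IR: the if became select(), and the store is unconditional.
float grad_select(float x, float adj, float grad) {
  bool cond = x > 0.0f;
  float a = cond ? adj : 0.0f;   // select($12, $16, $10)
  float delta = -(a / (x * x));  // 0 / 0 is NaN when x == 0
  return grad + delta;           // NaN reaches the global store
}
```

When the condition holds, both models compute the same gradient; they only diverge on inputs where the guarded branch would have been skipped.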
