|
18 | 18 | // but each shadow argument is `width` times larger (thus 16 and 20 elements here). |
19 | 19 | // `d_square3` instead takes `width` (4) shadow arguments, which are all the same size as the |
20 | 20 | // original function arguments. |
21 | | -// |
22 | | -// FIXME(autodiff): We currently can't test `d_square1` and `d_square3` in the same file, since they |
23 | | -// generate the same dummy functions which get merged by LLVM, breaking pieces of our pipeline which |
24 | | -// try to rewrite the dummy functions later. We should consider to change to pure declarations both |
25 | | -// in our frontend and in the llvm backend to avoid these issues. |
26 | 21 |
|
27 | 22 | #![feature(autodiff)] |
28 | 23 |
|
29 | 24 | use std::autodiff::autodiff_forward; |
30 | 25 |
|
31 | 26 | // CHECK: ; |
32 | 27 | #[no_mangle] |
33 | | -//#[autodiff(d_square1, Forward, Dual, Dual)] |
| 28 | +#[autodiff_forward(d_square1, Dual, Dual)] |
34 | 29 | #[autodiff_forward(d_square2, 4, Dualv, Dualv)] |
35 | 30 | #[autodiff_forward(d_square3, 4, Dual, Dual)] |
36 | 31 | fn square(x: &[f32], y: &mut [f32]) { |
@@ -79,25 +74,25 @@ fn main() { |
79 | 74 | let mut dy3_4 = std::hint::black_box(vec![0.0; 5]); |
80 | 75 |
|
81 | 76 | // scalar. |
82 | | - //d_square1(&x1, &z1, &mut y1, &mut dy1_1); |
83 | | - //d_square1(&x1, &z2, &mut y2, &mut dy1_2); |
84 | | - //d_square1(&x1, &z3, &mut y3, &mut dy1_3); |
85 | | - //d_square1(&x1, &z4, &mut y4, &mut dy1_4); |
| 77 | + d_square1(&x1, &z1, &mut y1, &mut dy1_1); |
| 78 | + d_square1(&x1, &z2, &mut y2, &mut dy1_2); |
| 79 | + d_square1(&x1, &z3, &mut y3, &mut dy1_3); |
| 80 | + d_square1(&x1, &z4, &mut y4, &mut dy1_4); |
86 | 81 |
|
87 | 82 | // assert y1 == y2 == y3 == y4 |
88 | | - //for i in 0..5 { |
89 | | - // assert_eq!(y1[i], y2[i]); |
90 | | - // assert_eq!(y1[i], y3[i]); |
91 | | - // assert_eq!(y1[i], y4[i]); |
92 | | - //} |
| 83 | + for i in 0..5 { |
| 84 | + assert_eq!(y1[i], y2[i]); |
| 85 | + assert_eq!(y1[i], y3[i]); |
| 86 | + assert_eq!(y1[i], y4[i]); |
| 87 | + } |
93 | 88 |
|
94 | 89 | // batch mode A) |
95 | 90 | d_square2(&x1, &z5, &mut y5, &mut dy2); |
96 | 91 |
|
97 | 92 | // assert y1 == y2 == y3 == y4 == y5 |
98 | | - //for i in 0..5 { |
99 | | - // assert_eq!(y1[i], y5[i]); |
100 | | - //} |
| 93 | + for i in 0..5 { |
| 94 | + assert_eq!(y1[i], y5[i]); |
| 95 | + } |
101 | 96 |
|
102 | 97 | // batch mode B) |
103 | 98 | d_square3(&x1, &z1, &z2, &z3, &z4, &mut y6, &mut dy3_1, &mut dy3_2, &mut dy3_3, &mut dy3_4); |
|
0 commit comments