Skip to content

Commit

Permalink
rewrote the concat test to avoid flaky failures (apache#14049)
Browse files Browse the repository at this point in the history
ran 10000 times with no failures
  • Loading branch information
gigasquid authored and vdantu committed Mar 31, 2019
1 parent d9f36a5 commit f10d9b2
Showing 1 changed file with 9 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,33 +51,16 @@
(is (= out-grad grad))))))

(deftest test-concat
(let [shape-vecs [[2 2] [3 2]]
x (sym/variable "x")
y (sym/variable "y")
out (sym/concat "conc" nil [x y] {:dim 0})
arr (mapv #(ndarray/empty %) shape-vecs)
arr-np (mapv #(ndarray/copy %) arr)
arr-grad (map #(ndarray/empty %) shape-vecs)
arg-names (sym/list-arguments out)
grad-map (zipmap arg-names arr-grad)
args (sym/list-arguments out)
[arg-shapes out-shapes aux-shapes] (sym/infer-shape out (zipmap args shape-vecs))
out-shape-vec (first out-shapes)
out-grad (ndarray/empty out-shape-vec)
exec1 (sym/bind out (context/default-context) arr grad-map)
out1 (-> (executor/forward exec1)
(let [a (sym/variable "a")
b (sym/variable "b")
c (sym/concat "conc" nil [a b] {:dim 0})
exec (sym/bind c (context/default-context) {"a" (ndarray/array [1 2] [2 1])
"b" (ndarray/array [3 4] [2 1])})
output (-> (executor/forward exec)
(executor/outputs)
(first))
ret (ndarray/concatenate arr)]
(is (= out1 ret))

;;backward
(ndarray/copy-to out1 out-grad)
(ndarray/+= out-grad 1)
(executor/backward exec1 out-grad)
(let [grads arr-grad
np-grads arr-np]
(is (= grads (mapv #(ndarray/+ % 1) np-grads))))))
(first))]
(is (= [1.0 2.0 3.0 4.0] (ndarray/->vec output)))
(is (= [4 1] (ndarray/shape-vec output)))))

(defn check-regression [model forward-fn backward-fn]
(let [shape-vec [3 1]
Expand Down

0 comments on commit f10d9b2

Please sign in to comment.