Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[Clojure] Helper function for n-dim vector to ndarray #14305

Merged
merged 4 commits into from
Mar 11, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -92,19 +92,24 @@
(arange start stop {})))

(defn ->ndarray
"Creates a new NDArray based on the given n-dimensional
float/double vector.
`nd-vec`: n-dimensional vector with floats or doubles.
"Creates a new NDArray based on the given n-dimenstional vector
of numbers.
`nd-vec`: n-dimensional vector with numbers.
`opts-map` {
`ctx`: Context of the output ndarray, will use default context if unspecified.
}
returns: `ndarray` with the given values and matching the shape of the input vector.
Ex:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Love your docstring!

(->ndarray [5.0 -4.0])
(->ndarray [[1.0 2.0 3.0] [4.0 5.0 6.0]])
(->ndarray [5 -4] {:ctx (context/cpu)})
(->ndarray [[1 2 3] [4 5 6]])
(->ndarray [[[1.0] [2.0]]]"
([nd-vec {:keys [ctx] :as opts}]
(NDArray/toNDArray (util/to-array-nd nd-vec) ctx))
([nd-vec {:keys [ctx]
:or {ctx (mx-context/default-context)}
:as opts}]
(array (vec (clojure.core/flatten nd-vec))
(util/nd-seq-shape nd-vec)
{:ctx ctx}))
([nd-vec] (->ndarray nd-vec {})))

(defn slice
Expand Down
22 changes: 16 additions & 6 deletions contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj
Original file line number Diff line number Diff line change
Expand Up @@ -218,15 +218,25 @@
(throw (ex-info error-msg
(s/explain-data spec value)))))

(s/def ::non-empty-seq sequential?)
(s/def ::non-empty-seq (s/and sequential? not-empty))
(defn to-array-nd
"Converts any N-D sequential structure to an array
with the same dimensions."
[s]
(validate! ::non-empty-seq s "Invalid N-D sequence")
(if (sequential? (first s))
(to-array (mapv to-array-nd s))
(to-array s)))
[nd-seq]
(validate! ::non-empty-seq nd-seq "Invalid N-D sequence")
(if (sequential? (first nd-seq))
(to-array (mapv to-array-nd nd-seq))
(to-array nd-seq)))

(defn nd-seq-shape
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!

"Computes the shape of a n-dimensional sequential structure"
[nd-seq]
(validate! ::non-empty-seq nd-seq "Invalid N-D sequence")
(loop [s nd-seq
shape [(count s)]]
(if (sequential? (first s))
(recur (first s) (conj shape (count (first s))))
shape)))

(defn map->scala-tuple-seq
"* Convert a map to a scala-Seq of scala-Tubple.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,8 @@

(deftest test->ndarray
(let [nda1 (ndarray/->ndarray [5.0 -4.0])
nda2 (ndarray/->ndarray [[1.0 2.0 3.0]
[4.0 5.0 6.0]])
nda2 (ndarray/->ndarray [[1 2 3]
[4 5 6]])
nda3 (ndarray/->ndarray [[[7.0] [8.0]]])]
(is (= [5.0 -4.0] (->vec nda1)))
(is (= [2] (mx-shape/->vec (shape nda1))))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,12 +164,12 @@
(util/tuple->vec)))))

(deftest test-to-array-nd
(let [a1 (util/to-array-nd '())
(let [a1 (util/to-array-nd '(1))
a2 (util/to-array-nd [1.0 2.0])
a3 (util/to-array-nd [[3.0] [4.0]])
a4 (util/to-array-nd [[[5 -5]]])]
(is (= 0 (alength a1)))
(is (= [] (->> a1 vec)))
(is (= 1 (alength a1)))
(is (= [1] (->> a1 vec)))
(is (= 2 (alength a2)))
(is (= 2.0 (aget a2 1)))
(is (= [1.0 2.0] (->> a2 vec)))
Expand All @@ -183,6 +183,13 @@
(is (= 5 (aget a4 0 0 0)))
(is (= [[[5 -5]]] (->> a4 vec (mapv vec) (mapv #(mapv vec %)))))))

(deftest test-nd-seq-shape
(is (= [1] (util/nd-seq-shape '(5))))
(is (= [2] (util/nd-seq-shape [1.0 2.0])))
(is (= [3] (util/nd-seq-shape [1 1 1])))
(is (= [2 1] (util/nd-seq-shape [[3.0] [4.0]])))
(is (= [1 3 2] (util/nd-seq-shape [[[5 -5] [5 -5] [5 -5]]]))))

(deftest test-coerce-return
(is (= [] (util/coerce-return (ArrayBuffer.))))
(is (= [1 2 3] (util/coerce-return (util/vec->indexed-seq [1 2 3]))))
Expand Down