Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[Large Tensor] Fixed Embedding op (#17599)
Browse files Browse the repository at this point in the history
* Switched from int to index_t for input_dim

* Implemented fix for output_dim

* Added nightly test for Embedding

* Set const value for output dim

* More standardization via const param
  • Loading branch information
connorgoggins committed Feb 25, 2020
1 parent 31144c7 commit d51753b
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/operator/tensor/indexing_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ enum QuantizedEmbeddingOpResource {kTempSpace};


struct SparseEmbeddingParam: public dmlc::Parameter<SparseEmbeddingParam> {
int input_dim;
int output_dim;
index_t input_dim;
index_t output_dim;
int dtype;
bool deterministic;
DMLC_DECLARE_PARAMETER(SparseEmbeddingParam) {
Expand All @@ -89,8 +89,8 @@ struct SparseEmbeddingParam: public dmlc::Parameter<SparseEmbeddingParam> {
};

struct EmbeddingParam: public dmlc::Parameter<EmbeddingParam> {
int input_dim;
int output_dim;
index_t input_dim;
index_t output_dim;
int dtype;
bool sparse_grad;
DMLC_DECLARE_PARAMETER(EmbeddingParam) {
Expand Down
13 changes: 13 additions & 0 deletions tests/nightly/test_large_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
SMALL_X = 100
SMALL_Y = 50
LARGE_SIZE = LARGE_X * SMALL_Y
LARGE_TENSOR_SHAPE = 2**32


def test_nn():
Expand Down Expand Up @@ -467,6 +468,17 @@ def check_col2im():
assert res.shape[2] == 2
assert res.shape[3] == 2
assert res.shape[4] == 1
def check_embedding():
data = nd.random_normal(shape=(LARGE_TENSOR_SHAPE, 1))
weight = nd.random_normal(shape=(LARGE_TENSOR_SHAPE, 1))
input_dim = LARGE_TENSOR_SHAPE
output_dim = 1

out = nd.Embedding(data=data, weight=weight, input_dim=input_dim, output_dim=output_dim)

assert out.shape[0] == LARGE_TENSOR_SHAPE
assert out.shape[1] == 1
assert out.shape[2] == 1

check_gluon_embedding()
check_fully_connected()
Expand All @@ -488,6 +500,7 @@ def check_col2im():
check_l2_normalization()
check_instance_norm()
check_col2im()
check_embedding()


def test_tensor():
Expand Down

0 comments on commit d51753b

Please sign in to comment.