Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[Large Tensor] Fixed Embedding op #17599

Merged
merged 5 commits into from
Feb 25, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/operator/tensor/indexing_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ enum QuantizedEmbeddingOpResource {kTempSpace};


struct SparseEmbeddingParam: public dmlc::Parameter<SparseEmbeddingParam> {
int input_dim;
int output_dim;
index_t input_dim;
index_t output_dim;
int dtype;
bool deterministic;
DMLC_DECLARE_PARAMETER(SparseEmbeddingParam) {
Expand All @@ -89,8 +89,8 @@ struct SparseEmbeddingParam: public dmlc::Parameter<SparseEmbeddingParam> {
};

struct EmbeddingParam: public dmlc::Parameter<EmbeddingParam> {
int input_dim;
int output_dim;
index_t input_dim;
index_t output_dim;
int dtype;
bool sparse_grad;
DMLC_DECLARE_PARAMETER(EmbeddingParam) {
Expand Down
13 changes: 13 additions & 0 deletions tests/nightly/test_large_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
SMALL_X = 100
SMALL_Y = 50
LARGE_SIZE = LARGE_X * SMALL_Y
LARGE_TENSOR_SHAPE = 2**32


def test_nn():
Expand Down Expand Up @@ -467,6 +468,17 @@ def check_col2im():
assert res.shape[2] == 2
assert res.shape[3] == 2
assert res.shape[4] == 1
def check_embedding():
data = nd.random_normal(shape=(LARGE_TENSOR_SHAPE, 1))
weight = nd.random_normal(shape=(LARGE_TENSOR_SHAPE, 1))
input_dim = LARGE_TENSOR_SHAPE
output_dim = 1

out = nd.Embedding(data=data, weight=weight, input_dim=input_dim, output_dim=output_dim)

assert out.shape[0] == LARGE_TENSOR_SHAPE
assert out.shape[1] == 1
assert out.shape[2] == 1

check_gluon_embedding()
check_fully_connected()
Expand All @@ -488,6 +500,7 @@ def check_col2im():
check_l2_normalization()
check_instance_norm()
check_col2im()
check_embedding()


def test_tensor():
Expand Down