Skip to content

Commit

Permalink
Fix one hot scalar tensor bug (#7975)
Browse files Browse the repository at this point in the history
* fix reduce_sum scalar check bug

* fix one_hot scalar tensor bug

* fix clang tidy error

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
BBuf and mergify[bot] authored Apr 8, 2022
1 parent 202c80f commit 3932e16
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
4 changes: 3 additions & 1 deletion oneflow/user/ops/one_hot_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ namespace oneflow {
const int64_t depth = ctx->Attr<int64_t>("depth");
CHECK_GT_OR_RETURN(depth, 0);
const user_op::TensorDesc& indices_desc = ctx->InputTensorDesc("indices", 0);
CHECK_GT_OR_RETURN(indices_desc.shape().NumAxes(), 0);
// For 0-dim Tensor
CHECK_GE_OR_RETURN(indices_desc.shape().NumAxes(), 0)
<< "indices dim must be great or equal than 0";
user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0);
*out_desc->mut_is_dynamic() = indices_desc.is_dynamic();
DimVector dim_vec = indices_desc.shape().dim_vec();
Expand Down
10 changes: 8 additions & 2 deletions python/oneflow/test/modules/test_one_hot.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import oneflow as flow


def _test_onehot(test_case, device, num_classes, size, on_value, off_value):
def _test_one_hot(test_case, device, num_classes, size, on_value, off_value):
x = np.random.randint(9, size=size)
input = flow.tensor(x, device=flow.device(device), dtype=flow.int64)
output = flow.nn.functional.one_hot(input, num_classes, on_value, off_value)
Expand All @@ -41,7 +41,7 @@ class TestOnehot(flow.unittest.TestCase):
def test_onehot(test_case):
arg_dict = OrderedDict()
arg_dict["test_fun"] = [
_test_onehot,
_test_one_hot,
]
arg_dict["device"] = ["cpu", "cuda"]
arg_dict["num_classes"] = [-1, 10, 11]
Expand All @@ -51,6 +51,12 @@ def test_onehot(test_case):
for arg in GenArgList(arg_dict):
arg[0](test_case, *arg[1:])

@autotest(auto_backward=False)
def test_one_hot_scalar(test_case):
x = torch.tensor(2)
y = torch.nn.functional.one_hot(x, num_classes=5)
return y


if __name__ == "__main__":
unittest.main()

0 comments on commit 3932e16

Please sign in to comment.