Skip to content

Commit

Permalink
+ python unitest
Browse files Browse the repository at this point in the history
  • Loading branch information
Meiyim committed Mar 16, 2021
1 parent 0b44afc commit 3cf6c2f
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 1 deletion.
2 changes: 1 addition & 1 deletion paddle/fluid/operators/expand_op_npu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,12 @@ class ExpandNPUKernel : public framework::OpKernel<T> {
"of dimensions (%d) of the input.",
expand_times.size(), static_cast<size_t>(in_dims.size())));
auto* out0 = context.Output<framework::LoDTensor>("Out");
out0->mutable_data<T>(context.device_context().GetPlace());
framework::DDim out_dims(in_dims);
for (size_t i = 0; i < expand_times.size(); ++i) {
out_dims[i] *= expand_times[i];
}
out0->Resize(out_dims);
out0->mutable_data<T>(context.device_context().GetPlace());
auto runner = NpuOpRunner("TileD", {*in0}, {*out0}, {{"multiples", expand_times}});
auto stream =
context.template device_context<paddle::platform::NPUDeviceContext>()
Expand Down
19 changes: 19 additions & 0 deletions python/paddle/fluid/tests/unittests/npu/test_expand_op_npu.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,25 @@ def test_check_output(self):
# self.check_grad(['X'], 'Out')
#

@unittest.skipIf(not paddle.is_compiled_with_npu(),
"core is not compiled with NPU")
class TestExpandV2(TestExpand):
def setUp(self):
self.set_npu()
self.op_type = "expand"
self.place = paddle.NPUPlace(0)

self.init_dtype()
np.random.seed(SEED)
x = np.random.randn(3,1,7).astype(self.dtype)
out = np.tile(x, [1,10,1])
expand_times = np.array([1,10,1]).astype(np.int32)

self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x), 'ExpandTimes': OpTest.np_dtype_to_fluid_dtype(expand_times)}
self.attrs = {}
self.outputs = {'Out': out}


@unittest.skipIf(not paddle.is_compiled_with_npu(),
"core is not compiled with NPU")
class TestExpandFp16(TestExpand):
Expand Down

0 comments on commit 3cf6c2f

Please sign in to comment.