Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Permalink
Add gather tests (#1399)
Browse files Browse the repository at this point in the history
* add gather tests

* add more gather tests

* add paddle.disable_static()

* fix bug

* fix bug

* fix bug

* use TestCaseHelper
  • Loading branch information
zrr1999 authored Jun 1, 2023
1 parent 55381a1 commit e38b000
Showing 1 changed file with 108 additions and 76 deletions.
184 changes: 108 additions & 76 deletions python/tests/ops/test_gather_op.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,4 @@
# Copyright (c) 2022 CINN Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#!/usr/bin/env python3

# Copyright (c) 2022 CINN Authors. All Rights Reserved.
#
Expand All @@ -29,12 +17,14 @@
import unittest
import numpy as np
from op_test import OpTest, OpTestTool
from op_test_helper import TestCaseHelper
import paddle
import cinn
from cinn.frontend import *
from cinn.common import *
import logging
import os
from itertools import product

logging.basicConfig(level=os.environ.get('LOG_LEVEL', 'INFO').upper())
logger = logging.getLogger(name="gather")
Expand All @@ -44,82 +34,124 @@
"x86 test will be skipped due to timeout.")
class TestGatherOp(OpTest):
def setUp(self):
self.init_case()

def init_case(self):
self.inputs = {
"x":
np.array([[[1.1, 1.2, 1.3], [2.1, 2.2, 2.3], [3.1, 3.2, 3.3],
[4.1, 4.2, 4.3]],
[[5.1, 5.2, 5.3], [6.1, 6.2, 6.3], [7.1, 7.2, 7.3],
[8.1, 8.2, 8.3]],
[[9.1, 9.2, 9.3], [10.1, 10.2, 10.3], [11.1, 11.2, 11.3],
[12.1, 12.2, 12.3]]]).astype("float32"),
"index":
np.array([0, 0, 2, 2]).astype("int32")
}
self.axis = 0
print(f"\nRunning {self.__class__.__name__}: {self.case}")
self.data = None

def build_paddle_program(self, target):
x = paddle.to_tensor(self.inputs["x"], stop_gradient=True)
index = paddle.to_tensor(self.inputs["index"], stop_gradient=True)
out = paddle.gather(x, index, self.axis)
inputs = self.case
dtype = self.case["x_dtype"]
axis = inputs["axis"]
x_shape = inputs["x"]
index_shape = inputs["index"]
# Paddle does not support negative axis values.
axis = axis if axis >= 0 else len(x_shape) + axis
x = np.random.randn(*x_shape).astype(dtype)
index = np.random.randint(0, x_shape[axis],
index_shape).astype("int32")
self.data = [x, index]
x = paddle.to_tensor(x, stop_gradient=False)
index = paddle.to_tensor(index, stop_gradient=False)
out = paddle.gather(x, index, axis)
logger.debug(" -- The output of Paddle:\n{}".format(out))
self.paddle_outputs = [out]
self.paddle_outputs.append(out)

def build_cinn_program(self, target):
inputs = self.case
dtype = self.case["x_dtype"]
axis = inputs["axis"]
builder = NetBuilder("gather")
x = builder.create_input(Float(32), self.inputs["x"].shape, "x")
index = builder.create_input(
Int(32), self.inputs["index"].shape, "index")
out = builder.gather(x, index, axis=self.axis)

x = builder.create_input(self.nptype2cinntype(dtype), inputs["x"], "x")
index = builder.create_input(Int(32), inputs["index"], "index")
out = builder.gather(x, index, axis=axis)
prog = builder.build()
res = self.get_cinn_output(prog, target, [x, index],
[self.inputs["x"], self.inputs["index"]],
[out])
res = self.get_cinn_output(prog, target, [x, index], self.data, [out])
logger.debug(" -- The output of CINN:\n{}".format(res))
self.cinn_outputs = res
self.cinn_outputs.extend(res)

def test_check_results(self):
self.check_outputs_and_grads(all_equal=True)


class TestGatherOpCase1(TestGatherOp):
def init_case(self):
self.inputs = {
"x": np.random.random([16, 32, 32]).astype("float32"),
"index": np.random.randint(0, 16, 64).astype("int32")
}
self.axis = 0


class TestGatherOpCase2(TestGatherOp):
def init_case(self):
self.inputs = {
"x": np.random.random([16, 32, 32]).astype("float32"),
"index": np.random.randint(0, 32, 15).astype("int32")
}
self.axis = 1


class TestGatherOpCase3(TestGatherOp):
def init_case(self):
self.inputs = {
"x": np.random.random([16, 16, 32, 32]).astype("float32"),
"index": np.random.randint(0, 32, 8).astype("int32")
}
self.axis = 2


class TestGatherOpCase4(TestGatherOp):
def init_case(self):
self.inputs = {
"x": np.random.random([17, 29, 31, 13]).astype("float32"),
"index": np.random.randint(0, 13, 11).astype("int32")
}
self.axis = 3
class TestGatherOpAll(TestCaseHelper):
def init_attrs(self):
self.class_name = "TestGatherOpAll"
self.cls = TestGatherOp
# note: The possible values of axis are related to x, so axis is added in self.inputs
self.inputs = [
{
"x": [128],
"index": [64],
"axis": 0
},
{
"x": [16, 32],
"index": [32],
"axis": 0
},
{
"x": [16, 32],
"index": [32],
"axis": 1
},
{
"x": [8, 16, 32],
"index": [16],
"axis": -3
},
{
"x": [8, 16, 32],
"index": [8],
"axis": -2
},
{
"x": [8, 16, 32],
"index": [8],
"axis": -1
},
{
"x": [8, 16, 32],
"index": [4],
"axis": 2
},
{
"x": [16, 8, 4, 64],
"index": [4],
"axis": 2
},
{
"x": [16, 8, 4, 1024],
"index": [4],
"axis": 2
},
{
"x": [16, 8, 4, 1],
"index": [4],
"axis": 2
},
{
"x": [1, 1, 1, 1],
"index": [4],
"axis": 2
},
]
self.dtypes = [{
"x_dtype": "int16",
"y_dtype": "int64"
}, {
"x_dtype": "int32",
"y_dtype": "int64"
}, {
"x_dtype": "int64",
"y_dtype": "int64"
}, {
"x_dtype": "float32",
"y_dtype": "int64"
}, {
"x_dtype": "float64",
"y_dtype": "int64"
}]
self.attrs = []


if __name__ == "__main__":
unittest.main()
TestGatherOpAll().run()

0 comments on commit e38b000

Please sign in to comment.