Skip to content

Commit

Permalink
添加pir:value情况
Browse files Browse the repository at this point in the history
  • Loading branch information
ZHOU05030 committed Dec 2, 2024
1 parent ca1a7f6 commit 46edbd0
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 18 deletions.
3 changes: 1 addition & 2 deletions python/paddle/tensorrt/impls/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,7 @@ def reshape_converter(network, paddle_op, inputs):

@converter_registry.register("pd_op.gather", trt_version="8.x")
def gather_converter(network, paddle_op, inputs):
input_tensor, index_tensor, *_ = inputs
axis = paddle_op.attrs()["axis"]
input_tensor, index_tensor, axis = inputs
reshape_layer = network.add_shuffle(index_tensor)
reshape_layer.reshape_dims = (-1,)
gather_layer = network.add_gather(
Expand Down
24 changes: 8 additions & 16 deletions test/tensorrt/test_converter_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,15 +573,11 @@ def setUp(self):
self.api_args = {
"x": np.random.random([3, 4, 10]).astype("float32"),
"index": np.array([0, 2]).astype("int64"),
"axis": 1,
}
self.program_config = {"feed_list": ["x", "index"]}
self.min_shape = {"x": [1, 4, 10], "index": [1]}
self.max_shape = {"x": [5, 4, 10], "index": [5]}
self.dynamic_shape = {
"x": {"min": [1, 4, 10], "max": [5, 4, 10], "opt": [3, 4, 10]},
"index": {"min": [1], "max": [5], "opt": [2]},
"axis": np.array([1]).astype("int32"),
}
self.program_config = {"feed_list": ["x", "index", "axis"]}
self.min_shape = {"x": [1, 4, 10], "index": [1], "axis": [1]}
self.max_shape = {"x": [5, 4, 10], "index": [5], "axis": [1]}

def test_trt_result(self):
self.check_trt_result()
Expand All @@ -593,15 +589,11 @@ def setUp(self):
self.api_args = {
"x": np.random.random([3, 4, 10]).astype("int64"),
"index": np.array([0, 2]).astype("int64"),
"axis": 1,
}
self.program_config = {"feed_list": ["x", "index"]}
self.min_shape = {"x": [1, 4, 10], "index": [1]}
self.max_shape = {"x": [5, 4, 10], "index": [5]}
self.dynamic_shape = {
"x": {"min": [1, 4, 10], "max": [5, 4, 10], "opt": [3, 4, 10]},
"index": {"min": [1], "max": [5], "opt": [2]},
"axis": np.array([1]).astype("int32"),
}
self.program_config = {"feed_list": ["x", "index", "axis"]}
self.min_shape = {"x": [1, 4, 10], "index": [1], "axis": [1]}
self.max_shape = {"x": [5, 4, 10], "index": [5], "axis": [1]}

def test_trt_result(self):
self.check_trt_result()
Expand Down

0 comments on commit 46edbd0

Please sign in to comment.