Skip to content

Commit 6c2fd00

Browse files
committed
fix bug
1 parent 162498a commit 6c2fd00

File tree

3 files changed

+64
-115
lines changed

3 files changed

+64
-115
lines changed

test/custom_op/test_inference_inplace_pir.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,5 +126,67 @@ def test_output(self):
126126
pir_output = self.get_outputs(pir_predictor)
127127

128128

129+
class TestPredictorRunWithConfig(unittest.TestCase):
130+
def setUp(self):
131+
self.temp_dir = tempfile.TemporaryDirectory()
132+
net = TestInplaceNet()
133+
model = paddle.jit.to_static(
134+
net,
135+
input_spec=[
136+
paddle.static.InputSpec(
137+
shape=[None, 4], dtype='float32', name='x'
138+
),
139+
],
140+
full_graph=True,
141+
)
142+
paddle.jit.save(
143+
model,
144+
os.path.join(
145+
self.temp_dir.name, 'test_predictor_run_model/inference'
146+
),
147+
)
148+
149+
def tearDown(self):
150+
self.temp_dir.cleanup()
151+
152+
def init_predictor(self):
153+
config = Config(
154+
os.path.join(
155+
self.temp_dir.name,
156+
'test_predictor_run_model',
157+
),
158+
'inference',
159+
)
160+
config.enable_use_gpu(256, 0)
161+
config.switch_ir_optim(False)
162+
config.enable_new_executor()
163+
config.enable_new_ir()
164+
predictor = create_predictor(config)
165+
return predictor
166+
167+
def get_inputs(self):
168+
x = np.array([[1, 2, 3, 4], [2, 3, 4, 5]]).astype(np.float32)
169+
170+
x_tensor = paddle.to_tensor(x)
171+
172+
return [x_tensor]
173+
174+
def get_outputs(self, predictor):
175+
[x_tensor] = self.get_inputs()
176+
177+
input_names = predictor.get_input_names()
178+
x_tensor.name = input_names[0]
179+
180+
# disorder
181+
inputs = [x_tensor]
182+
outputs = predictor.run(inputs)
183+
184+
return outputs[0]
185+
186+
def test_output(self):
187+
pir_predictor = self.init_predictor()
188+
pir_output = self.get_outputs(pir_predictor)
189+
190+
129191
if __name__ == "__main__":
130192
unittest.main()

test/ir/pir/CMakeLists.txt

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,8 @@ list(REMOVE_ITEM TEST_INTERP_CASES ${TEST_IR_SYSTEM_CASES})
2121
list(REMOVE_ITEM TEST_INTERP_CASES test_subgraph_exporter)
2222

2323
foreach(target ${TEST_INTERP_CASES})
24-
if(NOT ${target} STREQUAL "test_config")
25-
py_test_modules(${target} MODULES ${target} ENVS GLOG_v=1
26-
FLAGS_enable_pir_in_executor=true)
24+
py_test_modules(${target} MODULES ${target} ENVS GLOG_v=1
25+
FLAGS_enable_pir_in_executor=true)
2726
endif()
2827
endforeach()
2928

@@ -43,7 +42,5 @@ py_test_modules(
4342
FLAGS_enable_pir_in_executor=1
4443
FLAGS_pir_subgraph_saving_dir=${CMAKE_CURRENT_SOURCE_DIR})
4544

46-
py_test_modules(test_config MODULES test_config ENVS FLAGS_enable_pir_api=true)
47-
4845
add_subdirectory(fused_pass)
4946
add_subdirectory(translator)

test/ir/pir/test_config.py

Lines changed: 0 additions & 110 deletions
This file was deleted.

0 commit comments

Comments
 (0)