Skip to content

Commit

Permalink
fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
wangye707 committed Aug 9, 2021
1 parent 04ca2e1 commit d9cd374
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 11 deletions.
4 changes: 1 addition & 3 deletions inference/python_api_test/test_case/infer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,7 @@ def disable_gpu_test(self, input_data_dict: dict, repeat=20):
gpu_max_mem = max([float(i["used(MB)"]) for i in _gpu_mem_lists])
assert abs(gpu_max_mem - ori_gpu_mem) < 1, "set disable_gpu(), but gpu activity found"

def mkldnn_test(
self, input_data_dict: dict, output_data_dict: dict, mkldnn_cache_capacity=1, repeat=2, delta=1e-5, gpu_mem=1000
):
def mkldnn_test(self, input_data_dict: dict, output_data_dict: dict, mkldnn_cache_capacity=1, repeat=2, delta=1e-5):
"""
test enable_mkldnn()
Args:
Expand Down
8 changes: 4 additions & 4 deletions inference/python_api_test/test_class_model/test_resnet50.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ def test_trtfp32_more_bz():

file_path = "./resnet50"
images_size = 224
batch_size_pool = 10
for batch_size in range(1, batch_size_pool + 1, 3):
batch_size_pool = [1, 5, 10]
for batch_size in batch_size_pool:
test_suite = InferenceTest()
test_suite.load_config(model_file="./resnet50/inference.pdmodel", params_file="./resnet50/inference.pdiparams")
images_list, npy_list = test_suite.get_images_npy(file_path, images_size)
Expand Down Expand Up @@ -178,8 +178,8 @@ def test_trtfp16_more_bz():

file_path = "./resnet50"
images_size = 224
batch_size_pool = 10
for batch_size in range(1, batch_size_pool + 1, 3):
batch_size_pool = [1, 5, 10]
for batch_size in batch_size_pool:
test_suite = InferenceTest()
test_suite.load_config(model_file="./resnet50/inference.pdmodel", params_file="./resnet50/inference.pdiparams")
images_list, npy_list = test_suite.get_images_npy(file_path, images_size)
Expand Down
8 changes: 4 additions & 4 deletions inference/python_api_test/test_class_model/test_vgg11.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ def test_trtfp32_more_bz():

file_path = "./vgg11"
images_size = 224
batch_size_pool = 10
for batch_size in range(1, batch_size_pool + 1, 3):
batch_size_pool = [1, 10]
for batch_size in batch_size_pool:
test_suite = InferenceTest()
test_suite.load_config(model_file="./vgg11/inference.pdmodel", params_file="./vgg11/inference.pdiparams")
images_list, npy_list = test_suite.get_images_npy(file_path, images_size)
Expand Down Expand Up @@ -125,8 +125,8 @@ def test_trtfp16_more_bz():

file_path = "./vgg11"
images_size = 224
batch_size_pool = 10
for batch_size in range(1, batch_size_pool + 1, 4):
batch_size_pool = [1, 10]
for batch_size in batch_size_pool:
test_suite = InferenceTest()
test_suite.load_config(model_file="./vgg11/inference.pdmodel", params_file="./vgg11/inference.pdiparams")
images_list, npy_list = test_suite.get_images_npy(file_path, images_size)
Expand Down

0 comments on commit d9cd374

Please sign in to comment.