From 9daaf1ea8edeb21910baf73805b79904c1d18bae Mon Sep 17 00:00:00 2001 From: Haichen Shen Date: Sun, 22 Mar 2020 09:22:37 -0700 Subject: [PATCH] Adjust strategy plevel to achieve expected performance by default (#5118) --- python/tvm/relay/op/strategy/arm_cpu.py | 6 +++--- python/tvm/relay/op/strategy/bifrost.py | 2 +- python/tvm/relay/op/strategy/cuda.py | 6 +++--- python/tvm/relay/op/strategy/mali.py | 2 +- python/tvm/relay/op/strategy/rocm.py | 18 ++++++++---------- python/tvm/relay/op/strategy/x86.py | 6 +++--- 6 files changed, 19 insertions(+), 21 deletions(-) diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py index 79976eb439cb..87e48dc7fa04 100644 --- a/python/tvm/relay/op/strategy/arm_cpu.py +++ b/python/tvm/relay/op/strategy/arm_cpu.py @@ -67,13 +67,13 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target): wrap_compute_conv2d(topi.arm_cpu.conv2d_nchw_winograd), wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nchw_winograd), name="conv2d_nchw_winograd.arm_cpu", - plevel=15) + plevel=5) if "nnpack" in target.libs and pt == 1 and pb == 1 and pl == 1 and pr == 1: strategy.add_implementation( wrap_compute_conv2d(topi.arm_cpu.conv2d_nchw_winograd_nnpack), wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nchw_winograd_nnpack), name="conv2d_nchw_winograd_nnpack.arm_cpu", - plevel=13) + plevel=15) elif re.match(r"OIHW\d*o", kernel_layout): strategy.add_implementation( wrap_compute_conv2d(topi.arm_cpu.conv2d_nchw_spatial_pack), @@ -177,7 +177,7 @@ def conv2d_winograd_without_weight_transfrom_strategy_arm_cpu(attrs, inputs, out wrap_topi_schedule( topi.arm_cpu.schedule_conv2d_nchw_winograd_nnpack_without_weight_transform), name="conv2d_nchw_winograd_nnpack_withou_weight_transform.arm_cpu", - plevel=5) + plevel=15) else: raise RuntimeError("Unsupported kernel shape: {}".format(kernel.shape)) else: diff --git a/python/tvm/relay/op/strategy/bifrost.py b/python/tvm/relay/op/strategy/bifrost.py index e8f62980a621..a96463fa6ffa 100644 --- a/python/tvm/relay/op/strategy/bifrost.py +++ b/python/tvm/relay/op/strategy/bifrost.py @@ -50,7 +50,7 @@ def conv2d_strategy_bifrost(attrs, inputs, out_type, target): wrap_compute_conv2d(topi.bifrost.conv2d_nchw_winograd), wrap_topi_schedule(topi.bifrost.schedule_conv2d_nchw_winograd), name="conv2d_nchw_winograd.bifrost", - plevel=15) + plevel=5) elif re.match(r"OIHW\d*o", kernel_layout): strategy.add_implementation( wrap_compute_conv2d(topi.bifrost.conv2d_nchw_spatial_pack), diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index 8ccd6bf51508..f52a7d5f2dd1 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -135,7 +135,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target): wrap_compute_conv2d(topi.cuda.conv2d_cudnn, True), wrap_topi_schedule(topi.cuda.schedule_conv2d_cudnn), name="conv2d_cudnn.cuda", - plevel=5) + plevel=15) elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups): if layout == "NCHW": assert kernel_layout == "OIHW" @@ -295,13 +295,13 @@ def dense_strategy_cuda(attrs, inputs, out_type, target): wrap_compute_dense(topi.cuda.dense_large_batch), wrap_topi_schedule(topi.cuda.schedule_dense_large_batch), name="dense_large_batch.cuda", - plevel=15) + plevel=5) if target.target_name == "cuda" and "cublas" in target.libs: strategy.add_implementation( wrap_compute_dense(topi.cuda.dense_cublas), wrap_topi_schedule(topi.cuda.schedule_dense_cublas), name="dense_cublas.cuda", - plevel=20) + plevel=15) return strategy @batch_matmul_strategy.register(["cuda", "gpu"]) diff --git a/python/tvm/relay/op/strategy/mali.py b/python/tvm/relay/op/strategy/mali.py index 8f1fa291d236..5e4a7e5669d2 100644 --- a/python/tvm/relay/op/strategy/mali.py +++ b/python/tvm/relay/op/strategy/mali.py @@ -49,7 +49,7 @@ def conv2d_strategy_mali(attrs, inputs, out_type, target): wrap_compute_conv2d(topi.mali.conv2d_nchw_winograd), wrap_topi_schedule(topi.mali.schedule_conv2d_nchw_winograd), name="conv2d_nchw_winograd.mali", - plevel=15) + plevel=5) elif re.match(r"OIHW\d*o", kernel_layout): strategy.add_implementation( wrap_compute_conv2d(topi.mali.conv2d_nchw_spatial_pack), diff --git a/python/tvm/relay/op/strategy/rocm.py b/python/tvm/relay/op/strategy/rocm.py index 63bfe5e4a6e9..0486f71b526c 100644 --- a/python/tvm/relay/op/strategy/rocm.py +++ b/python/tvm/relay/op/strategy/rocm.py @@ -77,13 +77,12 @@ def conv2d_strategy_rocm(attrs, inputs, out_type, target): else: raise RuntimeError("Unsupported conv2d layout {} for CUDA".format(layout)) # add miopen implementation - if "miopen" in target.libs: - if layout == "NCHW": - strategy.add_implementation( - wrap_compute_conv2d(topi.rocm.conv2d_nchw_miopen, True), - wrap_topi_schedule(topi.rocm.schedule_conv2d_nchw_miopen), - name="conv2d_nchw_miopen.rocm", - plevel=15) + if "miopen" in target.libs and layout == "NCHW": + strategy.add_implementation( + wrap_compute_conv2d(topi.rocm.conv2d_nchw_miopen, True), + wrap_topi_schedule(topi.rocm.schedule_conv2d_nchw_miopen), + name="conv2d_nchw_miopen.rocm", + plevel=15) elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups): if layout == "NCHW": assert kernel_layout == "OIHW" @@ -120,9 +119,8 @@ def conv2d_strategy_rocm(attrs, inputs, out_type, target): @dense_strategy.register("rocm") def dense_strategy_rocm(attrs, inputs, out_type, target): """Dense strategy for ROCM""" - strategy = _op.OpStrategy() assert len(inputs[0].shape) == 2 and len(inputs[1].shape) == 2, "Only support 2-dim dense" - + strategy = _op.OpStrategy() strategy.add_implementation( wrap_compute_dense(topi.rocm.dense), wrap_topi_schedule(topi.rocm.schedule_dense), @@ -133,5 +131,5 @@ def dense_strategy_rocm(attrs, inputs, out_type, target): wrap_compute_dense(topi.rocm.dense_rocblas), wrap_topi_schedule(topi.rocm.dense_rocblas), name="dense_rocblas.rocm", - plevel=5) + plevel=15) return strategy diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py index e35838c1c5e8..6606b5c49184 100644 --- a/python/tvm/relay/op/strategy/x86.py +++ b/python/tvm/relay/op/strategy/x86.py @@ -232,13 +232,13 @@ def dense_strategy_cpu(attrs, inputs, out_type, target): strategy.add_implementation(wrap_compute_dense(topi.x86.dense_cblas), wrap_topi_schedule(topi.x86.schedule_dense_cblas), name="dense_cblas.x86", - plevel=5) + plevel=15) with SpecializedCondition(m >= 16): # this implementation may not be well-optimized, so use plevel=8 for now. strategy.add_implementation(wrap_compute_dense(topi.x86.dense_pack), wrap_topi_schedule(topi.x86.schedule_dense_pack), name="dense_pack.x86", - plevel=8) + plevel=5) return strategy @batch_matmul_strategy.register("cpu") @@ -253,7 +253,7 @@ def batch_matmul_strategy_cpu(attrs, inputs, out_type, target): strategy.add_implementation(wrap_compute_batch_matmul(topi.x86.batch_matmul_cblas), wrap_topi_schedule(topi.x86.schedule_batch_matmul_cblas), name="batch_matmul_cblas.x86", - plevel=5) + plevel=15) return strategy @schedule_sparse_dense.register("cpu")