Skip to content

Commit 0af9e71

Browse files
connorgogginsUbuntu
authored and
Ubuntu
committed
Implemented final two binary ops, added default params for functionality (apache#17407)
* Implemented final two binary ops, added default params for functionality * Removed extraneous param checks * Added support for ElementWiseSum * Moved ElementWiseSum to binary_element_wise_operators
1 parent 8ac452c commit 0af9e71

File tree

4 files changed

+62
-4
lines changed

4 files changed

+62
-4
lines changed

benchmark/opperf/nd_operations/binary_operators.py

+28-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,34 @@
3535

3636
from benchmark.opperf.utils.benchmark_utils import run_op_benchmarks
3737
from benchmark.opperf.utils.op_registry_utils import get_all_broadcast_binary_operators, \
38-
get_all_elemen_wise_binary_operators
38+
get_all_elemen_wise_binary_operators, get_all_misc_binary_operators
39+
40+
41+
def run_mx_binary_misc_operators_benchmarks(ctx=mx.cpu(), dtype='float32', profiler='native', warmup=25, runs=100):
42+
"""Runs benchmarks with the given context and precision (dtype) for all the miscellaneous
43+
binary operators in MXNet.
44+
45+
Parameters
46+
----------
47+
ctx: mx.ctx
48+
Context to run benchmarks
49+
dtype: str, default 'float32'
50+
Precision to use for benchmarks
51+
warmup: int, default 25
52+
Number of times to run for warmup
53+
runs: int, default 100
54+
Number of runs to capture benchmark results
55+
56+
Returns
57+
-------
58+
Dictionary of results. Key -> Name of the operator, Value -> Benchmark results.
59+
60+
"""
61+
# Fetch all Miscellaneous Binary Operators
62+
mx_binary_misc_ops = get_all_misc_binary_operators()
63+
# Run benchmarks
64+
mx_binary_op_results = run_op_benchmarks(mx_binary_misc_ops, dtype, ctx, profiler, warmup, runs)
65+
return mx_binary_op_results
3966

4067

4168
def run_mx_binary_broadcast_operators_benchmarks(ctx=mx.cpu(), dtype='float32', profiler='native', warmup=25, runs=100):

benchmark/opperf/opperf.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030

3131
from benchmark.opperf.nd_operations.unary_operators import run_mx_unary_operators_benchmarks
3232
from benchmark.opperf.nd_operations.binary_operators import run_mx_binary_broadcast_operators_benchmarks, \
33-
run_mx_binary_element_wise_operators_benchmarks
33+
run_mx_binary_element_wise_operators_benchmarks, run_mx_binary_misc_operators_benchmarks
3434
from benchmark.opperf.nd_operations.gemm_operators import run_gemm_operators_benchmarks
3535
from benchmark.opperf.nd_operations.random_sampling_operators import run_mx_random_sampling_operators_benchmarks
3636
from benchmark.opperf.nd_operations.reduction_operators import run_mx_reduction_operators_benchmarks
@@ -63,12 +63,15 @@ def run_all_mxnet_operator_benchmarks(ctx=mx.cpu(), dtype='float32', profiler='n
6363
# Run all Unary operations benchmarks with default input values
6464
mxnet_operator_benchmark_results.append(run_mx_unary_operators_benchmarks(ctx=ctx, dtype=dtype, profiler=profiler))
6565

66-
# Run all Binary Broadcast, element_wise operations benchmarks with default input values
66+
# Run all Binary Broadcast, element_wise, and miscellaneous operations benchmarks with default input values
6767
mxnet_operator_benchmark_results.append(run_mx_binary_broadcast_operators_benchmarks(ctx=ctx,
6868
dtype=dtype, profiler=profiler))
6969
mxnet_operator_benchmark_results.append(run_mx_binary_element_wise_operators_benchmarks(ctx=ctx,
7070
dtype=dtype, profiler=profiler))
7171

72+
mxnet_operator_benchmark_results.append(run_mx_binary_misc_operators_benchmarks(ctx=ctx,
73+
dtype=dtype, profiler=profiler))
74+
7275
# Run all GEMM operations benchmarks with default input values
7376
mxnet_operator_benchmark_results.append(run_gemm_operators_benchmarks(ctx=ctx,
7477
dtype=dtype, profiler=profiler))

benchmark/opperf/rules/default_params.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,11 @@
3232
# For Unary operators like abs, arccos, arcsin etc..
3333
DEFAULT_DATA = [(1024, 1024), (10000, 1), (10000, 100)]
3434

35+
# For Binary miscellaneous operators like choose_element0_index
36+
# argument data must be indexed via an NDArray.
37+
# NOTE: Data used is DEFAULT_DATA
38+
DEFAULT_INDEX = [(1, 1024), (1, 1), (1, 100)]
39+
3540
# For Binary broadcast operators like - broadcast_add/sub/mod/logical_and etc..
3641
DEFAULT_LHS = [(1024, 1024), (10000, 10), (10000, 1)]
3742
DEFAULT_RHS = [(1024, 1024), (10000, 10), (10000, 1)]
@@ -188,7 +193,8 @@
188193
"data_smce": DEFAULT_DATA_SMCE,
189194
"data_3d": DEFAULT_DATA_3d,
190195
"label_smce": DEFAULT_LABEL_SMCE,
191-
"label": DEFAULT_LABEL}
196+
"label": DEFAULT_LABEL,
197+
"index": DEFAULT_INDEX}
192198

193199

194200
# These are names of MXNet operator parameters that is of type NDArray.

benchmark/opperf/utils/op_registry_utils.py

+22
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,26 @@ def get_all_broadcast_binary_operators():
205205
return binary_broadcast_mx_operators
206206

207207

208+
def get_all_misc_binary_operators():
209+
"""Gets all miscellaneous binary operators registered with MXNet.
210+
211+
Returns
212+
-------
213+
{"operator_name": {"has_backward", "nd_op_handle", "params"}}
214+
"""
215+
# Get all mxnet operators
216+
mx_operators = _get_all_mxnet_operators()
217+
218+
# Filter for miscellaneous binary operators
219+
binary_misc_mx_operators = {}
220+
for op_name, op_params in mx_operators.items():
221+
if "choose_element_0index" == op_name:
222+
binary_misc_mx_operators[op_name] = mx_operators[op_name]
223+
elif "reshape_like" == op_name:
224+
binary_misc_mx_operators[op_name] = mx_operators[op_name]
225+
return binary_misc_mx_operators
226+
227+
208228
def get_all_elemen_wise_binary_operators():
209229
"""Gets all binary elemen_wise operators registered with MXNet.
210230
@@ -222,6 +242,8 @@ def get_all_elemen_wise_binary_operators():
222242
"lhs" in op_params["params"]["arg_names"] and \
223243
"rhs" in op_params["params"]["arg_names"]:
224244
binary_elemen_wise_mx_operators[op_name] = mx_operators[op_name]
245+
elif "ElementWiseSum" == op_name:
246+
binary_elemen_wise_mx_operators[op_name] = mx_operators[op_name]
225247
return binary_elemen_wise_mx_operators
226248

227249

0 commit comments

Comments
 (0)