From 764913821995e02c71b70cca576240cf26b1badf Mon Sep 17 00:00:00 2001 From: masa Date: Thu, 5 Nov 2020 15:48:56 +0900 Subject: [PATCH 1/5] enable scatter gpu test on cuda --- tests/python/frontend/pytorch/test_forward.py | 318 +++++++++--------- 1 file changed, 159 insertions(+), 159 deletions(-) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 4dec5f7e5916..b96470da3561 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -3149,17 +3149,17 @@ def test_fn_scatter_add(dim): in_data = torch.zeros(3, 5) in_index = torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]) in_src = torch.rand(2, 5) - # TODO: add scatter gpu schedule to enable gpu test. - verify_trace_model(test_fn_scatter(0), [in_data, in_index, in_src], ["llvm"]) - verify_trace_model(test_fn_scatter_add(0), [in_data, in_index, in_src], ["llvm"]) + + targets = ["llvm", "cuda"] + verify_trace_model(test_fn_scatter(0), [in_data, in_index, in_src], targets) + # verify_trace_model(test_fn_scatter_add(0), [in_data, in_index, in_src], targets) in_data = torch.zeros(2, 4) in_index = torch.tensor([[2], [3]]) in_src = torch.rand(2, 1) - # # TODO: add scatter gpu schedule to enable gpu test. - verify_trace_model(test_fn_scatter(1), [in_data, in_index, in_src], ["llvm"]) - verify_trace_model(test_fn_scatter_add(1), [in_data, in_index, in_src], ["llvm"]) + verify_trace_model(test_fn_scatter(1), [in_data, in_index, in_src], targets) + # verify_trace_model(test_fn_scatter_add(1), [in_data, in_index, in_src], ["llvm"]) def test_numel(): @@ -3364,167 +3364,167 @@ def test_fn(x, weights=None): if __name__ == "__main__": - # some structural tests - test_forward_traced_function() - test_forward_dtypes() - test_weight_names() - test_duplicate_weight_use() - - # Single operator tests - test_forward_pixel_shuffle() - test_forward_add() - test_forward_subtract() - test_forward_multiply() - test_forward_matmul() - test_forward_rsub() - test_forward_onehot() - test_forward_embedding() - test_forward_reshape() - test_forward_reciprocal() - test_forward_repeat() - test_forward_repeat_interleave() - test_forward_squeeze() - test_forward_unsqueeze() - test_forward_concatenate() - test_forward_reduce_sum() - test_forward_reduce_prod() - test_forward_argmin() - test_forward_argmax() - test_forward_norm() - test_forward_frobenius_norm() - test_forward_std() - test_forward_variance() - test_forward_relu() - test_forward_prelu() - test_forward_leakyrelu() - test_forward_elu() - test_forward_celu() - test_forward_gelu() - test_forward_selu() - test_forward_log_sigmoid() - test_forward_adaptiveavgpool() - test_forward_maxpool2d() - test_forward_maxpool1d() - test_forward_maxpool3d() - test_forward_hardtanh() - test_forward_conv() - test_forward_conv_transpose() - test_forward_threshold() - test_forward_contiguous() - test_forward_batchnorm() - test_forward_instancenorm() - test_forward_layernorm() - test_forward_groupnorm() - test_forward_transpose() - test_forward_size() - test_forward_view() - test_forward_select() - test_forward_take() - test_forward_topk() - test_forward_where() - test_forward_addcdiv() - test_forward_addcmul() - test_forward_true_divide() - test_forward_clone() - test_forward_softplus() - test_forward_softsign() - test_forward_logsoftmax() - test_forward_sigmoid() - test_forward_dense() - test_forward_avgpool() - test_forward_avgpool3d() - test_forward_dropout() - test_forward_slice() - test_forward_mean() - test_forward_expand() - test_forward_pow() - test_forward_unary() - test_forward_clamp() - test_forward_clamp_() - test_forward_logical_not() - test_forward_bitwise_not() - test_forward_bitwise_xor() - test_forward_logical_xor() - test_forward_isfinite() - test_forward_isnan() - test_forward_isinf() - test_forward_ones() - test_forward_ones_like() - test_forward_zeros() - test_forward_zeros_like() - test_forward_full() - test_forward_full_like() - test_forward_linspace() - test_forward_arange() - test_forward_mesh_grid() - test_forward_chunk() - test_forward_split() - test_forward_gather() - test_upsample() - test_forward_upsample3d() - test_forward_nms() - test_forward_roi_align() - test_to() - test_flatten() - test_type_as() - test_forward_functional_pad() - test_forward_zero_pad2d() - test_forward_constant_pad1d() - test_forward_constant_pad2d() - test_forward_constant_pad3d() - test_forward_reflection_pad1d() - test_forward_reflection_pad2d() - test_forward_replication_pad1d() - test_forward_replication_pad2d() - test_forward_replication_pad3d() - test_adaptive_pool3d() - test_conv3d() - test_conv3d_transpose() - test_forward_index() - test_min_max() - test_logsumexp() - test_stack() - test_stack_dynamic() - test_forward_unbind() - test_forward_nonzero() + # # some structural tests + # test_forward_traced_function() + # test_forward_dtypes() + # test_weight_names() + # test_duplicate_weight_use() + + # # Single operator tests + # test_forward_pixel_shuffle() + # test_forward_add() + # test_forward_subtract() + # test_forward_multiply() + # test_forward_matmul() + # test_forward_rsub() + # test_forward_onehot() + # test_forward_embedding() + # test_forward_reshape() + # test_forward_reciprocal() + # test_forward_repeat() + # test_forward_repeat_interleave() + # test_forward_squeeze() + # test_forward_unsqueeze() + # test_forward_concatenate() + # test_forward_reduce_sum() + # test_forward_reduce_prod() + # test_forward_argmin() + # test_forward_argmax() + # test_forward_norm() + # test_forward_frobenius_norm() + # test_forward_std() + # test_forward_variance() + # test_forward_relu() + # test_forward_prelu() + # test_forward_leakyrelu() + # test_forward_elu() + # test_forward_celu() + # test_forward_gelu() + # test_forward_selu() + # test_forward_log_sigmoid() + # test_forward_adaptiveavgpool() + # test_forward_maxpool2d() + # test_forward_maxpool1d() + # test_forward_maxpool3d() + # test_forward_hardtanh() + # test_forward_conv() + # test_forward_conv_transpose() + # test_forward_threshold() + # test_forward_contiguous() + # test_forward_batchnorm() + # test_forward_instancenorm() + # test_forward_layernorm() + # test_forward_groupnorm() + # test_forward_transpose() + # test_forward_size() + # test_forward_view() + # test_forward_select() + # test_forward_take() + # test_forward_topk() + # test_forward_where() + # test_forward_addcdiv() + # test_forward_addcmul() + # test_forward_true_divide() + # test_forward_clone() + # test_forward_softplus() + # test_forward_softsign() + # test_forward_logsoftmax() + # test_forward_sigmoid() + # test_forward_dense() + # test_forward_avgpool() + # test_forward_avgpool3d() + # test_forward_dropout() + # test_forward_slice() + # test_forward_mean() + # test_forward_expand() + # test_forward_pow() + # test_forward_unary() + # test_forward_clamp() + # test_forward_clamp_() + # test_forward_logical_not() + # test_forward_bitwise_not() + # test_forward_bitwise_xor() + # test_forward_logical_xor() + # test_forward_isfinite() + # test_forward_isnan() + # test_forward_isinf() + # test_forward_ones() + # test_forward_ones_like() + # test_forward_zeros() + # test_forward_zeros_like() + # test_forward_full() + # test_forward_full_like() + # test_forward_linspace() + # test_forward_arange() + # test_forward_mesh_grid() + # test_forward_chunk() + # test_forward_split() + # test_forward_gather() + # test_upsample() + # test_forward_upsample3d() + # test_forward_nms() + # test_forward_roi_align() + # test_to() + # test_flatten() + # test_type_as() + # test_forward_functional_pad() + # test_forward_zero_pad2d() + # test_forward_constant_pad1d() + # test_forward_constant_pad2d() + # test_forward_constant_pad3d() + # test_forward_reflection_pad1d() + # test_forward_reflection_pad2d() + # test_forward_replication_pad1d() + # test_forward_replication_pad2d() + # test_forward_replication_pad3d() + # test_adaptive_pool3d() + # test_conv3d() + # test_conv3d_transpose() + # test_forward_index() + # test_min_max() + # test_logsumexp() + # test_stack() + # test_stack_dynamic() + # test_forward_unbind() + # test_forward_nonzero() test_forward_scatter() - test_numel() - test_bincount() + # test_numel() + # test_bincount() - # Model tests - test_resnet18() - test_squeezenet1_0() - test_squeezenet1_1() - test_densenet121() - # disable inception test for now, since loading it takes ~5min on torchvision-0.5 due to scipy bug - # See https://discuss.pytorch.org/t/torchvisions-inception-v3-takes-much-longer-to-load-than-other-models/68756 - # test_inception_v3() - test_googlenet() - test_mnasnet0_5() - test_mobilenet_v2() + # # Model tests + # test_resnet18() + # test_squeezenet1_0() + # test_squeezenet1_1() + # test_densenet121() + # # disable inception test for now, since loading it takes ~5min on torchvision-0.5 due to scipy bug + # # See https://discuss.pytorch.org/t/torchvisions-inception-v3-takes-much-longer-to-load-than-other-models/68756 + # # test_inception_v3() + # test_googlenet() + # test_mnasnet0_5() + # test_mobilenet_v2() - test_custom_conversion_map() + # test_custom_conversion_map() - test_segmentaton_models() - test_3d_models() + # test_segmentaton_models() + # test_3d_models() - # Quantization test - from qnn_test import test_quantized_imagenet, test_quantized_modules + # # Quantization test + # from qnn_test import test_quantized_imagenet, test_quantized_modules - test_quantized_modules() - test_quantized_imagenet() + # test_quantized_modules() + # test_quantized_imagenet() - # Test simple conditionals and loop - test_control_flow() - test_simple_rnn() + # # Test simple conditionals and loop + # test_control_flow() + # test_simple_rnn() - # More complex recurrent models - from test_lstm import test_custom_lstm + # # More complex recurrent models + # from test_lstm import test_custom_lstm - test_custom_lstm() + # test_custom_lstm() - # Test bert model - test_forward_pretrained_bert_base_uncased() + # # Test bert model + # test_forward_pretrained_bert_base_uncased() - # Test convert torch script(jit) with specific inputs' types - test_convert_torch_script_with_input_types() + # # Test convert torch script(jit) with specific inputs' types + # test_convert_torch_script_with_input_types() From 9cb26d4644a3e8bb96ef12f6a1ec39d5f3f63a14 Mon Sep 17 00:00:00 2001 From: masa Date: Thu, 5 Nov 2020 16:06:59 +0900 Subject: [PATCH 2/5] adding update_func arg --- python/tvm/topi/cuda/scatter.py | 92 +++++++++++++++++++++++++++++---- 1 file changed, 81 insertions(+), 11 deletions(-) diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py index 6522d74d8bef..0b9463d1b9fe 100644 --- a/python/tvm/topi/cuda/scatter.py +++ b/python/tvm/topi/cuda/scatter.py @@ -24,7 +24,7 @@ def ceil_div(a, b): return (a + b - 1) // b -def gen_ir_1d(data, indices, updates, axis, out): +def gen_ir_1d(data, indices, updates, axis, out, update_func): """Generate scatter ir for 1d inputs Parameters @@ -44,6 +44,9 @@ def gen_ir_1d(data, indices, updates, axis, out): out : tir.Tensor The output tensor. + update_func: function + The function to be applied to a destination and the corresponding update. + Returns ------- ret : tir @@ -73,14 +76,14 @@ def gen_ir_1d(data, indices, updates, axis, out): with ib.for_range(0, ni, name="i") as i: index = indices_ptr[i] with ib.if_scope(index < 0): - out_ptr[index + n] = updates_ptr[i] + update_func(out_ptr, index + n, updates_ptr[i]) with ib.else_scope(): - out_ptr[index] = updates_ptr[i] + update_func(out_ptr, index, updates_ptr[i]) return ib.get() -def gen_ir_2d(data, indices, updates, axis, out): +def gen_ir_2d(data, indices, updates, axis, out, update_func): """Generate scatter ir for 2d inputs Parameters @@ -100,6 +103,9 @@ def gen_ir_2d(data, indices, updates, axis, out): out : tir.Tensor The output tensor. + update_func: function + The function to be applied to a destination and the corresponding update + Returns ------- ret : tir @@ -140,9 +146,9 @@ def gen_ir_2d(data, indices, updates, axis, out): idx = i * ci + j index = indices_ptr[idx] with ib.if_scope(index < 0): - out_ptr[(index + n) * c + j] = updates_ptr[idx] + update_func(out_ptr, (index + n) * c + j, updates_ptr[idx]) with ib.else_scope(): - out_ptr[index * c + j] = updates_ptr[idx] + update_func(out_ptr, index * c + j, updates_ptr[idx]) else: with ib.new_scope(): i = te.thread_axis("blockIdx.x") @@ -151,13 +157,13 @@ def gen_ir_2d(data, indices, updates, axis, out): idx = i * ci + j index = indices_ptr[idx] with ib.if_scope(index < 0): - out_ptr[i * c + (index + c)] = updates_ptr[idx] + update_func(out_ptr, i * c + (index + c), updates_ptr[idx]) with ib.else_scope(): - out_ptr[i * c + index] = updates_ptr[idx] + update_func(out_ptr, i * c + index, updates_ptr[idx]) return ib.get() -def gen_ir_3d(data, indices, updates, axis, out): +def gen_ir_3d(data, indices, updates, axis, out, update_func): """Generate scatter ir for 3d inputs Parameters @@ -177,6 +183,9 @@ def gen_ir_3d(data, indices, updates, axis, out): out : tir.Tensor The output tensor. + update_func: function + The function to be applied to a destination and the corresponding update + Returns ------- ret : tir @@ -260,7 +269,7 @@ def gen_ir_3d(data, indices, updates, axis, out): return ib.get() -def gen_ir_4d(data, indices, updates, axis, out): +def gen_ir_4d(data, indices, updates, axis, out, update_func): """Generate scatter ir for 4d inputs Parameters @@ -280,6 +289,9 @@ def gen_ir_4d(data, indices, updates, axis, out): out : tir.Tensor The output tensor. + update_func: function + The function to be applied to a destination and the corresponding update + Returns ------- ret : tir @@ -428,12 +440,15 @@ def scatter(data, indices, updates, axis=0): 4: gen_ir_4d, } + def update_func(dst_ptr, dst_index, update): + dst_ptr[dst_index] = update + out_shape = data.shape out_buf = tvm.tir.decl_buffer(out_shape, data.dtype, "out_buf") out = te.extern( [out_shape], [data, indices, updates], - lambda ins, outs: ir_funcs[rank](ins[0], ins[1], ins[2], axis, outs[0]), + lambda ins, outs: ir_funcs[rank](ins[0], ins[1], ins[2], axis, outs[0], update_func), dtype=data.dtype, out_buffers=[out_buf], name="scatter_gpu", @@ -441,3 +456,58 @@ def scatter(data, indices, updates, axis=0): ) return out + + +def scatter_add(data, indices, updates, axis=0): + """Update data by adding values in updates at positions defined by indices + + Parameters + ---------- + data : relay.Expr + The input data to the operator. + + indices : relay.Expr + The index locations to update. + + updates : relay.Expr + The values to be added. + + axis : int + The axis to scatter on + + Returns + ------- + ret : relay.Expr + The computed result. + """ + if axis < 0: + axis += len(data.shape) + assert axis >= 0 + assert axis < len(data.shape) + + rank = len(data.shape) + assert 1 <= rank <= 4, "scatter_add only supports 1-4 dimensions" + + ir_funcs = { + 1: gen_ir_1d, + 2: gen_ir_2d, + 3: gen_ir_3d, + 4: gen_ir_4d, + } + + def update_func(dst_ptr, dst_index, update): + dst_ptr[dst_index] += update + + out_shape = data.shape + out_buf = tvm.tir.decl_buffer(out_shape, data.dtype, "out_buf") + out = te.extern( + [out_shape], + [data, indices, updates], + lambda ins, outs: ir_funcs[rank](ins[0], ins[1], ins[2], axis, outs[0], update_func), + dtype=data.dtype, + out_buffers=[out_buf], + name="scatter_add_gpu", + tag="scatter_add_gpu", + ) + + return out From bb8418130b11d19c709fa07aedd7841749e01458 Mon Sep 17 00:00:00 2001 From: masa Date: Thu, 5 Nov 2020 16:15:35 +0900 Subject: [PATCH 3/5] pytorch scatter_add gpu tests working --- python/tvm/relay/op/_transform.py | 2 +- python/tvm/relay/op/strategy/cuda.py | 15 +- python/tvm/relay/op/strategy/generic.py | 15 +- tests/python/frontend/pytorch/test_forward.py | 310 +++++++++--------- 4 files changed, 179 insertions(+), 163 deletions(-) diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index edcd5b9ce161..e42b8bbae814 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -113,7 +113,7 @@ def compute_scatter_add(attrs, inputs, output_type): return [topi.scatter_add(inputs[0], inputs[1], inputs[2], attrs.axis)] -_reg.register_schedule("scatter_add", strategy.schedule_scatter_add) +_reg.register_strategy("scatter_add", strategy.scatter_add_strategy) ##################### # Shape functions # diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index b7ceda304639..26e9a0060b66 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -664,7 +664,7 @@ def sparse_dense_padded_strategy_cuda(attrs, inputs, out_type, target): @scatter_strategy.register(["cuda", "gpu"]) def scatter_cuda(attrs, inputs, out_type, target): - """sparse dense cuda strategy""" + """scatter cuda strategy""" strategy = _op.OpStrategy() strategy.add_implementation( wrap_compute_scatter(topi.cuda.scatter), @@ -675,6 +675,19 @@ def scatter_cuda(attrs, inputs, out_type, target): return strategy +@scatter_add_strategy.register(["cuda", "gpu"]) +def scatter_add_cuda(attrs, inputs, out_type, target): + """scatter_add cuda strategy""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_scatter(topi.cuda.scatter_add), + wrap_topi_schedule(topi.generic.schedule_extern), + name="scatter_add.cuda", + plevel=10, + ) + return strategy + + @argsort_strategy.register(["cuda", "gpu"]) def argsort_strategy_cuda(attrs, inputs, out_type, target): """argsort cuda strategy""" diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index bdefbcb79009..e49135c4d1bf 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -1052,12 +1052,15 @@ def _compute_scatter(attrs, inputs, _): return _compute_scatter -# scatter_add -@generic_func -def schedule_scatter_add(attrs, outs, target): - """schedule scatter_add""" - with target: - return topi.generic.schedule_scatter_add(outs) +@override_native_generic_func("scatter_add_strategy") +def scatter_add_strategy(attrs, outs, out_type, target): + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_scatter(topi.scatter_add), + wrap_topi_schedule(topi.generic.schedule_scatter), + name="scatter_add.generic", + ) + return strategy # bitserial_conv2d diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index b96470da3561..6250dfff811a 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -3152,14 +3152,14 @@ def test_fn_scatter_add(dim): targets = ["llvm", "cuda"] verify_trace_model(test_fn_scatter(0), [in_data, in_index, in_src], targets) - # verify_trace_model(test_fn_scatter_add(0), [in_data, in_index, in_src], targets) + verify_trace_model(test_fn_scatter_add(0), [in_data, in_index, in_src], targets) in_data = torch.zeros(2, 4) in_index = torch.tensor([[2], [3]]) in_src = torch.rand(2, 1) verify_trace_model(test_fn_scatter(1), [in_data, in_index, in_src], targets) - # verify_trace_model(test_fn_scatter_add(1), [in_data, in_index, in_src], ["llvm"]) + verify_trace_model(test_fn_scatter_add(1), [in_data, in_index, in_src], targets) def test_numel(): @@ -3364,167 +3364,167 @@ def test_fn(x, weights=None): if __name__ == "__main__": - # # some structural tests - # test_forward_traced_function() - # test_forward_dtypes() - # test_weight_names() - # test_duplicate_weight_use() - - # # Single operator tests - # test_forward_pixel_shuffle() - # test_forward_add() - # test_forward_subtract() - # test_forward_multiply() - # test_forward_matmul() - # test_forward_rsub() - # test_forward_onehot() - # test_forward_embedding() - # test_forward_reshape() - # test_forward_reciprocal() - # test_forward_repeat() - # test_forward_repeat_interleave() - # test_forward_squeeze() - # test_forward_unsqueeze() - # test_forward_concatenate() - # test_forward_reduce_sum() - # test_forward_reduce_prod() - # test_forward_argmin() - # test_forward_argmax() - # test_forward_norm() - # test_forward_frobenius_norm() - # test_forward_std() - # test_forward_variance() - # test_forward_relu() - # test_forward_prelu() - # test_forward_leakyrelu() - # test_forward_elu() - # test_forward_celu() - # test_forward_gelu() - # test_forward_selu() - # test_forward_log_sigmoid() - # test_forward_adaptiveavgpool() - # test_forward_maxpool2d() - # test_forward_maxpool1d() - # test_forward_maxpool3d() - # test_forward_hardtanh() - # test_forward_conv() - # test_forward_conv_transpose() - # test_forward_threshold() - # test_forward_contiguous() - # test_forward_batchnorm() - # test_forward_instancenorm() - # test_forward_layernorm() - # test_forward_groupnorm() - # test_forward_transpose() - # test_forward_size() - # test_forward_view() - # test_forward_select() - # test_forward_take() - # test_forward_topk() - # test_forward_where() - # test_forward_addcdiv() - # test_forward_addcmul() - # test_forward_true_divide() - # test_forward_clone() - # test_forward_softplus() - # test_forward_softsign() - # test_forward_logsoftmax() - # test_forward_sigmoid() - # test_forward_dense() - # test_forward_avgpool() - # test_forward_avgpool3d() - # test_forward_dropout() - # test_forward_slice() - # test_forward_mean() - # test_forward_expand() - # test_forward_pow() - # test_forward_unary() - # test_forward_clamp() - # test_forward_clamp_() - # test_forward_logical_not() - # test_forward_bitwise_not() - # test_forward_bitwise_xor() - # test_forward_logical_xor() - # test_forward_isfinite() - # test_forward_isnan() - # test_forward_isinf() - # test_forward_ones() - # test_forward_ones_like() - # test_forward_zeros() - # test_forward_zeros_like() - # test_forward_full() - # test_forward_full_like() - # test_forward_linspace() - # test_forward_arange() - # test_forward_mesh_grid() - # test_forward_chunk() - # test_forward_split() - # test_forward_gather() - # test_upsample() - # test_forward_upsample3d() - # test_forward_nms() - # test_forward_roi_align() - # test_to() - # test_flatten() - # test_type_as() - # test_forward_functional_pad() - # test_forward_zero_pad2d() - # test_forward_constant_pad1d() - # test_forward_constant_pad2d() - # test_forward_constant_pad3d() - # test_forward_reflection_pad1d() - # test_forward_reflection_pad2d() - # test_forward_replication_pad1d() - # test_forward_replication_pad2d() - # test_forward_replication_pad3d() - # test_adaptive_pool3d() - # test_conv3d() - # test_conv3d_transpose() - # test_forward_index() - # test_min_max() - # test_logsumexp() - # test_stack() - # test_stack_dynamic() - # test_forward_unbind() - # test_forward_nonzero() + # some structural tests + test_forward_traced_function() + test_forward_dtypes() + test_weight_names() + test_duplicate_weight_use() + + # Single operator tests + test_forward_pixel_shuffle() + test_forward_add() + test_forward_subtract() + test_forward_multiply() + test_forward_matmul() + test_forward_rsub() + test_forward_onehot() + test_forward_embedding() + test_forward_reshape() + test_forward_reciprocal() + test_forward_repeat() + test_forward_repeat_interleave() + test_forward_squeeze() + test_forward_unsqueeze() + test_forward_concatenate() + test_forward_reduce_sum() + test_forward_reduce_prod() + test_forward_argmin() + test_forward_argmax() + test_forward_norm() + test_forward_frobenius_norm() + test_forward_std() + test_forward_variance() + test_forward_relu() + test_forward_prelu() + test_forward_leakyrelu() + test_forward_elu() + test_forward_celu() + test_forward_gelu() + test_forward_selu() + test_forward_log_sigmoid() + test_forward_adaptiveavgpool() + test_forward_maxpool2d() + test_forward_maxpool1d() + test_forward_maxpool3d() + test_forward_hardtanh() + test_forward_conv() + test_forward_conv_transpose() + test_forward_threshold() + test_forward_contiguous() + test_forward_batchnorm() + test_forward_instancenorm() + test_forward_layernorm() + test_forward_groupnorm() + test_forward_transpose() + test_forward_size() + test_forward_view() + test_forward_select() + test_forward_take() + test_forward_topk() + test_forward_where() + test_forward_addcdiv() + test_forward_addcmul() + test_forward_true_divide() + test_forward_clone() + test_forward_softplus() + test_forward_softsign() + test_forward_logsoftmax() + test_forward_sigmoid() + test_forward_dense() + test_forward_avgpool() + test_forward_avgpool3d() + test_forward_dropout() + test_forward_slice() + test_forward_mean() + test_forward_expand() + test_forward_pow() + test_forward_unary() + test_forward_clamp() + test_forward_clamp_() + test_forward_logical_not() + test_forward_bitwise_not() + test_forward_bitwise_xor() + test_forward_logical_xor() + test_forward_isfinite() + test_forward_isnan() + test_forward_isinf() + test_forward_ones() + test_forward_ones_like() + test_forward_zeros() + test_forward_zeros_like() + test_forward_full() + test_forward_full_like() + test_forward_linspace() + test_forward_arange() + test_forward_mesh_grid() + test_forward_chunk() + test_forward_split() + test_forward_gather() + test_upsample() + test_forward_upsample3d() + test_forward_nms() + test_forward_roi_align() + test_to() + test_flatten() + test_type_as() + test_forward_functional_pad() + test_forward_zero_pad2d() + test_forward_constant_pad1d() + test_forward_constant_pad2d() + test_forward_constant_pad3d() + test_forward_reflection_pad1d() + test_forward_reflection_pad2d() + test_forward_replication_pad1d() + test_forward_replication_pad2d() + test_forward_replication_pad3d() + test_adaptive_pool3d() + test_conv3d() + test_conv3d_transpose() + test_forward_index() + test_min_max() + test_logsumexp() + test_stack() + test_stack_dynamic() + test_forward_unbind() + test_forward_nonzero() test_forward_scatter() - # test_numel() - # test_bincount() + test_numel() + test_bincount() - # # Model tests - # test_resnet18() - # test_squeezenet1_0() - # test_squeezenet1_1() - # test_densenet121() - # # disable inception test for now, since loading it takes ~5min on torchvision-0.5 due to scipy bug - # # See https://discuss.pytorch.org/t/torchvisions-inception-v3-takes-much-longer-to-load-than-other-models/68756 - # # test_inception_v3() - # test_googlenet() - # test_mnasnet0_5() - # test_mobilenet_v2() + # Model tests + test_resnet18() + test_squeezenet1_0() + test_squeezenet1_1() + test_densenet121() + # disable inception test for now, since loading it takes ~5min on torchvision-0.5 due to scipy bug + # See https://discuss.pytorch.org/t/torchvisions-inception-v3-takes-much-longer-to-load-than-other-models/68756 + # test_inception_v3() + test_googlenet() + test_mnasnet0_5() + test_mobilenet_v2() - # test_custom_conversion_map() + test_custom_conversion_map() - # test_segmentaton_models() - # test_3d_models() + test_segmentaton_models() + test_3d_models() - # # Quantization test - # from qnn_test import test_quantized_imagenet, test_quantized_modules + # Quantization test + from qnn_test import test_quantized_imagenet, test_quantized_modules - # test_quantized_modules() - # test_quantized_imagenet() + test_quantized_modules() + test_quantized_imagenet() - # # Test simple conditionals and loop - # test_control_flow() - # test_simple_rnn() + # Test simple conditionals and loop + test_control_flow() + test_simple_rnn() - # # More complex recurrent models - # from test_lstm import test_custom_lstm + # More complex recurrent models + from test_lstm import test_custom_lstm - # test_custom_lstm() + test_custom_lstm() - # # Test bert model - # test_forward_pretrained_bert_base_uncased() + # Test bert model + test_forward_pretrained_bert_base_uncased() - # # Test convert torch script(jit) with specific inputs' types - # test_convert_torch_script_with_input_types() + # Test convert torch script(jit) with specific inputs' types + test_convert_torch_script_with_input_types() From 5df65e13e0751b070f0a9db50b1516f620474547 Mon Sep 17 00:00:00 2001 From: masa Date: Thu, 5 Nov 2020 16:23:06 +0900 Subject: [PATCH 4/5] update 3d and 4d scatter --- python/tvm/topi/cuda/scatter.py | 41 +++++++++++++++++++++------------ 1 file changed, 26 insertions(+), 15 deletions(-) diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py index 0b9463d1b9fe..0a3e96f4be30 100644 --- a/python/tvm/topi/cuda/scatter.py +++ b/python/tvm/topi/cuda/scatter.py @@ -234,9 +234,9 @@ def gen_ir_3d(data, indices, updates, axis, out, update_func): idx = (i * ci + j) * hi + k index = indices_ptr[idx] with ib.if_scope(index < 0): - out_ptr[((index + n) * c + j) * h + k] = updates_ptr[idx] + update_func(out_ptr, ((index + n) * c + j) * h + k, updates_ptr[idx]) with ib.else_scope(): - out_ptr[(index * c + j) * h + k] = updates_ptr[idx] + update_func(out_ptr, (index * c + j) * h + k, updates_ptr[idx]) elif axis == 1: with ib.new_scope(): i = te.thread_axis("blockIdx.x") @@ -250,9 +250,9 @@ def gen_ir_3d(data, indices, updates, axis, out, update_func): idx = (i * ci + j) * hi + k index = indices_ptr[idx] with ib.if_scope(index < 0): - out_ptr[(i * c + (index + c)) * h + k] = updates_ptr[idx] + update_func(out_ptr, (i * c + (index + c)) * h + k, updates_ptr[idx]) with ib.else_scope(): - out_ptr[(i * c + index) * h + k] = updates_ptr[idx] + update_func(out_ptr, (i * c + index) * h + k, updates_ptr[idx]) else: with ib.new_scope(): i = te.thread_axis("blockIdx.x") @@ -263,9 +263,9 @@ def gen_ir_3d(data, indices, updates, axis, out, update_func): idx = (i * ci + j) * hi + k index = indices_ptr[idx] with ib.if_scope(index < 0): - out_ptr[(i * c + j) * h + (index + h)] = updates_ptr[idx] + update_func(out_ptr, (i * c + j) * h + (index + h), updates_ptr[idx]) with ib.else_scope(): - out_ptr[(i * c + j) * h + index] = updates_ptr[idx] + update_func(out_ptr, (i * c + j) * h + index, updates_ptr[idx]) return ib.get() @@ -345,9 +345,13 @@ def gen_ir_4d(data, indices, updates, axis, out, update_func): idx = ((i * ci + j) * hi + k) * wi + l index = indices_ptr[idx] with ib.if_scope(index < 0): - out_ptr[(((index + n) * c + j) * h + k) * w + l] = updates_ptr[idx] + update_func( + out_ptr, (((index + n) * c + j) * h + k) * w + l, updates_ptr[idx] + ) with ib.else_scope(): - out_ptr[((index * c + j) * h + k) * w + l] = updates_ptr[idx] + update_func( + out_ptr, ((index * c + j) * h + k) * w + l, updates_ptr[idx] + ) elif axis == 1: with ib.new_scope(): i = te.thread_axis("blockIdx.x") @@ -363,9 +367,13 @@ def gen_ir_4d(data, indices, updates, axis, out, update_func): idx = ((i * ci + j) * hi + k) * wi + l index = indices_ptr[idx] with ib.if_scope(index < 0): - out_ptr[((i * c + (index + c)) * h + k) * w + l] = updates_ptr[idx] + update_func( + out_ptr, ((i * c + (index + c)) * h + k) * w + l, updates_ptr[idx] + ) with ib.else_scope(): - out_ptr[((i * c + index) * h + k) * w + l] = updates_ptr[idx] + update_func( + out_ptr, ((i * c + index) * h + k) * w + l, updates_ptr[idx] + ) elif axis == 2: with ib.new_scope(): i = te.thread_axis("blockIdx.x") @@ -381,9 +389,13 @@ def gen_ir_4d(data, indices, updates, axis, out, update_func): idx = ((i * ci + j) * hi + k) * wi + l index = indices_ptr[idx] with ib.if_scope(index < 0): - out_ptr[((i * c + j) * h + (index + h)) * w + l] = updates_ptr[idx] + update_func( + out_ptr, ((i * c + j) * h + (index + h)) * w + l, updates_ptr[idx] + ) with ib.else_scope(): - out_ptr[((i * c + j) * h + index) * w + l] = updates_ptr[idx] + update_func( + out_ptr, ((i * c + j) * h + index) * w + l, updates_ptr[idx] + ) else: with ib.new_scope(): i = te.thread_axis("blockIdx.x") @@ -396,10 +408,9 @@ def gen_ir_4d(data, indices, updates, axis, out, update_func): idx = ((i * ci + j) * hi + k) * wi + l index = indices_ptr[idx] with ib.if_scope(index < 0): - out_ptr[((i * c + j) * h + k) * w + (index + w)] = updates_ptr[idx] + update_func(out_ptr, ((i * c + j) * h + k) * w + (index + w), updates_ptr[idx]) with ib.else_scope(): - out_ptr[((i * c + j) * h + k) * w + index] = updates_ptr[idx] - + update_func(out_ptr, ((i * c + j) * h + k) * w + index, updates_ptr[idx]) return ib.get() From 00e08bce0d442718ff5005cd743d4bfb5c666140 Mon Sep 17 00:00:00 2001 From: masa Date: Thu, 5 Nov 2020 16:25:56 +0900 Subject: [PATCH 5/5] enable scatter_add gpu test --- tests/python/relay/test_op_level3.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 3ea0777df8ca..01de6d265cb2 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -961,6 +961,7 @@ def verify_dynamic_scatter(dshape, ishape, axis=0): verify_dynamic_scatter((16, 16, 4, 5), (16, 16, 4, 5), 3) +@tvm.testing.uses_gpu def test_scatter_add(): def ref_scatter_add(data, indices, updates, axis=0): output = np.copy(data) @@ -983,8 +984,7 @@ def verify_scatter_add(dshape, ishape, axis=0): indices_np = np.random.randint(-dshape[axis], dshape[axis] - 1, ishape).astype("int64") ref_res = ref_scatter_add(data_np, indices_np, updates_np, axis) - # TODO(mbrookhart): expand testing when adding more backend schedules - for target, ctx in [("llvm", tvm.cpu())]: + for target, ctx in tvm.testing.enabled_targets(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, ctx=ctx, target=target) op_res = intrp.evaluate(func)(data_np, indices_np, updates_np)