diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h index 832934417484..8f63012e095a 100644 --- a/include/tvm/relax/attrs/nn.h +++ b/include/tvm/relax/attrs/nn.h @@ -462,6 +462,7 @@ struct BatchNormAttrs : public tvm::AttrsNode { bool center; bool scale; double momentum; + bool training; TVM_DECLARE_ATTRS(BatchNormAttrs, "relax.attrs.BatchNormAttrs") { TVM_ATTR_FIELD(axis).describe("The axis along which the normalization is applied."); @@ -470,6 +471,7 @@ struct BatchNormAttrs : public tvm::AttrsNode { "Indicating if the beta offset will be added to the normalized tensor."); TVM_ATTR_FIELD(scale).describe("Indicating if the gamma scale will be multiplied."); TVM_ATTR_FIELD(momentum).describe("The value used for the moving_mean and moving_var update."); + TVM_ATTR_FIELD(training).describe("Whether we are training (i.e., not in eval mode)."); } }; // struct BatchNormAttrs diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 4bfdb8c1bc36..ecd8665b4353 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1106,7 +1106,7 @@ def _detach(self, node: fx.Node) -> relax.Var: return self.env[node.args[0]] def _copy_(self, node: fx.Node) -> relax.Var: - # Copies the source tensor's to the destination tensor + # Copies the source tensor into the destination tensor # In TVM, that means simply returning the source tensor return self.env[node.args[1]] diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index bf01bd653130..84821d27b5e8 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -45,7 +45,7 @@ def _hardtanh(self, node: fx.Node) -> relax.Expr: ########## Neural Network ########## - def
_batch_norm_legit_no_training(self, node: fx.Node) -> relax.Var: + def _batch_norm(self, node: fx.Node, training) -> relax.Var: import numpy as np x = self.env[node.args[0]] @@ -55,22 +55,43 @@ def _batch_norm_legit_no_training(self, node: fx.Node) -> relax.Var: bias = self.env.get(node.args[2], relax.const(np.zeros(channel), dtype=dtype)) running_mean = self.env.get(node.args[3], relax.const(np.zeros(channel), dtype=dtype)) running_var = self.env.get(node.args[4], relax.const(np.ones(channel), dtype=dtype)) - momentum = node.args[5] if len(node.args) > 5 else node.kwargs.get("momentum", 0.1) - eps = node.args[6] if len(node.args) > 6 else node.kwargs.get("eps", 1e-05) + ignore_running_stats = ( + node.args[5] if len(node.args) > 5 else node.kwargs.get("track_running_stats", True) + ) + track_running_stats = not ignore_running_stats + momentum = node.args[6] if len(node.args) > 6 else node.kwargs.get("momentum", 0.1) + eps = node.args[7] if len(node.args) > 7 else node.kwargs.get("eps", 1e-05) + + if track_running_stats: + training = True return self.block_builder.emit( relax.op.nn.batch_norm( - x, - weight, - bias, - running_mean, - running_var, - axis=1, + data=x, + gamma=weight, + beta=bias, + moving_mean=running_mean, + moving_var=running_var, + axis=1, # Always over channel epsilon=eps, momentum=momentum, - ) + training=training, + )[0] ) + def _batch_norm_legit_functional(self, node: fx.Node) -> relax.Var: + # This method is called for batch_norm in training mode + # TODO does not have correctness! 
+ # TODO we need to store the running mean and variance returned by the + # previous call to batch_norm and pass it again + training = True + return self._batch_norm(node, training) + + def _batch_norm_legit_no_training(self, node: fx.Node) -> relax.Var: + # This method is called for batch_norm in eval mode + training = False + return self._batch_norm(node, training) + def _group_norm(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] num_groups = node.args[1] @@ -283,7 +304,9 @@ def create_convert_map( # linear algebra "linalg_vector_norm.default": self._linalg_vector_norm, # neural network + "_native_batch_norm_legit_functional.default": self._batch_norm_legit_functional, "_native_batch_norm_legit_no_training.default": self._batch_norm_legit_no_training, + "batch_norm.default": self._batch_norm_legit_no_training, "adaptive_avg_pool2d.default": self._adaptive_avg_pool2d, "addmm.default": self._addmm, "avg_pool2d.default": self._avg_pool2d, diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py index 5a1895cbc14f..09a7df5149f9 100644 --- a/python/tvm/relax/op/nn/nn.py +++ b/python/tvm/relax/op/nn/nn.py @@ -1393,6 +1393,7 @@ def batch_norm( center: bool = True, scale: bool = True, momentum: float = 0.1, + training: bool = True, ) -> Expr: r""" Batch normalization layer (Ioffe and Szegedy, 2014). @@ -1481,13 +1482,18 @@ def batch_norm( momentum : float The value used for the moving_mean and moving_var update. + training : bool + A boolean value indicating whether the op is in training or eval mode. By default, + relax batch_norm is in training mode. To transform it to inference mode, + use DecomposeOpsForInference. + Returns ------- result : relax.Expr The computed result.
""" return _ffi_api.batch_norm( # type: ignore - data, gamma, beta, moving_mean, moving_var, axis, epsilon, center, scale, momentum + data, gamma, beta, moving_mean, moving_var, axis, epsilon, center, scale, momentum, training ) diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py index d9fb4701f7e9..4c8bdbc6615c 100644 --- a/python/tvm/relax/transform/legalize_ops/nn.py +++ b/python/tvm/relax/transform/legalize_ops/nn.py @@ -551,9 +551,7 @@ def _nn_batch_norm(bb: BlockBuilder, call: Call) -> Expr: epsilon=call.attrs.epsilon, center=call.attrs.center, scale=call.attrs.scale, - # By default relax batch_norm is training mode. - # To transform it to inference mode, use DecomposeOpsForInference. - training=True, + training=call.attrs.training, momentum=call.attrs.momentum, ) diff --git a/python/tvm/topi/nn/batch_norm.py b/python/tvm/topi/nn/batch_norm.py index 3181efd7daa6..8308c93eae4f 100644 --- a/python/tvm/topi/nn/batch_norm.py +++ b/python/tvm/topi/nn/batch_norm.py @@ -111,22 +111,26 @@ def batch_norm( shape = [1] * len(data.shape) shape[axis] = data.shape[axis] + reduce_axes = list(range(len(data.shape))) + reduce_axes.remove(axis) + shape_prod = reduce(lambda x, y: x * y, [data.shape[ax] for ax in reduce_axes], 1) + + data_mean = topi.sum(data, axis=reduce_axes) / shape_prod + data_mean_rs = topi.reshape(data_mean, shape) + data_var = ( + topi.sum((data - data_mean_rs) * (data - data_mean_rs), axis=reduce_axes) / shape_prod + ) + data_var_rs = topi.reshape(data_var, shape) + if training: - reduce_axes = list(range(len(data.shape))) - reduce_axes.remove(axis) - shape_prod = reduce(lambda x, y: x * y, [data.shape[ax] for ax in reduce_axes], 1) - data_mean = topi.sum(data, axis=reduce_axes) / shape_prod - data_mean_rs = topi.reshape(data_mean, shape) - data_var = ( - topi.sum((data - data_mean_rs) * (data - data_mean_rs), axis=reduce_axes) / shape_prod - ) - data_var_rs = topi.reshape(data_var, shape) - out = 
(data - data_mean_rs) / topi.math.sqrt(data_var_rs + epsilon) - else: moving_mean_rs = topi.reshape(moving_mean, shape) moving_var_rs = topi.reshape(moving_var, shape) + out = (data - moving_mean_rs) / topi.math.sqrt(moving_var_rs + epsilon) + else: + out = (data - data_mean_rs) / topi.math.sqrt(data_var_rs + epsilon) + if scale: out = out * topi.reshape(gamma, shape) if center: diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc index b4668d65d399..826711538c68 100644 --- a/src/relax/op/nn/nn.cc +++ b/src/relax/op/nn/nn.cc @@ -252,13 +252,14 @@ bool NormCheckDtypeAndShape(const Call& call, const BlockBuilder& ctx, TVM_REGISTER_NODE_TYPE(BatchNormAttrs); Expr batch_norm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr moving_var, // - int axis, double epsilon, bool center, bool scale, double momentum) { + int axis, double epsilon, bool center, bool scale, double momentum, bool training) { ObjectPtr attrs = make_object(); attrs->axis = axis; attrs->epsilon = epsilon; attrs->center = center; attrs->scale = scale; attrs->momentum = momentum; + attrs->training = training; static const Op& op = Op::Get("relax.nn.batch_norm"); return Call(op, @@ -266,7 +267,6 @@ Expr batch_norm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr moving_ std::move(moving_var)}, Attrs{attrs}, {}); } - TVM_REGISTER_GLOBAL("relax.op.nn.batch_norm").set_body_typed(batch_norm); StructInfo InferStructInfoBatchNorm(const Call& call, const BlockBuilder& ctx) { @@ -388,7 +388,7 @@ InferLayoutOutput InferLayoutLayerNorm(const Call& call, TVM_REGISTER_OP("relax.nn.layer_norm") .set_attrs_type() .set_num_inputs(3) - .add_argument("data", "Tensor", "Input to which batch_norm will be applied.") + .add_argument("data", "Tensor", "Input to which layer_norm will be applied.") .add_argument("gamma", "Tensor", "The gamma scale factor.") .add_argument("beta", "Tensor", "The beta offset factor.") .set_attr("FInferStructInfo", InferStructInfoLayerNorm) @@ -500,7 +500,7 @@ InferLayoutOutput 
InferLayoutGroupNorm(const Call& call, TVM_REGISTER_OP("relax.nn.group_norm") .set_attrs_type() .set_num_inputs(3) - .add_argument("data", "Tensor", "Input to which batch_norm will be applied.") + .add_argument("data", "Tensor", "Input to which group_norm will be applied.") .add_argument("gamma", "Tensor", "The gamma scale factor.") .add_argument("beta", "Tensor", "The beta offset factor.") .set_attr("FInferStructInfo", InferStructInfoGroupNorm) diff --git a/src/relax/op/nn/nn.h b/src/relax/op/nn/nn.h index a3658fed5430..28c14139b97b 100644 --- a/src/relax/op/nn/nn.h +++ b/src/relax/op/nn/nn.h @@ -68,7 +68,7 @@ Expr log_softmax(Expr data, int axis); /*! \brief Compute batch normalization. */ Expr batch_norm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr moving_var, // - int axis, double epsilon, bool center, bool scale, double momentum); + int axis, double epsilon, bool center, bool scale, double momentum, bool training); /*! \brief Compute layer normalization. */ Expr layer_norm(Expr data, Expr gamma, Expr beta, Array axes, double epsilon, bool center, diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py index c120eb89811c..f7501dd3b5b3 100644 --- a/tests/python/relax/test_from_exported_to_cuda.py +++ b/tests/python/relax/test_from_exported_to_cuda.py @@ -289,6 +289,25 @@ def forward(self, x): assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module3, target, dev) +@tvm.testing.parametrize_targets("cuda") +def test_batch_norm_prog(target, dev): + # Default args, in a pytorch program (to ensure output is in proper type and format) + raw_data = np.random.randn(2, 3, 2, 2).astype(np.float32) + + class BatchNormWrapper(nn.Module): + def __init__(self): + super(BatchNormWrapper, self).__init__() + self.bn = nn.BatchNorm2d(3) + + def forward(self, x): + x = self.bn(x) + x = x + 1 + return x + + torch_module = BatchNormWrapper().eval() + 
assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + @tvm.testing.parametrize_targets("cuda") def test_split_size(target, dev): # Test split using the split_size argument such that it is not a divisor @@ -310,7 +329,46 @@ def forward(self, x): return torch.split(x, split_size_or_sections=self.split_size, dim=self.dim) torch_module = SplitModelSplitSize(split_size=split_size, dim=dim).eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + +@tvm.testing.parametrize_targets("cuda") +def test_batch_norm0(target, dev): + # Eval, no momentum, no affine, no running stats + raw_data = np.random.randn(8, 3, 4, 4).astype(np.float32) + torch_module = nn.BatchNorm2d( + 3, eps=1e-02, momentum=0.0, affine=False, track_running_stats=False, device=None, dtype=None + ).eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + +@tvm.testing.parametrize_targets("cuda") +def test_batch_norm1(target, dev): + # Eval, with momentum, no affine, with running stats + raw_data = np.random.randn(1, 4, 2, 2).astype(np.float32) + torch_module = nn.BatchNorm2d( + 4, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True, device=None, dtype=None + ).eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + +@tvm.testing.parametrize_targets("cuda") +def test_batch_norm2(target, dev): + # Eval, with momentum, affine, no running stats + raw_data = np.random.randn(3, 4, 2, 2).astype(np.float32) + torch_module = nn.BatchNorm2d( + 4, eps=1e-05, momentum=0.2, affine=True, track_running_stats=False + ).eval() + assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) + + +@tvm.testing.parametrize_targets("cuda") +def test_batch_norm3(target, dev): + # Eval, no momentum, affine, with running stats + raw_data = np.random.randn(1, 2, 2, 2).astype(np.float32) + torch_module = nn.BatchNorm2d( + 2, eps=1e-05, 
momentum=0.0, affine=True, track_running_stats=True + ).eval() assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) @@ -335,7 +393,6 @@ def forward(self, x): return torch.split(x, split_size_or_sections=self.split_size, dim=self.dim) torch_module = SplitModelSectionsList(split_size=sections, dim=dim).eval() - assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev) diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py index d83d0567e482..4ac4b57b91d4 100644 --- a/tests/python/relax/test_transform_legalize_ops_nn.py +++ b/tests/python/relax/test_transform_legalize_ops_nn.py @@ -1955,212 +1955,289 @@ def main(x: R.Tensor((2, 3, 28, 28), "float32"), gamma: R.Tensor((3,), "float32" @tvm.script.ir_module class Expected: @T.prim_func(private=True) - def batch_norm(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28)), "float32"), rxplaceholder_1: T.Buffer((T.int64(3),), "float32"), rxplaceholder_2: T.Buffer((T.int64(3),), "float32"), rxplaceholder_3: T.Buffer((T.int64(3),), "float32"), rxplaceholder_4: T.Buffer((T.int64(3),), "float32"), T_add: T.Buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28)), "float32"), T_add_1: T.Buffer((T.int64(3),), "float32"), T_add_2: T.Buffer((T.int64(3),), "float32")): - T.func_attr({"tir.noalias": True}) - # with T.block("root"): - rxplaceholder_red = T.alloc_buffer((T.int64(3),)) - T_divide = T.alloc_buffer((T.int64(3),)) - T_reshape = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1), T.int64(1))) - T_subtract = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28))) - T_subtract_1 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28))) - T_subtract_2 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28))) - T_multiply = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28))) - T_multiply_red = T.alloc_buffer((T.int64(3),)) - T_divide_1 
= T.alloc_buffer((T.int64(3),)) - T_reshape_1 = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1), T.int64(1))) - T_add_3 = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1), T.int64(1))) - compute = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1), T.int64(1))) - T_divide_2 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28))) - T_reshape_2 = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1), T.int64(1))) - T_multiply_1 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28))) - T_reshape_3 = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1), T.int64(1))) - T_multiply_2 = T.alloc_buffer((T.int64(3),)) - T_multiply_3 = T.alloc_buffer((T.int64(3),)) - T_multiply_4 = T.alloc_buffer((T.int64(3),)) - T_subtract_3 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28))) - T_subtract_4 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28))) - T_multiply_5 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28))) - T_multiply_red_1 = T.alloc_buffer((T.int64(3),)) - T_divide_3 = T.alloc_buffer((T.int64(3),)) - T_multiply_6 = T.alloc_buffer((T.int64(3),)) - for ax0, k0, k2, k3 in T.grid(T.int64(3), T.int64(2), T.int64(28), T.int64(28)): - with T.block("rxplaceholder_red"): - v_ax0, v_k0, v_k2, v_k3 = T.axis.remap("SRRR", [ax0, k0, k2, k3]) - T.reads(rxplaceholder[v_k0, v_ax0, v_k2, v_k3]) - T.writes(rxplaceholder_red[v_ax0]) - with T.init(): - rxplaceholder_red[v_ax0] = T.float32(0) - rxplaceholder_red[v_ax0] = rxplaceholder_red[v_ax0] + rxplaceholder[v_k0, v_ax0, v_k2, v_k3] - for ax0 in range(T.int64(3)): - with T.block("T_divide"): - v_ax0 = T.axis.spatial(T.int64(3), ax0) - T.reads(rxplaceholder_red[v_ax0]) - T.writes(T_divide[v_ax0]) - T_divide[v_ax0] = rxplaceholder_red[v_ax0] * T.float32(0.00063775510204081628) - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(3), T.int64(1), T.int64(1)): - with T.block("T_reshape"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - 
T.reads(T_divide[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)]) - T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3]) - T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = T_divide[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(28), T.int64(28)): - with T.block("T_subtract"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) - T.writes(T_subtract[v_ax0, v_ax1, v_ax2, v_ax3]) - T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(28), T.int64(28)): - with T.block("T_subtract_1"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) - T.writes(T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3]) - T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(28), T.int64(28)): - with T.block("T_subtract_2"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) - T.writes(T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3]) - T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(28), T.int64(28)): - with T.block("T_multiply"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3], T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3]) - T.writes(T_multiply[v_ax0, v_ax1, v_ax2, 
v_ax3]) - T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] * T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3] - for ax0, k0, k2, k3 in T.grid(T.int64(3), T.int64(2), T.int64(28), T.int64(28)): - with T.block("T_multiply_red"): - v_ax0, v_k0, v_k2, v_k3 = T.axis.remap("SRRR", [ax0, k0, k2, k3]) - T.reads(T_multiply[v_k0, v_ax0, v_k2, v_k3]) - T.writes(T_multiply_red[v_ax0]) - with T.init(): - T_multiply_red[v_ax0] = T.float32(0) - T_multiply_red[v_ax0] = T_multiply_red[v_ax0] + T_multiply[v_k0, v_ax0, v_k2, v_k3] - for ax0 in range(T.int64(3)): - with T.block("T_divide_1"): - v_ax0 = T.axis.spatial(T.int64(3), ax0) - T.reads(T_multiply_red[v_ax0]) - T.writes(T_divide_1[v_ax0]) - T_divide_1[v_ax0] = T_multiply_red[v_ax0] * T.float32(0.00063775510204081628) - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(3), T.int64(1), T.int64(1)): - with T.block("T_reshape_1"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(T_divide_1[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)]) - T.writes(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3]) - T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3] = T_divide_1[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(3), T.int64(1), T.int64(1)): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3]) - T.writes(T_add_3[v_ax0, v_ax1, v_ax2, v_ax3]) - T_add_3[v_ax0, v_ax1, v_ax2, v_ax3] = T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3] + T.float32(1.0000000000000001e-05) - for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(3), T.int64(1), T.int64(1)): - with T.block("compute"): - v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(T_add_3[v_i0, v_i1, v_i2, v_i3]) - T.writes(compute[v_i0, v_i1, v_i2, v_i3]) - compute[v_i0, v_i1, v_i2, v_i3] = T.sqrt(T_add_3[v_i0, v_i1, v_i2, v_i3]) - for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(28), 
T.int64(28)): - with T.block("T_divide_2"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(T_subtract[v_ax0, v_ax1, v_ax2, v_ax3], compute[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) - T.writes(T_divide_2[v_ax0, v_ax1, v_ax2, v_ax3]) - T_divide_2[v_ax0, v_ax1, v_ax2, v_ax3] = T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] / compute[T.int64(0), v_ax1, T.int64(0), T.int64(0)] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(3), T.int64(1), T.int64(1)): - with T.block("T_reshape_2"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(rxplaceholder_1[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)]) - T.writes(T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3]) - T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder_1[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(28), T.int64(28)): - with T.block("T_multiply_1"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(T_divide_2[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape_2[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) - T.writes(T_multiply_1[v_ax0, v_ax1, v_ax2, v_ax3]) - T_multiply_1[v_ax0, v_ax1, v_ax2, v_ax3] = T_divide_2[v_ax0, v_ax1, v_ax2, v_ax3] * T_reshape_2[T.int64(0), v_ax1, T.int64(0), T.int64(0)] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(3), T.int64(1), T.int64(1)): - with T.block("T_reshape_3"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(rxplaceholder_2[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)]) - T.writes(T_reshape_3[v_ax0, v_ax1, v_ax2, v_ax3]) - T_reshape_3[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder_2[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(28), T.int64(28)): - with T.block("T_add_1"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(T_multiply_1[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape_3[T.int64(0), v_ax1, T.int64(0), 
T.int64(0)]) - T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3]) - T_add[v_ax0, v_ax1, v_ax2, v_ax3] = T_multiply_1[v_ax0, v_ax1, v_ax2, v_ax3] + T_reshape_3[T.int64(0), v_ax1, T.int64(0), T.int64(0)] - for ax0 in range(T.int64(3)): - with T.block("T_multiply_2"): - v_ax0 = T.axis.spatial(T.int64(3), ax0) - T.reads(rxplaceholder_3[v_ax0]) - T.writes(T_multiply_2[v_ax0]) - T_multiply_2[v_ax0] = T.float32(0.90000000000000002) * rxplaceholder_3[v_ax0] - for ax0 in range(T.int64(3)): - with T.block("T_multiply_3"): - v_ax0 = T.axis.spatial(T.int64(3), ax0) - T.reads(T_divide[v_ax0]) - T.writes(T_multiply_3[v_ax0]) - T_multiply_3[v_ax0] = T.float32(0.10000000000000001) * T_divide[v_ax0] - for ax0 in range(T.int64(3)): - with T.block("T_add_2"): - v_ax0 = T.axis.spatial(T.int64(3), ax0) - T.reads(T_multiply_2[v_ax0], T_multiply_3[v_ax0]) - T.writes(T_add_1[v_ax0]) - T_add_1[v_ax0] = T_multiply_2[v_ax0] + T_multiply_3[v_ax0] - for ax0 in range(T.int64(3)): - with T.block("T_multiply_4"): - v_ax0 = T.axis.spatial(T.int64(3), ax0) - T.reads(rxplaceholder_4[v_ax0]) - T.writes(T_multiply_4[v_ax0]) - T_multiply_4[v_ax0] = T.float32(0.90000000000000002) * rxplaceholder_4[v_ax0] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(28), T.int64(28)): - with T.block("T_subtract_3"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) - T.writes(T_subtract_3[v_ax0, v_ax1, v_ax2, v_ax3]) - T_subtract_3[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(28), T.int64(28)): - with T.block("T_subtract_4"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) - 
T.writes(T_subtract_4[v_ax0, v_ax1, v_ax2, v_ax3]) - T_subtract_4[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(28), T.int64(28)): - with T.block("T_multiply_5"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(T_subtract_3[v_ax0, v_ax1, v_ax2, v_ax3], T_subtract_4[v_ax0, v_ax1, v_ax2, v_ax3]) - T.writes(T_multiply_5[v_ax0, v_ax1, v_ax2, v_ax3]) - T_multiply_5[v_ax0, v_ax1, v_ax2, v_ax3] = T_subtract_3[v_ax0, v_ax1, v_ax2, v_ax3] * T_subtract_4[v_ax0, v_ax1, v_ax2, v_ax3] - for ax0, k0, k2, k3 in T.grid(T.int64(3), T.int64(2), T.int64(28), T.int64(28)): - with T.block("T_multiply_red_1"): - v_ax0, v_k0, v_k2, v_k3 = T.axis.remap("SRRR", [ax0, k0, k2, k3]) - T.reads(T_multiply_5[v_k0, v_ax0, v_k2, v_k3]) - T.writes(T_multiply_red_1[v_ax0]) - with T.init(): - T_multiply_red_1[v_ax0] = T.float32(0) - T_multiply_red_1[v_ax0] = T_multiply_red_1[v_ax0] + T_multiply_5[v_k0, v_ax0, v_k2, v_k3] - for ax0 in range(T.int64(3)): - with T.block("T_divide_3"): - v_ax0 = T.axis.spatial(T.int64(3), ax0) - T.reads(T_multiply_red_1[v_ax0]) - T.writes(T_divide_3[v_ax0]) - T_divide_3[v_ax0] = T_multiply_red_1[v_ax0] * T.float32(0.00063775510204081628) - for ax0 in range(T.int64(3)): - with T.block("T_multiply_6"): - v_ax0 = T.axis.spatial(T.int64(3), ax0) - T.reads(T_divide_3[v_ax0]) - T.writes(T_multiply_6[v_ax0]) - T_multiply_6[v_ax0] = T.float32(0.10000000000000001) * T_divide_3[v_ax0] - for ax0 in range(T.int64(3)): - with T.block("T_add_3"): - v_ax0 = T.axis.spatial(T.int64(3), ax0) - T.reads(T_multiply_4[v_ax0], T_multiply_6[v_ax0]) - T.writes(T_add_2[v_ax0]) - T_add_2[v_ax0] = T_multiply_4[v_ax0] + T_multiply_6[v_ax0] + def batch_norm(var_x: T.handle, var_gamma: T.handle, var_beta: T.handle, var_moving_mean: T.handle, var_moving_var: T.handle, var_T_add: T.handle, var_T_add_1: T.handle, 
var_T_add_2: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + x = T.match_buffer(var_x, (T.int64(2), T.int64(3), T.int64(28), T.int64(28))) + gamma = T.match_buffer(var_gamma, (T.int64(3),)) + beta = T.match_buffer(var_beta, (T.int64(3),)) + moving_mean = T.match_buffer(var_moving_mean, (T.int64(3),)) + moving_var = T.match_buffer(var_moving_var, (T.int64(3),)) + T_add = T.match_buffer(var_T_add, (T.int64(2), T.int64(3), T.int64(28), T.int64(28))) + T_add_1 = T.match_buffer(var_T_add_1, (T.int64(3),)) + T_add_2 = T.match_buffer(var_T_add_2, (T.int64(3),)) + with T.block("root"): + T.reads() + T.writes() + T_reshape = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1), T.int64(1))) + T_subtract = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28))) + T_reshape_1 = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1), T.int64(1))) + T_add_3 = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1), T.int64(1))) + compute = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1), T.int64(1))) + T_divide = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28))) + T_reshape_2 = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1), T.int64(1))) + T_multiply = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28))) + T_reshape_3 = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1), T.int64(1))) + T_multiply_1 = T.alloc_buffer((T.int64(3),)) + x_red = T.alloc_buffer((T.int64(3),)) + T_divide_1 = T.alloc_buffer((T.int64(3),)) + T_multiply_2 = T.alloc_buffer((T.int64(3),)) + T_multiply_3 = T.alloc_buffer((T.int64(3),)) + T_reshape_4 = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1), T.int64(1))) + T_subtract_1 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28))) + T_subtract_2 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28))) + T_multiply_4 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28))) + T_multiply_red = T.alloc_buffer((T.int64(3),)) + T_divide_2 = T.alloc_buffer((T.int64(3),)) + 
T_multiply_5 = T.alloc_buffer((T.int64(3),)) + for ax0 in range(T.int64(1)): + for ax1 in range(T.int64(3)): + for ax2 in range(T.int64(1)): + for ax3 in range(T.int64(1)): + with T.block("T_reshape"): + v_ax0 = T.axis.spatial(T.int64(1), ax0) + v_ax1 = T.axis.spatial(T.int64(3), ax1) + v_ax2 = T.axis.spatial(T.int64(1), ax2) + v_ax3 = T.axis.spatial(T.int64(1), ax3) + T.reads(moving_mean[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)]) + T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3]) + T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = moving_mean[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)] + for ax0 in range(T.int64(2)): + for ax1 in range(T.int64(3)): + for ax2 in range(T.int64(28)): + for ax3 in range(T.int64(28)): + with T.block("T_subtract"): + v_ax0 = T.axis.spatial(T.int64(2), ax0) + v_ax1 = T.axis.spatial(T.int64(3), ax1) + v_ax2 = T.axis.spatial(T.int64(28), ax2) + v_ax3 = T.axis.spatial(T.int64(28), ax3) + T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) + T.writes(T_subtract[v_ax0, v_ax1, v_ax2, v_ax3]) + T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] = x[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)] + for ax0 in range(T.int64(1)): + for ax1 in range(T.int64(3)): + for ax2 in range(T.int64(1)): + for ax3 in range(T.int64(1)): + with T.block("T_reshape_1"): + v_ax0 = T.axis.spatial(T.int64(1), ax0) + v_ax1 = T.axis.spatial(T.int64(3), ax1) + v_ax2 = T.axis.spatial(T.int64(1), ax2) + v_ax3 = T.axis.spatial(T.int64(1), ax3) + T.reads(moving_var[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)]) + T.writes(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3]) + T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3] = moving_var[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)] + for ax0 in range(T.int64(1)): + for ax1 in range(T.int64(3)): + for ax2 in range(T.int64(1)): + for ax3 in range(T.int64(1)): + with T.block("T_add"): + v_ax0 = T.axis.spatial(T.int64(1), ax0) + v_ax1 = T.axis.spatial(T.int64(3), ax1) + v_ax2 = T.axis.spatial(T.int64(1), ax2) + v_ax3 = 
T.axis.spatial(T.int64(1), ax3) + T.reads(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(T_add_3[v_ax0, v_ax1, v_ax2, v_ax3]) + T_add_3[v_ax0, v_ax1, v_ax2, v_ax3] = T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3] + T.float32(1.0000000000000001e-05) + for i0 in range(T.int64(1)): + for i1 in range(T.int64(3)): + for i2 in range(T.int64(1)): + for i3 in range(T.int64(1)): + with T.block("compute"): + v_i0 = T.axis.spatial(T.int64(1), i0) + v_i1 = T.axis.spatial(T.int64(3), i1) + v_i2 = T.axis.spatial(T.int64(1), i2) + v_i3 = T.axis.spatial(T.int64(1), i3) + T.reads(T_add_3[v_i0, v_i1, v_i2, v_i3]) + T.writes(compute[v_i0, v_i1, v_i2, v_i3]) + compute[v_i0, v_i1, v_i2, v_i3] = T.sqrt(T_add_3[v_i0, v_i1, v_i2, v_i3]) + for ax0 in range(T.int64(2)): + for ax1 in range(T.int64(3)): + for ax2 in range(T.int64(28)): + for ax3 in range(T.int64(28)): + with T.block("T_divide"): + v_ax0 = T.axis.spatial(T.int64(2), ax0) + v_ax1 = T.axis.spatial(T.int64(3), ax1) + v_ax2 = T.axis.spatial(T.int64(28), ax2) + v_ax3 = T.axis.spatial(T.int64(28), ax3) + T.reads(T_subtract[v_ax0, v_ax1, v_ax2, v_ax3], compute[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) + T.writes(T_divide[v_ax0, v_ax1, v_ax2, v_ax3]) + T_divide[v_ax0, v_ax1, v_ax2, v_ax3] = T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] / compute[T.int64(0), v_ax1, T.int64(0), T.int64(0)] + for ax0 in range(T.int64(1)): + for ax1 in range(T.int64(3)): + for ax2 in range(T.int64(1)): + for ax3 in range(T.int64(1)): + with T.block("T_reshape_2"): + v_ax0 = T.axis.spatial(T.int64(1), ax0) + v_ax1 = T.axis.spatial(T.int64(3), ax1) + v_ax2 = T.axis.spatial(T.int64(1), ax2) + v_ax3 = T.axis.spatial(T.int64(1), ax3) + T.reads(gamma[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)]) + T.writes(T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3]) + T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3] = gamma[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)] + for ax0 in range(T.int64(2)): + for ax1 in range(T.int64(3)): + for ax2 in range(T.int64(28)): + for ax3 in range(T.int64(28)): + with 
T.block("T_multiply"): + v_ax0 = T.axis.spatial(T.int64(2), ax0) + v_ax1 = T.axis.spatial(T.int64(3), ax1) + v_ax2 = T.axis.spatial(T.int64(28), ax2) + v_ax3 = T.axis.spatial(T.int64(28), ax3) + T.reads(T_divide[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape_2[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) + T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3]) + T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = T_divide[v_ax0, v_ax1, v_ax2, v_ax3] * T_reshape_2[T.int64(0), v_ax1, T.int64(0), T.int64(0)] + for ax0 in range(T.int64(1)): + for ax1 in range(T.int64(3)): + for ax2 in range(T.int64(1)): + for ax3 in range(T.int64(1)): + with T.block("T_reshape_3"): + v_ax0 = T.axis.spatial(T.int64(1), ax0) + v_ax1 = T.axis.spatial(T.int64(3), ax1) + v_ax2 = T.axis.spatial(T.int64(1), ax2) + v_ax3 = T.axis.spatial(T.int64(1), ax3) + T.reads(beta[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)]) + T.writes(T_reshape_3[v_ax0, v_ax1, v_ax2, v_ax3]) + T_reshape_3[v_ax0, v_ax1, v_ax2, v_ax3] = beta[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)] + for ax0 in range(T.int64(2)): + for ax1 in range(T.int64(3)): + for ax2 in range(T.int64(28)): + for ax3 in range(T.int64(28)): + with T.block("T_add_1"): + v_ax0 = T.axis.spatial(T.int64(2), ax0) + v_ax1 = T.axis.spatial(T.int64(3), ax1) + v_ax2 = T.axis.spatial(T.int64(28), ax2) + v_ax3 = T.axis.spatial(T.int64(28), ax3) + T.reads(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape_3[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) + T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3]) + T_add[v_ax0, v_ax1, v_ax2, v_ax3] = T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] + T_reshape_3[T.int64(0), v_ax1, T.int64(0), T.int64(0)] + for ax0 in range(T.int64(3)): + with T.block("T_multiply_1"): + v_ax0 = T.axis.spatial(T.int64(3), ax0) + T.reads(moving_mean[v_ax0]) + T.writes(T_multiply_1[v_ax0]) + T_multiply_1[v_ax0] = T.float32(0.90000000000000002) * moving_mean[v_ax0] + for ax0 in range(T.int64(3)): + for k0 in range(T.int64(2)): + for k2 in range(T.int64(28)): + for k3 in range(T.int64(28)): + with 
T.block("x_red"): + v_ax0 = T.axis.spatial(T.int64(3), ax0) + v_k0 = T.axis.reduce(T.int64(2), k0) + v_k2 = T.axis.reduce(T.int64(28), k2) + v_k3 = T.axis.reduce(T.int64(28), k3) + T.reads(x[v_k0, v_ax0, v_k2, v_k3]) + T.writes(x_red[v_ax0]) + with T.init(): + x_red[v_ax0] = T.float32(0.0) + x_red[v_ax0] = x_red[v_ax0] + x[v_k0, v_ax0, v_k2, v_k3] + for ax0 in range(T.int64(3)): + with T.block("T_divide_1"): + v_ax0 = T.axis.spatial(T.int64(3), ax0) + T.reads(x_red[v_ax0]) + T.writes(T_divide_1[v_ax0]) + T_divide_1[v_ax0] = x_red[v_ax0] * T.float32(0.00063775510204081628) + for ax0 in range(T.int64(3)): + with T.block("T_multiply_2"): + v_ax0 = T.axis.spatial(T.int64(3), ax0) + T.reads(T_divide_1[v_ax0]) + T.writes(T_multiply_2[v_ax0]) + T_multiply_2[v_ax0] = T.float32(0.10000000000000001) * T_divide_1[v_ax0] + for ax0 in range(T.int64(3)): + with T.block("T_add_2"): + v_ax0 = T.axis.spatial(T.int64(3), ax0) + T.reads(T_multiply_1[v_ax0], T_multiply_2[v_ax0]) + T.writes(T_add_1[v_ax0]) + T_add_1[v_ax0] = T_multiply_1[v_ax0] + T_multiply_2[v_ax0] + for ax0 in range(T.int64(3)): + with T.block("T_multiply_3"): + v_ax0 = T.axis.spatial(T.int64(3), ax0) + T.reads(moving_var[v_ax0]) + T.writes(T_multiply_3[v_ax0]) + T_multiply_3[v_ax0] = T.float32(0.90000000000000002) * moving_var[v_ax0] + for ax0 in range(T.int64(1)): + for ax1 in range(T.int64(3)): + for ax2 in range(T.int64(1)): + for ax3 in range(T.int64(1)): + with T.block("T_reshape_4"): + v_ax0 = T.axis.spatial(T.int64(1), ax0) + v_ax1 = T.axis.spatial(T.int64(3), ax1) + v_ax2 = T.axis.spatial(T.int64(1), ax2) + v_ax3 = T.axis.spatial(T.int64(1), ax3) + T.reads(T_divide_1[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)]) + T.writes(T_reshape_4[v_ax0, v_ax1, v_ax2, v_ax3]) + T_reshape_4[v_ax0, v_ax1, v_ax2, v_ax3] = T_divide_1[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)] + for ax0 in range(T.int64(2)): + for ax1 in range(T.int64(3)): + for ax2 in range(T.int64(28)): + for ax3 in range(T.int64(28)): + with T.block("T_subtract_1"): + 
v_ax0 = T.axis.spatial(T.int64(2), ax0) + v_ax1 = T.axis.spatial(T.int64(3), ax1) + v_ax2 = T.axis.spatial(T.int64(28), ax2) + v_ax3 = T.axis.spatial(T.int64(28), ax3) + T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape_4[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) + T.writes(T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3]) + T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] = x[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape_4[T.int64(0), v_ax1, T.int64(0), T.int64(0)] + for ax0 in range(T.int64(2)): + for ax1 in range(T.int64(3)): + for ax2 in range(T.int64(28)): + for ax3 in range(T.int64(28)): + with T.block("T_subtract_2"): + v_ax0 = T.axis.spatial(T.int64(2), ax0) + v_ax1 = T.axis.spatial(T.int64(3), ax1) + v_ax2 = T.axis.spatial(T.int64(28), ax2) + v_ax3 = T.axis.spatial(T.int64(28), ax3) + T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape_4[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) + T.writes(T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3]) + T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3] = x[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape_4[T.int64(0), v_ax1, T.int64(0), T.int64(0)] + for ax0 in range(T.int64(2)): + for ax1 in range(T.int64(3)): + for ax2 in range(T.int64(28)): + for ax3 in range(T.int64(28)): + with T.block("T_multiply_4"): + v_ax0 = T.axis.spatial(T.int64(2), ax0) + v_ax1 = T.axis.spatial(T.int64(3), ax1) + v_ax2 = T.axis.spatial(T.int64(28), ax2) + v_ax3 = T.axis.spatial(T.int64(28), ax3) + T.reads(T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3], T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(T_multiply_4[v_ax0, v_ax1, v_ax2, v_ax3]) + T_multiply_4[v_ax0, v_ax1, v_ax2, v_ax3] = T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] * T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3] + for ax0 in range(T.int64(3)): + for k0 in range(T.int64(2)): + for k2 in range(T.int64(28)): + for k3 in range(T.int64(28)): + with T.block("T_multiply_red"): + v_ax0 = T.axis.spatial(T.int64(3), ax0) + v_k0 = T.axis.reduce(T.int64(2), k0) + v_k2 = T.axis.reduce(T.int64(28), k2) + v_k3 = T.axis.reduce(T.int64(28), k3) + 
T.reads(T_multiply_4[v_k0, v_ax0, v_k2, v_k3]) + T.writes(T_multiply_red[v_ax0]) + with T.init(): + T_multiply_red[v_ax0] = T.float32(0.0) + T_multiply_red[v_ax0] = T_multiply_red[v_ax0] + T_multiply_4[v_k0, v_ax0, v_k2, v_k3] + for ax0 in range(T.int64(3)): + with T.block("T_divide_2"): + v_ax0 = T.axis.spatial(T.int64(3), ax0) + T.reads(T_multiply_red[v_ax0]) + T.writes(T_divide_2[v_ax0]) + T_divide_2[v_ax0] = T_multiply_red[v_ax0] * T.float32(0.00063775510204081628) + for ax0 in range(T.int64(3)): + with T.block("T_multiply_5"): + v_ax0 = T.axis.spatial(T.int64(3), ax0) + T.reads(T_divide_2[v_ax0]) + T.writes(T_multiply_5[v_ax0]) + T_multiply_5[v_ax0] = T.float32(0.10000000000000001) * T_divide_2[v_ax0] + for ax0 in range(T.int64(3)): + with T.block("T_add_3"): + v_ax0 = T.axis.spatial(T.int64(3), ax0) + T.reads(T_multiply_3[v_ax0], T_multiply_5[v_ax0]) + T.writes(T_add_2[v_ax0]) + T_add_2[v_ax0] = T_multiply_3[v_ax0] + T_multiply_5[v_ax0] @R.function def main(x: R.Tensor((2, 3, 28, 28), dtype="float32"), gamma: R.Tensor((3,), dtype="float32"), beta: R.Tensor((3,), dtype="float32"), moving_mean: R.Tensor((3,), dtype="float32"), moving_var: R.Tensor((3,), dtype="float32")) -> R.Tuple(R.Tensor((2, 3, 28, 28), dtype="float32"), R.Tensor((3,), dtype="float32"), R.Tensor((3,), dtype="float32")): - gv = R.call_tir(Expected.batch_norm, (x, gamma, beta, moving_mean, moving_var), out_sinfo=[R.Tensor((2, 3, 28, 28), dtype="float32"), R.Tensor((3,), dtype="float32"), R.Tensor((3,), dtype="float32")]) + cls = Expected + gv = R.call_tir(cls.batch_norm, (x, gamma, beta, moving_mean, moving_var), out_sinfo=[R.Tensor((2, 3, 28, 28), dtype="float32"), R.Tensor((3,), dtype="float32"), R.Tensor((3,), dtype="float32")]) return gv # fmt: on @@ -2184,230 +2261,295 @@ def main(x: R.Tensor(("n", "h", "w", "c"), "float32"), gamma: R.Tensor(("c",), " @tvm.script.ir_module class Expected: @T.prim_func(private=True) - def batch_norm(var_rxplaceholder: T.handle, var_rxplaceholder_1: 
T.handle, var_rxplaceholder_2: T.handle, var_rxplaceholder_3: T.handle, var_rxplaceholder_4: T.handle, var_T_add: T.handle, var_T_add_1: T.handle, var_T_add_2: T.handle): - T.func_attr({"tir.noalias": True}) - n = T.int64() - h = T.int64() - w = T.int64() - c = T.int64() - rxplaceholder = T.match_buffer(var_rxplaceholder, (n, h, w, c)) - rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (c,)) - rxplaceholder_2 = T.match_buffer(var_rxplaceholder_2, (c,)) - rxplaceholder_3 = T.match_buffer(var_rxplaceholder_3, (c,)) - rxplaceholder_4 = T.match_buffer(var_rxplaceholder_4, (c,)) + def batch_norm(var_x: T.handle, var_gamma: T.handle, var_beta: T.handle, var_moving_mean: T.handle, var_moving_var: T.handle, var_T_add: T.handle, var_T_add_1: T.handle, var_T_add_2: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n, h, w, c = T.int64(), T.int64(), T.int64(), T.int64() + x = T.match_buffer(var_x, (n, h, w, c)) + gamma = T.match_buffer(var_gamma, (c,)) + beta = T.match_buffer(var_beta, (c,)) + moving_mean = T.match_buffer(var_moving_mean, (c,)) + moving_var = T.match_buffer(var_moving_var, (c,)) T_add = T.match_buffer(var_T_add, (n, h, w, c)) T_add_1 = T.match_buffer(var_T_add_1, (T.max(c, h),)) T_add_2 = T.match_buffer(var_T_add_2, (T.max(c, h),)) - # with T.block("root"): - rxplaceholder_red = T.alloc_buffer((h,)) - T_divide = T.alloc_buffer((h,)) - T_reshape = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1))) - T_subtract = T.alloc_buffer((n, h, w, c)) - T_subtract_1 = T.alloc_buffer((n, h, w, c)) - T_subtract_2 = T.alloc_buffer((n, h, w, c)) - T_multiply = T.alloc_buffer((n, h, w, c)) - T_multiply_red = T.alloc_buffer((h,)) - T_divide_1 = T.alloc_buffer((h,)) - T_reshape_1 = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1))) - T_add_3 = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1))) - compute = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1))) - T_divide_2 = T.alloc_buffer((n, h, w, c)) - T_reshape_2 = T.alloc_buffer((T.int64(1), h, 
T.int64(1), T.int64(1))) - T_multiply_1 = T.alloc_buffer((n, h, w, c)) - T_reshape_3 = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1))) - T_multiply_2 = T.alloc_buffer((c,)) - T_multiply_3 = T.alloc_buffer((h,)) - T_multiply_4 = T.alloc_buffer((c,)) - T_subtract_3 = T.alloc_buffer((n, h, w, c)) - T_subtract_4 = T.alloc_buffer((n, h, w, c)) - T_multiply_5 = T.alloc_buffer((n, h, w, c)) - T_multiply_red_1 = T.alloc_buffer((h,)) - T_divide_3 = T.alloc_buffer((h,)) - T_multiply_6 = T.alloc_buffer((h,)) - for ax0, k0, k2, k3 in T.grid(h, n, w, c): - with T.block("rxplaceholder_red"): - v_ax0, v_k0, v_k2, v_k3 = T.axis.remap("SRRR", [ax0, k0, k2, k3]) - T.reads(rxplaceholder[v_k0, v_ax0, v_k2, v_k3]) - T.writes(rxplaceholder_red[v_ax0]) - with T.init(): - rxplaceholder_red[v_ax0] = T.float32(0) - rxplaceholder_red[v_ax0] = rxplaceholder_red[v_ax0] + rxplaceholder[v_k0, v_ax0, v_k2, v_k3] - for ax0 in range(h): - with T.block("T_divide"): - v_ax0 = T.axis.spatial(h, ax0) - T.reads(rxplaceholder_red[v_ax0]) - T.writes(T_divide[v_ax0]) - T_divide[v_ax0] = rxplaceholder_red[v_ax0] / T.Cast("float32", n * w * c) - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), h, T.int64(1), T.int64(1)): - with T.block("T_reshape"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(T_divide[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % h]) - T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3]) - T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = T_divide[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % h] - for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c): - with T.block("T_subtract"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) - T.writes(T_subtract[v_ax0, v_ax1, v_ax2, v_ax3]) - T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)] - for ax0, ax1, ax2, ax3 in T.grid(n, h, 
w, c): - with T.block("T_subtract_1"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) - T.writes(T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3]) - T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)] - for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c): - with T.block("T_subtract_2"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) - T.writes(T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3]) - T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)] - for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c): - with T.block("T_multiply"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3], T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3]) - T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3]) - T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] * T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3] - for ax0, k0, k2, k3 in T.grid(h, n, w, c): - with T.block("T_multiply_red"): - v_ax0, v_k0, v_k2, v_k3 = T.axis.remap("SRRR", [ax0, k0, k2, k3]) - T.reads(T_multiply[v_k0, v_ax0, v_k2, v_k3]) - T.writes(T_multiply_red[v_ax0]) - with T.init(): - T_multiply_red[v_ax0] = T.float32(0) - T_multiply_red[v_ax0] = T_multiply_red[v_ax0] + T_multiply[v_k0, v_ax0, v_k2, v_k3] - for ax0 in range(h): - with T.block("T_divide_1"): - v_ax0 = T.axis.spatial(h, ax0) - T.reads(T_multiply_red[v_ax0]) - T.writes(T_divide_1[v_ax0]) - T_divide_1[v_ax0] = T_multiply_red[v_ax0] / T.Cast("float32", n * w * c) - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), h, T.int64(1), T.int64(1)): - with T.block("T_reshape_1"): - 
v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(T_divide_1[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % h]) - T.writes(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3]) - T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3] = T_divide_1[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % h] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), h, T.int64(1), T.int64(1)): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3]) - T.writes(T_add_3[v_ax0, v_ax1, v_ax2, v_ax3]) - T_add_3[v_ax0, v_ax1, v_ax2, v_ax3] = T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3] + T.float32(1.0000000000000001e-05) - for i0, i1, i2, i3 in T.grid(T.int64(1), h, T.int64(1), T.int64(1)): - with T.block("compute"): - v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(T_add_3[v_i0, v_i1, v_i2, v_i3]) - T.writes(compute[v_i0, v_i1, v_i2, v_i3]) - compute[v_i0, v_i1, v_i2, v_i3] = T.sqrt(T_add_3[v_i0, v_i1, v_i2, v_i3]) - for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c): - with T.block("T_divide_2"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(T_subtract[v_ax0, v_ax1, v_ax2, v_ax3], compute[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) - T.writes(T_divide_2[v_ax0, v_ax1, v_ax2, v_ax3]) - T_divide_2[v_ax0, v_ax1, v_ax2, v_ax3] = T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] / compute[T.int64(0), v_ax1, T.int64(0), T.int64(0)] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), h, T.int64(1), T.int64(1)): - with T.block("T_reshape_2"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(rxplaceholder_1[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c]) - T.writes(T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3]) - T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder_1[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c] - for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c): - with T.block("T_multiply_1"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - 
T.reads(T_divide_2[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape_2[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) - T.writes(T_multiply_1[v_ax0, v_ax1, v_ax2, v_ax3]) - T_multiply_1[v_ax0, v_ax1, v_ax2, v_ax3] = T_divide_2[v_ax0, v_ax1, v_ax2, v_ax3] * T_reshape_2[T.int64(0), v_ax1, T.int64(0), T.int64(0)] - for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), h, T.int64(1), T.int64(1)): - with T.block("T_reshape_3"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(rxplaceholder_2[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c]) - T.writes(T_reshape_3[v_ax0, v_ax1, v_ax2, v_ax3]) - T_reshape_3[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder_2[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c] - for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c): - with T.block("T_add_1"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(T_multiply_1[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape_3[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) - T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3]) - T_add[v_ax0, v_ax1, v_ax2, v_ax3] = T_multiply_1[v_ax0, v_ax1, v_ax2, v_ax3] + T_reshape_3[T.int64(0), v_ax1, T.int64(0), T.int64(0)] - for ax0 in range(c): - with T.block("T_multiply_2"): - v_ax0 = T.axis.spatial(c, ax0) - T.reads(rxplaceholder_3[v_ax0]) - T.writes(T_multiply_2[v_ax0]) - T_multiply_2[v_ax0] = T.float32(0.90000000000000002) * rxplaceholder_3[v_ax0] - for ax0 in range(h): - with T.block("T_multiply_3"): - v_ax0 = T.axis.spatial(h, ax0) - T.reads(T_divide[v_ax0]) - T.writes(T_multiply_3[v_ax0]) - T_multiply_3[v_ax0] = T.float32(0.10000000000000001) * T_divide[v_ax0] - for ax0 in range(T.max(c, h)): - with T.block("T_add_2"): - v_ax0 = T.axis.spatial(T.max(c, h), ax0) - T.reads(T_multiply_2[v_ax0], T_multiply_3[v_ax0]) - T.writes(T_add_1[v_ax0]) - T_add_1[v_ax0] = T_multiply_2[v_ax0] + T_multiply_3[v_ax0] - for ax0 in range(c): - with T.block("T_multiply_4"): - v_ax0 = T.axis.spatial(c, ax0) - T.reads(rxplaceholder_4[v_ax0]) - T.writes(T_multiply_4[v_ax0]) - 
T_multiply_4[v_ax0] = T.float32(0.90000000000000002) * rxplaceholder_4[v_ax0] - for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c): - with T.block("T_subtract_3"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) - T.writes(T_subtract_3[v_ax0, v_ax1, v_ax2, v_ax3]) - T_subtract_3[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)] - for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c): - with T.block("T_subtract_4"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) - T.writes(T_subtract_4[v_ax0, v_ax1, v_ax2, v_ax3]) - T_subtract_4[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)] - for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c): - with T.block("T_multiply_5"): - v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - T.reads(T_subtract_3[v_ax0, v_ax1, v_ax2, v_ax3], T_subtract_4[v_ax0, v_ax1, v_ax2, v_ax3]) - T.writes(T_multiply_5[v_ax0, v_ax1, v_ax2, v_ax3]) - T_multiply_5[v_ax0, v_ax1, v_ax2, v_ax3] = T_subtract_3[v_ax0, v_ax1, v_ax2, v_ax3] * T_subtract_4[v_ax0, v_ax1, v_ax2, v_ax3] - for ax0, k0, k2, k3 in T.grid(h, n, w, c): - with T.block("T_multiply_red_1"): - v_ax0, v_k0, v_k2, v_k3 = T.axis.remap("SRRR", [ax0, k0, k2, k3]) - T.reads(T_multiply_5[v_k0, v_ax0, v_k2, v_k3]) - T.writes(T_multiply_red_1[v_ax0]) - with T.init(): - T_multiply_red_1[v_ax0] = T.float32(0) - T_multiply_red_1[v_ax0] = T_multiply_red_1[v_ax0] + T_multiply_5[v_k0, v_ax0, v_k2, v_k3] - for ax0 in range(h): - with T.block("T_divide_3"): - v_ax0 = T.axis.spatial(h, ax0) - T.reads(T_multiply_red_1[v_ax0]) - T.writes(T_divide_3[v_ax0]) - T_divide_3[v_ax0] = 
T_multiply_red_1[v_ax0] / T.Cast("float32", n * w * c) - for ax0 in range(h): - with T.block("T_multiply_6"): - v_ax0 = T.axis.spatial(h, ax0) - T.reads(T_divide_3[v_ax0]) - T.writes(T_multiply_6[v_ax0]) - T_multiply_6[v_ax0] = T.float32(0.10000000000000001) * T_divide_3[v_ax0] - for ax0 in range(T.max(c, h)): - with T.block("T_add_3"): - v_ax0 = T.axis.spatial(T.max(c, h), ax0) - T.reads(T_multiply_4[v_ax0], T_multiply_6[v_ax0]) - T.writes(T_add_2[v_ax0]) - T_add_2[v_ax0] = T_multiply_4[v_ax0] + T_multiply_6[v_ax0] - - @R.function - def main(x: R.Tensor(("n", "h", "w", "c"), dtype="float32"), gamma: R.Tensor(("c",), dtype="float32"), beta: R.Tensor(("c",), dtype="float32"), moving_mean: R.Tensor(("c",), dtype="float32"), moving_var: R.Tensor(("c",), dtype="float32")) -> R.Tuple(R.Tensor(("n", "h", "w", "c"), dtype="float32"), R.Tensor(("T.max(c,h)",), dtype="float32"), R.Tensor(("T.max(c,h)",), dtype="float32")): + with T.block("root"): + T.reads() + T.writes() + T_reshape = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1))) + T_subtract = T.alloc_buffer((n, h, w, c)) + T_reshape_1 = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1))) + T_add_3 = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1))) + compute = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1))) + T_divide = T.alloc_buffer((n, h, w, c)) + T_reshape_2 = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1))) + T_multiply = T.alloc_buffer((n, h, w, c)) + T_reshape_3 = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1))) + T_multiply_1 = T.alloc_buffer((c,)) + x_red = T.alloc_buffer((h,)) + T_divide_1 = T.alloc_buffer((h,)) + T_multiply_2 = T.alloc_buffer((h,)) + T_multiply_3 = T.alloc_buffer((c,)) + T_reshape_4 = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1))) + T_subtract_1 = T.alloc_buffer((n, h, w, c)) + T_subtract_2 = T.alloc_buffer((n, h, w, c)) + T_multiply_4 = T.alloc_buffer((n, h, w, c)) + T_multiply_red = T.alloc_buffer((h,)) + T_divide_2 = T.alloc_buffer((h,)) + 
T_multiply_5 = T.alloc_buffer((h,)) + for ax0 in range(T.int64(1)): + for ax1 in range(h): + for ax2 in range(T.int64(1)): + for ax3 in range(T.int64(1)): + with T.block("T_reshape"): + v_ax0 = T.axis.spatial(T.int64(1), ax0) + v_ax1 = T.axis.spatial(h, ax1) + v_ax2 = T.axis.spatial(T.int64(1), ax2) + v_ax3 = T.axis.spatial(T.int64(1), ax3) + T.reads(moving_mean[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c]) + T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3]) + T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = moving_mean[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c] + for ax0 in range(n): + for ax1 in range(h): + for ax2 in range(w): + for ax3 in range(c): + with T.block("T_subtract"): + v_ax0 = T.axis.spatial(n, ax0) + v_ax1 = T.axis.spatial(h, ax1) + v_ax2 = T.axis.spatial(w, ax2) + v_ax3 = T.axis.spatial(c, ax3) + T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) + T.writes(T_subtract[v_ax0, v_ax1, v_ax2, v_ax3]) + T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] = x[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)] + for ax0 in range(T.int64(1)): + for ax1 in range(h): + for ax2 in range(T.int64(1)): + for ax3 in range(T.int64(1)): + with T.block("T_reshape_1"): + v_ax0 = T.axis.spatial(T.int64(1), ax0) + v_ax1 = T.axis.spatial(h, ax1) + v_ax2 = T.axis.spatial(T.int64(1), ax2) + v_ax3 = T.axis.spatial(T.int64(1), ax3) + T.reads(moving_var[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c]) + T.writes(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3]) + T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3] = moving_var[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c] + for ax0 in range(T.int64(1)): + for ax1 in range(h): + for ax2 in range(T.int64(1)): + for ax3 in range(T.int64(1)): + with T.block("T_add"): + v_ax0 = T.axis.spatial(T.int64(1), ax0) + v_ax1 = T.axis.spatial(h, ax1) + v_ax2 = T.axis.spatial(T.int64(1), ax2) + v_ax3 = T.axis.spatial(T.int64(1), ax3) + T.reads(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(T_add_3[v_ax0, v_ax1, v_ax2, 
v_ax3]) + T_add_3[v_ax0, v_ax1, v_ax2, v_ax3] = T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3] + T.float32(1.0000000000000001e-05) + for i0 in range(T.int64(1)): + for i1 in range(h): + for i2 in range(T.int64(1)): + for i3 in range(T.int64(1)): + with T.block("compute"): + v_i0 = T.axis.spatial(T.int64(1), i0) + v_i1 = T.axis.spatial(h, i1) + v_i2 = T.axis.spatial(T.int64(1), i2) + v_i3 = T.axis.spatial(T.int64(1), i3) + T.reads(T_add_3[v_i0, v_i1, v_i2, v_i3]) + T.writes(compute[v_i0, v_i1, v_i2, v_i3]) + compute[v_i0, v_i1, v_i2, v_i3] = T.sqrt(T_add_3[v_i0, v_i1, v_i2, v_i3]) + for ax0 in range(n): + for ax1 in range(h): + for ax2 in range(w): + for ax3 in range(c): + with T.block("T_divide"): + v_ax0 = T.axis.spatial(n, ax0) + v_ax1 = T.axis.spatial(h, ax1) + v_ax2 = T.axis.spatial(w, ax2) + v_ax3 = T.axis.spatial(c, ax3) + T.reads(T_subtract[v_ax0, v_ax1, v_ax2, v_ax3], compute[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) + T.writes(T_divide[v_ax0, v_ax1, v_ax2, v_ax3]) + T_divide[v_ax0, v_ax1, v_ax2, v_ax3] = T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] / compute[T.int64(0), v_ax1, T.int64(0), T.int64(0)] + for ax0 in range(T.int64(1)): + for ax1 in range(h): + for ax2 in range(T.int64(1)): + for ax3 in range(T.int64(1)): + with T.block("T_reshape_2"): + v_ax0 = T.axis.spatial(T.int64(1), ax0) + v_ax1 = T.axis.spatial(h, ax1) + v_ax2 = T.axis.spatial(T.int64(1), ax2) + v_ax3 = T.axis.spatial(T.int64(1), ax3) + T.reads(gamma[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c]) + T.writes(T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3]) + T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3] = gamma[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c] + for ax0 in range(n): + for ax1 in range(h): + for ax2 in range(w): + for ax3 in range(c): + with T.block("T_multiply"): + v_ax0 = T.axis.spatial(n, ax0) + v_ax1 = T.axis.spatial(h, ax1) + v_ax2 = T.axis.spatial(w, ax2) + v_ax3 = T.axis.spatial(c, ax3) + T.reads(T_divide[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape_2[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) + 
T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3]) + T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = T_divide[v_ax0, v_ax1, v_ax2, v_ax3] * T_reshape_2[T.int64(0), v_ax1, T.int64(0), T.int64(0)] + for ax0 in range(T.int64(1)): + for ax1 in range(h): + for ax2 in range(T.int64(1)): + for ax3 in range(T.int64(1)): + with T.block("T_reshape_3"): + v_ax0 = T.axis.spatial(T.int64(1), ax0) + v_ax1 = T.axis.spatial(h, ax1) + v_ax2 = T.axis.spatial(T.int64(1), ax2) + v_ax3 = T.axis.spatial(T.int64(1), ax3) + T.reads(beta[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c]) + T.writes(T_reshape_3[v_ax0, v_ax1, v_ax2, v_ax3]) + T_reshape_3[v_ax0, v_ax1, v_ax2, v_ax3] = beta[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c] + for ax0 in range(n): + for ax1 in range(h): + for ax2 in range(w): + for ax3 in range(c): + with T.block("T_add_1"): + v_ax0 = T.axis.spatial(n, ax0) + v_ax1 = T.axis.spatial(h, ax1) + v_ax2 = T.axis.spatial(w, ax2) + v_ax3 = T.axis.spatial(c, ax3) + T.reads(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape_3[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) + T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3]) + T_add[v_ax0, v_ax1, v_ax2, v_ax3] = T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] + T_reshape_3[T.int64(0), v_ax1, T.int64(0), T.int64(0)] + for ax0 in range(c): + with T.block("T_multiply_1"): + v_ax0 = T.axis.spatial(c, ax0) + T.reads(moving_mean[v_ax0]) + T.writes(T_multiply_1[v_ax0]) + T_multiply_1[v_ax0] = T.float32(0.90000000000000002) * moving_mean[v_ax0] + for ax0 in range(h): + for k0 in range(n): + for k2 in range(w): + for k3 in range(c): + with T.block("x_red"): + v_ax0 = T.axis.spatial(h, ax0) + v_k0 = T.axis.reduce(n, k0) + v_k2 = T.axis.reduce(w, k2) + v_k3 = T.axis.reduce(c, k3) + T.reads(x[v_k0, v_ax0, v_k2, v_k3]) + T.writes(x_red[v_ax0]) + with T.init(): + x_red[v_ax0] = T.float32(0.0) + x_red[v_ax0] = x_red[v_ax0] + x[v_k0, v_ax0, v_k2, v_k3] + for ax0 in range(h): + with T.block("T_divide_1"): + v_ax0 = T.axis.spatial(h, ax0) + T.reads(x_red[v_ax0]) + 
T.writes(T_divide_1[v_ax0]) + T_divide_1[v_ax0] = x_red[v_ax0] / T.Cast("float32", n * w * c) + for ax0 in range(h): + with T.block("T_multiply_2"): + v_ax0 = T.axis.spatial(h, ax0) + T.reads(T_divide_1[v_ax0]) + T.writes(T_multiply_2[v_ax0]) + T_multiply_2[v_ax0] = T.float32(0.10000000000000001) * T_divide_1[v_ax0] + for ax0 in range(T.max(c, h)): + with T.block("T_add_2"): + v_ax0 = T.axis.spatial(T.max(c, h), ax0) + T.reads(T_multiply_1[v_ax0], T_multiply_2[v_ax0]) + T.writes(T_add_1[v_ax0]) + T_add_1[v_ax0] = T_multiply_1[v_ax0] + T_multiply_2[v_ax0] + for ax0 in range(c): + with T.block("T_multiply_3"): + v_ax0 = T.axis.spatial(c, ax0) + T.reads(moving_var[v_ax0]) + T.writes(T_multiply_3[v_ax0]) + T_multiply_3[v_ax0] = T.float32(0.90000000000000002) * moving_var[v_ax0] + for ax0 in range(T.int64(1)): + for ax1 in range(h): + for ax2 in range(T.int64(1)): + for ax3 in range(T.int64(1)): + with T.block("T_reshape_4"): + v_ax0 = T.axis.spatial(T.int64(1), ax0) + v_ax1 = T.axis.spatial(h, ax1) + v_ax2 = T.axis.spatial(T.int64(1), ax2) + v_ax3 = T.axis.spatial(T.int64(1), ax3) + T.reads(T_divide_1[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % h]) + T.writes(T_reshape_4[v_ax0, v_ax1, v_ax2, v_ax3]) + T_reshape_4[v_ax0, v_ax1, v_ax2, v_ax3] = T_divide_1[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % h] + for ax0 in range(n): + for ax1 in range(h): + for ax2 in range(w): + for ax3 in range(c): + with T.block("T_subtract_1"): + v_ax0 = T.axis.spatial(n, ax0) + v_ax1 = T.axis.spatial(h, ax1) + v_ax2 = T.axis.spatial(w, ax2) + v_ax3 = T.axis.spatial(c, ax3) + T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape_4[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) + T.writes(T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3]) + T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] = x[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape_4[T.int64(0), v_ax1, T.int64(0), T.int64(0)] + for ax0 in range(n): + for ax1 in range(h): + for ax2 in range(w): + for ax3 in range(c): + with T.block("T_subtract_2"): + v_ax0 = T.axis.spatial(n, ax0) 
+ v_ax1 = T.axis.spatial(h, ax1) + v_ax2 = T.axis.spatial(w, ax2) + v_ax3 = T.axis.spatial(c, ax3) + T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3], T_reshape_4[T.int64(0), v_ax1, T.int64(0), T.int64(0)]) + T.writes(T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3]) + T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3] = x[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape_4[T.int64(0), v_ax1, T.int64(0), T.int64(0)] + for ax0 in range(n): + for ax1 in range(h): + for ax2 in range(w): + for ax3 in range(c): + with T.block("T_multiply_4"): + v_ax0 = T.axis.spatial(n, ax0) + v_ax1 = T.axis.spatial(h, ax1) + v_ax2 = T.axis.spatial(w, ax2) + v_ax3 = T.axis.spatial(c, ax3) + T.reads(T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3], T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(T_multiply_4[v_ax0, v_ax1, v_ax2, v_ax3]) + T_multiply_4[v_ax0, v_ax1, v_ax2, v_ax3] = T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] * T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3] + for ax0 in range(h): + for k0 in range(n): + for k2 in range(w): + for k3 in range(c): + with T.block("T_multiply_red"): + v_ax0 = T.axis.spatial(h, ax0) + v_k0 = T.axis.reduce(n, k0) + v_k2 = T.axis.reduce(w, k2) + v_k3 = T.axis.reduce(c, k3) + T.reads(T_multiply_4[v_k0, v_ax0, v_k2, v_k3]) + T.writes(T_multiply_red[v_ax0]) + with T.init(): + T_multiply_red[v_ax0] = T.float32(0.0) + T_multiply_red[v_ax0] = T_multiply_red[v_ax0] + T_multiply_4[v_k0, v_ax0, v_k2, v_k3] + for ax0 in range(h): + with T.block("T_divide_2"): + v_ax0 = T.axis.spatial(h, ax0) + T.reads(T_multiply_red[v_ax0]) + T.writes(T_divide_2[v_ax0]) + T_divide_2[v_ax0] = T_multiply_red[v_ax0] / T.Cast("float32", n * w * c) + for ax0 in range(h): + with T.block("T_multiply_5"): + v_ax0 = T.axis.spatial(h, ax0) + T.reads(T_divide_2[v_ax0]) + T.writes(T_multiply_5[v_ax0]) + T_multiply_5[v_ax0] = T.float32(0.10000000000000001) * T_divide_2[v_ax0] + for ax0 in range(T.max(c, h)): + with T.block("T_add_3"): + v_ax0 = T.axis.spatial(T.max(c, h), ax0) + T.reads(T_multiply_3[v_ax0], T_multiply_5[v_ax0]) + 
T.writes(T_add_2[v_ax0]) + T_add_2[v_ax0] = T_multiply_3[v_ax0] + T_multiply_5[v_ax0] + + @R.function + def main(x: R.Tensor(("n", "h", "w", "c"), dtype="float32"), gamma: R.Tensor(("c",), dtype="float32"), beta: R.Tensor(("c",), dtype="float32"), moving_mean: R.Tensor(("c",), dtype="float32"), moving_var: R.Tensor(("c",), dtype="float32")) -> R.Tuple(R.Tensor(("n", "h", "w", "c"), dtype="float32"), R.Tensor(("T.max(c, h)",), dtype="float32"), R.Tensor(("T.max(c, h)",), dtype="float32")): n = T.int64() h = T.int64() w = T.int64() c = T.int64() - gv = R.call_tir(Expected.batch_norm, (x, gamma, beta, moving_mean, moving_var), out_sinfo=[R.Tensor((n, h, w, c), dtype="float32"), R.Tensor((T.max(c, h),), dtype="float32"), R.Tensor((T.max(c, h),), dtype="float32")]) + cls = Expected + gv = R.call_tir(cls.batch_norm, (x, gamma, beta, moving_mean, moving_var), out_sinfo=[R.Tensor((n, h, w, c), dtype="float32"), R.Tensor((T.max(c, h),), dtype="float32"), R.Tensor((T.max(c, h),), dtype="float32")]) return gv - # fmt: on mod = LegalizeOps()(BatchNorm) tvm.ir.assert_structural_equal(mod, Expected)