Merged
52 commits
8b5a81d
trying to understand why batchnorm returns all zeros
hugolatendresse Mar 10, 2025
99373ae
debugging training vs non-training batch norm
hugolatendresse Mar 10, 2025
4f93317
merge main
hugolatendresse Mar 16, 2025
b0e1154
added training in attrs
hugolatendresse Mar 16, 2025
dde7872
training False
hugolatendresse Mar 16, 2025
dff60db
training argument in nn.py
hugolatendresse Mar 16, 2025
f1986d9
little cleanup before building
hugolatendresse Mar 16, 2025
1545b99
fix copy-paste errors
hugolatendresse Mar 16, 2025
77cc1d8
builds, but should probably just update nn.h instead
hugolatendresse Mar 16, 2025
0dbf8fe
batch_norm build
hugolatendresse Mar 16, 2025
1164d21
first batchnorm test passes with .eval(), but not without, and copy f…
hugolatendresse Mar 16, 2025
a72ce6e
copy failing
hugolatendresse Mar 16, 2025
42728f7
todo
hugolatendresse Mar 16, 2025
9ee0672
cleanup
hugolatendresse Mar 16, 2025
e3f0236
training failing
hugolatendresse Mar 16, 2025
3f68087
no need to pass center and scale since default ok
hugolatendresse Mar 16, 2025
5cd314d
cleanup
hugolatendresse Mar 16, 2025
d5d30b7
cleanup
hugolatendresse Mar 16, 2025
125a9a6
reformat
hugolatendresse Mar 16, 2025
281fb53
Merge branch 'main' into batch_norm
hugolatendresse Mar 16, 2025
3f0eaea
batch norm default and print torch version
hugolatendresse Mar 16, 2025
79e3ec6
whitespace
hugolatendresse Mar 16, 2025
79c4a0e
remove dummy test
hugolatendresse Mar 16, 2025
2dc643e
Merge branch 'main' of https://github.com/apache/tvm into batch_norm
hugolatendresse Mar 21, 2025
b9697f3
getting a tuple as output of batchnorm
hugolatendresse Mar 21, 2025
bc18182
output now of the right dimension, and close! but is not exactly equal
hugolatendresse Mar 21, 2025
e2e7263
still not the same with 2 1 2 2
hugolatendresse Mar 21, 2025
4cdb05a
missing eps
hugolatendresse Mar 21, 2025
b256163
last small test passes, but most tests still fail
hugolatendresse Mar 21, 2025
ab8d75c
passes
hugolatendresse Mar 21, 2025
7cb5a56
passes
hugolatendresse Mar 21, 2025
536310a
need to fix test_batch_norm7
hugolatendresse Mar 21, 2025
4c55f20
commented out tests that pass
hugolatendresse Mar 21, 2025
e99d659
legalize tests
hugolatendresse Mar 23, 2025
56b3999
correct calc of data for everyone
hugolatendresse Mar 23, 2025
fc6b03a
track running stats is equivalent to training! passes all
hugolatendresse Mar 23, 2025
7139590
all tests pass except for cache size
hugolatendresse Mar 23, 2025
267f011
all batch norm only pass!
hugolatendresse Mar 23, 2025
7a5cadd
all exported tests work, moved to main script
hugolatendresse Mar 23, 2025
e76ab8c
need to fix legalize tests
hugolatendresse Mar 23, 2025
7c99174
Merge branch 'main' into batch_norm
hugolatendresse Mar 24, 2025
54f00f1
first legalize test passes
hugolatendresse Mar 24, 2025
d74cfbf
all legalize pass
hugolatendresse Mar 24, 2025
a9de5ef
all tests pass
hugolatendresse Mar 24, 2025
6458a64
linting
hugolatendresse Mar 24, 2025
cd8fa7b
cleanup
hugolatendresse Mar 24, 2025
73ef53e
cleanup batchnorm
hugolatendresse Mar 24, 2025
1578ae2
linting
hugolatendresse Mar 24, 2025
95254bc
smaller third test
hugolatendresse Mar 24, 2025
6984609
formatting
hugolatendresse Mar 24, 2025
8233013
renaming
hugolatendresse Mar 24, 2025
8c5cfc7
resolve conflicts with main
hugolatendresse Mar 24, 2025
2 changes: 2 additions & 0 deletions include/tvm/relax/attrs/nn.h
@@ -462,6 +462,7 @@ struct BatchNormAttrs : public tvm::AttrsNode<BatchNormAttrs> {
bool center;
bool scale;
double momentum;
bool training;

TVM_DECLARE_ATTRS(BatchNormAttrs, "relax.attrs.BatchNormAttrs") {
TVM_ATTR_FIELD(axis).describe("The axis along which the normalization is applied.");
@@ -470,6 +471,7 @@ struct BatchNormAttrs : public tvm::AttrsNode<BatchNormAttrs> {
"Indicating if the beta offset will be added to the normalized tensor.");
TVM_ATTR_FIELD(scale).describe("Indicating if the gamma scale will be multiplied.");
TVM_ATTR_FIELD(momentum).describe("The value used for the moving_mean and moving_var update.");
TVM_ATTR_FIELD(training).describe("Whether we are training (i.e., not in eval mode).");
}
}; // struct BatchNormAttrs

@@ -1106,7 +1106,7 @@ def _detach(self, node: fx.Node) -> relax.Var:
return self.env[node.args[0]]

def _copy_(self, node: fx.Node) -> relax.Var:
# Copies the source tensor's to the destination tensor
# Copies the source tensor into the destination tensor
# In TVM, that means simply returning the source tensor
return self.env[node.args[1]]

43 changes: 33 additions & 10 deletions python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -45,7 +45,7 @@ def _hardtanh(self, node: fx.Node) -> relax.Expr:

########## Neural Network ##########

def _batch_norm_legit_no_training(self, node: fx.Node) -> relax.Var:
def _batch_norm(self, node: fx.Node, training) -> relax.Var:
Contributor:
Good to add a type annotation in any follow-up PR.

Suggested change
def _batch_norm(self, node: fx.Node, training) -> relax.Var:
def _batch_norm(self, node: fx.Node, training: bool) -> relax.Var:

Contributor Author:
Got it, will do, thanks.

import numpy as np

x = self.env[node.args[0]]
@@ -55,22 +55,43 @@ def _batch_norm_legit_no_training(self, node: fx.Node) -> relax.Var:
bias = self.env.get(node.args[2], relax.const(np.zeros(channel), dtype=dtype))
running_mean = self.env.get(node.args[3], relax.const(np.zeros(channel), dtype=dtype))
running_var = self.env.get(node.args[4], relax.const(np.ones(channel), dtype=dtype))
momentum = node.args[5] if len(node.args) > 5 else node.kwargs.get("momentum", 0.1)
eps = node.args[6] if len(node.args) > 6 else node.kwargs.get("eps", 1e-05)
ignore_running_stats = (
node.args[5] if len(node.args) > 5 else node.kwargs.get("track_running_stats", True)
)
track_running_stats = not ignore_running_stats
momentum = node.args[6] if len(node.args) > 6 else node.kwargs.get("momentum", 0.1)
eps = node.args[7] if len(node.args) > 7 else node.kwargs.get("eps", 1e-05)

if track_running_stats:
training = True

return self.block_builder.emit(
relax.op.nn.batch_norm(
x,
weight,
bias,
running_mean,
running_var,
axis=1,
data=x,
gamma=weight,
beta=bias,
moving_mean=running_mean,
moving_var=running_var,
axis=1, # Always over channel
epsilon=eps,
momentum=momentum,
)
training=training,
)[0]
)

def _batch_norm_legit_functional(self, node: fx.Node) -> relax.Var:
# This method is called for batch_norm in training mode
# TODO: results are not yet correct in this mode
# TODO: we need to store the running mean and variance returned by the
# previous call to batch_norm and pass them in again
training = True
return self._batch_norm(node, training)

def _batch_norm_legit_no_training(self, node: fx.Node) -> relax.Var:
# This method is called for batch_norm in eval mode
training = False
return self._batch_norm(node, training)

def _group_norm(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
num_groups = node.args[1]
@@ -283,7 +304,9 @@ def create_convert_map(
# linear algebra
"linalg_vector_norm.default": self._linalg_vector_norm,
# neural network
"_native_batch_norm_legit_functional.default": self._batch_norm_legit_functional,
"_native_batch_norm_legit_no_training.default": self._batch_norm_legit_no_training,
"batch_norm.default": self._batch_norm_legit_no_training,
"adaptive_avg_pool2d.default": self._adaptive_avg_pool2d,
"addmm.default": self._addmm,
"avg_pool2d.default": self._avg_pool2d,
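As context for the two handler registrations above, here is a minimal sketch (not part of this PR) for checking which ATen batch-norm op torch.export emits for a module in eval versus train mode. The run_decompositions() call and the expected op names are assumptions and may vary across PyTorch versions.

import torch
from torch import nn

class BNModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.bn = nn.BatchNorm2d(3)

    def forward(self, x):
        return self.bn(x)

x = torch.randn(2, 3, 4, 4)
for train in (False, True):
    module = BNModule().train(train)
    # Decompose to core ATen so the batch-norm variant becomes visible.
    exported = torch.export.export(module, (x,)).run_decompositions()
    ops = [str(node.target) for node in exported.graph.nodes if node.op == "call_function"]
    print("train" if train else "eval", ops)
# Expected (assumption): eval mode lowers to _native_batch_norm_legit_no_training.default,
# train mode to _native_batch_norm_legit_functional.default, matching the map above.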
8 changes: 7 additions & 1 deletion python/tvm/relax/op/nn/nn.py
@@ -1393,6 +1393,7 @@ def batch_norm(
center: bool = True,
scale: bool = True,
momentum: float = 0.1,
training: bool = True,
) -> Expr:
r"""
Batch normalization layer (Ioffe and Szegedy, 2014).
@@ -1481,13 +1482,18 @@
momentum : float
The value used for the moving_mean and moving_var update.

training : bool
A boolean value indicating whether we are in training mode or eval mode. By default,
relax batch_norm is in training mode. To transform it to inference mode,
use DecomposeOpsForInference.

Returns
-------
result : relax.Expr
The computed result.
"""
return _ffi_api.batch_norm( # type: ignore
data, gamma, beta, moving_mean, moving_var, axis, epsilon, center, scale, momentum
data, gamma, beta, moving_mean, moving_var, axis, epsilon, center, scale, momentum, training
)


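As a usage illustration of the new training flag (a sketch using the public relax BlockBuilder API, not code from this PR), the snippet below emits batch_norm in inference mode and extracts the normalized output from the 3-tuple the op returns:

import tvm
from tvm import relax

channels = 3
x = relax.Var("x", relax.TensorStructInfo((2, channels, 4, 4), "float32"))
gamma = relax.Var("gamma", relax.TensorStructInfo((channels,), "float32"))
beta = relax.Var("beta", relax.TensorStructInfo((channels,), "float32"))
mean = relax.Var("mean", relax.TensorStructInfo((channels,), "float32"))
var = relax.Var("var", relax.TensorStructInfo((channels,), "float32"))

bb = relax.BlockBuilder()
with bb.function("main", [x, gamma, beta, mean, var]):
    with bb.dataflow():
        # batch_norm returns a 3-tuple: (normalized, new_moving_mean, new_moving_var)
        bn = bb.emit(
            relax.op.nn.batch_norm(
                x, gamma, beta, mean, var, axis=1, epsilon=1e-5, momentum=0.1, training=False
            )
        )
        out = bb.emit_output(relax.TupleGetItem(bn, 0))
    bb.emit_func_output(out)
bb.get().show()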
4 changes: 1 addition & 3 deletions python/tvm/relax/transform/legalize_ops/nn.py
@@ -551,9 +551,7 @@ def _nn_batch_norm(bb: BlockBuilder, call: Call) -> Expr:
epsilon=call.attrs.epsilon,
center=call.attrs.center,
scale=call.attrs.scale,
# By default relax batch_norm is training mode.
# To transform it to inference mode, use DecomposeOpsForInference.
training=True,
training=call.attrs.training,
momentum=call.attrs.momentum,
)

26 changes: 15 additions & 11 deletions python/tvm/topi/nn/batch_norm.py
@@ -111,22 +111,26 @@ def batch_norm(
shape = [1] * len(data.shape)
shape[axis] = data.shape[axis]

reduce_axes = list(range(len(data.shape)))
reduce_axes.remove(axis)
shape_prod = reduce(lambda x, y: x * y, [data.shape[ax] for ax in reduce_axes], 1)

data_mean = topi.sum(data, axis=reduce_axes) / shape_prod
data_mean_rs = topi.reshape(data_mean, shape)
data_var = (
topi.sum((data - data_mean_rs) * (data - data_mean_rs), axis=reduce_axes) / shape_prod
)
data_var_rs = topi.reshape(data_var, shape)

if training:
reduce_axes = list(range(len(data.shape)))
reduce_axes.remove(axis)
shape_prod = reduce(lambda x, y: x * y, [data.shape[ax] for ax in reduce_axes], 1)
data_mean = topi.sum(data, axis=reduce_axes) / shape_prod
data_mean_rs = topi.reshape(data_mean, shape)
data_var = (
topi.sum((data - data_mean_rs) * (data - data_mean_rs), axis=reduce_axes) / shape_prod
)
data_var_rs = topi.reshape(data_var, shape)
out = (data - data_mean_rs) / topi.math.sqrt(data_var_rs + epsilon)
else:
moving_mean_rs = topi.reshape(moving_mean, shape)
moving_var_rs = topi.reshape(moving_var, shape)

out = (data - moving_mean_rs) / topi.math.sqrt(moving_var_rs + epsilon)

else:
out = (data - data_mean_rs) / topi.math.sqrt(data_var_rs + epsilon)

if scale:
out = out * topi.reshape(gamma, shape)
if center:
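For reference, the branch the TOPI kernel now takes can be summarized with this NumPy sketch (illustrative only; batch_norm_ref is a made-up helper, not the TOPI implementation):

import numpy as np

def batch_norm_ref(data, gamma, beta, moving_mean, moving_var, axis=1, eps=1e-5, training=True):
    shape = [1] * data.ndim
    shape[axis] = data.shape[axis]
    reduce_axes = tuple(ax for ax in range(data.ndim) if ax != axis)
    if training:
        # Training: normalize with statistics of the current batch.
        mean = data.mean(axis=reduce_axes).reshape(shape)
        var = ((data - mean) ** 2).mean(axis=reduce_axes).reshape(shape)
    else:
        # Inference: normalize with the stored moving statistics.
        mean = moving_mean.reshape(shape)
        var = moving_var.reshape(shape)
    out = (data - mean) / np.sqrt(var + eps)
    return out * gamma.reshape(shape) + beta.reshape(shape)

data = np.random.randn(2, 3, 4, 4).astype("float32")
gamma, beta = np.ones(3, "float32"), np.zeros(3, "float32")
moving_mean, moving_var = np.zeros(3, "float32"), np.ones(3, "float32")
print(batch_norm_ref(data, gamma, beta, moving_mean, moving_var, training=False).shape)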
8 changes: 4 additions & 4 deletions src/relax/op/nn/nn.cc
@@ -252,21 +252,21 @@ bool NormCheckDtypeAndShape(const Call& call, const BlockBuilder& ctx,
TVM_REGISTER_NODE_TYPE(BatchNormAttrs);

Expr batch_norm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr moving_var, //
int axis, double epsilon, bool center, bool scale, double momentum) {
int axis, double epsilon, bool center, bool scale, double momentum, bool training) {
ObjectPtr<BatchNormAttrs> attrs = make_object<BatchNormAttrs>();
attrs->axis = axis;
attrs->epsilon = epsilon;
attrs->center = center;
attrs->scale = scale;
attrs->momentum = momentum;
attrs->training = training;

static const Op& op = Op::Get("relax.nn.batch_norm");
return Call(op,
{std::move(data), std::move(gamma), std::move(beta), std::move(moving_mean),
std::move(moving_var)},
Attrs{attrs}, {});
}

TVM_REGISTER_GLOBAL("relax.op.nn.batch_norm").set_body_typed(batch_norm);

StructInfo InferStructInfoBatchNorm(const Call& call, const BlockBuilder& ctx) {
@@ -388,7 +388,7 @@ InferLayoutOutput InferLayoutLayerNorm(const Call& call,
TVM_REGISTER_OP("relax.nn.layer_norm")
.set_attrs_type<LayerNormAttrs>()
.set_num_inputs(3)
.add_argument("data", "Tensor", "Input to which batch_norm will be applied.")
.add_argument("data", "Tensor", "Input to which layer_norm will be applied.")
.add_argument("gamma", "Tensor", "The gamma scale factor.")
.add_argument("beta", "Tensor", "The beta offset factor.")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoLayerNorm)
@@ -500,7 +500,7 @@ InferLayoutOutput InferLayoutGroupNorm(const Call& call,
TVM_REGISTER_OP("relax.nn.group_norm")
.set_attrs_type<GroupNormAttrs>()
.set_num_inputs(3)
.add_argument("data", "Tensor", "Input to which batch_norm will be applied.")
.add_argument("data", "Tensor", "Input to which group_norm will be applied.")
.add_argument("gamma", "Tensor", "The gamma scale factor.")
.add_argument("beta", "Tensor", "The beta offset factor.")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoGroupNorm)
2 changes: 1 addition & 1 deletion src/relax/op/nn/nn.h
@@ -68,7 +68,7 @@ Expr log_softmax(Expr data, int axis);

/*! \brief Compute batch normalization. */
Expr batch_norm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr moving_var, //
int axis, double epsilon, bool center, bool scale, double momentum);
int axis, double epsilon, bool center, bool scale, double momentum, bool training);

/*! \brief Compute layer normalization. */
Expr layer_norm(Expr data, Expr gamma, Expr beta, Array<Integer> axes, double epsilon, bool center,
59 changes: 58 additions & 1 deletion tests/python/relax/test_from_exported_to_cuda.py
@@ -289,6 +289,25 @@ def forward(self, x):
assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module3, target, dev)


@tvm.testing.parametrize_targets("cuda")
def test_batch_norm_prog(target, dev):
# Default args, in a pytorch program (to ensure output is in proper type and format)
raw_data = np.random.randn(2, 3, 2, 2).astype(np.float32)

class BatchNormWrapper(nn.Module):
def __init__(self):
super(BatchNormWrapper, self).__init__()
self.bn = nn.BatchNorm2d(3)

def forward(self, x):
x = self.bn(x)
x = x + 1
return x

torch_module = BatchNormWrapper().eval()
assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)


@tvm.testing.parametrize_targets("cuda")
def test_split_size(target, dev):
# Test split using the split_size argument such that it is not a divisor
@@ -310,7 +329,46 @@ def forward(self, x):
return torch.split(x, split_size_or_sections=self.split_size, dim=self.dim)

torch_module = SplitModelSplitSize(split_size=split_size, dim=dim).eval()
assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)


@tvm.testing.parametrize_targets("cuda")
def test_batch_norm0(target, dev):
# Eval, no momentum, no affine, no running stats
raw_data = np.random.randn(8, 3, 4, 4).astype(np.float32)
torch_module = nn.BatchNorm2d(
3, eps=1e-02, momentum=0.0, affine=False, track_running_stats=False, device=None, dtype=None
).eval()
assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)


@tvm.testing.parametrize_targets("cuda")
def test_batch_norm1(target, dev):
# Eval, with momentum, no affine, with running stats
raw_data = np.random.randn(1, 4, 2, 2).astype(np.float32)
torch_module = nn.BatchNorm2d(
4, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True, device=None, dtype=None
).eval()
assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)


@tvm.testing.parametrize_targets("cuda")
def test_batch_norm2(target, dev):
# Eval, with momentum, affine, no running stats
raw_data = np.random.randn(3, 4, 2, 2).astype(np.float32)
torch_module = nn.BatchNorm2d(
4, eps=1e-05, momentum=0.2, affine=True, track_running_stats=False
).eval()
assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)


@tvm.testing.parametrize_targets("cuda")
def test_batch_norm3(target, dev):
# Eval, no momentum, affine, with running stats
raw_data = np.random.randn(1, 2, 2, 2).astype(np.float32)
torch_module = nn.BatchNorm2d(
2, eps=1e-05, momentum=0.0, affine=True, track_running_stats=True
).eval()
assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)


@@ -335,7 +393,6 @@ def forward(self, x):
return torch.split(x, split_size_or_sections=self.split_size, dim=self.dim)

torch_module = SplitModelSectionsList(split_size=sections, dim=dim).eval()

assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)

