fix group norm, add scaffolding for autograd.grad tests
Sam Andow committed Apr 6, 2022
1 parent b504e6d commit 1d8759d
Showing 2 changed files with 86 additions and 52 deletions.
68 changes: 41 additions & 27 deletions functorch/csrc/BatchRulesNorm.cpp
@@ -349,6 +349,39 @@ std::tuple<Tensor,Tensor,Tensor> native_group_norm_plumbing(
return std::make_tuple(result0, mean, rstd);
}

std::tuple<at::Tensor,optional<int64_t>> group_norm_backward_no_weight_bias_batch_rule(
const at::Tensor & grad_out, optional<int64_t> grad_out_bdim,
const at::Tensor & input, optional<int64_t> input_bdim,
const at::Tensor & mean, optional<int64_t> mean_bdim,
const at::Tensor & rstd, optional<int64_t> rstd_bdim,
int64_t N, int64_t C, int64_t HxW, int64_t group) {
auto grad_out_ = moveBatchDimToFront(grad_out, grad_out_bdim);
auto input_ = moveBatchDimToFront(input, input_bdim);
auto mean_ = moveBatchDimToFront(mean, mean_bdim);
auto rstd_ = moveBatchDimToFront(rstd, rstd_bdim);

const auto bdim_size = get_bdim_size2(grad_out, grad_out_bdim, input, input_bdim);
grad_out_ = ensure_has_bdim(grad_out_, grad_out_bdim.has_value(), bdim_size);
input_ = ensure_has_bdim(input_, input_bdim.has_value(), bdim_size);
mean_ = ensure_has_bdim(mean_, mean_bdim.has_value(), bdim_size);
rstd_ = ensure_has_bdim(rstd_, rstd_bdim.has_value(), bdim_size);

grad_out_ = reshape_dim_into(0, 0, grad_out_); // [B0 * N, C, *]
input_ = reshape_dim_into(0, 0, input_); // [B0 * N, C, *]
mean_ = reshape_dim_into(0, 0, mean_); // [B0 * N, G]
rstd_ = reshape_dim_into(0, 0, rstd_); // [B0 * N, G]

const auto result = native_group_norm_backward(
grad_out_.contiguous(),
input_.contiguous(),
mean_.contiguous(),
rstd_.contiguous(),
nullopt, N * bdim_size, C, HxW, group, {true, false, false});
auto result0 = std::get<0>(result);
result0 = reshape_dim_outof(0, bdim_size, result0);
return std::make_tuple(result0, 0);
}

std::tuple<Tensor,Tensor,Tensor> native_group_norm_backward_plumbing(
const Tensor & grad_out, const Tensor & input, const Tensor & mean,
const Tensor & rstd, const c10::optional<Tensor> & weight_opt,
@@ -368,9 +401,6 @@ std::tuple<Tensor,Tensor,Tensor> native_group_norm_backward_plumbing(
return at::native_group_norm_backward(grad_out, input, mean, rstd, weight_opt, N, C, HxW, group, output_mask);
}

Tensor grad_out_value;
optional<int64_t> grad_out_bdim;
std::tie(grad_out_value, grad_out_bdim) = unwrapTensorAtLevel(grad_out, cur_level);
Tensor input_value;
optional<int64_t> input_bdim;
std::tie(input_value, input_bdim) = unwrapTensorAtLevel(input, cur_level);
@@ -410,32 +440,16 @@ std::tuple<Tensor,Tensor,Tensor> native_group_norm_backward_plumbing(
optional<int64_t> grad_normalized_input_bdim;
std::tie(grad_normalized_input_value, grad_normalized_input_bdim) =
unwrapTensorAtLevel(grad_normalized_input, cur_level);
auto grad_out_ = moveBatchDimToFront(grad_normalized_input_value, grad_normalized_input_bdim);
auto input_ = moveBatchDimToFront(input_value, input_bdim);
auto mean_ = moveBatchDimToFront(mean_value, mean_bdim);
auto rstd_ = moveBatchDimToFront(rstd_value, rstd_bdim);

const auto bdim_size = get_bdim_size3(grad_out_, grad_out_bdim, input_, input_bdim, weight, weight_bdim);
grad_out_ = ensure_has_bdim(grad_out_, grad_out_bdim.has_value(), bdim_size);
input_ = ensure_has_bdim(input_, input_bdim.has_value(), bdim_size);
mean_ = ensure_has_bdim(mean_, mean_bdim.has_value(), bdim_size);
rstd_ = ensure_has_bdim(rstd_, rstd_bdim.has_value(), bdim_size);

grad_out_ = reshape_dim_into(0, 0, grad_out_); // [B0 * N, C, *]
input_ = reshape_dim_into(0, 0, input_); // [B0 * N, C, *]
mean_ = reshape_dim_into(0, 0, mean_); // [B0 * N, G]
rstd_ = reshape_dim_into(0, 0, rstd_); // [B0 * N, G]

c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
const auto result = native_group_norm_backward(
grad_out_,
input_,
mean_,
rstd_,
nullopt, N * bdim_size, C, HxW, group, {true, false, false});
auto result0 = std::get<0>(result);
result0 = reshape_dim_outof(0, bdim_size, result0);
grad_input = makeBatched(result0, 0, cur_level);
const auto res = group_norm_backward_no_weight_bias_batch_rule(
grad_normalized_input_value, grad_normalized_input_bdim,
input_value, input_bdim,
mean_value, mean_bdim,
rstd_value, rstd_bdim,
N, C, HxW, group
);
grad_input = makeBatched(std::get<0>(res), std::get<1>(res), cur_level);
}
return std::make_tuple(grad_input, grad_weight, grad_bias);
}
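The new batch rule above (group_norm_backward_no_weight_bias_batch_rule) reuses the unbatched native_group_norm_backward kernel by folding the vmap dimension B0 into group norm's batch dimension N and then splitting it back out of the result. Below is a minimal Python sketch of that reshape round-trip, with assumed example sizes; plain reshape stands in for the C++ helpers reshape_dim_into/reshape_dim_outof:

import torch

# Assumed example sizes: B0 is the vmap dim, (N, C, H, W) the group-norm input.
B0, N, C, H, W = 2, 3, 6, 4, 5
grad_out = torch.randn(B0, N, C, H, W)        # [B0, N, C, *]

# reshape_dim_into(0, 0, grad_out_): fold B0 into N so the unbatched kernel
# can run with batch size N * B0.
flat = grad_out.reshape(B0 * N, C, H, W)      # [B0 * N, C, *]

# ... native_group_norm_backward would run on `flat` here ...

# reshape_dim_outof(0, B0, result0): split B0 back out of the result.
unflat = flat.reshape(B0, N, C, H, W)         # [B0, N, C, *]
assert torch.equal(unflat, grad_out)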
70 changes: 45 additions & 25 deletions test/test_ops.py
@@ -7,10 +7,8 @@
from torch.testing._internal.common_utils import TestCase, run_tests, is_iterable_of_tensors
import torch
from torch import Tensor
import torch.nn.functional as F
import functools
import unittest
import itertools
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_device_type import ops
from torch.testing._internal.common_dtype import integral_types
@@ -28,7 +26,6 @@
# tol2,
opsToleranceOverride,
check_vmap_fallback,
loop,
IS_FBCODE,
)
from torch.utils._pytree import tree_flatten, tree_unflatten, tree_map
@@ -197,6 +194,35 @@ def wrapped(*args):
return wrapped, tuple(flat_args + flat_cotangents)


# Returns a new function g(*args, *cotangents) that computes vjps
# via torch.autograd.grad, along with the flattened (*args, *cotangents)
def get_autograd_fn_and_args_with_cotangents(f, sample, cotangents):
args = tuple([sample.input] + list(sample.args))
kwargs = sample.kwargs
flat_args, args_spec = tree_flatten(args)
flat_cotangents, cotangents_spec = tree_flatten(cotangents)

@functools.wraps(f)
def wrapped(*args):
assert len(args) == len(flat_args) + len(flat_cotangents)
actual_args = args[:len(flat_args)]
cotangents = args[len(flat_args):]
actual_args = tree_unflatten(actual_args, args_spec)
cotangents = tree_unflatten(cotangents, cotangents_spec)

fn, primals = normalize_op_input_output3(f, actual_args, kwargs,
flat_args,
sample.output_process_fn_grad)
out = fn(*primals)
diff_wrt = tuple(primal for primal in primals if (primal.requires_grad or primal.grad_fn is not None))
if diff_wrt:
return torch.autograd.grad(out, diff_wrt, grad_outputs=cotangents)
else:
return (torch.ones(()),)  # hack: nothing here requires grad; this will need to be more generic

return wrapped, tuple(flat_args + flat_cotangents)
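
The helper above mirrors the existing vjp scaffolding but routes the backward pass through torch.autograd.grad, differentiating only with respect to the leaves that require grad and feeding the cotangents via grad_outputs. A minimal standalone sketch of that pattern, using torch.sin as an illustrative stand-in for an OpInfo sample:

import torch

x = torch.randn(3, requires_grad=True)
cotangent = torch.randn(3)   # plays the role of the sample cotangents

out = torch.sin(x)
# Same call shape as in `wrapped` above: explicit cotangents via grad_outputs.
(vjp_x,) = torch.autograd.grad(out, (x,), grad_outputs=cotangent)
assert torch.allclose(vjp_x, cotangent * torch.cos(x))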


# Returns a new function g(*args, *cotangents) that computes vjps and
# sample (*args, *cotangents)
def get_vjpfull_variant(f, sample):
@@ -1215,28 +1241,22 @@ def get_names(inpt):
for op in get_names(run_decompositions):
f.write(f'{op}\n')

def test_group_norm_backward(self, device):
# group norm will hit the decomposable ``infinitely_differentiable_group_norm_backward`` when
# GradMode is on, which happens by default in the grad transform. This avoids that
def f(x, weight, bias, grad_out):
output = F.group_norm(x, 6, weight, bias)
inputs = []
for input in (x, weight, bias):
if input.requires_grad:
inputs.append(input)
return torch.autograd.grad(outputs=output, inputs=inputs, grad_outputs=grad_out)

B, N, C, H, W = 2, 3, 24, 5, 7
for (input_grad, weight_grad, bias_grad) in itertools.product((True, False), (True, False), (True, False)):
if not input_grad and not weight_grad and not bias_grad:
continue
x = torch.randn(N, C, H, W, device=device, requires_grad=input_grad)
weight = torch.randn(C, device=device, requires_grad=weight_grad)
bias = torch.randn(C, device=device, requires_grad=bias_grad)
grad_out = torch.randn(B, N, C, H, W, device=device)
loop_out = loop(f, (None, None, None, 0), 0, 2, x, weight, bias, grad_out)
batched_out = vmap(f, (None, None, None, 0), 0)(x, weight, bias, grad_out)
self.assertEqual(loop_out, batched_out)
@ops(filter(lambda op: op.name == "nn.functional.group_norm", functorch_lagging_op_db + additional_op_db),
allowed_dtypes=(torch.float32, torch.double)) # TODO: generalize
def test_group_norm_backward(self, device, dtype, op):
# hacky; this only works because no group norm inputs can be scalars
def was_skipped_from_batched_tensors(batched_out, batch_size):
return batched_out.shape == (batch_size,) and all(tuple(e == 1 for e in batched_out))

sample_inputs = op.sample_inputs(device, dtype, requires_grad=True)

for sample_input in sample_inputs:
cotangents = get_sample_cotangents(op, sample_input)
f, args = get_autograd_fn_and_args_with_cotangents(op, sample_input, cotangents)
for loop_out, batched_out in get_fallback_and_vmap_exhaustive(f, args, {}, opinfo=op):
if all(was_skipped_from_batched_tensors(bo, lo.shape[0]) for (bo, lo) in zip(batched_out, loop_out)):
continue # we weren't able to use the batched tensor in autograd.grad
self.assertEqual(loop_out, batched_out)
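
The OpInfo-driven test above relies on get_fallback_and_vmap_exhaustive to compare a per-sample loop against vmap. A standalone sketch of the equivalence it asserts for the group norm backward, with assumed sizes (C=6 channels, 3 groups) and without the OpInfo machinery:

import torch
import torch.nn.functional as F
from functorch import vmap

N, C, H, W, B = 3, 6, 4, 4, 2
x = torch.randn(N, C, H, W, requires_grad=True)
weight = torch.randn(C, requires_grad=True)
bias = torch.randn(C, requires_grad=True)
grad_out = torch.randn(B, N, C, H, W)   # one cotangent per vmap slice

def vjp(grad_out):
    out = F.group_norm(x, 3, weight, bias)
    return torch.autograd.grad(out, (x, weight, bias), grad_outputs=grad_out)

per_sample = [vjp(go) for go in grad_out]                    # loop reference
loop_out = tuple(torch.stack(gs) for gs in zip(*per_sample))
batched_out = vmap(vjp)(grad_out)                            # batched via vmap
for lo, bo in zip(loop_out, batched_out):
    assert torch.allclose(lo, bo, atol=1e-5)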


only_for = ("cpu", "cuda")
