[PROTOTYPE] generated batching rules for custom dispatcher ops #578

Open · wants to merge 4 commits into main
2 changes: 2 additions & 0 deletions functorch/_src/custom_function.py
@@ -8,6 +8,8 @@ def custom_vjp(name, filter_fn, fwd_fn, bwd_fn):
m.def_(f"{name}(Tensor[] args) -> Tensor[]")
m.impl(f"{name}", "CompositeImplicitAutograd", fwd_fn)

m.gen_vmap_binding(f"{name}")

m.def_(f"{name}_vjp(Tensor[] args) -> Tensor[]")
m.impl(f"{name}_vjp", "CompositeImplicitAutograd", bwd_fn)

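For context, here is a minimal usage sketch of what the new gen_vmap_binding registration enables from the Python side. It mirrors the test added in this PR; the operator name, the import path for the prototype custom_vjp helper, and the gradient formula are illustrative assumptions, and it presumes a functorch build with this branch applied.

import torch
from functorch import vmap
from functorch._src.custom_function import custom_vjp  # assumed import path for the prototype API

def my_sin_fwd(args):
    # Forward pass written as a decomposition into prim ops, so the generated
    # batching rule can batch it by running each prim op's batching rule.
    x, = args
    return x.sin(), x

def my_sin_bwd(args):
    # Custom backward, kept separate from the decomposition
    # (e.g. for numeric stability).
    grad_y, result, x = args
    return (grad_y * x.cos(),)

def filter_fn(args):
    return args[0]

my_sin = custom_vjp('my_sin_sketch', filter_fn, my_sin_fwd, my_sin_bwd)

x = torch.randn(2, 3, requires_grad=True)
y = vmap(my_sin)(x)   # forward dispatches through the generated FuncTorchBatched kernel
y.sum().backward()    # backward runs my_sin_bwd on non-batched tensors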
107 changes: 99 additions & 8 deletions functorch/csrc/CustomFunction.cpp
@@ -1,4 +1,5 @@
#include <functorch/csrc/CustomFunction.h>
#include <functorch/csrc/BatchedTensorImpl.h>
#include <ATen/ATen.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/variable.h>
@@ -207,7 +208,7 @@ using torch::autograd::collect_next_edges;
using torch::autograd::deleteNode;
using torch::autograd::flatten_tensor_args;

void customFunctionBoxed(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
void customFunctionBoxed(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool get_output_by_running_forward_pass) {
auto tensors = torch::jit::pop(stack).toTensorList().vec();
auto tensors_ = unpack(tensors, "tensors", 0);
auto _any_requires_grad = compute_requires_grad(tensors);
@@ -226,12 +227,27 @@ void customFunctionBoxed(const c10::OperatorHandle& op, torch::jit::Stack* stack
grad_fn->num_inputs_ = tensors_.size();
}

auto typed_handle = op.typed<custom_function_t>();
std::vector<Tensor> _tmp = ([&]() {
at::AutoDispatchBelowADInplaceOrView guard;
return typed_handle.call(tensors_);
})();
auto result = std::move(_tmp);
std::vector<at::Tensor> result;
// When this is true, we:
// - run the forward pass
// - construct the autograd graph
// - return the result
// When this is false, we:
// - DONT run the forward pass (and instead, assume that the output from the forward pass
// was already pushed on the stack)
// - construct the autograd graph, using the (unwrapped) inputs and outputs from the fwd pass
// - DONT return the result
if (get_output_by_running_forward_pass) {
auto typed_handle = op.typed<custom_function_t>();
std::vector<Tensor> _tmp = ([&]() {
at::AutoDispatchBelowADInplaceOrView guard;
return typed_handle.call(tensors_);
})();
result = std::move(_tmp);
} else {
result = torch::jit::pop(stack).toTensorList().vec();
}

if (grad_fn) {
for (auto& tensor : result) {
// TODO: is this right?
@@ -248,9 +264,77 @@ void customFunctionBoxed(const c10::OperatorHandle& op, torch::jit::Stack* stack
grad_fn->saved_tensors_.push_back(torch::autograd::SavedVariable(tensor, !is_input));
}
}
torch::jit::push(stack, result);
// if we computed the output ourselves, return it.
if (get_output_by_running_forward_pass) {
torch::jit::push(stack, result);
}
}

void customFunctionBoxed(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
customFunctionBoxed(op, stack, /*get_output_by_running_forward_pass=*/true);
}


void generatedCustomBatchingRule(const c10::OperatorHandle& op, c10::DispatchKeySet ks, torch::jit::Stack* stack) {
// We basically simulate running the user's op in inference mode WITH the decomposition
// And then separately we create the autograd graph WITHOUT the decomposition.
// This allows us to decompose and "get batching rules for free",
// while still being able to run a user's custom backward function
// (which might be necessary for numeric stability).

auto tensors = torch::jit::pop(stack).toTensorList().vec();
auto typed_handle = op.typed<custom_function_t>();

// Step (1) = run the forward using the decomposition
std::vector<Tensor> _tmp = ([&]() {
// NOTE: I don't think this composes with DynamicLayer very well.
// I want to fully turn off autograd when I call the custom python operator,
// but when vmap() is active, DynamicLayer will overwrite TLS and (potentially) run autograd anyway.
// TODO: think more about this.
at::AutoDispatchBelowADInplaceOrView guard;
// The tensor arguments should all be batched tensors at this point,
// so what will happen is we:
// (a) Skip the autograd key and go straight to the backend
// (potentially running other stuff like AMP along the way)
// (b) Enter the user's python kernel, which runs a bunch of "prim" aten ops
// (c) Those prim ops each enter the dispatcher, and we'll hit each of their
// batching rule kernels (because our inputs are *still* BatchedTensors)
// TODO better idiom for this - I just want to go straight to the python impl
return typed_handle.redispatch(c10::DispatchKeySet(c10::DispatchKey::CPU), tensors);
})();
auto forward_result = std::move(_tmp);

// Step (2) = Create the autograd graph without the decomposition.
// Taking special care to "re-use" the same inputs/outputs in the autograd kernel
// that we got from the forward pass.
// This is really hacky - I'm hardcoding the boxed autograd kernel
// to know that when it's running in "don't run the forward pass" mode,
// it can assume that the arguments on the stack are <unwrapped_output, unwrapped_inputs...>
// from the forward pass.
auto unwrapped_args = std::vector<Tensor>();
for (const auto& a : tensors) {
TORCH_INTERNAL_ASSERT(at::functorch::isBatchedTensor(a));
unwrapped_args.push_back(at::functorch::unsafeGetBatchedImpl(a)->value());
}
auto unwrapped_outs = std::vector<Tensor>();
for (const auto& a : forward_result) {
TORCH_INTERNAL_ASSERT(at::functorch::isBatchedTensor(a));
unwrapped_outs.push_back(at::functorch::unsafeGetBatchedImpl(a)->value());
}
// customFunctionBoxed will pop these off the stack.
torch::jit::push(stack, unwrapped_outs);
torch::jit::push(stack, unwrapped_args);
{
// When get_output_by_running_forward_pass is false, the autograd boxed fallback knows to:
// (a) add the vjp to the autograd graph
// (b) NOT run the forward pass
customFunctionBoxed(op, stack, /*get_output_by_running_forward_pass=*/false);
}

torch::jit::push(stack, forward_result);
}


void initDispatchBindings(PyObject* module) {
auto m = py::handle(module).cast<py::module>();

@@ -272,6 +356,13 @@ void initDispatchBindings(PyObject* module) {
torch::CppFunction::makeFromBoxedFunction<&customFunctionBoxed>())
);
}, "", py::arg("name"), py::arg("dispatch"))
.def("gen_vmap_binding", [](py::object self, const char* name) {
self.cast<torch::Library&>().impl(
name,
dispatch_str("FuncTorchBatched",
torch::CppFunction::makeFromBoxedFunction<&generatedCustomBatchingRule>())
);
}, "", py::arg("name"))
.def("fallback_fallthrough", [](py::object self, const char* dispatch) {
self.cast<torch::Library&>().fallback(
dispatch_str(dispatch, torch::CppFunction::makeFallthrough())
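To make the new control flow easier to follow, below is a pure-Python sketch of how generatedCustomBatchingRule cooperates with the two-mode customFunctionBoxed: run the decomposed forward on batched tensors first, then replay only the autograd-graph construction on the unwrapped inputs and outputs. This is also why the custom backward still sees non-batched tensors when only the forward is vmapped. The stack is modeled as a plain list and BatchedTensor as a tiny wrapper class; Batched, autograd_kernel, generated_batching_rule, and my_sin_fwd are illustrative names, not functorch APIs.

import torch
from dataclasses import dataclass

@dataclass
class Batched:
    # Stand-in for functorch's BatchedTensor: a plain tensor plus its vmapped dim.
    value: torch.Tensor
    bdim: int

def autograd_kernel(stack, fwd_fn, run_forward):
    # Models customFunctionBoxed(op, stack, get_output_by_running_forward_pass).
    inputs = stack.pop()
    if run_forward:
        # Normal path: run the user's (decomposed) forward here.
        outputs = fwd_fn(inputs)
    else:
        # Batching-rule path: the forward already ran; reuse its outputs and only
        # record the autograd graph over the unwrapped inputs/outputs.
        outputs = stack.pop()
    # ... the real kernel records grad_fn and saved tensors over (inputs, outputs) here ...
    if run_forward:
        stack.append(outputs)

def generated_batching_rule(stack, fwd_fn):
    # Models generatedCustomBatchingRule.
    batched_inputs = stack.pop()
    # Step (1): run the forward WITH the decomposition on batched tensors; in the
    # real kernel each prim op hits its own batching rule. Modeled here by running
    # the decomposition on the underlying values and re-wrapping the results.
    bdim = batched_inputs[0].bdim
    batched_outputs = [Batched(t, bdim) for t in fwd_fn([b.value for b in batched_inputs])]
    # Step (2): build the autograd graph WITHOUT the decomposition, reusing the
    # unwrapped inputs/outputs from step (1) instead of recomputing the forward.
    stack.append([b.value for b in batched_outputs])  # unwrapped outputs
    stack.append([b.value for b in batched_inputs])   # unwrapped inputs (popped first)
    autograd_kernel(stack, fwd_fn, run_forward=False)
    stack.append(batched_outputs)                     # return the batched forward result

def my_sin_fwd(args):
    x, = args
    return [x.sin(), x]

stack = [[Batched(torch.randn(2, 3), 0)]]
generated_batching_rule(stack, my_sin_fwd)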
51 changes: 51 additions & 0 deletions test/test_eager_transforms.py
@@ -1905,6 +1905,57 @@ def filter_fn(args):

assert torch.allclose(x.grad, 3 * x.cos())

@onlyCPU
def test_generated_batching_rule_for_custom_op(self, device):
called_impl = False
called_vjp = False
called_impl_with_batched_args = None
called_vjp_with_batched_args = None

def my_sin_impl(args):
x, = args
nonlocal called_impl
nonlocal called_impl_with_batched_args
called_impl = True
called_impl_with_batched_args = functorch._C.is_batchedtensor(x)
return x.sin(), x

def my_sin_vjp(args):
grad_y, result, x = args
nonlocal called_vjp
nonlocal called_vjp_with_batched_args
called_vjp = True
called_vjp_with_batched_args = all(functorch._C.is_batchedtensor(a) for a in [grad_y, result, x])
return (grad_y * 3 * x.cos(),)

def filter_fn(args):
return args[0]

my_sin = custom_vjp('my_sin2', filter_fn, my_sin_impl, my_sin_vjp)

x = torch.tensor([[[1., 2.], [3., 4.]], [[1., 2.], [3., 4.]]], requires_grad=True, device=device)
x_copy = torch.tensor([[[1., 2.], [3., 4.]], [[1., 2.], [3., 4.]]], requires_grad=True, device=device)

vmap_my_sin = vmap(my_sin)
y = vmap_my_sin(x)
self.assertTrue(called_impl)
# We expect to run the custom forward with batched tensors, so when
# it decomposes into base ops we run the batching rule on each base op.
self.assertTrue(called_impl_with_batched_args)

y.sum().backward()
self.assertTrue(called_vjp)
# We expect to run the custom backward (vjp) with non-batched tensors,
# because we didn't explicitly vmap over the backward() call.
self.assertFalse(called_vjp_with_batched_args)

assert torch.allclose(x.grad, 3 * x.cos())

y_copy = my_sin(x_copy)
y_copy.sum().backward()
assert torch.allclose(y_copy, y)
assert torch.allclose(x_copy.grad, x.grad)


class TestComposability(TestCase):
def test_grad_grad(self, device):