Add partial support for second-order derivatives (for grid.h's input) #69

Merged
merged 9 commits on Mar 25, 2022
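For context, the new code path is exercised whenever the gradient of an encoding's output with respect to its input is itself part of a loss (e.g. eikonal regularization of an SDF). Below is a minimal usage sketch; the encoding config and dimensions are illustrative only, and, per the notes in modules.py, the d(dL_dinput)_d(input) term is still not computed.

```python
import torch
import tinycudann as tcnn

# Hypothetical hash-grid encoding; any grid.h-based encoding applies.
encoding = tcnn.Encoding(n_input_dims=3, encoding_config={
    "otype": "HashGrid", "n_levels": 16, "n_features_per_level": 2,
    "log2_hashmap_size": 19, "base_resolution": 16, "per_level_scale": 1.5,
})

x = torch.rand(128, 3, device="cuda", requires_grad=True)
y = encoding(x).float()

# First-order gradient, kept in the graph so it can be differentiated again.
dy_dx = torch.autograd.grad(y.sum(), x, create_graph=True)[0]

# A loss on that gradient now backpropagates through bwd_bwd_input,
# producing d(dL_dinput)_d(dL_doutput) and d(dL_dinput)_d(params).
loss = ((dy_dx.norm(dim=-1) - 1.0) ** 2).mean()
loss.backward()
```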
53 changes: 53 additions & 0 deletions bindings/torch/tinycudann/bindings.cpp
@@ -146,6 +146,58 @@ class Module {
return { dL_dinput, dL_dparams };
}

std::tuple<torch::Tensor, torch::Tensor> bwd_bwd_input(const tcnn::cpp::Context& ctx, torch::Tensor input, torch::Tensor params, torch::Tensor dL_ddLdinput, torch::Tensor dL_doutput) {
// from: dL_ddLdinput
// to: dL_ddLdoutput, dL_dparams

if (!ctx.ctx) {
throw std::runtime_error{"Module::bwd_bwd_input: called with invalid context. fwd likely (mistakenly) ran in inference mode."};
}

// Types
CHECK_THROW(input.scalar_type() == torch::kFloat32);
CHECK_THROW(dL_ddLdinput.scalar_type() == torch::kFloat32);
CHECK_THROW(params.scalar_type() == c10_param_precision());
CHECK_THROW(dL_doutput.scalar_type() == c10_output_precision());

// Sizes
CHECK_THROW(input.size(1) == n_input_dims());
CHECK_THROW(dL_doutput.size(1) == n_output_dims());
CHECK_THROW(dL_ddLdinput.size(1) == n_input_dims());
CHECK_THROW(params.size(0) == n_params());
CHECK_THROW(dL_doutput.size(0) == input.size(0));

uint32_t batch_size = input.size(0);

cudaStream_t stream = at::cuda::getCurrentCUDAStream();

torch::Tensor dL_ddLdoutput;
if (dL_doutput.requires_grad()) {
dL_ddLdoutput = torch::zeros({ batch_size, n_output_dims() }, torch::TensorOptions().dtype(c10_output_precision()).device(torch::kCUDA));
}

torch::Tensor dL_dparams;
if (params.requires_grad()) {
dL_dparams = torch::zeros({ n_params() }, torch::TensorOptions().dtype(c10_param_precision()).device(torch::kCUDA));
}

if (dL_doutput.requires_grad() || params.requires_grad()) {
m_module->backward_backward_input(
stream,
ctx,
batch_size,
dL_ddLdinput.data_ptr<float>(),
input.data_ptr<float>(),
params.requires_grad() ? void_data_ptr(dL_doutput) : nullptr,
params.requires_grad() ? void_data_ptr(dL_dparams) : nullptr,
dL_doutput.requires_grad() ? void_data_ptr(dL_ddLdoutput) : nullptr,
void_data_ptr(params)
);
}

return {dL_ddLdoutput, dL_dparams};
}

torch::Tensor initial_params(size_t seed) {
torch::Tensor output = torch::zeros({ n_params() }, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA));
m_module->initialize_params(seed, output.data_ptr<float>());
@@ -230,6 +282,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::class_<Module>(m, "Module")
.def("fwd", &Module::fwd)
.def("bwd", &Module::bwd)
.def("bwd_bwd_input", &Module::bwd_bwd_input)
.def("initial_params", &Module::initial_params)
.def("n_input_dims", &Module::n_input_dims)
.def("n_params", &Module::n_params)
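For reference, the two quantities this binding produces can be written with plain autograd on a toy, fully differentiable stand-in for the encoder. This is illustrative only: the tensor names mirror the binding's parameters, and unlike the toy below the real binding does not propagate a d(dL_dinput)_d(input) term.

```python
import torch

B, n_in, n_out = 4, 3, 8
params = torch.randn(n_out, n_in, requires_grad=True)

def E(x, p):
    # Toy differentiable encoder standing in for the grid encoding.
    return torch.tanh(x @ p.t())

x = torch.rand(B, n_in, requires_grad=True)
dL_doutput = torch.randn(B, n_out, requires_grad=True)  # incoming first-order gradient
dL_ddLdinput = torch.randn(B, n_in)                     # upstream grad w.r.t. dL/dinput

y = E(x, params)
# First backward (what Module::bwd computes): dL_dinput = (dy/dx)^T dL_doutput
dL_dinput, = torch.autograd.grad(y, x, grad_outputs=dL_doutput, create_graph=True)

# The two terms bwd_bwd_input returns, expressed via plain autograd:
dL_ddLdoutput, = torch.autograd.grad(dL_dinput, dL_doutput, grad_outputs=dL_ddLdinput, retain_graph=True)
dL_dparams, = torch.autograd.grad(dL_dinput, params, grad_outputs=dL_ddLdinput)
```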
51 changes: 45 additions & 6 deletions bindings/torch/tinycudann/modules.py
@@ -37,7 +37,6 @@ def forward(ctx, native_tcnn_module, input, params, loss_scale):
return output

@staticmethod
@once_differentiable
def backward(ctx, doutput):
if doutput is None:
return None, None, None, None
@@ -47,14 +46,54 @@ def backward(ctx, doutput):
doutput = doutput.cuda()

input, params, output = ctx.saved_tensors
with torch.no_grad():
scaled_grad = doutput * ctx.loss_scale
input_grad, weight_grad = ctx.native_tcnn_module.bwd(ctx.native_ctx, input, params, output, scaled_grad)
input_grad = None if input_grad is None else (input_grad / ctx.loss_scale)
weight_grad = None if weight_grad is None else (weight_grad / ctx.loss_scale)
input_grad, weight_grad = _module_function_backward.apply(ctx, doutput, input, params, output)

return None, input_grad, weight_grad, None

class _module_function_backward(torch.autograd.Function):
@staticmethod
def forward(ctx, ctx_fwd, doutput, input, params, output):
ctx.ctx_fwd = ctx_fwd
ctx.save_for_backward(input, params, doutput)
with torch.no_grad():
scaled_grad = doutput * ctx_fwd.loss_scale
input_grad, weight_grad = ctx_fwd.native_tcnn_module.bwd(ctx_fwd.native_ctx, input, params, output, scaled_grad)
input_grad = None if input_grad is None else (input_grad / ctx_fwd.loss_scale)
weight_grad = None if weight_grad is None else (weight_grad / ctx_fwd.loss_scale)
return input_grad, weight_grad

@staticmethod
def backward(ctx, dinput_grad, dweight_grad):
# NOTE: second-order terms currently supported:
#   ✓ d(dL_dinput)_d(dL_doutput) -> doutput_grad
#   ✓ d(dL_dinput)_d(params)     -> weight_grad
#   x d(dL_dinput)_d(input)         (not implemented)
#   x d(dL_dparam)_d(...)           (not implemented)
input, params, doutput = ctx.saved_tensors
# assert dweight_grad is None, "currently do not support 2nd-order gradients from gradient of grid"
# assert not input.requires_grad, "currently do not support 2nd-order gradients from gradient of dLdx to input"
with torch.enable_grad():
# NOTE: multiplying inside enable_grad() preserves doutput's requires_grad flag.
# NOTE: be cautious when multiplying and dividing by loss_scale:
#       dinput_grad is intentionally left unscaled here.
# dinput_grad = dinput_grad * ctx.ctx_fwd.loss_scale
doutput = doutput * ctx.ctx_fwd.loss_scale
with torch.no_grad():
doutput_grad, weight_grad = ctx.ctx_fwd.native_tcnn_module.bwd_bwd_input(
ctx.ctx_fwd.native_ctx,
input,
params,
dinput_grad,
doutput
)
# NOTE: loss_scale bookkeeping:
#       doutput_grad depends only on dinput_grad (never scaled), so no division is needed;
#       weight_grad depends on dinput_grad * doutput (scaled above), so divide it back out.
# doutput_grad = None if doutput_grad is None else (doutput_grad / ctx.ctx_fwd.loss_scale)
weight_grad = None if weight_grad is None else (weight_grad / ctx.ctx_fwd.loss_scale)

# gradients w.r.t. forward()'s arguments: ctx_fwd, doutput, input, params, output
return None, doutput_grad, None, weight_grad, None

class Module(torch.nn.Module):
def __init__(self, seed=1337):
super(Module, self).__init__()
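The reason backward() now delegates to a second torch.autograd.Function (and drops @once_differentiable) is that autograd can only differentiate a custom backward that is itself recorded as a differentiable graph. Below is a self-contained toy sketch of the same nesting pattern; it is not tcnn code and omits the loss_scale and mixed-precision handling above.

```python
import torch

class _Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, dy):
        x, = ctx.saved_tensors
        # Delegate to a second Function so this backward is differentiable again.
        return _SquareBackward.apply(x, dy)

class _SquareBackward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, dy):
        ctx.save_for_backward(x, dy)
        return 2.0 * x * dy  # dL/dx of _Square

    @staticmethod
    def backward(ctx, ddx):
        x, dy = ctx.saved_tensors
        # Grads w.r.t. (x, dy); tcnn's version returns None for the input term
        # because grid.h implements only the dL_doutput and params terms.
        return 2.0 * dy * ddx, 2.0 * x * ddx

x = torch.tensor([3.0], requires_grad=True)
y = _Square.apply(x)
g, = torch.autograd.grad(y, x, create_graph=True)  # g = 2x = 6
gg, = torch.autograd.grad(g, x)                     # gg = 2
```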
1 change: 1 addition & 0 deletions include/tiny-cuda-nn/cpp_api.h
@@ -71,6 +71,7 @@ class Module {
virtual void inference(cudaStream_t stream, uint32_t n_elements, const float* input, void* output, void* params) = 0;
virtual Context forward(cudaStream_t stream, uint32_t n_elements, const float* input, void* output, void* params, bool prepare_input_gradients) = 0;
virtual void backward(cudaStream_t stream, const Context& ctx, uint32_t n_elements, float* dL_dinput, const void* dL_doutput, void* dL_dparams, const float* input, const void* output, const void* params) = 0;
virtual void backward_backward_input(cudaStream_t stream, const Context& ctx, uint32_t n_elements, const float* dL_ddLdinput, const float* input, const void* dL_doutput, void* dL_dparams, void* dL_ddLdoutput, const void* params) = 0;

virtual uint32_t n_input_dims() const = 0;
