merge 2nd order derivative of CutlassMLP #370
Conversation
Hi Susan, thank you (and @yanglixiao1994) very much for this contribution as well as its thorough derivation and testing! I'd like to take the time to properly review it, but am already occupied for the next week or two. I'll get back to this PR afterwards -- thanks again.
Thanks for your wonderful work. So is this implementation almost the same as NeuS2? @SusanLiu0709
I have tried your PR, but I cannot find the file Armadillo.ply. How can I generate it? @SusanLiu0709
Hi zebin-dm, we implemented this feature before NeuS2 was released. We are currently studying the details of NeuS2 and checking whether it can be merged into our implementation.
Thanks for testing :) @yanglixiao1994 may help upload the test data and params soon.
Hi zebin, this is the testing Armadillo data: https://drive.google.com/file/d/1KfIkGcLkQOopnXoBLmkT55bBBNQu6nBm/view?usp=sharing. Actually, you can generate your own data (a 3D grid and the corresponding SDF values) following https://github.com/SusanLiu0709/tiny-cuda-nn/blob/8c66716e59b94f73f918c058797e17368528c748/scripts/test_armadillo_numeric_align.py#L129
Thank you very much.
Hi, thanks again for this PR! I requested a few changes in the C++ code that are required before I can merge. Please feel free to discuss if anything is unclear or if I missed something.
Once the changes are in, I'll go through the testing code (which I appreciate a lot, by the way) and give another round of feedback.
```diff
@@ -60,7 +60,7 @@ __global__ void identity(
 	const uint32_t j = encoded_index - i * fan_out;

 	if (j >= num_to_encode) {
-		data_out(j, i) = 1;
+		data_out(j, i) = 0; // data_out(j, i) = 0;
```
This should be reverted to the previous behavior.
```cpp
// element-wise convert float* to T*
template <typename T>
__global__ void element_wise_convert(uint32_t n_elements, float* in, T* out) {
	uint32_t i = blockIdx.x * blockDim.x + threadIdx.x;
	if (i >= n_elements) return;

	out[i] = (T)in[i];
}

// element-wise convert T* to float* and then add back to *out
template <typename T>
__global__ void element_wise_convert_float(uint32_t n_elements, T* in, float* out) {
	uint32_t i = blockIdx.x * blockDim.x + threadIdx.x;
	if (i >= n_elements) return;

	out[i] += (float)in[i];
}

// element-wise add
template <typename T>
__global__ void element_wise_add(uint32_t n_elements, T* in, T* out) {
	uint32_t i = blockIdx.x * blockDim.x + threadIdx.x;
	if (i >= n_elements) return;

	out[i] += in[i];
}
```
Use the `cast()` and `add()` kernels from common_device.h, which implement the same functionality, instead.
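For illustration, a minimal sketch of what the replacement could look like at a call site, assuming the `cast()`/`add()` kernel signatures and the `linear_kernel()` launcher from current master (the buffer names here are hypothetical; the convert-and-accumulate case of `element_wise_convert_float` would need a `cast()` into a temporary buffer followed by an `add()`):

```cpp
// float* -> T* conversion via the existing cast() kernel
linear_kernel(cast<T>, 0, stream, n_elements, input_float, output_T);

// element-wise accumulation via the existing add() kernel
linear_kernel(add<T>, 0, stream, n_elements, input_T, output_T);
```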
```diff
@@ -181,6 +291,7 @@ class NetworkWithInputEncoding : public Network<float, T> {
 private:
 	std::shared_ptr<Encoding<T>> m_encoding;
 	std::shared_ptr<Network<T>> m_network;
+	GPUMatrixDynamic<T> dL_dnetwork_input;
```
GPU memory (or pointers to GPU memory) cannot be class members. They should be part of either the `ForwardContext` or, if necessary, an additional context for the Fwd+Bwd pass (in analogy to how double bwd contexts are implemented here https://github.com/NVlabs/tiny-cuda-nn/blob/212104156403bd87616c1a4f73a1c5f2c2e172a9/bindings/torch/tinycudann/modules.py#L120C6-L120C6) to enable multiple parallel passes through the model, and to support multi-GPU execution.
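For illustration, a hedged sketch of that structure (the struct name and placement are hypothetical, mirroring how CutlassMLP's existing `ForwardContext` holds per-pass scratch memory):

```cpp
// Per-pass scratch lives in a context created by the pass itself and handed
// back to the caller, keeping the model object stateless so that multiple
// passes can run in parallel (and on multiple GPUs).
struct ForwardBackwardContext : public Context {
	GPUMatrixDynamic<T> dL_dnetwork_input; // temporary for the fused bwd+bwd pass
};
```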
```cpp
template <typename CutlassLayer, typename T>
bool compute_fc_layer(
	cudaStream_t stream,
	const GPUMatrix<T, RM>& weights,
	const GPUMatrixDynamic<T>& input,
	GPUMatrixDynamic<T>& p
) {
	// compute for forward values before activation
	fc_multiply<CutlassLayer>(stream, weights, input, p);

	return true;
}
```
This function looks like it does not serve a purpose and should be inlined.
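That is, each call like `compute_fc_layer<CutlassLayer>(stream, weights, input, p)` reduces to the direct call (a sketch of the inlining, using only names from the snippet above):

```cpp
// compute forward values before activation
fc_multiply<CutlassLayer>(stream, weights, input, p);
```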
```cpp
template <typename T>
__global__ void compute_activation_backward_backward(uint32_t n_elements, Activation activation, T* p, T* res) {
	uint32_t i = blockIdx.x * blockDim.x + threadIdx.x;
	if (i >= n_elements) return;

	switch (activation) {
		case Activation::Softplus: {
			// Braces are required here: declaring locals directly under a case
			// label does not compile (the jump would cross their initialization).
			float K_ACT = 10.0;
			float tmp = (float)p[i] * K_ACT;
			if (tmp > 10.0) {
				tmp = 10.0;
			} else if (tmp < -15.0) {
				tmp = -15.0;
			}

			float exp_tmp = expf(tmp);
			float pow_tmp = (exp_tmp + 1.0) * (exp_tmp + 1.0);
			float ddoutputdp_dp = exp_tmp / pow_tmp * K_ACT;
			res[i] = (T)ddoutputdp_dp;
			return;
		}
		case Activation::ReLU:
			res[i] = 0.0;
			return;
		default:
			// ERROR: this activation is currently not supported
			res[i] = 0.0;
			return;
	}
}
```
Second order gradients should be implemented for all activation functions that admit them, not just for Softplus. They should also be placed in common_device.h next to the other activation-function-related implementations (and follow the same warp-level API).
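As a rough, hedged sketch of that suggestion (the fragment API is modeled on `warp_activation()`/`warp_activation_backward()`; the exact interface should follow common_device.h on the target branch, and the remaining activations are omitted here):

```cpp
// Element-wise second derivative of the activation, computed per warp fragment.
template <typename T, typename fragment_t>
__host__ __device__ void warp_activation_backward_backward(Activation activation, const fragment_t& frag, fragment_t& result) {
	switch (activation) {
		case Activation::ReLU:
			TCNN_PRAGMA_UNROLL
			for (int t = 0; t < result.num_elements; t++) {
				result.x[t] = (T)0.0f; // ReLU'' == 0 almost everywhere
			}
			return;
		case Activation::Softplus:
			TCNN_PRAGMA_UNROLL
			for (int t = 0; t < result.num_elements; t++) {
				// softplus(x) = log(1 + exp(K_ACT*x)) / K_ACT, so with e = exp(K_ACT*x):
				// sigma''(x) = K_ACT * e / (1 + e)^2; clamp the exponent to avoid inf/inf.
				float e = expf(fminf((float)frag.x[t] * K_ACT, 10.0f));
				result.x[t] = (T)(K_ACT * e / ((1.0f + e) * (1.0f + e)));
			}
			return;
		default:
			// Other activations (Sigmoid, Tanh, ...) omitted in this sketch.
			return;
	}
}
```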
```cpp
		param_gradients_mode
	);

	if (dL_ddLdoutput) { // if dL_ddLdoutput is not nullptr
```
Redundant comments like this one should be removed from the PR. (Basically, use comments only to explain behavior that is not clear from the programming language itself. Like algorithmic details, mathematical reasoning, hidden details, etc.)
```cpp
template <typename CutlassLayer, typename T>
bool compute_dL2dinput(
	cudaStream_t stream,
	bool is_inference,
```
Unused parameter?
```cpp
	const GPUMatrixDynamic<T>& dL_ddLdinput, // dL2_ddL1dinput
	const GPUMatrixDynamic<T>& ddoutputdp_dp, // ddoutputdp_dp
	GPUMatrixDynamic<T>& dL_dinput,
	GradientMode param_gradients_mode
```
Also unused parameter?
```cpp
	GPUMatrixDynamic<T>& dL_dinput,
	GradientMode param_gradients_mode
) {
	// no dL2dinput when activation is None
```
Redundant comment
```cpp
		fc_multiply_split_k<FullLayerK>(stream, dL_dp, dL_ddLdinput.transposed(), weight_gradient, split_k_factor, param_gradient_beta);
	}

	// when activation is None, don't have to compute dL2dw_2
```
Redundant comment. If you'd like to have a comment here, it should explain why dL2dw_2 does not need to be computed if the activation is None. (Though I think most readers familiar with derivatives will accept it in any case.)
Hi Thomas @Tom94, sorry for the late reply; I was busy during the last 3 months. Best,
Thank you very much, looking forward to your nice work.
Hi, using the network config below, I still get an error when calculating the second derivative.
Should these changes also be available in PyTorch at the moment?
For now, the 2nd order derivative of the activation function "Sine" is not supported. You can try with "ReLU" and "Softplus".
Hi @Tom94,
The second order derivative is a common operation in 3D reconstruction (neural SDF related): training often uses an Eikonal loss, in which a first order derivative is leveraged to better learn the surface normal, so a second order derivative must be computed during backpropagation. We (@SusanLiu0709 and @yanglixiao1994) implemented the 2nd order derivative feature for class CutlassMLP and NetworkWithInputEncoding.
Overall
There are major changes in the files below:

- `network_with_input_encoding.h`: adds the `backward_backward_input_impl` function to `class NetworkWithInputEncoding`, whose dataflow involves the interaction of 2 modules, encoding and network, as shown in Figure 1.
- `cutlass_mlp.h`: adds the definitions involved in the 2nd order derivative of the network class CutlassMLP, including `backward_backward_input_impl` and `prepare_backward_variables`.
- `cutlass_mlp.cu`: adds the implementation of `backward_backward_input_impl` and the necessary kernels and functions, including `compute_dL2dw`, `compute_dL2_ddL1doutput`, `compute_dL2dinput`, etc. Its detailed dataflow is shown in Figure 2.
- `identity.h`: adds 2nd order derivative support for the identity encoding, in case `class Network` is used on its own and the encoding defaults to identity.

Theoretical Derivation
Class NetworkWithInputEncoding
Class NetworkWithInputEncoding contains 2 modules, an encoder and a network. The encoder can be set to identity, grid (hash encoding), etc., and the network is a multi-layer perceptron comprising several hidden layers and activations. The dataflow of its forward, backward and backward backward passes is visualized in Figure 1.
tiny-cuda-nn already supported the forward and backward passes of class NetworkWithInputEncoding. The backward backward pass consists of 3 stages involving the 2 modules, encoder and network: the 2nd order derivative of the encoder, the 2nd order derivative of the network, and a 1st order derivative propagating dL2dmlp_input back to dL2dinput.
Each 2nd order derivative (i.e. backward backward pass) of a module (encoder/network) requires 3 inputs: input, dL2_ddL1dinput (the 2nd order derivative term), and dL1doutput (the 1st order derivative term). It produces 2 outputs: dL2_ddL1doutput, the 2nd order derivative with respect to dL1doutput, and dL2_dinput, the 2nd order derivative with respect to the input.
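In symbols (a hedged restatement of the previous paragraph, for a module $y = f(x)$ with Jacobian $J_f$): the first-order backward pass computes $g_x = J_f(x)^\top g_y$, and given the incoming second-order term $v = \partial L_2 / \partial g_x$, the backward backward pass produces

$$\frac{\partial L_2}{\partial g_y} = J_f(x)\,v, \qquad \frac{\partial L_2}{\partial x} = \left(\frac{\partial}{\partial x} J_f(x)^\top g_y\right) v.$$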
The 1st order derivative involved in the backward backward pass likewise requires 3 inputs for the current module, input, output and dL2doutput, and produces 1 output, dL2dinput.
We have implemented the backward backward pass of class NetworkWithInputEncoding in the function `backward_backward_input_impl()` in `network_with_input_encoding.h`.

Figure 1. The overall forward, backward and backward backward dataflow of encoding and linear layer (including activation). Note that a 1st order derivative is also involved once the 2nd order term dL2dinput has finished computing, marked as the dashed amaranth line.
Class CutlassMLP
tiny-cuda-nn already supports the 2nd order derivative of hash encoding (#69). This pull request mainly focuses on implementing the 2nd order derivative of the network module, class CutlassMLP (FullyFusedMLP is not supported yet). The simplified dataflow of the network module is visualized in Figure 2.
Figure 2. The forward, backward and backward backward dataflow of a single linear layer (including activation). Note that a 1st order derivative is also involved once the 2nd order term dL2dinput has finished computing; it is not marked in this figure but is analogous to the dashed amaranth line in Figure 1.
The detailed numeric derivation of the CutlassMLP 2nd order derivative is shown below:
Figure 3. The numeric derivation of the 2nd order derivative of a multi-layer perceptron network.
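As a hedged summary of the derivation in Figure 3: for a single layer with pre-activation $p = Wx$ and output $y = \sigma(p)$, the first-order backward pass computes $\frac{\partial L_1}{\partial x} = W^\top\left(\sigma'(p) \odot \frac{\partial L_1}{\partial y}\right)$. Writing $v = \frac{\partial L_2}{\partial (\partial L_1/\partial x)}$ for the incoming second-order term, one obtains

$$\frac{\partial L_2}{\partial (\partial L_1/\partial y)} = \sigma'(p) \odot (Wv), \qquad \frac{\partial L_2}{\partial x} = W^\top\left(\sigma''(p) \odot \frac{\partial L_1}{\partial y} \odot (Wv)\right),$$

$$\frac{\partial L_2}{\partial W} = \left(\sigma'(p) \odot \frac{\partial L_1}{\partial y}\right) v^\top + \left(\sigma''(p) \odot \frac{\partial L_1}{\partial y} \odot (Wv)\right) x^\top,$$

where the second term of $\partial L_2 / \partial W$ corresponds to dL2dw_2 from the review thread above: it vanishes when the activation is None, since then $\sigma'' = 0$.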
Implementation Details
Based on the numeric derivation of the 2nd order derivative above, we designed and implemented the computing workflow shown in Figure 4.
Figure 4. The workflow of the implemented 2nd order derivative of CutlassMLP.
Visual Comparison
To further verify the correctness of the implemented 2nd order derivative of the network module, we conducted an experiment comparing the visual quality of models trained with PyTorch and with tiny-cuda-nn (TCNN). The training is based on the open-source method NeuS, and the only difference between the PyTorch and TCNN versions is the definition of the SDF network. The training results are shown in Figures 5 and 6; there is no obvious difference between the PyTorch and TCNN results.
Figure 5. Visual comparison of the results trained with PyTorch and tiny-cuda-nn (TCNN). The encoder of the SDF network is set to positional encoding, and the network has 3 hidden layers with ReLU activation and no activation on the output layer.
Figure 6. Visual comparison of the results trained with PyTorch and tiny-cuda-nn (TCNN). The encoder of the SDF network is set to hash encoding, and the network has 3 hidden layers with ReLU activation and no activation on the output layer.
Numeric Alignment with PyTorch
To verify the correctness of the implemented 2nd order derivative of CutlassMLP and NetworkWithInputEncoding, we implemented a toy test script defining a simple neural SDF supervised with an Eikonal loss. Sample code is shown below:
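Below is a minimal sketch of such a test, assuming the standard tinycudann PyTorch bindings (the config values and tensor names are illustrative, not the authors' exact script):

```python
import torch
import tinycudann as tcnn

# Toy SDF network: identity encoding + CutlassMLP. Softplus is used because it
# admits a 2nd order derivative (Sine is not supported by this PR).
network = tcnn.NetworkWithInputEncoding(
    n_input_dims=3,
    n_output_dims=1,
    encoding_config={"otype": "Identity"},
    network_config={
        "otype": "CutlassMLP",
        "activation": "Softplus",
        "output_activation": "None",
        "n_neurons": 64,
        "n_hidden_layers": 3,
    },
)

x = torch.rand(2**14, 3, device="cuda", requires_grad=True)
sdf = network(x).float()

# 1st order derivative of the SDF w.r.t. the input; create_graph=True keeps the
# graph so that backpropagating the Eikonal loss triggers the backward backward pass.
(grad,) = torch.autograd.grad(sdf.sum(), x, create_graph=True)

# Eikonal loss: SDF gradients should have unit length.
eikonal_loss = ((grad.norm(dim=-1) - 1.0) ** 2).mean()
eikonal_loss.backward()  # exercises backward_backward_input_impl
```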
Figure 7. Comparison between gradients sampled from the PyTorch-defined and TCNN-defined NeuS. All numbers are sampled from the 1st and 2nd hidden layers at the first training iteration.
Figure 8. Weight distribution comparison between PyTorch and TCNN for the 1st hidden layer.
Figure 9. Weight distribution comparison between PyTorch and TCNN for the 2nd hidden layer.
Figure 10. Weight distribution comparison between PyTorch and TCNN for the 3rd hidden layer.
TODO
More details will be added soon.