diff --git a/cuda_rasterizer/auxiliary.h b/cuda_rasterizer/auxiliary.h index 49c7f98..8b31bf5 100644 --- a/cuda_rasterizer/auxiliary.h +++ b/cuda_rasterizer/auxiliary.h @@ -17,8 +17,6 @@ #define BLOCK_SIZE (BLOCK_X * BLOCK_Y) #define NUM_WARPS (BLOCK_SIZE/32) -#define FilterSize 0.7071067811865476 -#define FilterInvSquare 1/(FilterSize*FilterSize) #define TIGHTBBOX 0 #define RENDER_AXUTILITY 1 @@ -27,15 +25,19 @@ #define NORMAL_OFFSET 2 #define MIDDEPTH_OFFSET 5 #define DISTORTION_OFFSET 6 -#define MEDIAN_WEIGHT_OFFSET 7 +// #define MEDIAN_WEIGHT_OFFSET 7 // distortion helper macros #define BACKFACE_CULL 1 #define DUAL_VISIABLE 1 -#define NEAR_PLANE 0.2 -#define FAR_PLANE 100.0 +// #define NEAR_PLANE 0.2 +// #define FAR_PLANE 100.0 #define DETACH_WEIGHT 0 +__device__ const float near_n = 0.2; +__device__ const float far_n = 100.0; +__device__ const float FilterInvSquare = 2.0f; + // Spherical harmonics coefficients __device__ const float SH_C0 = 0.28209479177387814f; __device__ const float SH_C1 = 0.4886025119029199f; @@ -149,13 +151,35 @@ __forceinline__ __device__ float4 dnormvdv(float4 v, float4 dv) return dnormvdv; } -__forceinline__ __device__ float3 crossProduct(float3 a, float3 b) { - float3 result; - result.x = a.y * b.z - a.z * b.y; - result.y = a.z * b.x - a.x * b.z; - result.z = a.x * b.y - a.y * b.x; - return result; -} +__forceinline__ __device__ float3 cross(float3 a, float3 b){return make_float3(a.y*b.z - a.z*b.y, a.z*b.x - a.x*b.z, a.x*b.y - a.y*b.x);} + +__forceinline__ __device__ float3 operator*(float3 a, float3 b){return make_float3(a.x * b.x, a.y * b.y, a.z*b.z);} + +__forceinline__ __device__ float2 operator*(float2 a, float2 b){return make_float2(a.x * b.x, a.y * b.y);} + +__forceinline__ __device__ float3 operator*(float f, float3 a){return make_float3(f * a.x, f * a.y, f * a.z);} + +__forceinline__ __device__ float2 operator*(float f, float2 a){return make_float2(f * a.x, f * a.y);} + +__forceinline__ __device__ float3 operator-(float3 a, float3 b){return make_float3(a.x - b.x, a.y - b.y, a.z - b.z);} + +__forceinline__ __device__ float2 operator-(float2 a, float2 b){return make_float2(a.x - b.x, a.y - b.y);} + +__forceinline__ __device__ float sumf3(float3 a){return a.x + a.y + a.z;} + +__forceinline__ __device__ float sumf2(float2 a){return a.x + a.y;} + +__forceinline__ __device__ float3 sqrtf3(float3 a){return make_float3(sqrtf(a.x), sqrtf(a.y), sqrtf(a.z));} + +__forceinline__ __device__ float2 sqrtf2(float2 a){return make_float2(sqrtf(a.x), sqrtf(a.y));} + +__forceinline__ __device__ float3 minf3(float f, float3 a){return make_float3(min(f, a.x), min(f, a.y), min(f, a.z));} + +__forceinline__ __device__ float2 minf2(float f, float2 a){return make_float2(min(f, a.x), min(f, a.y));} + +__forceinline__ __device__ float3 maxf3(float f, float3 a){return make_float3(max(f, a.x), max(f, a.y), max(f, a.z));} + +__forceinline__ __device__ float2 maxf2(float f, float2 a){return make_float2(max(f, a.x), max(f, a.y));} __forceinline__ __device__ bool in_frustum(int idx, const float* orig_points, @@ -258,11 +282,11 @@ quat_to_rotmat_vjp(const glm::vec4 quat, const glm::mat3 v_R) { inline __device__ glm::mat3 -scale_to_mat(const float3 scale, const float glob_scale) { +scale_to_mat(const glm::vec2 scale, const float glob_scale) { glm::mat3 S = glm::mat3(1.f); S[0][0] = glob_scale * scale.x; S[1][1] = glob_scale * scale.y; - S[2][2] = glob_scale * scale.z; + // S[2][2] = glob_scale * scale.z; return S; } diff --git a/cuda_rasterizer/backward.cu b/cuda_rasterizer/backward.cu index 05f8272..c52ff2d 100644 --- a/cuda_rasterizer/backward.cu +++ b/cuda_rasterizer/backward.cu @@ -219,7 +219,7 @@ renderCUDA( dL_dnormal2D[i] = dL_depths[(NORMAL_OFFSET + i) * H * W + pix_id]; dL_dmedian_depth = dL_depths[MIDDEPTH_OFFSET * H * W + pix_id]; - dL_dmax_dweight = dL_depths[MEDIAN_WEIGHT_OFFSET * H * W + pix_id]; + // dL_dmax_dweight = dL_depths[MEDIAN_WEIGHT_OFFSET * H * W + pix_id]; } // for compute gradient with respect to depth and normal @@ -280,51 +280,42 @@ renderCUDA( continue; // compute ray-splat intersection as before - float3 Tu = collected_Tu[j]; - float3 Tv = collected_Tv[j]; - float3 Tw = collected_Tw[j]; - // compute two planes intersection as the ray intersection - float3 k = {-Tu.x + pixf.x * Tw.x, -Tu.y + pixf.x * Tw.y, -Tu.z + pixf.x * Tw.z}; - float3 l = {-Tv.x + pixf.y * Tw.x, -Tv.y + pixf.y * Tw.y, -Tv.z + pixf.y * Tw.z}; - // cross product of two planes is a line (i.e., homogeneous point), See Eq. (10) - float3 p = crossProduct(k, l); -#if BACKFACE_CULL - // May hanle this by replacing a low pass filter, - // but this case is extremely rare. - if (p.z == 0.0) continue; // there is not intersection -#endif - - float2 s = {p.x / p.z, p.y / p.z}; - // Compute Mahalanobis distance in the canonical splat' space - float rho3d = (s.x * s.x + s.y * s.y); // splat distance - - // Add low pass filter according to Botsch et al. [2005], - // see Eq. (11) from 2DGS paper. - float2 xy = collected_xy[j]; - // 2d screen distance + // Fisrt compute two homogeneous planes, See Eq. (8) + const float2 xy = collected_xy[j]; + const float3 Tu = collected_Tu[j]; + const float3 Tv = collected_Tv[j]; + const float3 Tw = collected_Tw[j]; + float3 k = pix.x * Tw - Tu; + float3 l = pix.y * Tw - Tv; + float3 p = cross(k, l); + if (p.z == 0.0) continue; + float2 s = {p.x / p.z, p.y / p.z}; + float rho3d = (s.x * s.x + s.y * s.y); float2 d = {xy.x - pixf.x, xy.y - pixf.y}; - float rho2d = FilterInvSquare * (d.x * d.x + d.y * d.y); // screen distance - float rho = min(rho3d, rho2d); - - // Compute accurate depth when necessary - float c_d = (rho3d <= rho2d) ? (s.x * Tw.x + s.y * Tw.y) + Tw.z : Tw.z; - if (c_d < NEAR_PLANE) continue; + float rho2d = FilterInvSquare * (d.x * d.x + d.y * d.y); + // compute intersection and depth + float rho = min(rho3d, rho2d); + float c_d = (rho3d <= rho2d) ? (s.x * Tw.x + s.y * Tw.y) + Tw.z : Tw.z; + if (c_d < near_n) continue; float4 nor_o = collected_normal_opacity[j]; float normal[3] = {nor_o.x, nor_o.y, nor_o.z}; + float opa = nor_o.w; + + // accumulations float power = -0.5f * rho; if (power > 0.0f) continue; const float G = exp(power); - const float alpha = min(0.99f, nor_o.w * G); + const float alpha = min(0.99f, opa * G); if (alpha < 1.0f / 255.0f) continue; T = T / (1.f - alpha); const float dchannel_dcolor = alpha * T; - + const float w = alpha * T; // Propagate gradients to per-Gaussian colors and keep // gradients w.r.t. alpha (blending factor for a Gaussian/pixel // pair). @@ -348,11 +339,11 @@ renderCUDA( float dL_dz = 0.0f; float dL_dweight = 0; #if RENDER_AXUTILITY - float m_d = (FAR_PLANE * c_d - FAR_PLANE * NEAR_PLANE) / ((FAR_PLANE - NEAR_PLANE) * c_d); - float dmd_dd = (FAR_PLANE * NEAR_PLANE) / ((FAR_PLANE - NEAR_PLANE) * c_d * c_d); + const float m_d = far_n / (far_n - near_n) * (1 - near_n / c_d); + const float dmd_dd = (far_n * near_n) / ((far_n - near_n) * c_d * c_d); if (contributor == median_contributor-1) { dL_dz += dL_dmedian_depth; - dL_dweight += dL_dmax_dweight; + // dL_dweight += dL_dmax_dweight; } #if DETACH_WEIGHT // if not detached weight, sometimes @@ -364,7 +355,7 @@ renderCUDA( dL_dalpha += dL_dweight - last_dL_dT; // propagate the current weight W_{i} to next weight W_{i-1} last_dL_dT = dL_dweight * alpha + (1 - alpha) * last_dL_dT; - float dL_dmd = 2.0f * (T * alpha) * (m_d * final_A - final_D) * dL_dreg; + const float dL_dmd = 2.0f * (T * alpha) * (m_d * final_A - final_D) * dL_dreg; dL_dz += dL_dmd * dmd_dd; // Propagate gradients w.r.t ray-splat depths @@ -404,20 +395,20 @@ renderCUDA( if (rho3d <= rho2d) { // Update gradients w.r.t. covariance of Gaussian 3x3 (T) - float2 dL_ds = { + const float2 dL_ds = { dL_dG * -G * s.x + dL_dz * Tw.x, dL_dG * -G * s.y + dL_dz * Tw.y }; - float3 dz_dTw = {s.x, s.y, 1.0}; - float dsx_pz = dL_ds.x / p.z; - float dsy_pz = dL_ds.y / p.z; - float3 dL_dp = {dsx_pz, dsy_pz, -(dsx_pz * s.x + dsy_pz * s.y)}; - float3 dL_dk = crossProduct(l, dL_dp); - float3 dL_dl = crossProduct(dL_dp, k); - - float3 dL_dTu = {-dL_dk.x, -dL_dk.y, -dL_dk.z}; - float3 dL_dTv = {-dL_dl.x, -dL_dl.y, -dL_dl.z}; - float3 dL_dTw = { + const float3 dz_dTw = {s.x, s.y, 1.0}; + const float dsx_pz = dL_ds.x / p.z; + const float dsy_pz = dL_ds.y / p.z; + const float3 dL_dp = {dsx_pz, dsy_pz, -(dsx_pz * s.x + dsy_pz * s.y)}; + const float3 dL_dk = cross(l, dL_dp); + const float3 dL_dl = cross(dL_dp, k); + + const float3 dL_dTu = {-dL_dk.x, -dL_dk.y, -dL_dk.z}; + const float3 dL_dTv = {-dL_dl.x, -dL_dl.y, -dL_dl.z}; + const float3 dL_dTw = { pixf.x * dL_dk.x + pixf.y * dL_dl.x + dL_dz * dz_dTw.x, pixf.x * dL_dk.y + pixf.y * dL_dl.y + dL_dz * dz_dTw.y, pixf.x * dL_dk.z + pixf.y * dL_dl.z + dL_dz * dz_dTw.z}; @@ -435,8 +426,8 @@ renderCUDA( atomicAdd(&dL_dtransMat[global_id * 9 + 8], dL_dTw.z); } else { // // Update gradients w.r.t. center of Gaussian 2D mean position - float dG_ddelx = -G * FilterInvSquare * d.x; - float dG_ddely = -G * FilterInvSquare * d.y; + const float dG_ddelx = -G * FilterInvSquare * d.x; + const float dG_ddely = -G * FilterInvSquare * d.y; atomicAdd(&dL_dmean2D[global_id].x, dL_dG * dG_ddelx); // not scaled atomicAdd(&dL_dmean2D[global_id].y, dL_dG * dG_ddely); // not scaled atomicAdd(&dL_dtransMat[global_id * 9 + 8], dL_dz); // propagate depth loss @@ -448,75 +439,148 @@ renderCUDA( } } -inline __device__ void computeTransMat( - const glm::vec3 & p_world, - const glm::vec4 & quat, - const glm::vec2 & scale, - const float* viewmat, - const float* projmat, - const int W, - const int H, - const float* transMat, - const float* dL_dtransMat, - const float* dL_dnormal3D, - glm::vec3 & dL_dmean3D, - glm::vec2 & dL_dscale, - glm::vec4 & dL_drot -) { - glm::mat4 world2ndc = glm::mat4( - projmat[0], projmat[4], projmat[8], projmat[12], - projmat[1], projmat[5], projmat[9], projmat[13], - projmat[2], projmat[6], projmat[10], projmat[14], - projmat[3], projmat[7], projmat[11], projmat[15] - ); - - glm::mat3x4 ndc2pix = glm::mat3x4( - glm::vec4(float(W) / 2.0, 0.0, 0.0, float(W-1) / 2.0), - glm::vec4(0.0, float(H) / 2.0, 0.0, float(H-1) / 2.0), - glm::vec4(0.0, 0.0, 0.0, 1.0) - ); - - glm::mat3x4 P = world2ndc * ndc2pix; - glm::mat3 dL_dT = glm::mat3( - dL_dtransMat[0], dL_dtransMat[1], dL_dtransMat[2], - dL_dtransMat[3], dL_dtransMat[4], dL_dtransMat[5], - dL_dtransMat[6], dL_dtransMat[7], dL_dtransMat[8] - ); +__device__ void compute_transmat_aabb( + int idx, + const float* Ts_precomp, + const float3* p_origs, + const glm::vec2* scales, + const glm::vec4* rots, + const float* projmatrix, + const float* viewmatrix, + const int W, const int H, + const float3* dL_dnormals, + const float3* dL_dmean2Ds, + float* dL_dTs, + glm::vec3* dL_dmeans, + glm::vec2* dL_dscales, + glm::vec4* dL_drots) +{ + glm::mat3 T; + float3 normal; + glm::mat3x4 P; + glm::mat3 R; + glm::mat3 S; + float3 p_orig; + glm::vec4 rot; + glm::vec2 scale; + + // Get transformation matrix of the Gaussian + if (Ts_precomp != nullptr) { + T = glm::mat3( + Ts_precomp[idx * 9 + 0], Ts_precomp[idx * 9 + 1], Ts_precomp[idx * 9 + 2], + Ts_precomp[idx * 9 + 3], Ts_precomp[idx * 9 + 4], Ts_precomp[idx * 9 + 5], + Ts_precomp[idx * 9 + 6], Ts_precomp[idx * 9 + 7], Ts_precomp[idx * 9 + 8] + ); + normal = {0.0, 0.0, 0.0}; + } else { + p_orig = p_origs[idx]; + rot = rots[idx]; + scale = scales[idx]; + R = quat_to_rotmat(rot); + S = scale_to_mat(scale, 1.0f); + + glm::mat3 L = R * S; + glm::mat3x4 M = glm::mat3x4( + glm::vec4(L[0], 0.0), + glm::vec4(L[1], 0.0), + glm::vec4(p_orig.x, p_orig.y, p_orig.z, 1) + ); - glm::mat3x4 dL_dsplat = P * glm::transpose(dL_dT); + glm::mat4 world2ndc = glm::mat4( + projmatrix[0], projmatrix[4], projmatrix[8], projmatrix[12], + projmatrix[1], projmatrix[5], projmatrix[9], projmatrix[13], + projmatrix[2], projmatrix[6], projmatrix[10], projmatrix[14], + projmatrix[3], projmatrix[7], projmatrix[11], projmatrix[15] + ); - const glm::mat3 R = quat_to_rotmat(quat); + glm::mat3x4 ndc2pix = glm::mat3x4( + glm::vec4(float(W) / 2.0, 0.0, 0.0, float(W-1) / 2.0), + glm::vec4(0.0, float(H) / 2.0, 0.0, float(H-1) / 2.0), + glm::vec4(0.0, 0.0, 0.0, 1.0) + ); - float multiplier = 1; + P = world2ndc * ndc2pix; + T = glm::transpose(M) * P; + normal = transformVec4x3({L[2].x, L[2].y, L[2].z}, viewmatrix); + } -#if DUAL_VISIABLE - float3 normal = transformVec4x3({R[2].x, R[2].y, R[2].z}, viewmat); - multiplier = normal.z < 0 ? 1: -1; + // Update gradients w.r.t. transformation matrix of the Gaussian + glm::mat3 dL_dT = glm::mat3( + dL_dTs[idx*9+0], dL_dTs[idx*9+1], dL_dTs[idx*9+2], + dL_dTs[idx*9+3], dL_dTs[idx*9+4], dL_dTs[idx*9+5], + dL_dTs[idx*9+6], dL_dTs[idx*9+7], dL_dTs[idx*9+8] + ); + float3 dL_dmean2D = dL_dmean2Ds[idx]; + if(dL_dmean2D.x != 0 || dL_dmean2D.y != 0) + { + const float distance = T[2].x * T[2].x + T[2].y * T[2].y - T[2].z * T[2].z; + const float f = 1 / (distance); + const float dpx_dT00 = f * T[2].x; + const float dpx_dT01 = f * T[2].y; + const float dpx_dT02 = -f * T[2].z; + const float dpy_dT10 = f * T[2].x; + const float dpy_dT11 = f * T[2].y; + const float dpy_dT12 = -f * T[2].z; + const float dpx_dT30 = T[0].x * (f - 2 * f * f * T[2].x * T[2].x); + const float dpx_dT31 = T[0].y * (f - 2 * f * f * T[2].y * T[2].y); + const float dpx_dT32 = -T[0].z * (f + 2 * f * f * T[2].z * T[2].z); + const float dpy_dT30 = T[1].x * (f - 2 * f * f * T[2].x * T[2].x); + const float dpy_dT31 = T[1].y * (f - 2 * f * f * T[2].y * T[2].y); + const float dpy_dT32 = -T[1].z * (f + 2 * f * f * T[2].z * T[2].z); + + dL_dT[0].x += dL_dmean2D.x * dpx_dT00; + dL_dT[0].y += dL_dmean2D.x * dpx_dT01; + dL_dT[0].z += dL_dmean2D.x * dpx_dT02; + dL_dT[1].x += dL_dmean2D.y * dpy_dT10; + dL_dT[1].y += dL_dmean2D.y * dpy_dT11; + dL_dT[1].z += dL_dmean2D.y * dpy_dT12; + dL_dT[2].x += dL_dmean2D.x * dpx_dT30 + dL_dmean2D.y * dpy_dT30; + dL_dT[2].y += dL_dmean2D.x * dpx_dT31 + dL_dmean2D.y * dpy_dT31; + dL_dT[2].z += dL_dmean2D.x * dpx_dT32 + dL_dmean2D.y * dpy_dT32; + + if (Ts_precomp != nullptr) { + dL_dTs[idx * 9 + 0] = dL_dT[0].x; + dL_dTs[idx * 9 + 1] = dL_dT[0].y; + dL_dTs[idx * 9 + 2] = dL_dT[0].z; + dL_dTs[idx * 9 + 3] = dL_dT[1].x; + dL_dTs[idx * 9 + 4] = dL_dT[1].y; + dL_dTs[idx * 9 + 5] = dL_dT[1].z; + dL_dTs[idx * 9 + 6] = dL_dT[2].x; + dL_dTs[idx * 9 + 7] = dL_dT[2].y; + dL_dTs[idx * 9 + 8] = dL_dT[2].z; + return; + } + } + + if (Ts_precomp != nullptr) return; + + // Update gradients w.r.t. scaling, rotation, position of the Gaussian + glm::mat3x4 dL_dM = P * glm::transpose(dL_dT); + float3 dL_dtn = transformVec4x3Transpose(dL_dnormals[idx], viewmatrix); +#if DUAL_VISIABLE + float multiplier = normal.z < 0 ? 1: -1; + dL_dtn = multiplier * dL_dtn; #endif - - float3 dL_dtn = transformVec4x3Transpose({dL_dnormal3D[0],dL_dnormal3D[1],dL_dnormal3D[2]}, viewmat); glm::mat3 dL_dRS = glm::mat3( - glm::vec3(dL_dsplat[0]), - glm::vec3(dL_dsplat[1]), - multiplier * glm::vec3(dL_dtn.x, dL_dtn.y, dL_dtn.z) + glm::vec3(dL_dM[0]), + glm::vec3(dL_dM[1]), + glm::vec3(dL_dtn.x, dL_dtn.y, dL_dtn.z) ); - // propagate to scale and quat, mean glm::mat3 dL_dR = glm::mat3( dL_dRS[0] * glm::vec3(scale.x), dL_dRS[1] * glm::vec3(scale.y), dL_dRS[2]); - - dL_dmean3D = glm::vec3(dL_dsplat[2]); - dL_drot = quat_to_rotmat_vjp(quat, dL_dR); - dL_dscale = glm::vec2( + + dL_drots[idx] = quat_to_rotmat_vjp(rot, dL_dR); + dL_dscales[idx] = glm::vec2( (float)glm::dot(dL_dRS[0], R[0]), - (float)glm::dot(dL_dRS[1], R[1])); + (float)glm::dot(dL_dRS[1], R[1]) + ); + dL_dmeans[idx] = glm::vec3(dL_dM[2]); } - - template __global__ void preprocessCUDA( int P, int D, int M, @@ -536,11 +600,11 @@ __global__ void preprocessCUDA( const float tan_fovy, const glm::vec3* campos, // grad input - const float* dL_dtransMats, + float* dL_dtransMats, const float* dL_dnormal3Ds, float* dL_dcolors, float* dL_dshs, - // grad output + float3* dL_dmean2Ds, glm::vec3* dL_dmean3Ds, glm::vec2* dL_dscales, glm::vec4* dL_drots) @@ -549,93 +613,29 @@ __global__ void preprocessCUDA( if (idx >= P || !(radii[idx] > 0)) return; - if (scales) { - const float* transMat = &(transMats[9 * idx]); - const float* dL_dtransMat = &(dL_dtransMats[9 * idx]); - const float* dL_dnormal3D = &(dL_dnormal3Ds[3 * idx]); - - glm::vec3 p_world = glm::vec3(means3D[idx].x, means3D[idx].y, means3D[idx].z); - - const int W = int(focal_x * tan_fovx * 2); - const int H = int(focal_y * tan_fovy * 2); - glm::vec3 dL_dmean3D; - glm::vec2 dL_dscale; - glm::vec4 dL_drot; - computeTransMat( - p_world, - rotations[idx], - scales[idx], - viewmatrix, - projmatrix, - W, - H, - transMat, - dL_dtransMat, - dL_dnormal3D, - dL_dmean3D, - dL_dscale, - dL_drot - ); - // update - dL_dmean3Ds[idx] = dL_dmean3D; - dL_dscales[idx] = dL_dscale; - dL_drots[idx] = dL_drot; - } + const int W = int(focal_x * tan_fovx * 2); + const int H = int(focal_y * tan_fovy * 2); + const float * Ts_precomp = (scales) ? nullptr : transMats; + compute_transmat_aabb( + idx, + Ts_precomp, + means3D, scales, rotations, + projmatrix, viewmatrix, W, H, + (float3*)dL_dnormal3Ds, + dL_dmean2Ds, + (dL_dtransMats), + dL_dmean3Ds, + dL_dscales, + dL_drots + ); if (shs) computeColorFromSH(idx, D, M, (glm::vec3*)means3D, *campos, shs, clamped, (glm::vec3*)dL_dcolors, (glm::vec3*)dL_dmean3Ds, (glm::vec3*)dL_dshs); -} - -__global__ void computeAABB(int P, - const int * radii, - const float W, const float H, - const float * transMats, - float3 * dL_dmean2Ds, - float *dL_dtransMats) { - auto idx = cg::this_grid().thread_rank(); - if (idx >= P || !(radii[idx] > 0)) - return; - - const float* transMat = transMats + 9 * idx; - - const float3 dL_dmean2D = dL_dmean2Ds[idx]; - glm::mat4x3 T = glm::mat4x3( - transMat[0], transMat[1], transMat[2], - transMat[3], transMat[4], transMat[5], - transMat[6], transMat[7], transMat[8], - transMat[6], transMat[7], transMat[8] - ); - - float d = glm::dot(glm::vec3(1.0, 1.0, -1.0), T[3] * T[3]); - glm::vec3 f = glm::vec3(1.0, 1.0, -1.0) * (1.0f / d); - - glm::vec3 p = glm::vec3( - glm::dot(f, T[0] * T[3]), - glm::dot(f, T[1] * T[3]), - glm::dot(f, T[2] * T[3])); - - glm::vec3 dL_dT0 = dL_dmean2D.x * f * T[3]; - glm::vec3 dL_dT1 = dL_dmean2D.y * f * T[3]; - glm::vec3 dL_dT3 = dL_dmean2D.x * f * T[0] + dL_dmean2D.y * f * T[1]; - glm::vec3 dL_df = (dL_dmean2D.x * T[0] * T[3]) + (dL_dmean2D.y * T[1] * T[3]); - float dL_dd = glm::dot(dL_df, f) * (-1.0 / d); - glm::vec3 dd_dT3 = glm::vec3(1.0, 1.0, -1.0) * T[3] * 2.0f; - dL_dT3 += dL_dd * dd_dT3; - dL_dtransMats[9 * idx + 0] += dL_dT0.x; - dL_dtransMats[9 * idx + 1] += dL_dT0.y; - dL_dtransMats[9 * idx + 2] += dL_dT0.z; - dL_dtransMats[9 * idx + 3] += dL_dT1.x; - dL_dtransMats[9 * idx + 4] += dL_dT1.y; - dL_dtransMats[9 * idx + 5] += dL_dT1.z; - dL_dtransMats[9 * idx + 6] += dL_dT3.x; - dL_dtransMats[9 * idx + 7] += dL_dT3.y; - dL_dtransMats[9 * idx + 8] += dL_dT3.z; - - // just use to hack the projected 2D gradient here. - float z = transMat[8]; - dL_dmean2Ds[idx].x = dL_dtransMats[9 * idx + 2] * z * W; // to ndc - dL_dmean2Ds[idx].y = dL_dtransMats[9 * idx + 5] * z * H; // to ndc + // hack the gradient here for densitification + float depth = transMats[idx * 9 + 8]; + dL_dmean2Ds[idx].x = dL_dtransMats[idx * 9 + 2] * depth * 0.5 * float(W); // to ndc + dL_dmean2Ds[idx].y = dL_dtransMats[idx * 9 + 5] * depth * 0.5 * float(H); // to ndc } @@ -662,25 +662,7 @@ void BACKWARD::preprocess( glm::vec3* dL_dmean3Ds, glm::vec2* dL_dscales, glm::vec4* dL_drots) -{ - // Propagate gradients for the path of 2D conic matrix computation. - // Somewhat long, thus it is its own kernel rather than being part of - // "preprocess". When done, loss gradient w.r.t. 3D means has been - // modified and gradient w.r.t. 3D covariance matrix has been computed. - // propagate gradients to transMat - - // we do not use the center actually - float W = focal_x * tan_fovx; - float H = focal_y * tan_fovy; - computeAABB << <(P + 255) / 256, 256 >> >( - P, - radii, - W, H, - transMats, - dL_dmean2Ds, - dL_dtransMats); - - // propagate gradients from transMat to mean3d, scale, rot, sh, color +{ preprocessCUDA<< <(P + 255) / 256, 256 >> > ( P, D, M, (float3*)means3D, @@ -702,6 +684,7 @@ void BACKWARD::preprocess( dL_dnormal3Ds, dL_dcolors, dL_dshs, + dL_dmean2Ds, dL_dmean3Ds, dL_dscales, dL_drots diff --git a/cuda_rasterizer/forward.cu b/cuda_rasterizer/forward.cu index 19605f9..59f9c89 100644 --- a/cuda_rasterizer/forward.cu +++ b/cuda_rasterizer/forward.cu @@ -72,14 +72,34 @@ __device__ glm::vec3 computeColorFromSH(int idx, int deg, int max_coeffs, const // Compute a 2D-to-2D mapping matrix from a tangent plane into a image plane // given a 2D gaussian parameters. -__device__ void computeTransMat(const glm::vec3 &p_world, const glm::vec4 &quat, const glm::vec2 &scale, const float *viewmat, const float*projmat, const int W, const int H, float* transMat, float3 &normal) { - // setup camera - // can be fatored out to reduce computations +__device__ void compute_transmat( + const float3& p_orig, + const glm::vec2 scale, + const glm::vec4 rot, + const float* projmatrix, + const float* viewmatrix, + const int W, + const int H, + glm::mat3 &T, + float3 &normal +) { + + glm::mat3 R = quat_to_rotmat(rot); + glm::mat3 S = scale_to_mat(scale, 1.0f); + glm::mat3 L = R * S; + + // center of Gaussians in the camera coordinate + glm::mat3x4 splat2world = glm::mat3x4( + glm::vec4(L[0], 0.0), + glm::vec4(L[1], 0.0), + glm::vec4(p_orig.x, p_orig.y, p_orig.z, 1) + ); + glm::mat4 world2ndc = glm::mat4( - projmat[0], projmat[4], projmat[8], projmat[12], - projmat[1], projmat[5], projmat[9], projmat[13], - projmat[2], projmat[6], projmat[10], projmat[14], - projmat[3], projmat[7], projmat[11], projmat[15] + projmatrix[0], projmatrix[4], projmatrix[8], projmatrix[12], + projmatrix[1], projmatrix[5], projmatrix[9], projmatrix[13], + projmatrix[2], projmatrix[6], projmatrix[10], projmatrix[14], + projmatrix[3], projmatrix[7], projmatrix[11], projmatrix[15] ); glm::mat3x4 ndc2pix = glm::mat3x4( @@ -88,70 +108,43 @@ __device__ void computeTransMat(const glm::vec3 &p_world, const glm::vec4 &quat, glm::vec4(0.0, 0.0, 0.0, 1.0) ); - glm::mat3x4 P = world2ndc * ndc2pix; - // Make the geometry of 2D Gaussian as a Homogeneous transformation matrix - // under the camera view, See Eq. (5) in 2DGS' paper. - glm::mat3 RS = quat_to_rotmat(quat) * scale_to_mat({scale.x, scale.y, 1.0f}, 1.0f); - glm::mat3x4 splat2world = glm::mat3x4( - glm::vec4(RS[0], 0.0), - glm::vec4(RS[1], 0.0), - glm::vec4(p_world, 1.0) - ); - // projection into screen space, see Eq. (7) in 2DGS - glm::mat3 T = glm::transpose(splat2world) * P; - - transMat[0] = T[0].x; - transMat[1] = T[0].y; - transMat[2] = T[0].z; - transMat[3] = T[1].x; - transMat[4] = T[1].y; - transMat[5] = T[1].z; - transMat[6] = T[2].x; - transMat[7] = T[2].y; - transMat[8] = T[2].z; - - normal = transformVec4x3({RS[2].x, RS[2].y, RS[2].z}, viewmat); + T = glm::transpose(splat2world) * world2ndc * ndc2pix; + normal = transformVec4x3({L[2].x, L[2].y, L[2].z}, viewmatrix); #if DUAL_VISIABLE - // This means a 2D Gaussian is dual visiable. - // Experimentally, turning off the dual visiable works eqully. float multiplier = normal.z < 0 ? 1: -1; - normal = {multiplier * normal.x, multiplier * normal.y, multiplier * normal.z}; + normal = multiplier * normal; #endif } -// Computing the bounding box of the 2D Gaussian and its center, -// where the center of the bounding box is used to create a low pass filter -// in the image plane -__device__ bool computeAABB(const float *transMat, float2 & center, float2 & extent) { - glm::mat4x3 T = glm::mat4x3( - transMat[0], transMat[1], transMat[2], - transMat[3], transMat[4], transMat[5], - transMat[6], transMat[7], transMat[8], - transMat[6], transMat[7], transMat[8] - ); - - float d = glm::dot(glm::vec3(1.0, 1.0, -1.0), T[3] * T[3]); - - if (d == 0.0f) return false; - - glm::vec3 f = glm::vec3(1.0, 1.0, -1.0) * (1.0f / d); - - glm::vec3 p = glm::vec3( - glm::dot(f, T[0] * T[3]), - glm::dot(f, T[1] * T[3]), - glm::dot(f, T[2] * T[3])); +// Computing the bounding box of the 2D Gaussian and its center +// The center of the bounding box is used to create a low pass filter +__device__ bool compute_aabb( + glm::mat3 T, + float2& point_image, + float2 & extent +) { + float3 T0 = {T[0][0], T[0][1], T[0][2]}; + float3 T1 = {T[1][0], T[1][1], T[1][2]}; + float3 T3 = {T[2][0], T[2][1], T[2][2]}; + + // Compute AABB + float3 temp_point = {1.0f, 1.0f, -1.0f}; + float distance = sumf3(T3 * T3 * temp_point); + float3 f = (1 / distance) * temp_point; + if (distance == 0.0) return false; + + point_image = { + sumf3(f * T0 * T3), + sumf3(f * T1 * T3) + }; - glm::vec3 h0 = p * p - - glm::vec3( - glm::dot(f, T[0] * T[0]), - glm::dot(f, T[1] * T[1]), - glm::dot(f, T[2] * T[2]) - ); - - glm::vec3 h = sqrt(max(glm::vec3(0.0), h0)) + glm::vec3(0.0, 0.0, 1e-2); - center = {p.x, p.y}; - extent = {h.x, h.y}; + float2 temp = { + sumf3(f * T0 * T0), + sumf3(f * T1 * T1) + }; + float2 half_extend = point_image * point_image - temp; + extent = sqrtf2(maxf2(1e-4, half_extend)); return true; } @@ -192,46 +185,47 @@ __global__ void preprocessCUDA(int P, int D, int M, radii[idx] = 0; tiles_touched[idx] = 0; - glm::vec3 p_world = glm::vec3(orig_points[3 * idx], orig_points[3 * idx + 1], orig_points[3 * idx + 2]); // Perform near culling, quit if outside. float3 p_view; if (!in_frustum(idx, orig_points, viewmatrix, projmatrix, prefiltered, p_view)) return; - const float* transMat; + // Compute transformation matrix + glm::mat3 T; float3 normal; - if (transMat_precomp != nullptr) + if (transMat_precomp == nullptr) { - transMat = transMat_precomp + idx * 9; - normal = {0.f, 0.f, 0.f}; // not support precomp normal + compute_transmat(((float3*)orig_points)[idx], scales[idx], rotations[idx], projmatrix, viewmatrix, W, H, T, normal); + float3 *T_ptr = (float3*)transMats; + T_ptr[idx * 3 + 0] = {T[0][0], T[0][1], T[0][2]}; + T_ptr[idx * 3 + 1] = {T[1][0], T[1][1], T[1][2]}; + T_ptr[idx * 3 + 2] = {T[2][0], T[2][1], T[2][2]}; + } else { + glm::vec3 *T_ptr = (glm::vec3*)transMat_precomp; + T = glm::mat3( + T_ptr[idx * 3 + 0], + T_ptr[idx * 3 + 1], + T_ptr[idx * 3 + 2] + ); + normal = make_float3(0.0, 0.0, 1.0); } - else + + // Compute center and radius + float2 point_image; + float radius; { - computeTransMat(p_world, rotations[idx], scales[idx], viewmatrix, projmatrix, W, H, transMats + idx * 9, normal); - transMat = transMats + idx * 9; + float2 extent; + bool ok = compute_aabb(T, point_image, extent); + if (!ok) return; + radius = 3.0f * ceil(max(extent.x, extent.y)); } - - // compute center and extent - float2 center; - float2 extent; - bool ok = computeAABB(transMat, center, extent); - if (!ok) return; - - // add the bounding of countour -#if TIGHTBBOX // no use in the paper, but it indeed help speeds. - // the effective extent is now depended on the opacity of gaussian. - float truncated_R = sqrtf(max(9.f + 2.f * logf(opacities[idx]), 0.000001)); -#else - float truncated_R = 3.f; -#endif - float radius = ceil(truncated_R * max(max(extent.x, extent.y), FilterSize)); uint2 rect_min, rect_max; - getRect(center, radius, rect_min, rect_max, grid); + getRect(point_image, radius, rect_min, rect_max, grid); if ((rect_max.x - rect_min.x) * (rect_max.y - rect_min.y) == 0) return; - // compute colors + // Compute colors if (colors_precomp == nullptr) { glm::vec3 result = computeColorFromSH(idx, D, M, (glm::vec3*)orig_points, *cam_pos, shs, clamped); rgb[idx * C + 0] = result.x; @@ -241,8 +235,7 @@ __global__ void preprocessCUDA(int P, int D, int M, depths[idx] = p_view.z; radii[idx] = (int)radius; - points_xy_image[idx] = center; - // store them in float4 + points_xy_image[idx] = point_image; normal_opacity[idx] = {normal.x, normal.y, normal.z, opacities[idx]}; tiles_touched[idx] = (rect_max.y - rect_min.y) * (rect_max.x - rect_min.x); } @@ -304,13 +297,13 @@ renderCUDA( #if RENDER_AXUTILITY // render axutility ouput - float D = { 0 }; float N[3] = {0}; - float dist1 = {0}; - float dist2 = {0}; + float D = { 0 }; + float M1 = {0}; + float M2 = {0}; float distortion = {0}; float median_depth = {0}; - float median_weight = {0}; + // float median_weight = {0}; float median_contributor = {-1}; #endif @@ -344,37 +337,28 @@ renderCUDA( contributor++; // Fisrt compute two homogeneous planes, See Eq. (8) - float3 Tu = collected_Tu[j]; - float3 Tv = collected_Tv[j]; - float3 Tw = collected_Tw[j]; - float3 k = {-Tu.x + pixf.x * Tw.x, -Tu.y + pixf.x * Tw.y, -Tu.z + pixf.x * Tw.z}; - float3 l = {-Tv.x + pixf.y * Tw.x, -Tv.y + pixf.y * Tw.y, -Tv.z + pixf.y * Tw.z}; - // cross product of two planes is a line (i.e., homogeneous point), See Eq. (10) - float3 p = crossProduct(k, l); -#if BACKFACE_CULL - // May hanle this by replacing a low pass filter, - // but this case is extremely rare. - if (p.z == 0.0) continue; // there is not intersection -#endif - // 3d homogeneous point to 2d point on the splat + const float2 xy = collected_xy[j]; + const float3 Tu = collected_Tu[j]; + const float3 Tv = collected_Tv[j]; + const float3 Tw = collected_Tw[j]; + float3 k = pix.x * Tw - Tu; + float3 l = pix.y * Tw - Tv; + float3 p = cross(k, l); + if (p.z == 0.0) continue; float2 s = {p.x / p.z, p.y / p.z}; - // 3d distance. Compute Mahalanobis distance in the canonical splat' space float rho3d = (s.x * s.x + s.y * s.y); - - // Add low pass filter according to Botsch et al. [2005], - // see Eq. (11) from 2DGS paper. - float2 xy = collected_xy[j]; float2 d = {xy.x - pixf.x, xy.y - pixf.y}; - // 2d screen distance float rho2d = FilterInvSquare * (d.x * d.x + d.y * d.y); + + // compute intersection and depth float rho = min(rho3d, rho2d); - - float depth = (rho3d <= rho2d) ? (s.x * Tw.x + s.y * Tw.y) + Tw.z : Tw.z; // splat depth - if (depth < NEAR_PLANE) continue; + float depth = (rho3d <= rho2d) ? (s.x * Tw.x + s.y * Tw.y) + Tw.z : Tw.z; + if (depth < near_n) continue; float4 nor_o = collected_normal_opacity[j]; float normal[3] = {nor_o.x, nor_o.y, nor_o.z}; + float opa = nor_o.w; + float power = -0.5f * rho; - // power = -0.5f * 100.f * max(rho - 1, 0.0f); if (power > 0.0f) continue; @@ -382,7 +366,7 @@ renderCUDA( // Obtain alpha by multiplying with Gaussian opacity // and its exponential falloff from mean. // Avoid numerical instabilities (see paper appendix). - float alpha = min(0.99f, nor_o.w * exp(power)); + float alpha = min(0.99f, opa * exp(power)); if (alpha < 1.0f / 255.0f) continue; float test_T = T * (1 - alpha); @@ -392,33 +376,29 @@ renderCUDA( continue; } - + float w = alpha * T; #if RENDER_AXUTILITY // Render depth distortion map // Efficient implementation of distortion loss, see 2DGS' paper appendix. float A = 1-T; - float mapped_depth = (FAR_PLANE * depth - FAR_PLANE * NEAR_PLANE) / ((FAR_PLANE - NEAR_PLANE) * depth); - float error = mapped_depth * mapped_depth * A + dist2 - 2 * mapped_depth * dist1; - distortion += error * alpha * T; + float m = far_n / (far_n - near_n) * (1 - near_n / depth); + distortion += (m * m * A + M2 - 2 * m * M1) * w; + D += depth * w; + M1 += m * w; + M2 += m * m * w; if (T > 0.5) { median_depth = depth; - median_weight = alpha * T; + // median_weight = w; median_contributor = contributor; } // Render normal map - for (int ch=0; ch<3; ch++) N[ch] += normal[ch] * alpha * T; - - // Render depth map - D += depth * alpha * T; - // Efficient implementation of distortion loss, see 2DGS' paper appendix. - dist1 += mapped_depth * alpha * T; - dist2 += mapped_depth * mapped_depth * alpha * T; + for (int ch=0; ch<3; ch++) N[ch] += normal[ch] * w; #endif // Eq. (3) from 3D Gaussian splatting paper. for (int ch = 0; ch < CHANNELS; ch++) - C[ch] += features[collected_id[j] * CHANNELS + ch] * alpha * T; + C[ch] += features[collected_id[j] * CHANNELS + ch] * w; T = test_T; // Keep track of last range entry to update this @@ -438,14 +418,14 @@ renderCUDA( #if RENDER_AXUTILITY n_contrib[pix_id + H * W] = median_contributor; - final_T[pix_id + H * W] = dist1; - final_T[pix_id + 2 * H * W] = dist2; + final_T[pix_id + H * W] = M1; + final_T[pix_id + 2 * H * W] = M2; out_others[pix_id + DEPTH_OFFSET * H * W] = D; out_others[pix_id + ALPHA_OFFSET * H * W] = 1 - T; for (int ch=0; ch<3; ch++) out_others[pix_id + (NORMAL_OFFSET+ch) * H * W] = N[ch]; out_others[pix_id + MIDDEPTH_OFFSET * H * W] = median_depth; out_others[pix_id + DISTORTION_OFFSET * H * W] = distortion; - out_others[pix_id + MEDIAN_WEIGHT_OFFSET * H * W] = median_weight; + // out_others[pix_id + MEDIAN_WEIGHT_OFFSET * H * W] = median_weight; #endif } } diff --git a/rasterize_points.cu b/rasterize_points.cu index c1426e8..321d5dd 100644 --- a/rasterize_points.cu +++ b/rasterize_points.cu @@ -83,7 +83,7 @@ RasterizeGaussiansCUDA( auto float_opts = means3D.options().dtype(torch::kFloat32); torch::Tensor out_color = torch::full({NUM_CHANNELS, H, W}, 0.0, float_opts); - torch::Tensor out_others = torch::full({3+3+2, H, W}, 0.0, float_opts); + torch::Tensor out_others = torch::full({3+3+1, H, W}, 0.0, float_opts); torch::Tensor radii = torch::full({P}, 0, means3D.options().dtype(torch::kInt32)); torch::Device device(torch::kCUDA);