From 9e9838b8482b70054a58f4bd842d74357ea4b966 Mon Sep 17 00:00:00 2001 From: Martin Valigursky <59932779+mvaligursky@users.noreply.github.com> Date: Fri, 8 Mar 2024 14:20:16 +0000 Subject: [PATCH] Use more compact texture formats storing GSplat data (#6129) Co-authored-by: Martin Valigursky <mvaligursky@snapchat.com> --- src/framework/parsers/gsplat-resource.js | 8 +- src/scene/gsplat/gsplat.js | 290 ++++++++++---------- src/scene/gsplat/shader-generator-gsplat.js | 57 ++-- 3 files changed, 180 insertions(+), 175 deletions(-) diff --git a/src/framework/parsers/gsplat-resource.js b/src/framework/parsers/gsplat-resource.js index 9351b056209..841665c336c 100644 --- a/src/framework/parsers/gsplat-resource.js +++ b/src/framework/parsers/gsplat-resource.js @@ -57,9 +57,11 @@ class GSplatResource { // texture data splat.updateColorData(splatData.getProp('f_dc_0'), splatData.getProp('f_dc_1'), splatData.getProp('f_dc_2'), splatData.getProp('opacity')); - splat.updateCenterData(splatData.getProp('x'), splatData.getProp('y'), splatData.getProp('z')); - splat.updateCovData(splatData.getProp('rot_0'), splatData.getProp('rot_1'), splatData.getProp('rot_2'), splatData.getProp('rot_3'), - splatData.getProp('scale_0'), splatData.getProp('scale_1'), splatData.getProp('scale_2')); + splat.updateTransformData( + splatData.getProp('x'), splatData.getProp('y'), splatData.getProp('z'), + splatData.getProp('rot_0'), splatData.getProp('rot_1'), splatData.getProp('rot_2'), splatData.getProp('rot_3'), + splatData.getProp('scale_0'), splatData.getProp('scale_1'), splatData.getProp('scale_2') + ); // centers - constant buffer that is sent to the worker const x = splatData.getProp('x'); diff --git a/src/scene/gsplat/gsplat.js b/src/scene/gsplat/gsplat.js index ee134b38500..4ad4a46ee9b 100644 --- a/src/scene/gsplat/gsplat.js +++ b/src/scene/gsplat/gsplat.js @@ -4,7 +4,7 @@ import { Quat } from '../../core/math/quat.js'; import { Vec2 } from '../../core/math/vec2.js'; import { Mat3 } from '../../core/math/mat3.js'; import { - ADDRESS_CLAMP_TO_EDGE, FILTER_NEAREST, PIXELFORMAT_RGB32F, PIXELFORMAT_RGBA16F, PIXELFORMAT_RGBA32F, + ADDRESS_CLAMP_TO_EDGE, FILTER_NEAREST, PIXELFORMAT_R16F, PIXELFORMAT_R32F, PIXELFORMAT_RGBA16F, PIXELFORMAT_RGBA32F, PIXELFORMAT_RGBA8, SEMANTIC_ATTR13, TYPE_FLOAT32, TYPE_UINT32 } from '../../platform/graphics/constants.js'; import { Texture } from '../../platform/graphics/texture.js'; @@ -20,35 +20,34 @@ const _m2 = new Vec3(); const _s = new Vec3(); const _r = new Vec3(); -/** - * @typedef {object} SplatTextureFormat - * @property {number} format - The pixel format of the texture. - * @property {number} numComponents - The number of components in the texture format. - * @property {boolean} isHalf - Indicates if the format uses half-precision floats. - */ - /** @ignore */ class GSplat { device; numSplats; + /** @type {VertexFormat} */ vertexFormat; - /** @type {SplatTextureFormat} */ - format; + /** + * True if half format should be used, false is float format should be used or undefined if none + * are available. + * + * @type {boolean|undefined} + */ + halfFormat; /** @type {Texture} */ colorTexture; /** @type {Texture} */ - covATexture; + transformATexture; /** @type {Texture} */ - covBTexture; + transformBTexture; /** @type {Texture} */ - centerTexture; + transformCTexture; /** @type {Float32Array} */ centers; @@ -70,20 +69,23 @@ class GSplat { { semantic: SEMANTIC_ATTR13, components: 1, type: device.isWebGL1 ? TYPE_FLOAT32 : TYPE_UINT32, asInt: !device.isWebGL1 } ]); - // create data textures - const size = this.evalTextureSize(numSplats); - this.format = this.getTextureFormat(device, false); - this.colorTexture = this.createTexture(device, 'splatColor', PIXELFORMAT_RGBA8, size); - this.centerTexture = this.createTexture(device, 'splatCenter', this.format.format, size); - this.covATexture = this.createTexture(device, 'splatCovA', this.format.format, size); - this.covBTexture = this.createTexture(device, 'splatCovB', this.format.format, size); + // create data textures if any format is available + this.halfFormat = this.getTextureFormat(device, true); + + if (this.halfFormat !== undefined) { + const size = this.evalTextureSize(numSplats); + this.colorTexture = this.createTexture(device, 'splatColor', PIXELFORMAT_RGBA8, size); + this.transformATexture = this.createTexture(device, 'transformA', this.halfFormat ? PIXELFORMAT_RGBA16F : PIXELFORMAT_RGBA32F, size); + this.transformBTexture = this.createTexture(device, 'transformB', this.halfFormat ? PIXELFORMAT_RGBA16F : PIXELFORMAT_RGBA32F, size); + this.transformCTexture = this.createTexture(device, 'transformC', this.halfFormat ? PIXELFORMAT_R16F : PIXELFORMAT_R32F, size); + } } destroy() { - this.colorTexture.destroy(); - this.centerTexture.destroy(); - this.covATexture.destroy(); - this.covBTexture.destroy(); + this.colorTexture?.destroy(); + this.transformATexture?.destroy(); + this.transformBTexture?.destroy(); + this.transformCTexture?.destroy(); } /** @@ -92,13 +94,15 @@ class GSplat { */ setupMaterial(material) { - material.setParameter('splatColor', this.colorTexture); - material.setParameter('splatCenter', this.centerTexture); - material.setParameter('splatCovA', this.covATexture); - material.setParameter('splatCovB', this.covBTexture); + if (this.colorTexture) { + material.setParameter('splatColor', this.colorTexture); + material.setParameter('transformA', this.transformATexture); + material.setParameter('transformB', this.transformBTexture); + material.setParameter('transformC', this.transformCTexture); - const { width, height } = this.colorTexture; - material.setParameter('tex_params', new Float32Array([width, height, 1 / width, 1 / height])); + const { width, height } = this.colorTexture; + material.setParameter('tex_params', new Float32Array([width, height, 1 / width, 1 / height])); + } } /** @@ -144,24 +148,35 @@ class GSplat { * * @param {import('../../platform/graphics/graphics-device.js').GraphicsDevice} device - The graphics device. * @param {boolean} preferHighPrecision - True to prefer high precision when available. - * @returns {SplatTextureFormat} The texture format info or undefined if not available. + * @returns {boolean|undefined} True if half format should be used, false is float format should + * be used or undefined if none are available. */ getTextureFormat(device, preferHighPrecision) { - const halfFormat = (device.extTextureHalfFloat && device.textureHalfFloatUpdatable) ? PIXELFORMAT_RGBA16F : undefined; - const half = halfFormat ? { - format: halfFormat, - numComponents: 4, - isHalf: true - } : undefined; - - const floatFormat = device.isWebGPU ? PIXELFORMAT_RGBA32F : (device.extTextureFloat ? PIXELFORMAT_RGB32F : undefined); - const float = floatFormat ? { - format: floatFormat, - numComponents: floatFormat === PIXELFORMAT_RGBA32F ? 4 : 3, - isHalf: false - } : undefined; - - return preferHighPrecision ? (float ?? half) : (half ?? float); + + // on WebGL1 R32F is not supported, always use half precision + if (device.isWebGL1) + preferHighPrecision = false; + + const halfSupported = device.extTextureHalfFloat && device.textureHalfFloatUpdatable; + const floatSupported = device.extTextureFloat; + + // true if half format should be used, false is float format should be used or undefined if none are available. + let halfFormat; + if (preferHighPrecision) { + if (floatSupported) { + halfFormat = false; + } else if (halfSupported) { + halfFormat = true; + } + } else { + if (halfSupported) { + halfFormat = true; + } else if (floatSupported) { + halfFormat = false; + } + } + + return halfFormat; } /** @@ -177,6 +192,8 @@ class GSplat { updateColorData(c0, c1, c2, opacity) { const SH_C0 = 0.28209479177387814; const texture = this.colorTexture; + if (!texture) + return; const data = texture.lock(); /** @@ -210,6 +227,89 @@ class GSplat { texture.unlock(); } + /** + * @param {Float32Array} x - The array containing the 'x' component of the center points. + * @param {Float32Array} y - The array containing the 'y' component of the center points. + * @param {Float32Array} z - The array containing the 'z' component of the center points. + * @param {Float32Array} rot0 - The array containing the 'x' component of quaternion rotations. + * @param {Float32Array} rot1 - The array containing the 'y' component of quaternion rotations. + * @param {Float32Array} rot2 - The array containing the 'z' component of quaternion rotations. + * @param {Float32Array} rot3 - The array containing the 'w' component of quaternion rotations. + * @param {Float32Array} scale0 - The first scale component associated with the x-dimension. + * @param {Float32Array} scale1 - The second scale component associated with the y-dimension. + * @param {Float32Array} scale2 - The third scale component associated with the z-dimension. + */ + updateTransformData(x, y, z, rot0, rot1, rot2, rot3, scale0, scale1, scale2) { + + const { halfFormat } = this; + const float2Half = FloatPacking.float2Half; + + if (!this.transformATexture) + return; + + const dataA = this.transformATexture.lock(); + const dataB = this.transformBTexture.lock(); + const dataC = this.transformCTexture.lock(); + + const quat = new Quat(); + const mat = new Mat3(); + const cA = new Vec3(); + const cB = new Vec3(); + + for (let i = 0; i < this.numSplats; i++) { + + // rotation + quat.set(rot0[i], rot1[i], rot2[i], rot3[i]).normalize(); + if (quat.w < 0) { + quat.conjugate(); + } + _r.set(quat.x, quat.y, quat.z); + this.quatToMat3(_r, mat); + + // scale + _s.set( + Math.exp(scale0[i]), + Math.exp(scale1[i]), + Math.exp(scale2[i]) + ); + + this.computeCov3d(mat, _s, cA, cB); + + if (halfFormat) { + + dataA[i * 4 + 0] = float2Half(x[i]); + dataA[i * 4 + 1] = float2Half(y[i]); + dataA[i * 4 + 2] = float2Half(z[i]); + dataA[i * 4 + 3] = float2Half(cB.x); + + dataB[i * 4 + 0] = float2Half(cA.x); + dataB[i * 4 + 1] = float2Half(cA.y); + dataB[i * 4 + 2] = float2Half(cA.z); + dataB[i * 4 + 3] = float2Half(cB.y); + + dataC[i] = float2Half(cB.z); + + } else { + + dataA[i * 4 + 0] = x[i]; + dataA[i * 4 + 1] = y[i]; + dataA[i * 4 + 2] = z[i]; + dataA[i * 4 + 3] = cB.x; + + dataB[i * 4 + 0] = cA.x; + dataB[i * 4 + 1] = cA.y; + dataB[i * 4 + 2] = cA.z; + dataB[i * 4 + 3] = cB.y; + + dataC[i] = cB.z; + } + } + + this.transformATexture.unlock(); + this.transformBTexture.unlock(); + this.transformCTexture.unlock(); + } + /** * Convert quaternion rotation stored in Vec3 to a rotation matrix. * @@ -268,104 +368,6 @@ class GSplat { _m2.dot(_m2) ); } - - /** - * Updates data of covATexture and covBTexture based on the supplied rotation and scale - * components. - * - * @param {Float32Array} rot0 - The array containing the 'x' component of quaternion rotations. - * @param {Float32Array} rot1 - The array containing the 'y' component of quaternion rotations. - * @param {Float32Array} rot2 - The array containing the 'z' component of quaternion rotations. - * @param {Float32Array} rot3 - The array containing the 'w' component of quaternion rotations. - * @param {Float32Array} scale0 - The first scale component associated with the x-dimension. - * @param {Float32Array} scale1 - The second scale component associated with the y-dimension. - * @param {Float32Array} scale2 - The third scale component associated with the z-dimension. - */ - updateCovData(rot0, rot1, rot2, rot3, scale0, scale1, scale2) { - - const { numComponents, isHalf } = this.format; - const float2Half = FloatPacking.float2Half; - const quat = new Quat(); - const mat = new Mat3(); - const cA = new Vec3(); - const cB = new Vec3(); - - const covA = this.covATexture.lock(); - const covB = this.covBTexture.lock(); - - for (let i = 0; i < this.numSplats; ++i) { - - // rotation - quat.set(rot0[i], rot1[i], rot2[i], rot3[i]).normalize(); - if (quat.w < 0) { - quat.conjugate(); - } - _r.set(quat.x, quat.y, quat.z); - this.quatToMat3(_r, mat); - - // scale - _s.set( - Math.exp(scale0[i]), - Math.exp(scale1[i]), - Math.exp(scale2[i]) - ); - - this.computeCov3d(mat, _s, cA, cB); - - if (isHalf) { - covA[i * numComponents + 0] = float2Half(cA.x); - covA[i * numComponents + 1] = float2Half(cA.y); - covA[i * numComponents + 2] = float2Half(cA.z); - - covB[i * numComponents + 0] = float2Half(cB.x); - covB[i * numComponents + 1] = float2Half(cB.y); - covB[i * numComponents + 2] = float2Half(cB.z); - } else { - covA[i * numComponents + 0] = cA.x; - covA[i * numComponents + 1] = cA.y; - covA[i * numComponents + 2] = cA.z; - - covB[i * numComponents + 0] = cB.x; - covB[i * numComponents + 1] = cB.y; - covB[i * numComponents + 2] = cB.z; - } - } - - this.covATexture.unlock(); - this.covBTexture.unlock(); - } - - /** - * Updates pixel data of this.centerTexture based on the supplied center coordinates. - * The center coordinates are stored as either half or full precision floats depending on the - * texture format. - * - * @param {Float32Array} x - The array containing the 'x' component of the center points. - * @param {Float32Array} y - The array containing the 'y' component of the center points. - * @param {Float32Array} z - The array containing the 'z' component of the center points. - */ - updateCenterData(x, y, z) { - const { numComponents, isHalf } = this.format; - - const texture = this.centerTexture; - const data = texture.lock(); - const float2Half = FloatPacking.float2Half; - - for (let i = 0; i < this.numSplats; i++) { - - if (isHalf) { - data[i * numComponents + 0] = float2Half(x[i]); - data[i * numComponents + 1] = float2Half(y[i]); - data[i * numComponents + 2] = float2Half(z[i]); - } else { - data[i * numComponents + 0] = x[i]; - data[i * numComponents + 1] = y[i]; - data[i * numComponents + 2] = z[i]; - } - } - - texture.unlock(); - } } export { GSplat }; diff --git a/src/scene/gsplat/shader-generator-gsplat.js b/src/scene/gsplat/shader-generator-gsplat.js index 0a9393daa89..7ad07014cab 100644 --- a/src/scene/gsplat/shader-generator-gsplat.js +++ b/src/scene/gsplat/shader-generator-gsplat.js @@ -34,10 +34,14 @@ const splatCoreVS = ` uniform vec4 tex_params; uniform sampler2D splatColor; - uniform highp sampler2D splatCenter; - uniform highp sampler2D splatCovA; - uniform highp sampler2D splatCovB; + uniform highp sampler2D transformA; + uniform highp sampler2D transformB; + uniform highp sampler2D transformC; + + vec3 center; + vec3 covA; + vec3 covB; #ifdef INT_INDICES @@ -58,16 +62,14 @@ const splatCoreVS = ` return texelFetch(splatColor, dataUV, 0); } - vec3 getCenter() { - return texelFetch(splatCenter, dataUV, 0).xyz; - } + void getTransform() { + vec4 tA = texelFetch(transformA, dataUV, 0); + vec4 tB = texelFetch(transformB, dataUV, 0); + vec4 tC = texelFetch(transformC, dataUV, 0); - vec3 getCovA() { - return texelFetch(splatCovA, dataUV, 0).xyz; - } - - vec3 getCovB() { - return texelFetch(splatCovB, dataUV, 0).xyz; + center = tA.xyz; + covA = tB.xyz; + covB = vec3(tA.w, tB.w, tC.x); } #else @@ -92,23 +94,25 @@ const splatCoreVS = ` return texture2D(splatColor, dataUV); } - vec3 getCenter() { - return texture2D(splatCenter, dataUV).xyz; - } - - vec3 getCovA() { - return texture2D(splatCovA, dataUV).xyz; - } + void getTransform() { + vec4 tA = texture2D(transformA, dataUV); + vec4 tB = texture2D(transformB, dataUV); + vec4 tC = texture2D(transformC, dataUV); - vec3 getCovB() { - return texture2D(splatCovB, dataUV).xyz; + center = tA.xyz; + covA = tB.xyz; + covB = vec3(tA.w, tB.w, tC.x); } #endif vec3 evalCenter() { evalDataUV(); - return getCenter(); + + // get data + getTransform(); + + return center; } vec4 evalSplat(vec4 centerWorld) @@ -123,13 +127,10 @@ const splatCoreVS = ` color = getColor(); - vec3 splat_cova = getCovA(); - vec3 splat_covb = getCovB(); - mat3 Vrk = mat3( - splat_cova.x, splat_cova.y, splat_cova.z, - splat_cova.y, splat_covb.x, splat_covb.y, - splat_cova.z, splat_covb.y, splat_covb.z + covA.x, covA.y, covA.z, + covA.y, covB.x, covB.y, + covA.z, covB.y, covB.z ); float focal = viewport.x * matrix_projection[0][0];