From 9e9838b8482b70054a58f4bd842d74357ea4b966 Mon Sep 17 00:00:00 2001
From: Martin Valigursky <59932779+mvaligursky@users.noreply.github.com>
Date: Fri, 8 Mar 2024 14:20:16 +0000
Subject: [PATCH] Use more compact texture formats storing GSplat data (#6129)

Co-authored-by: Martin Valigursky <mvaligursky@snapchat.com>
---
 src/framework/parsers/gsplat-resource.js    |   8 +-
 src/scene/gsplat/gsplat.js                  | 290 ++++++++++----------
 src/scene/gsplat/shader-generator-gsplat.js |  57 ++--
 3 files changed, 180 insertions(+), 175 deletions(-)

diff --git a/src/framework/parsers/gsplat-resource.js b/src/framework/parsers/gsplat-resource.js
index 9351b056209..841665c336c 100644
--- a/src/framework/parsers/gsplat-resource.js
+++ b/src/framework/parsers/gsplat-resource.js
@@ -57,9 +57,11 @@ class GSplatResource {
 
             // texture data
             splat.updateColorData(splatData.getProp('f_dc_0'), splatData.getProp('f_dc_1'), splatData.getProp('f_dc_2'), splatData.getProp('opacity'));
-            splat.updateCenterData(splatData.getProp('x'), splatData.getProp('y'), splatData.getProp('z'));
-            splat.updateCovData(splatData.getProp('rot_0'), splatData.getProp('rot_1'), splatData.getProp('rot_2'), splatData.getProp('rot_3'),
-                                splatData.getProp('scale_0'), splatData.getProp('scale_1'), splatData.getProp('scale_2'));
+            splat.updateTransformData(
+                splatData.getProp('x'), splatData.getProp('y'), splatData.getProp('z'),
+                splatData.getProp('rot_0'), splatData.getProp('rot_1'), splatData.getProp('rot_2'), splatData.getProp('rot_3'),
+                splatData.getProp('scale_0'), splatData.getProp('scale_1'), splatData.getProp('scale_2')
+            );
 
             // centers - constant buffer that is sent to the worker
             const x = splatData.getProp('x');
diff --git a/src/scene/gsplat/gsplat.js b/src/scene/gsplat/gsplat.js
index ee134b38500..4ad4a46ee9b 100644
--- a/src/scene/gsplat/gsplat.js
+++ b/src/scene/gsplat/gsplat.js
@@ -4,7 +4,7 @@ import { Quat } from '../../core/math/quat.js';
 import { Vec2 } from '../../core/math/vec2.js';
 import { Mat3 } from '../../core/math/mat3.js';
 import {
-    ADDRESS_CLAMP_TO_EDGE, FILTER_NEAREST, PIXELFORMAT_RGB32F, PIXELFORMAT_RGBA16F, PIXELFORMAT_RGBA32F,
+    ADDRESS_CLAMP_TO_EDGE, FILTER_NEAREST, PIXELFORMAT_R16F, PIXELFORMAT_R32F, PIXELFORMAT_RGBA16F, PIXELFORMAT_RGBA32F,
     PIXELFORMAT_RGBA8, SEMANTIC_ATTR13, TYPE_FLOAT32, TYPE_UINT32
 } from '../../platform/graphics/constants.js';
 import { Texture } from '../../platform/graphics/texture.js';
@@ -20,35 +20,34 @@ const _m2 = new Vec3();
 const _s = new Vec3();
 const _r = new Vec3();
 
-/**
- * @typedef {object} SplatTextureFormat
- * @property {number} format - The pixel format of the texture.
- * @property {number} numComponents - The number of components in the texture format.
- * @property {boolean} isHalf - Indicates if the format uses half-precision floats.
- */
-
 /** @ignore */
 class GSplat {
     device;
 
     numSplats;
 
+    /** @type {VertexFormat} */
     vertexFormat;
 
-    /** @type {SplatTextureFormat} */
-    format;
+    /**
+     * True if half format should be used, false is float format should be used or undefined if none
+     * are available.
+     *
+     * @type {boolean|undefined}
+     */
+    halfFormat;
 
     /** @type {Texture} */
     colorTexture;
 
     /** @type {Texture} */
-    covATexture;
+    transformATexture;
 
     /** @type {Texture} */
-    covBTexture;
+    transformBTexture;
 
     /** @type {Texture} */
-    centerTexture;
+    transformCTexture;
 
     /** @type {Float32Array} */
     centers;
@@ -70,20 +69,23 @@ class GSplat {
             { semantic: SEMANTIC_ATTR13, components: 1, type: device.isWebGL1 ? TYPE_FLOAT32 : TYPE_UINT32, asInt: !device.isWebGL1 }
         ]);
 
-        // create data textures
-        const size = this.evalTextureSize(numSplats);
-        this.format = this.getTextureFormat(device, false);
-        this.colorTexture = this.createTexture(device, 'splatColor', PIXELFORMAT_RGBA8, size);
-        this.centerTexture = this.createTexture(device, 'splatCenter', this.format.format, size);
-        this.covATexture = this.createTexture(device, 'splatCovA', this.format.format, size);
-        this.covBTexture = this.createTexture(device, 'splatCovB', this.format.format, size);
+        // create data textures if any format is available
+        this.halfFormat = this.getTextureFormat(device, true);
+
+        if (this.halfFormat !== undefined) {
+            const size = this.evalTextureSize(numSplats);
+            this.colorTexture = this.createTexture(device, 'splatColor', PIXELFORMAT_RGBA8, size);
+            this.transformATexture = this.createTexture(device, 'transformA', this.halfFormat ? PIXELFORMAT_RGBA16F : PIXELFORMAT_RGBA32F, size);
+            this.transformBTexture = this.createTexture(device, 'transformB', this.halfFormat ? PIXELFORMAT_RGBA16F : PIXELFORMAT_RGBA32F, size);
+            this.transformCTexture = this.createTexture(device, 'transformC', this.halfFormat ? PIXELFORMAT_R16F : PIXELFORMAT_R32F, size);
+        }
     }
 
     destroy() {
-        this.colorTexture.destroy();
-        this.centerTexture.destroy();
-        this.covATexture.destroy();
-        this.covBTexture.destroy();
+        this.colorTexture?.destroy();
+        this.transformATexture?.destroy();
+        this.transformBTexture?.destroy();
+        this.transformCTexture?.destroy();
     }
 
     /**
@@ -92,13 +94,15 @@ class GSplat {
      */
     setupMaterial(material) {
 
-        material.setParameter('splatColor', this.colorTexture);
-        material.setParameter('splatCenter', this.centerTexture);
-        material.setParameter('splatCovA', this.covATexture);
-        material.setParameter('splatCovB', this.covBTexture);
+        if (this.colorTexture) {
+            material.setParameter('splatColor', this.colorTexture);
+            material.setParameter('transformA', this.transformATexture);
+            material.setParameter('transformB', this.transformBTexture);
+            material.setParameter('transformC', this.transformCTexture);
 
-        const { width, height } = this.colorTexture;
-        material.setParameter('tex_params', new Float32Array([width, height, 1 / width, 1 / height]));
+            const { width, height } = this.colorTexture;
+            material.setParameter('tex_params', new Float32Array([width, height, 1 / width, 1 / height]));
+        }
     }
 
     /**
@@ -144,24 +148,35 @@ class GSplat {
      *
      * @param {import('../../platform/graphics/graphics-device.js').GraphicsDevice} device - The graphics device.
      * @param {boolean} preferHighPrecision - True to prefer high precision when available.
-     * @returns {SplatTextureFormat} The texture format info or undefined if not available.
+     * @returns {boolean|undefined} True if half format should be used, false is float format should
+     * be used or undefined if none are available.
      */
     getTextureFormat(device, preferHighPrecision) {
-        const halfFormat = (device.extTextureHalfFloat && device.textureHalfFloatUpdatable) ? PIXELFORMAT_RGBA16F : undefined;
-        const half = halfFormat ? {
-            format: halfFormat,
-            numComponents: 4,
-            isHalf: true
-        } : undefined;
-
-        const floatFormat = device.isWebGPU ? PIXELFORMAT_RGBA32F : (device.extTextureFloat ? PIXELFORMAT_RGB32F : undefined);
-        const float = floatFormat ? {
-            format: floatFormat,
-            numComponents: floatFormat === PIXELFORMAT_RGBA32F ? 4 : 3,
-            isHalf: false
-        } : undefined;
-
-        return preferHighPrecision ? (float ?? half) : (half ?? float);
+
+        // on WebGL1 R32F is not supported, always use half precision
+        if (device.isWebGL1)
+            preferHighPrecision = false;
+
+        const halfSupported = device.extTextureHalfFloat && device.textureHalfFloatUpdatable;
+        const floatSupported = device.extTextureFloat;
+
+        // true if half format should be used, false is float format should be used or undefined if none are available.
+        let halfFormat;
+        if (preferHighPrecision) {
+            if (floatSupported) {
+                halfFormat = false;
+            } else if (halfSupported) {
+                halfFormat = true;
+            }
+        } else {
+            if (halfSupported) {
+                halfFormat = true;
+            } else if (floatSupported) {
+                halfFormat = false;
+            }
+        }
+
+        return halfFormat;
     }
 
     /**
@@ -177,6 +192,8 @@ class GSplat {
     updateColorData(c0, c1, c2, opacity) {
         const SH_C0 = 0.28209479177387814;
         const texture = this.colorTexture;
+        if (!texture)
+            return;
         const data = texture.lock();
 
         /**
@@ -210,6 +227,89 @@ class GSplat {
         texture.unlock();
     }
 
+    /**
+     * @param {Float32Array} x - The array containing the 'x' component of the center points.
+     * @param {Float32Array} y - The array containing the 'y' component of the center points.
+     * @param {Float32Array} z - The array containing the 'z' component of the center points.
+     * @param {Float32Array} rot0 - The array containing the 'x' component of quaternion rotations.
+     * @param {Float32Array} rot1 - The array containing the 'y' component of quaternion rotations.
+     * @param {Float32Array} rot2 - The array containing the 'z' component of quaternion rotations.
+     * @param {Float32Array} rot3 - The array containing the 'w' component of quaternion rotations.
+     * @param {Float32Array} scale0 - The first scale component associated with the x-dimension.
+     * @param {Float32Array} scale1 - The second scale component associated with the y-dimension.
+     * @param {Float32Array} scale2 - The third scale component associated with the z-dimension.
+     */
+    updateTransformData(x, y, z, rot0, rot1, rot2, rot3, scale0, scale1, scale2) {
+
+        const { halfFormat } = this;
+        const float2Half = FloatPacking.float2Half;
+
+        if (!this.transformATexture)
+            return;
+
+        const dataA = this.transformATexture.lock();
+        const dataB = this.transformBTexture.lock();
+        const dataC = this.transformCTexture.lock();
+
+        const quat = new Quat();
+        const mat = new Mat3();
+        const cA = new Vec3();
+        const cB = new Vec3();
+
+        for (let i = 0; i < this.numSplats; i++) {
+
+            // rotation
+            quat.set(rot0[i], rot1[i], rot2[i], rot3[i]).normalize();
+            if (quat.w < 0) {
+                quat.conjugate();
+            }
+            _r.set(quat.x, quat.y, quat.z);
+            this.quatToMat3(_r, mat);
+
+            // scale
+            _s.set(
+                Math.exp(scale0[i]),
+                Math.exp(scale1[i]),
+                Math.exp(scale2[i])
+            );
+
+            this.computeCov3d(mat, _s, cA, cB);
+
+            if (halfFormat) {
+
+                dataA[i * 4 + 0] = float2Half(x[i]);
+                dataA[i * 4 + 1] = float2Half(y[i]);
+                dataA[i * 4 + 2] = float2Half(z[i]);
+                dataA[i * 4 + 3] = float2Half(cB.x);
+
+                dataB[i * 4 + 0] = float2Half(cA.x);
+                dataB[i * 4 + 1] = float2Half(cA.y);
+                dataB[i * 4 + 2] = float2Half(cA.z);
+                dataB[i * 4 + 3] = float2Half(cB.y);
+
+                dataC[i] = float2Half(cB.z);
+
+            } else {
+
+                dataA[i * 4 + 0] = x[i];
+                dataA[i * 4 + 1] = y[i];
+                dataA[i * 4 + 2] = z[i];
+                dataA[i * 4 + 3] = cB.x;
+
+                dataB[i * 4 + 0] = cA.x;
+                dataB[i * 4 + 1] = cA.y;
+                dataB[i * 4 + 2] = cA.z;
+                dataB[i * 4 + 3] = cB.y;
+
+                dataC[i] = cB.z;
+            }
+        }
+
+        this.transformATexture.unlock();
+        this.transformBTexture.unlock();
+        this.transformCTexture.unlock();
+    }
+
     /**
      * Convert quaternion rotation stored in Vec3 to a rotation matrix.
      *
@@ -268,104 +368,6 @@ class GSplat {
             _m2.dot(_m2)
         );
     }
-
-    /**
-     * Updates data of covATexture and covBTexture based on the supplied rotation and scale
-     * components.
-     *
-     * @param {Float32Array} rot0 - The array containing the 'x' component of quaternion rotations.
-     * @param {Float32Array} rot1 - The array containing the 'y' component of quaternion rotations.
-     * @param {Float32Array} rot2 - The array containing the 'z' component of quaternion rotations.
-     * @param {Float32Array} rot3 - The array containing the 'w' component of quaternion rotations.
-     * @param {Float32Array} scale0 - The first scale component associated with the x-dimension.
-     * @param {Float32Array} scale1 - The second scale component associated with the y-dimension.
-     * @param {Float32Array} scale2 - The third scale component associated with the z-dimension.
-     */
-    updateCovData(rot0, rot1, rot2, rot3, scale0, scale1, scale2) {
-
-        const { numComponents, isHalf } = this.format;
-        const float2Half = FloatPacking.float2Half;
-        const quat = new Quat();
-        const mat = new Mat3();
-        const cA = new Vec3();
-        const cB = new Vec3();
-
-        const covA = this.covATexture.lock();
-        const covB = this.covBTexture.lock();
-
-        for (let i = 0; i < this.numSplats; ++i) {
-
-            // rotation
-            quat.set(rot0[i], rot1[i], rot2[i], rot3[i]).normalize();
-            if (quat.w < 0) {
-                quat.conjugate();
-            }
-            _r.set(quat.x, quat.y, quat.z);
-            this.quatToMat3(_r, mat);
-
-            // scale
-            _s.set(
-                Math.exp(scale0[i]),
-                Math.exp(scale1[i]),
-                Math.exp(scale2[i])
-            );
-
-            this.computeCov3d(mat, _s, cA, cB);
-
-            if (isHalf) {
-                covA[i * numComponents + 0] = float2Half(cA.x);
-                covA[i * numComponents + 1] = float2Half(cA.y);
-                covA[i * numComponents + 2] = float2Half(cA.z);
-
-                covB[i * numComponents + 0] = float2Half(cB.x);
-                covB[i * numComponents + 1] = float2Half(cB.y);
-                covB[i * numComponents + 2] = float2Half(cB.z);
-            } else {
-                covA[i * numComponents + 0] = cA.x;
-                covA[i * numComponents + 1] = cA.y;
-                covA[i * numComponents + 2] = cA.z;
-
-                covB[i * numComponents + 0] = cB.x;
-                covB[i * numComponents + 1] = cB.y;
-                covB[i * numComponents + 2] = cB.z;
-            }
-        }
-
-        this.covATexture.unlock();
-        this.covBTexture.unlock();
-    }
-
-    /**
-     * Updates pixel data of this.centerTexture based on the supplied center coordinates.
-     * The center coordinates are stored as either half or full precision floats depending on the
-     * texture format.
-     *
-     * @param {Float32Array} x - The array containing the 'x' component of the center points.
-     * @param {Float32Array} y - The array containing the 'y' component of the center points.
-     * @param {Float32Array} z - The array containing the 'z' component of the center points.
-     */
-    updateCenterData(x, y, z) {
-        const { numComponents, isHalf } = this.format;
-
-        const texture = this.centerTexture;
-        const data = texture.lock();
-        const float2Half = FloatPacking.float2Half;
-
-        for (let i = 0; i < this.numSplats; i++) {
-
-            if (isHalf) {
-                data[i * numComponents + 0] = float2Half(x[i]);
-                data[i * numComponents + 1] = float2Half(y[i]);
-                data[i * numComponents + 2] = float2Half(z[i]);
-            } else {
-                data[i * numComponents + 0] = x[i];
-                data[i * numComponents + 1] = y[i];
-                data[i * numComponents + 2] = z[i];
-            }
-        }
-
-        texture.unlock();
-    }
 }
 
 export { GSplat };
diff --git a/src/scene/gsplat/shader-generator-gsplat.js b/src/scene/gsplat/shader-generator-gsplat.js
index 0a9393daa89..7ad07014cab 100644
--- a/src/scene/gsplat/shader-generator-gsplat.js
+++ b/src/scene/gsplat/shader-generator-gsplat.js
@@ -34,10 +34,14 @@ const splatCoreVS = `
 
     uniform vec4 tex_params;
     uniform sampler2D splatColor;
-    uniform highp sampler2D splatCenter;
 
-    uniform highp sampler2D splatCovA;
-    uniform highp sampler2D splatCovB;
+    uniform highp sampler2D transformA;
+    uniform highp sampler2D transformB;
+    uniform highp sampler2D transformC;
+
+    vec3 center;
+    vec3 covA;
+    vec3 covB;
 
     #ifdef INT_INDICES
 
@@ -58,16 +62,14 @@ const splatCoreVS = `
             return texelFetch(splatColor, dataUV, 0);
         }
 
-        vec3 getCenter() {
-            return texelFetch(splatCenter, dataUV, 0).xyz;
-        }
+        void getTransform() {
+            vec4 tA = texelFetch(transformA, dataUV, 0);
+            vec4 tB = texelFetch(transformB, dataUV, 0);
+            vec4 tC = texelFetch(transformC, dataUV, 0);
 
-        vec3 getCovA() {
-            return texelFetch(splatCovA, dataUV, 0).xyz;
-        }
-
-        vec3 getCovB() {
-            return texelFetch(splatCovB, dataUV, 0).xyz;
+            center = tA.xyz;
+            covA = tB.xyz;
+            covB = vec3(tA.w, tB.w, tC.x);
         }
 
     #else
@@ -92,23 +94,25 @@ const splatCoreVS = `
             return texture2D(splatColor, dataUV);
         }
 
-        vec3 getCenter() {
-            return texture2D(splatCenter, dataUV).xyz;
-        }
-
-        vec3 getCovA() {
-            return texture2D(splatCovA, dataUV).xyz;
-        }
+        void getTransform() {
+            vec4 tA = texture2D(transformA, dataUV);
+            vec4 tB = texture2D(transformB, dataUV);
+            vec4 tC = texture2D(transformC, dataUV);
 
-        vec3 getCovB() {
-            return texture2D(splatCovB, dataUV).xyz;
+            center = tA.xyz;
+            covA = tB.xyz;
+            covB = vec3(tA.w, tB.w, tC.x);
         }
 
     #endif
 
     vec3 evalCenter() {
         evalDataUV();
-        return getCenter();
+
+        // get data
+        getTransform();
+
+        return center;
     }
 
     vec4 evalSplat(vec4 centerWorld)
@@ -123,13 +127,10 @@ const splatCoreVS = `
 
         color = getColor();
 
-        vec3 splat_cova = getCovA();
-        vec3 splat_covb = getCovB();
-
         mat3 Vrk = mat3(
-            splat_cova.x, splat_cova.y, splat_cova.z, 
-            splat_cova.y, splat_covb.x, splat_covb.y,
-            splat_cova.z, splat_covb.y, splat_covb.z
+            covA.x, covA.y, covA.z, 
+            covA.y, covB.x, covB.y,
+            covA.z, covB.y, covB.z
         );
 
         float focal = viewport.x * matrix_projection[0][0];