Skip to content

Commit

Permalink
Use more compact texture formats storing GSplat data (#6129)
Browse files Browse the repository at this point in the history
Co-authored-by: Martin Valigursky <mvaligursky@snapchat.com>
  • Loading branch information
mvaligursky and Martin Valigursky authored Mar 8, 2024
1 parent f270a7e commit 9e9838b
Show file tree
Hide file tree
Showing 3 changed files with 180 additions and 175 deletions.
8 changes: 5 additions & 3 deletions src/framework/parsers/gsplat-resource.js
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,11 @@ class GSplatResource {

// texture data
splat.updateColorData(splatData.getProp('f_dc_0'), splatData.getProp('f_dc_1'), splatData.getProp('f_dc_2'), splatData.getProp('opacity'));
splat.updateCenterData(splatData.getProp('x'), splatData.getProp('y'), splatData.getProp('z'));
splat.updateCovData(splatData.getProp('rot_0'), splatData.getProp('rot_1'), splatData.getProp('rot_2'), splatData.getProp('rot_3'),
splatData.getProp('scale_0'), splatData.getProp('scale_1'), splatData.getProp('scale_2'));
splat.updateTransformData(
splatData.getProp('x'), splatData.getProp('y'), splatData.getProp('z'),
splatData.getProp('rot_0'), splatData.getProp('rot_1'), splatData.getProp('rot_2'), splatData.getProp('rot_3'),
splatData.getProp('scale_0'), splatData.getProp('scale_1'), splatData.getProp('scale_2')
);

// centers - constant buffer that is sent to the worker
const x = splatData.getProp('x');
Expand Down
290 changes: 146 additions & 144 deletions src/scene/gsplat/gsplat.js
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import { Quat } from '../../core/math/quat.js';
import { Vec2 } from '../../core/math/vec2.js';
import { Mat3 } from '../../core/math/mat3.js';
import {
ADDRESS_CLAMP_TO_EDGE, FILTER_NEAREST, PIXELFORMAT_RGB32F, PIXELFORMAT_RGBA16F, PIXELFORMAT_RGBA32F,
ADDRESS_CLAMP_TO_EDGE, FILTER_NEAREST, PIXELFORMAT_R16F, PIXELFORMAT_R32F, PIXELFORMAT_RGBA16F, PIXELFORMAT_RGBA32F,
PIXELFORMAT_RGBA8, SEMANTIC_ATTR13, TYPE_FLOAT32, TYPE_UINT32
} from '../../platform/graphics/constants.js';
import { Texture } from '../../platform/graphics/texture.js';
Expand All @@ -20,35 +20,34 @@ const _m2 = new Vec3();
const _s = new Vec3();
const _r = new Vec3();

/**
* @typedef {object} SplatTextureFormat
* @property {number} format - The pixel format of the texture.
* @property {number} numComponents - The number of components in the texture format.
* @property {boolean} isHalf - Indicates if the format uses half-precision floats.
*/

/** @ignore */
class GSplat {
device;

numSplats;

/** @type {VertexFormat} */
vertexFormat;

/** @type {SplatTextureFormat} */
format;
/**
* True if half format should be used, false is float format should be used or undefined if none
* are available.
*
* @type {boolean|undefined}
*/
halfFormat;

/** @type {Texture} */
colorTexture;

/** @type {Texture} */
covATexture;
transformATexture;

/** @type {Texture} */
covBTexture;
transformBTexture;

/** @type {Texture} */
centerTexture;
transformCTexture;

/** @type {Float32Array} */
centers;
Expand All @@ -70,20 +69,23 @@ class GSplat {
{ semantic: SEMANTIC_ATTR13, components: 1, type: device.isWebGL1 ? TYPE_FLOAT32 : TYPE_UINT32, asInt: !device.isWebGL1 }
]);

// create data textures
const size = this.evalTextureSize(numSplats);
this.format = this.getTextureFormat(device, false);
this.colorTexture = this.createTexture(device, 'splatColor', PIXELFORMAT_RGBA8, size);
this.centerTexture = this.createTexture(device, 'splatCenter', this.format.format, size);
this.covATexture = this.createTexture(device, 'splatCovA', this.format.format, size);
this.covBTexture = this.createTexture(device, 'splatCovB', this.format.format, size);
// create data textures if any format is available
this.halfFormat = this.getTextureFormat(device, true);

if (this.halfFormat !== undefined) {
const size = this.evalTextureSize(numSplats);
this.colorTexture = this.createTexture(device, 'splatColor', PIXELFORMAT_RGBA8, size);
this.transformATexture = this.createTexture(device, 'transformA', this.halfFormat ? PIXELFORMAT_RGBA16F : PIXELFORMAT_RGBA32F, size);
this.transformBTexture = this.createTexture(device, 'transformB', this.halfFormat ? PIXELFORMAT_RGBA16F : PIXELFORMAT_RGBA32F, size);
this.transformCTexture = this.createTexture(device, 'transformC', this.halfFormat ? PIXELFORMAT_R16F : PIXELFORMAT_R32F, size);
}
}

destroy() {
this.colorTexture.destroy();
this.centerTexture.destroy();
this.covATexture.destroy();
this.covBTexture.destroy();
this.colorTexture?.destroy();
this.transformATexture?.destroy();
this.transformBTexture?.destroy();
this.transformCTexture?.destroy();
}

/**
Expand All @@ -92,13 +94,15 @@ class GSplat {
*/
setupMaterial(material) {

material.setParameter('splatColor', this.colorTexture);
material.setParameter('splatCenter', this.centerTexture);
material.setParameter('splatCovA', this.covATexture);
material.setParameter('splatCovB', this.covBTexture);
if (this.colorTexture) {
material.setParameter('splatColor', this.colorTexture);
material.setParameter('transformA', this.transformATexture);
material.setParameter('transformB', this.transformBTexture);
material.setParameter('transformC', this.transformCTexture);

const { width, height } = this.colorTexture;
material.setParameter('tex_params', new Float32Array([width, height, 1 / width, 1 / height]));
const { width, height } = this.colorTexture;
material.setParameter('tex_params', new Float32Array([width, height, 1 / width, 1 / height]));
}
}

/**
Expand Down Expand Up @@ -144,24 +148,35 @@ class GSplat {
*
* @param {import('../../platform/graphics/graphics-device.js').GraphicsDevice} device - The graphics device.
* @param {boolean} preferHighPrecision - True to prefer high precision when available.
* @returns {SplatTextureFormat} The texture format info or undefined if not available.
* @returns {boolean|undefined} True if half format should be used, false is float format should
* be used or undefined if none are available.
*/
getTextureFormat(device, preferHighPrecision) {
const halfFormat = (device.extTextureHalfFloat && device.textureHalfFloatUpdatable) ? PIXELFORMAT_RGBA16F : undefined;
const half = halfFormat ? {
format: halfFormat,
numComponents: 4,
isHalf: true
} : undefined;

const floatFormat = device.isWebGPU ? PIXELFORMAT_RGBA32F : (device.extTextureFloat ? PIXELFORMAT_RGB32F : undefined);
const float = floatFormat ? {
format: floatFormat,
numComponents: floatFormat === PIXELFORMAT_RGBA32F ? 4 : 3,
isHalf: false
} : undefined;

return preferHighPrecision ? (float ?? half) : (half ?? float);

// on WebGL1 R32F is not supported, always use half precision
if (device.isWebGL1)
preferHighPrecision = false;

const halfSupported = device.extTextureHalfFloat && device.textureHalfFloatUpdatable;
const floatSupported = device.extTextureFloat;

// true if half format should be used, false is float format should be used or undefined if none are available.
let halfFormat;
if (preferHighPrecision) {
if (floatSupported) {
halfFormat = false;
} else if (halfSupported) {
halfFormat = true;
}
} else {
if (halfSupported) {
halfFormat = true;
} else if (floatSupported) {
halfFormat = false;
}
}

return halfFormat;
}

/**
Expand All @@ -177,6 +192,8 @@ class GSplat {
updateColorData(c0, c1, c2, opacity) {
const SH_C0 = 0.28209479177387814;
const texture = this.colorTexture;
if (!texture)
return;
const data = texture.lock();

/**
Expand Down Expand Up @@ -210,6 +227,89 @@ class GSplat {
texture.unlock();
}

/**
* @param {Float32Array} x - The array containing the 'x' component of the center points.
* @param {Float32Array} y - The array containing the 'y' component of the center points.
* @param {Float32Array} z - The array containing the 'z' component of the center points.
* @param {Float32Array} rot0 - The array containing the 'x' component of quaternion rotations.
* @param {Float32Array} rot1 - The array containing the 'y' component of quaternion rotations.
* @param {Float32Array} rot2 - The array containing the 'z' component of quaternion rotations.
* @param {Float32Array} rot3 - The array containing the 'w' component of quaternion rotations.
* @param {Float32Array} scale0 - The first scale component associated with the x-dimension.
* @param {Float32Array} scale1 - The second scale component associated with the y-dimension.
* @param {Float32Array} scale2 - The third scale component associated with the z-dimension.
*/
updateTransformData(x, y, z, rot0, rot1, rot2, rot3, scale0, scale1, scale2) {

const { halfFormat } = this;
const float2Half = FloatPacking.float2Half;

if (!this.transformATexture)
return;

const dataA = this.transformATexture.lock();
const dataB = this.transformBTexture.lock();
const dataC = this.transformCTexture.lock();

const quat = new Quat();
const mat = new Mat3();
const cA = new Vec3();
const cB = new Vec3();

for (let i = 0; i < this.numSplats; i++) {

// rotation
quat.set(rot0[i], rot1[i], rot2[i], rot3[i]).normalize();
if (quat.w < 0) {
quat.conjugate();
}
_r.set(quat.x, quat.y, quat.z);
this.quatToMat3(_r, mat);

// scale
_s.set(
Math.exp(scale0[i]),
Math.exp(scale1[i]),
Math.exp(scale2[i])
);

this.computeCov3d(mat, _s, cA, cB);

if (halfFormat) {

dataA[i * 4 + 0] = float2Half(x[i]);
dataA[i * 4 + 1] = float2Half(y[i]);
dataA[i * 4 + 2] = float2Half(z[i]);
dataA[i * 4 + 3] = float2Half(cB.x);

dataB[i * 4 + 0] = float2Half(cA.x);
dataB[i * 4 + 1] = float2Half(cA.y);
dataB[i * 4 + 2] = float2Half(cA.z);
dataB[i * 4 + 3] = float2Half(cB.y);

dataC[i] = float2Half(cB.z);

} else {

dataA[i * 4 + 0] = x[i];
dataA[i * 4 + 1] = y[i];
dataA[i * 4 + 2] = z[i];
dataA[i * 4 + 3] = cB.x;

dataB[i * 4 + 0] = cA.x;
dataB[i * 4 + 1] = cA.y;
dataB[i * 4 + 2] = cA.z;
dataB[i * 4 + 3] = cB.y;

dataC[i] = cB.z;
}
}

this.transformATexture.unlock();
this.transformBTexture.unlock();
this.transformCTexture.unlock();
}

/**
* Convert quaternion rotation stored in Vec3 to a rotation matrix.
*
Expand Down Expand Up @@ -268,104 +368,6 @@ class GSplat {
_m2.dot(_m2)
);
}

/**
* Updates data of covATexture and covBTexture based on the supplied rotation and scale
* components.
*
* @param {Float32Array} rot0 - The array containing the 'x' component of quaternion rotations.
* @param {Float32Array} rot1 - The array containing the 'y' component of quaternion rotations.
* @param {Float32Array} rot2 - The array containing the 'z' component of quaternion rotations.
* @param {Float32Array} rot3 - The array containing the 'w' component of quaternion rotations.
* @param {Float32Array} scale0 - The first scale component associated with the x-dimension.
* @param {Float32Array} scale1 - The second scale component associated with the y-dimension.
* @param {Float32Array} scale2 - The third scale component associated with the z-dimension.
*/
updateCovData(rot0, rot1, rot2, rot3, scale0, scale1, scale2) {

const { numComponents, isHalf } = this.format;
const float2Half = FloatPacking.float2Half;
const quat = new Quat();
const mat = new Mat3();
const cA = new Vec3();
const cB = new Vec3();

const covA = this.covATexture.lock();
const covB = this.covBTexture.lock();

for (let i = 0; i < this.numSplats; ++i) {

// rotation
quat.set(rot0[i], rot1[i], rot2[i], rot3[i]).normalize();
if (quat.w < 0) {
quat.conjugate();
}
_r.set(quat.x, quat.y, quat.z);
this.quatToMat3(_r, mat);

// scale
_s.set(
Math.exp(scale0[i]),
Math.exp(scale1[i]),
Math.exp(scale2[i])
);

this.computeCov3d(mat, _s, cA, cB);

if (isHalf) {
covA[i * numComponents + 0] = float2Half(cA.x);
covA[i * numComponents + 1] = float2Half(cA.y);
covA[i * numComponents + 2] = float2Half(cA.z);

covB[i * numComponents + 0] = float2Half(cB.x);
covB[i * numComponents + 1] = float2Half(cB.y);
covB[i * numComponents + 2] = float2Half(cB.z);
} else {
covA[i * numComponents + 0] = cA.x;
covA[i * numComponents + 1] = cA.y;
covA[i * numComponents + 2] = cA.z;

covB[i * numComponents + 0] = cB.x;
covB[i * numComponents + 1] = cB.y;
covB[i * numComponents + 2] = cB.z;
}
}

this.covATexture.unlock();
this.covBTexture.unlock();
}

/**
* Updates pixel data of this.centerTexture based on the supplied center coordinates.
* The center coordinates are stored as either half or full precision floats depending on the
* texture format.
*
* @param {Float32Array} x - The array containing the 'x' component of the center points.
* @param {Float32Array} y - The array containing the 'y' component of the center points.
* @param {Float32Array} z - The array containing the 'z' component of the center points.
*/
updateCenterData(x, y, z) {
const { numComponents, isHalf } = this.format;

const texture = this.centerTexture;
const data = texture.lock();
const float2Half = FloatPacking.float2Half;

for (let i = 0; i < this.numSplats; i++) {

if (isHalf) {
data[i * numComponents + 0] = float2Half(x[i]);
data[i * numComponents + 1] = float2Half(y[i]);
data[i * numComponents + 2] = float2Half(z[i]);
} else {
data[i * numComponents + 0] = x[i];
data[i * numComponents + 1] = y[i];
data[i * numComponents + 2] = z[i];
}
}

texture.unlock();
}
}

export { GSplat };
Loading

0 comments on commit 9e9838b

Please sign in to comment.