Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use more compact texture formats storing GSplat data #6129

Merged
merged 1 commit into from
Mar 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions src/framework/parsers/gsplat-resource.js
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,11 @@ class GSplatResource {

// texture data
splat.updateColorData(splatData.getProp('f_dc_0'), splatData.getProp('f_dc_1'), splatData.getProp('f_dc_2'), splatData.getProp('opacity'));
splat.updateCenterData(splatData.getProp('x'), splatData.getProp('y'), splatData.getProp('z'));
splat.updateCovData(splatData.getProp('rot_0'), splatData.getProp('rot_1'), splatData.getProp('rot_2'), splatData.getProp('rot_3'),
splatData.getProp('scale_0'), splatData.getProp('scale_1'), splatData.getProp('scale_2'));
splat.updateTransformData(
splatData.getProp('x'), splatData.getProp('y'), splatData.getProp('z'),
splatData.getProp('rot_0'), splatData.getProp('rot_1'), splatData.getProp('rot_2'), splatData.getProp('rot_3'),
splatData.getProp('scale_0'), splatData.getProp('scale_1'), splatData.getProp('scale_2')
);

// centers - constant buffer that is sent to the worker
const x = splatData.getProp('x');
Expand Down
290 changes: 146 additions & 144 deletions src/scene/gsplat/gsplat.js
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import { Quat } from '../../core/math/quat.js';
import { Vec2 } from '../../core/math/vec2.js';
import { Mat3 } from '../../core/math/mat3.js';
import {
ADDRESS_CLAMP_TO_EDGE, FILTER_NEAREST, PIXELFORMAT_RGB32F, PIXELFORMAT_RGBA16F, PIXELFORMAT_RGBA32F,
ADDRESS_CLAMP_TO_EDGE, FILTER_NEAREST, PIXELFORMAT_R16F, PIXELFORMAT_R32F, PIXELFORMAT_RGBA16F, PIXELFORMAT_RGBA32F,
PIXELFORMAT_RGBA8, SEMANTIC_ATTR13, TYPE_FLOAT32, TYPE_UINT32
} from '../../platform/graphics/constants.js';
import { Texture } from '../../platform/graphics/texture.js';
Expand All @@ -20,35 +20,34 @@ const _m2 = new Vec3();
const _s = new Vec3();
const _r = new Vec3();

/**
* @typedef {object} SplatTextureFormat
* @property {number} format - The pixel format of the texture.
* @property {number} numComponents - The number of components in the texture format.
* @property {boolean} isHalf - Indicates if the format uses half-precision floats.
*/

/** @ignore */
class GSplat {
device;

numSplats;

/** @type {VertexFormat} */
vertexFormat;

/** @type {SplatTextureFormat} */
format;
/**
* True if half format should be used, false is float format should be used or undefined if none
* are available.
*
* @type {boolean|undefined}
*/
halfFormat;

/** @type {Texture} */
colorTexture;

/** @type {Texture} */
covATexture;
transformATexture;

/** @type {Texture} */
covBTexture;
transformBTexture;

/** @type {Texture} */
centerTexture;
transformCTexture;

/** @type {Float32Array} */
centers;
Expand All @@ -70,20 +69,23 @@ class GSplat {
{ semantic: SEMANTIC_ATTR13, components: 1, type: device.isWebGL1 ? TYPE_FLOAT32 : TYPE_UINT32, asInt: !device.isWebGL1 }
]);

// create data textures
const size = this.evalTextureSize(numSplats);
this.format = this.getTextureFormat(device, false);
this.colorTexture = this.createTexture(device, 'splatColor', PIXELFORMAT_RGBA8, size);
this.centerTexture = this.createTexture(device, 'splatCenter', this.format.format, size);
this.covATexture = this.createTexture(device, 'splatCovA', this.format.format, size);
this.covBTexture = this.createTexture(device, 'splatCovB', this.format.format, size);
// create data textures if any format is available
this.halfFormat = this.getTextureFormat(device, true);

if (this.halfFormat !== undefined) {
const size = this.evalTextureSize(numSplats);
this.colorTexture = this.createTexture(device, 'splatColor', PIXELFORMAT_RGBA8, size);
this.transformATexture = this.createTexture(device, 'transformA', this.halfFormat ? PIXELFORMAT_RGBA16F : PIXELFORMAT_RGBA32F, size);
this.transformBTexture = this.createTexture(device, 'transformB', this.halfFormat ? PIXELFORMAT_RGBA16F : PIXELFORMAT_RGBA32F, size);
this.transformCTexture = this.createTexture(device, 'transformC', this.halfFormat ? PIXELFORMAT_R16F : PIXELFORMAT_R32F, size);
}
}

destroy() {
this.colorTexture.destroy();
this.centerTexture.destroy();
this.covATexture.destroy();
this.covBTexture.destroy();
this.colorTexture?.destroy();
this.transformATexture?.destroy();
this.transformBTexture?.destroy();
this.transformCTexture?.destroy();
}

/**
Expand All @@ -92,13 +94,15 @@ class GSplat {
*/
setupMaterial(material) {

material.setParameter('splatColor', this.colorTexture);
material.setParameter('splatCenter', this.centerTexture);
material.setParameter('splatCovA', this.covATexture);
material.setParameter('splatCovB', this.covBTexture);
if (this.colorTexture) {
material.setParameter('splatColor', this.colorTexture);
material.setParameter('transformA', this.transformATexture);
material.setParameter('transformB', this.transformBTexture);
material.setParameter('transformC', this.transformCTexture);

const { width, height } = this.colorTexture;
material.setParameter('tex_params', new Float32Array([width, height, 1 / width, 1 / height]));
const { width, height } = this.colorTexture;
material.setParameter('tex_params', new Float32Array([width, height, 1 / width, 1 / height]));
}
}

/**
Expand Down Expand Up @@ -144,24 +148,35 @@ class GSplat {
*
* @param {import('../../platform/graphics/graphics-device.js').GraphicsDevice} device - The graphics device.
* @param {boolean} preferHighPrecision - True to prefer high precision when available.
* @returns {SplatTextureFormat} The texture format info or undefined if not available.
* @returns {boolean|undefined} True if half format should be used, false is float format should
* be used or undefined if none are available.
*/
getTextureFormat(device, preferHighPrecision) {
const halfFormat = (device.extTextureHalfFloat && device.textureHalfFloatUpdatable) ? PIXELFORMAT_RGBA16F : undefined;
const half = halfFormat ? {
format: halfFormat,
numComponents: 4,
isHalf: true
} : undefined;

const floatFormat = device.isWebGPU ? PIXELFORMAT_RGBA32F : (device.extTextureFloat ? PIXELFORMAT_RGB32F : undefined);
const float = floatFormat ? {
format: floatFormat,
numComponents: floatFormat === PIXELFORMAT_RGBA32F ? 4 : 3,
isHalf: false
} : undefined;

return preferHighPrecision ? (float ?? half) : (half ?? float);

// on WebGL1 R32F is not supported, always use half precision
if (device.isWebGL1)
preferHighPrecision = false;

const halfSupported = device.extTextureHalfFloat && device.textureHalfFloatUpdatable;
const floatSupported = device.extTextureFloat;

// true if half format should be used, false is float format should be used or undefined if none are available.
let halfFormat;
if (preferHighPrecision) {
if (floatSupported) {
halfFormat = false;
} else if (halfSupported) {
halfFormat = true;
}
} else {
if (halfSupported) {
halfFormat = true;
} else if (floatSupported) {
halfFormat = false;
}
}

return halfFormat;
}

/**
Expand All @@ -177,6 +192,8 @@ class GSplat {
updateColorData(c0, c1, c2, opacity) {
const SH_C0 = 0.28209479177387814;
const texture = this.colorTexture;
if (!texture)
return;
const data = texture.lock();

/**
Expand Down Expand Up @@ -210,6 +227,89 @@ class GSplat {
texture.unlock();
}

/**
* @param {Float32Array} x - The array containing the 'x' component of the center points.
* @param {Float32Array} y - The array containing the 'y' component of the center points.
* @param {Float32Array} z - The array containing the 'z' component of the center points.
* @param {Float32Array} rot0 - The array containing the 'x' component of quaternion rotations.
* @param {Float32Array} rot1 - The array containing the 'y' component of quaternion rotations.
* @param {Float32Array} rot2 - The array containing the 'z' component of quaternion rotations.
* @param {Float32Array} rot3 - The array containing the 'w' component of quaternion rotations.
* @param {Float32Array} scale0 - The first scale component associated with the x-dimension.
* @param {Float32Array} scale1 - The second scale component associated with the y-dimension.
* @param {Float32Array} scale2 - The third scale component associated with the z-dimension.
*/
updateTransformData(x, y, z, rot0, rot1, rot2, rot3, scale0, scale1, scale2) {

const { halfFormat } = this;
const float2Half = FloatPacking.float2Half;

if (!this.transformATexture)
return;

const dataA = this.transformATexture.lock();
const dataB = this.transformBTexture.lock();
const dataC = this.transformCTexture.lock();

const quat = new Quat();
const mat = new Mat3();
const cA = new Vec3();
const cB = new Vec3();

for (let i = 0; i < this.numSplats; i++) {

// rotation
quat.set(rot0[i], rot1[i], rot2[i], rot3[i]).normalize();
if (quat.w < 0) {
quat.conjugate();
}
_r.set(quat.x, quat.y, quat.z);
this.quatToMat3(_r, mat);

// scale
_s.set(
Math.exp(scale0[i]),
Math.exp(scale1[i]),
Math.exp(scale2[i])
);

this.computeCov3d(mat, _s, cA, cB);

if (halfFormat) {

dataA[i * 4 + 0] = float2Half(x[i]);
dataA[i * 4 + 1] = float2Half(y[i]);
dataA[i * 4 + 2] = float2Half(z[i]);
dataA[i * 4 + 3] = float2Half(cB.x);

dataB[i * 4 + 0] = float2Half(cA.x);
dataB[i * 4 + 1] = float2Half(cA.y);
dataB[i * 4 + 2] = float2Half(cA.z);
dataB[i * 4 + 3] = float2Half(cB.y);

dataC[i] = float2Half(cB.z);

} else {

dataA[i * 4 + 0] = x[i];
dataA[i * 4 + 1] = y[i];
dataA[i * 4 + 2] = z[i];
dataA[i * 4 + 3] = cB.x;

dataB[i * 4 + 0] = cA.x;
dataB[i * 4 + 1] = cA.y;
dataB[i * 4 + 2] = cA.z;
dataB[i * 4 + 3] = cB.y;

dataC[i] = cB.z;
}
}

this.transformATexture.unlock();
this.transformBTexture.unlock();
this.transformCTexture.unlock();
}

/**
* Convert quaternion rotation stored in Vec3 to a rotation matrix.
*
Expand Down Expand Up @@ -268,104 +368,6 @@ class GSplat {
_m2.dot(_m2)
);
}

/**
* Updates data of covATexture and covBTexture based on the supplied rotation and scale
* components.
*
* @param {Float32Array} rot0 - The array containing the 'x' component of quaternion rotations.
* @param {Float32Array} rot1 - The array containing the 'y' component of quaternion rotations.
* @param {Float32Array} rot2 - The array containing the 'z' component of quaternion rotations.
* @param {Float32Array} rot3 - The array containing the 'w' component of quaternion rotations.
* @param {Float32Array} scale0 - The first scale component associated with the x-dimension.
* @param {Float32Array} scale1 - The second scale component associated with the y-dimension.
* @param {Float32Array} scale2 - The third scale component associated with the z-dimension.
*/
updateCovData(rot0, rot1, rot2, rot3, scale0, scale1, scale2) {

const { numComponents, isHalf } = this.format;
const float2Half = FloatPacking.float2Half;
const quat = new Quat();
const mat = new Mat3();
const cA = new Vec3();
const cB = new Vec3();

const covA = this.covATexture.lock();
const covB = this.covBTexture.lock();

for (let i = 0; i < this.numSplats; ++i) {

// rotation
quat.set(rot0[i], rot1[i], rot2[i], rot3[i]).normalize();
if (quat.w < 0) {
quat.conjugate();
}
_r.set(quat.x, quat.y, quat.z);
this.quatToMat3(_r, mat);

// scale
_s.set(
Math.exp(scale0[i]),
Math.exp(scale1[i]),
Math.exp(scale2[i])
);

this.computeCov3d(mat, _s, cA, cB);

if (isHalf) {
covA[i * numComponents + 0] = float2Half(cA.x);
covA[i * numComponents + 1] = float2Half(cA.y);
covA[i * numComponents + 2] = float2Half(cA.z);

covB[i * numComponents + 0] = float2Half(cB.x);
covB[i * numComponents + 1] = float2Half(cB.y);
covB[i * numComponents + 2] = float2Half(cB.z);
} else {
covA[i * numComponents + 0] = cA.x;
covA[i * numComponents + 1] = cA.y;
covA[i * numComponents + 2] = cA.z;

covB[i * numComponents + 0] = cB.x;
covB[i * numComponents + 1] = cB.y;
covB[i * numComponents + 2] = cB.z;
}
}

this.covATexture.unlock();
this.covBTexture.unlock();
}

/**
* Updates pixel data of this.centerTexture based on the supplied center coordinates.
* The center coordinates are stored as either half or full precision floats depending on the
* texture format.
*
* @param {Float32Array} x - The array containing the 'x' component of the center points.
* @param {Float32Array} y - The array containing the 'y' component of the center points.
* @param {Float32Array} z - The array containing the 'z' component of the center points.
*/
updateCenterData(x, y, z) {
const { numComponents, isHalf } = this.format;

const texture = this.centerTexture;
const data = texture.lock();
const float2Half = FloatPacking.float2Half;

for (let i = 0; i < this.numSplats; i++) {

if (isHalf) {
data[i * numComponents + 0] = float2Half(x[i]);
data[i * numComponents + 1] = float2Half(y[i]);
data[i * numComponents + 2] = float2Half(z[i]);
} else {
data[i * numComponents + 0] = x[i];
data[i * numComponents + 1] = y[i];
data[i * numComponents + 2] = z[i];
}
}

texture.unlock();
}
}

export { GSplat };
Loading