Skip to content

Commit

Permalink
Nodes: Add AtomicFunctionNode (#29385)
Browse files Browse the repository at this point in the history
* add atomic operations

* add storeNode

* cleanup

---------
  • Loading branch information
cmhhelgeson authored Sep 17, 2024
1 parent 915392f commit 38fd5e9
Show file tree
Hide file tree
Showing 5 changed files with 126 additions and 2 deletions.
10 changes: 9 additions & 1 deletion examples/webgpu_compute_sort_bitonic.html
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
<script type="module">

import * as THREE from 'three';
import { storageObject, If, vec3, not, uniform, uv, uint, float, Fn, vec2, abs, int, invocationLocalIndex, workgroupArray, uvec2, floor, instanceIndex, workgroupBarrier } from 'three/tsl';
import { storageObject, If, vec3, not, uniform, uv, uint, float, Fn, vec2, abs, int, invocationLocalIndex, workgroupArray, uvec2, floor, instanceIndex, workgroupBarrier, atomicAdd, atomicStore } from 'three/tsl';

import { GUI } from 'three/addons/libs/lil-gui.module.min.js';

Expand Down Expand Up @@ -149,6 +149,9 @@
const highestBlockHeightBuffer = new THREE.StorageInstancedBufferAttribute( new Uint32Array( 1 ).fill( 2 ), 1 );
const highestBlockHeightStorage = storageObject( highestBlockHeightBuffer, 'uint', highestBlockHeightBuffer.count ).label( 'HighestBlockHeight' );

const counterBuffer = new THREE.StorageBufferAttribute( 1, 1 );
const counterStorage = storageObject( counterBuffer, 'uint', counterBuffer.count ).toAtomic().label( 'Counter' );

const array = new Uint32Array( Array.from( { length: size }, ( _, i ) => {

return i;
Expand Down Expand Up @@ -219,6 +222,7 @@

If( localStorage.element( idxAfter ).lessThan( localStorage.element( idxBefore ) ), () => {

atomicAdd( counterStorage.element( 0 ), 1 );
const temp = localStorage.element( idxBefore ).toVar();
localStorage.element( idxBefore ).assign( localStorage.element( idxAfter ) );
localStorage.element( idxAfter ).assign( temp );
Expand All @@ -233,6 +237,7 @@
If( currentElementsStorage.element( idxAfter ).lessThan( currentElementsStorage.element( idxBefore ) ), () => {

// Apply the swapped values to temporary storage.
atomicAdd( counterStorage.element( 0 ), 1 );
tempStorage.element( idxBefore ).assign( currentElementsStorage.element( idxAfter ) );
tempStorage.element( idxAfter ).assign( currentElementsStorage.element( idxBefore ) );

Expand Down Expand Up @@ -396,6 +401,7 @@
nextAlgoStorage.element( 0 ).assign( forceGlobalSwap ? StepType.FLIP_GLOBAL : StepType.FLIP_LOCAL );
nextBlockHeightStorage.element( 0 ).assign( 2 );
highestBlockHeightStorage.element( 0 ).assign( 2 );
atomicStore( counterStorage.element( 0 ), 0 );

} );

Expand Down Expand Up @@ -511,12 +517,14 @@

const algo = new Uint32Array( await renderer.getArrayBufferAsync( nextAlgoBuffer ) );
algo > StepType.DISPERSE_LOCAL ? ( nextStepGlobal = true ) : ( nextStepGlobal = false );
const totalSwaps = new Uint32Array( await renderer.getArrayBufferAsync( counterBuffer ) );

renderer.render( scene, camera );

timestamps[ forceGlobalSwap ? 'global_swap' : 'local_swap' ].innerHTML = `
Compute ${forceGlobalSwap ? 'Global' : 'Local'}: ${renderer.info.compute.frameCalls} pass in ${renderer.info.compute.timestamp.toFixed( 6 )}ms<br>
Total Swaps: ${totalSwaps}<br>
<div style="display: flex; flex-direction:row; justify-content: center; align-items: center;">
${forceGlobalSwap ? 'Global Swaps' : 'Local Swaps'} Compare Region&nbsp;
<div style="background-color: ${ forceGlobalSwap ? globalColors[ 0 ] : localColors[ 0 ]}; width:12.5px; height: 1em; border-radius: 20%;"></div>
Expand Down
1 change: 1 addition & 0 deletions src/nodes/TSL.js
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ export * from './geometry/RangeNode.js';
export * from './gpgpu/ComputeNode.js';
export * from './gpgpu/BarrierNode.js';
export * from './gpgpu/WorkgroupInfoNode.js';
export * from './gpgpu/AtomicFunctionNode.js';

// lighting
export * from './accessors/Lights.js';
Expand Down
15 changes: 15 additions & 0 deletions src/nodes/accessors/StorageBufferNode.js
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class StorageBufferNode extends BufferNode {
this.isStorageBufferNode = true;

this.access = GPUBufferBindingType.Storage;
this.isAtomic = false;

this.bufferObject = false;
this.bufferCount = bufferCount;
Expand Down Expand Up @@ -97,6 +98,20 @@ class StorageBufferNode extends BufferNode {

}

setAtomic( value ) {

this.isAtomic = value;

return this;

}

toAtomic() {

return this.setAtomic( true );

}

generate( builder ) {

if ( builder.isAvailable( 'storageBuffer' ) ) {
Expand Down
99 changes: 99 additions & 0 deletions src/nodes/gpgpu/AtomicFunctionNode.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import TempNode from '../core/TempNode.js';
import { nodeProxy } from '../tsl/TSLCore.js';

class AtomicFunctionNode extends TempNode {

static get type() {

return 'AtomicFunctionNode';

}

constructor( method, pointerNode, valueNode, storeNode = null ) {

super( 'uint' );

this.method = method;

this.pointerNode = pointerNode;
this.valueNode = valueNode;
this.storeNode = storeNode;

}

getInputType( builder ) {

return this.pointerNode.getNodeType( builder );

}

getNodeType( builder ) {

return this.getInputType( builder );

}

generate( builder ) {

const method = this.method;

const type = this.getNodeType( builder );
const inputType = this.getInputType( builder );

const a = this.pointerNode;
const b = this.valueNode;

const params = [];

params.push( `&${ a.build( builder, inputType ) }` );
params.push( b.build( builder, inputType ) );

const methodSnippet = `${ builder.getMethod( method, type ) }( ${params.join( ', ' )} )`;

if ( this.storeNode !== null ) {

const varSnippet = this.storeNode.build( builder, inputType );

builder.addLineFlowCode( `${varSnippet} = ${methodSnippet}` );

} else {

builder.addLineFlowCode( methodSnippet );

}

}

}

AtomicFunctionNode.ATOMIC_LOAD = 'atomicLoad';
AtomicFunctionNode.ATOMIC_STORE = 'atomicStore';
AtomicFunctionNode.ATOMIC_ADD = 'atomicAdd';
AtomicFunctionNode.ATOMIC_SUB = 'atomicSub';
AtomicFunctionNode.ATOMIC_MAX = 'atomicMax';
AtomicFunctionNode.ATOMIC_MIN = 'atomicMin';
AtomicFunctionNode.ATOMIC_AND = 'atomicAnd';
AtomicFunctionNode.ATOMIC_OR = 'atomicOr';
AtomicFunctionNode.ATOMIC_XOR = 'atomicXor';

export default AtomicFunctionNode;

const atomicNode = nodeProxy( AtomicFunctionNode );

export const atomicFunc = ( method, pointerNode, valueNode, storeNode ) => {

const node = atomicNode( method, pointerNode, valueNode, storeNode );
node.append();

return node;

};

export const atomicStore = ( pointerNode, valueNode, storeNode = null ) => atomicFunc( AtomicFunctionNode.ATOMIC_STORE, pointerNode, valueNode, storeNode );
export const atomicAdd = ( pointerNode, valueNode, storeNode = null ) => atomicFunc( AtomicFunctionNode.ATOMIC_ADD, pointerNode, valueNode, storeNode );
export const atomicSub = ( pointerNode, valueNode, storeNode = null ) => atomicFunc( AtomicFunctionNode.ATOMIC_SUB, pointerNode, valueNode, storeNode );
export const atomicMax = ( pointerNode, valueNode, storeNode = null ) => atomicFunc( AtomicFunctionNode.ATOMIC_MAX, pointerNode, valueNode, storeNode );
export const atomicMin = ( pointerNode, valueNode, storeNode = null ) => atomicFunc( AtomicFunctionNode.ATOMIC_MIN, pointerNode, valueNode, storeNode );
export const atomicAnd = ( pointerNode, valueNode, storeNode = null ) => atomicFunc( AtomicFunctionNode.ATOMIC_AND, pointerNode, valueNode, storeNode );
export const atomicOr = ( pointerNode, valueNode, storeNode = null ) => atomicFunc( AtomicFunctionNode.ATOMIC_OR, pointerNode, valueNode, storeNode );
export const atomicXor = ( pointerNode, valueNode, storeNode = null ) => atomicFunc( AtomicFunctionNode.ATOMIC_XOR, pointerNode, valueNode, storeNode );
3 changes: 2 additions & 1 deletion src/renderers/webgpu/nodes/WGSLNodeBuilder.js
Original file line number Diff line number Diff line change
Expand Up @@ -1052,7 +1052,8 @@ ${ flowData.code }
const bufferCount = bufferNode.bufferCount;

const bufferCountSnippet = bufferCount > 0 ? ', ' + bufferCount : '';
const bufferSnippet = `\t${ uniform.name } : array< ${ bufferType }${ bufferCountSnippet } >\n`;
const bufferTypeSnippet = bufferNode.isAtomic ? `atomic<${bufferType}>` : `${bufferType}`;
const bufferSnippet = `\t${ uniform.name } : array< ${ bufferTypeSnippet }${ bufferCountSnippet } >\n`;
const bufferAccessMode = bufferNode.isStorageBufferNode ? `storage, ${ this.getStorageAccess( bufferNode ) }` : 'uniform';

bufferSnippets.push( this._getWGSLStructBinding( 'NodeBuffer_' + bufferNode.id, bufferSnippet, bufferAccessMode, uniformIndexes.binding ++, uniformIndexes.group ) );
Expand Down

0 comments on commit 38fd5e9

Please sign in to comment.