Skip to content

Commit

Permalink
fix: avoid memory allocation in hashtree
Browse files Browse the repository at this point in the history
  • Loading branch information
twoeths committed Jun 29, 2024
1 parent 9699e21 commit d5dd936
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 44 deletions.
2 changes: 1 addition & 1 deletion packages/persistent-merkle-tree/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
"homepage": "https://github.com/ChainSafe/persistent-merkle-tree#readme",
"dependencies": {
"@chainsafe/as-sha256": "0.4.2",
"@chainsafe/hashtree": "1.0.0",
"@chainsafe/hashtree": "1.0.1",
"@noble/hashes": "^1.3.0"
},
"peerDependencies": {
Expand Down
88 changes: 45 additions & 43 deletions packages/persistent-merkle-tree/src/hasher/hashtree.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import {hash, hashInto} from "@chainsafe/hashtree";
import {hashInto} from "@chainsafe/hashtree";
import {Hasher, HashObject} from "./types";
import {HashComputation, Node} from "../node";
import { byteArrayToHashObject, hashObjectToByteArray } from "@chainsafe/as-sha256";
import { byteArrayToHashObject } from "@chainsafe/as-sha256";

/**
* Best SIMD implementation is in 512 bits = 64 bytes
Expand All @@ -14,29 +14,27 @@ const MAX_INPUT_SIZE = PARALLEL_FACTOR * 64;
const uint8Input = new Uint8Array(MAX_INPUT_SIZE);
const uint32Input = new Uint32Array(uint8Input.buffer);
const uint8Output = new Uint8Array(PARALLEL_FACTOR * 32);
const uint32Output = new Uint32Array(uint8Output.buffer);

// having this will cause more memory to extract uint32
// const uint32Output = new Uint32Array(uint8Output.buffer);
// convenient reusable Uint8Array for hash64
const hash64Input = uint8Input.subarray(0, 64);
const hash64Output = uint8Output.subarray(0, 32);

export const hasher: Hasher = {
name: "hashtree",
digest64(obj1: Uint8Array, obj2: Uint8Array): Uint8Array {
if (obj1.length !== 32 || obj2.length !== 32) {
throw new Error("Invalid input length");
}
uint8Input.set(obj1, 0);
uint8Input.set(obj2, 32);
const hashInput = uint8Input.subarray(0, 64);
const hashOutput = uint8Output.subarray(0, 32);
hashInto(hashInput, hashOutput);
return hashOutput.slice();
hash64Input.set(obj1, 0);
hash64Input.set(obj2, 32);
hashInto(hash64Input, hash64Output);
return hash64Output.slice();
},
digest64HashObjects(obj1: HashObject, obj2: HashObject): HashObject {
hashObjectToUint32Array(obj1, uint32Input, 0);
hashObjectToUint32Array(obj2, uint32Input, 8);
const hashInput = uint8Input.subarray(0, 64);
const hashOutput = uint8Output.subarray(0, 32);
hashInto(hashInput, hashOutput);
return uint32ArrayToHashObject(uint32Output, 0);
hashObjectsToUint32Array(obj1, obj2, uint32Input);
hashInto(hash64Input, hash64Output);
return byteArrayToHashObject(hash64Output);
},
// given nLevel = 3
// digest multiple of 8 chunks = 256 bytes
Expand All @@ -57,21 +55,20 @@ export const hasher: Hasher = {
}

let outputLength = Math.floor(inputLength / 2);
let hashOutput: Uint8Array | null = null;

uint8Input.set(data, 0);
// hash into same buffer
let bufferIn = uint8Input.subarray(0, inputLength);
for (let i = nLevel; i > 0; i--) {
uint8Input.set(hashOutput ?? data, 0);
const hashInput = uint8Input.subarray(0, inputLength);
hashOutput = uint8Output.subarray(0, outputLength);
hashInto(hashInput, hashOutput);
const bufferOut = bufferIn.subarray(0, outputLength);
hashInto(bufferIn, bufferOut);
bufferIn = bufferOut;
inputLength = outputLength;
outputLength = Math.floor(inputLength / 2);
}

if (hashOutput === null) {
throw new Error("hashOutput is null");
}
// the result is unsafe as it will be modified later, consumer should save the result if needed
return hashOutput;
return bufferIn;
},
// eslint-disable-next-line @typescript-eslint/no-unused-vars
batchHashObjects(inputs: HashObject[]): HashObject[] {
Expand All @@ -90,7 +87,7 @@ export const hasher: Hasher = {
if (indexInBatch === batch - 1) {
hashInto(uint8Input, uint8Output);
for (let j = 0; j < batch / 2; j++) {
outHashObjects.push(uint32ArrayToHashObject(uint32Output, j * 8));
outHashObjects.push(byteArrayToHashObject(uint8Output.subarray(j * 32, (j + 1) * 32)));
}
}
}
Expand All @@ -102,7 +99,7 @@ export const hasher: Hasher = {
const remainingOutput = uint8Output.subarray(0, remaining * 16);
hashInto(remainingInput, remainingOutput);
for (let i = 0; i < remaining / 2; i++) {
outHashObjects.push(uint32ArrayToHashObject(uint32Output, i * 8));
outHashObjects.push(byteArrayToHashObject(remainingOutput.subarray(i * 32, (i + 1) * 32)));
}
}

Expand Down Expand Up @@ -131,8 +128,7 @@ export const hasher: Hasher = {
if (indexInBatch === PARALLEL_FACTOR - 1) {
hashInto(uint8Input, uint8Output);
for (const [j, destNode] of destNodes.entries()) {
const outputOffset = j * 8;
destNode.applyHash(uint32ArrayToHashObject(uint32Output, outputOffset));
destNode.applyHash(byteArrayToHashObject(uint8Output.subarray(j * 32, (j + 1) * 32)));
}
destNodes = [];
}
Expand All @@ -146,8 +142,7 @@ export const hasher: Hasher = {
hashInto(remainingInput, remainingOutput);
// destNodes was prepared above
for (const [i, destNode] of destNodes.entries()) {
const offset = i * 8;
destNode.applyHash(uint32ArrayToHashObject(uint32Output, offset));
destNode.applyHash(byteArrayToHashObject(remainingOutput.subarray(i * 32, (i + 1) * 32)));
}
}
}
Expand All @@ -165,15 +160,22 @@ function hashObjectToUint32Array(obj: HashObject, arr: Uint32Array, offset: numb
arr[offset + 7] = obj.h7;
}

function uint32ArrayToHashObject(arr: Uint32Array, offset: number): HashObject {
return {
h0: arr[offset],
h1: arr[offset + 1],
h2: arr[offset + 2],
h3: arr[offset + 3],
h4: arr[offset + 4],
h5: arr[offset + 5],
h6: arr[offset + 6],
h7: arr[offset + 7],
};
}
// note that uint32ArrayToHashObject will cause more memory
function hashObjectsToUint32Array(obj1: HashObject, obj2: HashObject, arr: Uint32Array): void {
arr[0] = obj1.h0;
arr[1] = obj1.h1;
arr[2] = obj1.h2;
arr[3] = obj1.h3;
arr[4] = obj1.h4;
arr[5] = obj1.h5;
arr[6] = obj1.h6;
arr[7] = obj1.h7;
arr[8] = obj2.h0;
arr[9] = obj2.h1;
arr[10] = obj2.h2;
arr[11] = obj2.h3;
arr[12] = obj2.h4;
arr[13] = obj2.h5;
arr[14] = obj2.h6;
arr[15] = obj2.h7;
}

0 comments on commit d5dd936

Please sign in to comment.