Skip to content

Commit

Permalink
feat: implement and use hashtree.digestNLevelUnsafe()
Browse files Browse the repository at this point in the history
  • Loading branch information
twoeths committed Jun 21, 2024
1 parent 347f766 commit 5a917c6
Show file tree
Hide file tree
Showing 9 changed files with 245 additions and 138 deletions.
3 changes: 3 additions & 0 deletions packages/persistent-merkle-tree/src/hasher/as-sha256.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ export const hasher: Hasher = {
name: "as-sha256",
digest64: digest2Bytes32,
digest64HashObjects,
// TODO - batch: implement using as-sha256 SIMD; see the hashtree hasher for the reference implementation
// eslint-disable-next-line @typescript-eslint/no-unused-vars
digestNLevelUnsafe(data: Uint8Array, nLevel: number): Uint8Array {
  throw new Error("digestNLevelUnsafe is not implemented for as-sha256 hasher");
},
batchHashObjects: (inputs: HashObject[]) => {
// as-sha256 uses SIMD for batch hash
if (inputs.length === 0) {
Expand Down
39 changes: 37 additions & 2 deletions packages/persistent-merkle-tree/src/hasher/hashtree.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import {hashInto} from "@chainsafe/hashtree";
import {hash, hashInto} from "@chainsafe/hashtree";
import {Hasher, HashObject} from "./types";
import {HashComputation, Node} from "../node";
import { byteArrayToHashObject, hashObjectToByteArray } from "@chainsafe/as-sha256";
Expand All @@ -10,7 +10,8 @@ import { byteArrayToHashObject, hashObjectToByteArray } from "@chainsafe/as-sha2
* Each input is 64 bytes
*/
const PARALLEL_FACTOR = 16;
const uint8Input = new Uint8Array(PARALLEL_FACTOR * 64);
const MAX_INPUT_SIZE = PARALLEL_FACTOR * 64;
const uint8Input = new Uint8Array(MAX_INPUT_SIZE);
const uint32Input = new Uint32Array(uint8Input.buffer);
const uint8Output = new Uint8Array(PARALLEL_FACTOR * 32);
const uint32Output = new Uint32Array(uint8Output.buffer);
Expand All @@ -37,6 +38,40 @@ export const hasher: Hasher = {
hashInto(hashInput, hashOutput);
return uint32ArrayToHashObject(uint32Output, 0);
},
/**
 * Hash `data` up `nLevel` levels of a merkle tree.
 * Given nLevel = 3, digest a multiple of 8 chunks = 256 bytes, the result is a multiple of 1 chunk = 32 bytes.
 * This is the same as hashTreeRoot() of multiple validators.
 * The result is UNSAFE: it is a view over a shared output buffer that the next call overwrites;
 * the consumer must copy the result if it needs to keep it.
 */
digestNLevelUnsafe(data: Uint8Array, nLevel: number): Uint8Array {
  if (nLevel < 1) {
    throw new Error(`Invalid nLevel, expect to be greater than 0, got ${nLevel}`);
  }
  // each level halves the data, so the input must be a whole number of 2^nLevel-chunk batches
  const bytesInBatch = Math.pow(2, nLevel) * 32;
  if (data.length % bytesInBatch !== 0) {
    throw new Error(`Invalid input length, expect to be multiple of ${bytesInBatch} for nLevel ${nLevel}, got ${data.length}`);
  }
  if (data.length > MAX_INPUT_SIZE) {
    throw new Error(`Invalid input length, expect to be at most ${MAX_INPUT_SIZE}, got ${data.length}`);
  }

  // copy the input into the shared buffer once, then hash level by level:
  // each round consumes inputLength bytes from uint8Input and produces half as many in uint8Output
  uint8Input.set(data, 0);
  let inputLength = data.length;
  for (let level = 0; level < nLevel; level++) {
    // halving is always exact: divisibility by 2^nLevel * 32 was validated above
    const outputLength = inputLength / 2;
    hashInto(uint8Input.subarray(0, inputLength), uint8Output.subarray(0, outputLength));
    if (level < nLevel - 1) {
      // feed this level's output back in as the next level's input
      uint8Input.set(uint8Output.subarray(0, outputLength), 0);
    }
    inputLength = outputLength;
  }

  // inputLength now holds the final output length; see UNSAFE note in the doc comment
  return uint8Output.subarray(0, inputLength);
},
// eslint-disable-next-line @typescript-eslint/no-unused-vars
batchHashObjects(inputs: HashObject[]): HashObject[] {
if (inputs.length === 0) {
Expand Down
8 changes: 8 additions & 0 deletions packages/persistent-merkle-tree/src/hasher/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,11 @@ export let hasher: Hasher = hashtreeHasher;
export function setHasher(newHasher: Hasher): void {
hasher = newHasher;
}

/**
 * Hash two 32-byte inputs into a single 32-byte output, delegating to the currently-configured hasher.
 */
export function digest64(a: Uint8Array, b: Uint8Array): Uint8Array {
  return hasher.digest64(a, b);
}

/**
 * Hash `data` up `nLevel` merkle-tree levels, delegating to the currently-configured hasher.
 * The returned bytes are unsafe: per the Hasher contract they may be overwritten by the next call,
 * so the consumer should copy the result if it needs to keep it.
 */
export function digestNLevelUnsafe(data: Uint8Array, nLevel: number): Uint8Array {
  return hasher.digestNLevelUnsafe(data, nLevel);
}
3 changes: 3 additions & 0 deletions packages/persistent-merkle-tree/src/hasher/noble.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ export const hasher: Hasher = {
name: "noble",
digest64,
digest64HashObjects: (a, b) => uint8ArrayToHashObject(digest64(hashObjectToUint8Array(a), hashObjectToUint8Array(b))),
// TODO - batch: noble has no batch/level hashing; see the hashtree hasher for the reference implementation
// eslint-disable-next-line @typescript-eslint/no-unused-vars
digestNLevelUnsafe(data: Uint8Array, nLevel: number): Uint8Array {
  throw new Error("digestNLevelUnsafe is not implemented for noble hasher");
},
batchHashObjects: (inputs: HashObject[]) => {
// noble does not support batch hash
if (inputs.length === 0) {
Expand Down
6 changes: 6 additions & 0 deletions packages/persistent-merkle-tree/src/hasher/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@ export type Hasher = {
* Hash two 32-byte HashObjects
*/
digest64HashObjects(a: HashObject, b: HashObject): HashObject;
/**
* Hash multiple chunks (1 chunk = 32 bytes) at multiple levels
* With nLevel = 3, hash multiple of 256 bytes, return multiple of 32 bytes.
* The result is unsafe as it will be overwritten by the next call.
*/
digestNLevelUnsafe(data: Uint8Array, nLevel: number): Uint8Array
/**
* Batch hash 2 * n HashObjects, return n HashObjects output
*/
Expand Down
1 change: 1 addition & 0 deletions packages/persistent-merkle-tree/src/node.ts
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,7 @@ export function getHashComputations(node: Node, offset: number, hashCompsByLevel
// else stop the recursion, LeafNode should have h0
}

// TODO - batch: move to hasher/index.ts
export function executeHashComputations(hashComputations: Array<HashComputation[]>): void {
hasher.executeHashComputations(hashComputations);
}
Expand Down
177 changes: 63 additions & 114 deletions packages/ssz/test/lodestarTypes/phase0/viewDU/listValidator.ts
Original file line number Diff line number Diff line change
@@ -1,128 +1,50 @@
import {BranchNode, HashComputation, HashComputationGroup, LeafNode, Node, arrayAtIndex, executeHashComputations, getHashComputations, setNodesAtDepth} from "@chainsafe/persistent-merkle-tree";
import {HashComputationGroup, Node, digestNLevelUnsafe, getHashComputations, setNodesAtDepth} from "@chainsafe/persistent-merkle-tree";
import { ListCompositeType } from "../../../../src/type/listComposite";
import { ArrayCompositeTreeViewDUCache } from "../../../../src/viewDU/arrayComposite";
import { ListCompositeTreeViewDU } from "../../../../src/viewDU/listComposite";
import { ValidatorNodeStructType } from "../validator";
import { ValidatorTreeViewDU } from "./validator";
import { ByteViews } from "../../../../src";
import { byteArrayToHashObject } from "@chainsafe/as-sha256";

/**
* Best SIMD implementation is in 512 bits = 64 bytes
* If not, hashtree will make a loop inside
* Given sha256 operates on a block of 4 bytes, we can hash 16 inputs at once
* Each input is 64 bytes
* TODO - batch: is 8 better?
 * hashtree has a MAX_SIZE of 1024 bytes = 32 chunks
 * Given that level 3 of a validator has 8 chunks, we can hash 4 validators at a time
*/
const PARALLEL_FACTOR = 16;
const PARALLEL_FACTOR = 4;

export class ListValidatorTreeViewDU extends ListCompositeTreeViewDU<ValidatorNodeStructType> {
private batchHashComputations: Array<HashComputation[]>;
private singleHashComputations: Array<HashComputation[]>;
private batchHashRootNodes: Array<Node>;
private singleHashRootNode: Node;
private batchLevel3Nodes: Array<Node[]>;
private singleLevel3Nodes: Node[];
private batchLevel3Bytes: Uint8Array;
private batchLevel4Bytes: Uint8Array;
// 32 * 8 = 256 bytes each
private level3ByteViewsArr: ByteViews[];
// 64 bytes each
private level4BytesArr: Uint8Array[];
private singleLevel3ByteView: ByteViews;
private singleLevel4Bytes: Uint8Array;

constructor(
readonly type: ListCompositeType<ValidatorNodeStructType>,
protected _rootNode: Node,
cache?: ArrayCompositeTreeViewDUCache
) {
super(type, _rootNode, cache);

this.batchHashComputations = [];
this.singleHashComputations = [];
this.batchHashRootNodes = [];
this.batchLevel3Nodes = [];
this.singleLevel3Nodes = [];
// each level 3 of validator has 8 chunks, each chunk has 32 bytes
this.batchLevel3Bytes = new Uint8Array(PARALLEL_FACTOR * 8 * 32);
this.level3ByteViewsArr = [];
for (let i = 0; i < PARALLEL_FACTOR; i++) {
// level 3, validator.pubkey
const pubkey0 = LeafNode.fromZero();
const pubkey1 = LeafNode.fromZero();
const pubkey = new BranchNode(pubkey0, pubkey1);
let hc: HashComputation = {src0: pubkey0, src1: pubkey1, dest: pubkey};
if (i === 0) {
arrayAtIndex(this.singleHashComputations, 3).push(hc);
this.singleLevel3Nodes.push(pubkey);
}
arrayAtIndex(this.batchHashComputations, 3).push(hc);
arrayAtIndex(this.batchLevel3Nodes, i).push(pubkey);

// level 2
const withdrawalCredential = LeafNode.fromZero();
const node20 = new BranchNode(pubkey, withdrawalCredential);
hc = {src0: pubkey, src1: withdrawalCredential, dest: node20};
if (i === 0) {
arrayAtIndex(this.singleHashComputations, 2).push(hc);
this.singleLevel3Nodes.push(withdrawalCredential);
}
arrayAtIndex(this.batchHashComputations, 2).push(hc);
arrayAtIndex(this.batchLevel3Nodes, i).push(withdrawalCredential);
// effectiveBalance, slashed
const effectiveBalance = LeafNode.fromZero();
const slashed = LeafNode.fromZero();
const node21 = new BranchNode(effectiveBalance, slashed);
hc = {src0: effectiveBalance, src1: slashed, dest: node21};
if (i === 0) {
arrayAtIndex(this.singleHashComputations, 2).push(hc);
this.singleLevel3Nodes.push(effectiveBalance);
this.singleLevel3Nodes.push(slashed);
}
arrayAtIndex(this.batchHashComputations, 2).push(hc);
arrayAtIndex(this.batchLevel3Nodes, i).push(effectiveBalance);
arrayAtIndex(this.batchLevel3Nodes, i).push(slashed);
// activationEligibilityEpoch, activationEpoch
const activationEligibilityEpoch = LeafNode.fromZero();
const activationEpoch = LeafNode.fromZero();
const node22 = new BranchNode(activationEligibilityEpoch, activationEpoch);
hc = {src0: activationEligibilityEpoch, src1: activationEpoch, dest: node22};
if (i === 0) {
arrayAtIndex(this.singleHashComputations, 2).push(hc);
this.singleLevel3Nodes.push(activationEligibilityEpoch);
this.singleLevel3Nodes.push(activationEpoch);
}
arrayAtIndex(this.batchHashComputations, 2).push(hc);
arrayAtIndex(this.batchLevel3Nodes, i).push(activationEligibilityEpoch);
arrayAtIndex(this.batchLevel3Nodes, i).push(activationEpoch);
// exitEpoch, withdrawableEpoch
const exitEpoch = LeafNode.fromZero();
const withdrawableEpoch = LeafNode.fromZero();
const node23 = new BranchNode(exitEpoch, withdrawableEpoch);
hc = {src0: exitEpoch, src1: withdrawableEpoch, dest: node23};
if (i === 0) {
arrayAtIndex(this.singleHashComputations, 2).push(hc);
this.singleLevel3Nodes.push(exitEpoch);
this.singleLevel3Nodes.push(withdrawableEpoch);
}
arrayAtIndex(this.batchHashComputations, 2).push(hc);
arrayAtIndex(this.batchLevel3Nodes, i).push(exitEpoch);
arrayAtIndex(this.batchLevel3Nodes, i).push(withdrawableEpoch);

// level 1
const node10 = new BranchNode(node20, node21);
hc = {src0: node20, src1: node21, dest: node10};
if (i === 0) {
arrayAtIndex(this.singleHashComputations, 1).push(hc);
}
arrayAtIndex(this.batchHashComputations, 1).push(hc);
const node11 = new BranchNode(node22, node23);
hc = {src0: node22, src1: node23, dest: node11};
if (i === 0) {
arrayAtIndex(this.singleHashComputations, 1).push(hc);
}
arrayAtIndex(this.batchHashComputations, 1).push(hc);

// level 0
const node00 = new BranchNode(node10, node11);
hc = {src0: node10, src1: node11, dest: node00};
if (i === 0) {
arrayAtIndex(this.singleHashComputations, 0).push(hc);
// this.singleHashRootNode = node00;
}
arrayAtIndex(this.batchHashComputations, 0).push(hc);
this.batchHashRootNodes.push(node00);
const uint8Array = this.batchLevel3Bytes.subarray(i * 8 * 32, (i + 1) * 8 * 32);
const dataView = new DataView(uint8Array.buffer, uint8Array.byteOffset, uint8Array.byteLength);
this.level3ByteViewsArr.push({uint8Array, dataView});
}

this.singleHashRootNode = this.batchHashRootNodes[0];
this.singleLevel3ByteView = this.level3ByteViewsArr[0];
// each level 4 of validator has 2 chunks for pubkey, each chunk has 32 bytes
this.batchLevel4Bytes = new Uint8Array(PARALLEL_FACTOR * 2 * 32);
this.level4BytesArr = [];
for (let i = 0; i < PARALLEL_FACTOR; i++) {
this.level4BytesArr.push(this.batchLevel4Bytes.subarray(i * 2 * 32, (i + 1) * 2 * 32));
}
this.singleLevel4Bytes = this.level4BytesArr[0];
}

commit(hashComps: HashComputationGroup | null = null): void {
Expand All @@ -142,22 +64,49 @@ export class ListValidatorTreeViewDU extends ListCompositeTreeViewDU<ValidatorNo
// commit every 16 validators in batch
for (let i = 0; i < endBatch; i++) {
const indexInBatch = i % PARALLEL_FACTOR;
viewsChanged[i].valueToTree(this.batchLevel3Nodes[indexInBatch]);
viewsChanged[i].valueToMerkleBytes(this.level3ByteViewsArr[indexInBatch], this.level4BytesArr[indexInBatch]);

if (indexInBatch === PARALLEL_FACTOR - 1) {
executeHashComputations(this.batchHashComputations);
// hash level 4
const pubkeyRoots = digestNLevelUnsafe(this.batchLevel4Bytes, 1);
if (pubkeyRoots.length !== PARALLEL_FACTOR * 32) {
throw new Error(`Invalid pubkeyRoots length, expect ${PARALLEL_FACTOR * 32}, got ${pubkeyRoots.length}`);
}
for (let j = 0; j < PARALLEL_FACTOR; j++) {
this.level3ByteViewsArr[j].uint8Array.set(pubkeyRoots.subarray(j * 32, (j + 1) * 32), 0);
}
const validatorRoots = digestNLevelUnsafe(this.batchLevel3Bytes, 3);
if (validatorRoots.length !== PARALLEL_FACTOR * 32) {
throw new Error(`Invalid validatorRoots length, expect ${PARALLEL_FACTOR * 32}, got ${validatorRoots.length}`);
}
// commit all validators in this batch
for (let j = PARALLEL_FACTOR - 1; j >= 0; j--) {
viewsChanged[i - j].commitToHashObject(this.batchHashRootNodes[PARALLEL_FACTOR - 1 - j]);
nodesChanged.push({index: i - j, node: viewsChanged[i - j].node});
const viewIndex = i - j;
const indexInBatch = viewIndex % PARALLEL_FACTOR;
const hashObject = byteArrayToHashObject(validatorRoots.subarray(indexInBatch * 32, (indexInBatch + 1) * 32));
viewsChanged[viewIndex].commitToHashObject(hashObject);
nodesChanged.push({index: viewIndex, node: viewsChanged[viewIndex].node});
}
}
}

// commit the remaining validators one by one
    // commit the remaining validators; we could batch these too, but we don't want to create new Uint8Array views
    // and it is not much different from committing them one by one
for (let i = endBatch; i < viewsChanged.length; i++) {
viewsChanged[i].valueToTree(this.singleLevel3Nodes);
executeHashComputations(this.singleHashComputations);
viewsChanged[i].commitToHashObject(this.singleHashRootNode);
viewsChanged[i].valueToMerkleBytes(this.singleLevel3ByteView, this.singleLevel4Bytes);
// level 4 hash
const pubkeyRoot = digestNLevelUnsafe(this.singleLevel4Bytes, 1);
if (pubkeyRoot.length !== 32) {
throw new Error(`Invalid pubkeyRoot length, expect 32, got ${pubkeyRoot.length}`);
}
this.singleLevel3ByteView.uint8Array.set(pubkeyRoot, 0);
// level 3 hash
const validatorRoot = digestNLevelUnsafe(this.singleLevel3ByteView.uint8Array, 3);
if (validatorRoot.length !== 32) {
throw new Error(`Invalid validatorRoot length, expect 32, got ${validatorRoot.length}`);
}
const hashObject = byteArrayToHashObject(validatorRoot);
viewsChanged[i].commitToHashObject(hashObject);
nodesChanged.push({index: i, node: viewsChanged[i].node});
}

Expand Down
Loading

0 comments on commit 5a917c6

Please sign in to comment.