Skip to content

Commit

Permalink
prepare message → prepare state
Browse files Browse the repository at this point in the history
  • Loading branch information
jbr committed Sep 22, 2022
1 parent 0c57316 commit c26fc05
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 51 deletions.
56 changes: 28 additions & 28 deletions packages/prio3/src/prio3.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import { Field } from "@divviup/field";
import { Flp } from "./flp";
import { Buffer } from "buffer";

type PrepareMessage = {
type PrepareState = {
outputShare: OutputShare;
jointRandSeed: Buffer | null;
outboundMessage: Buffer;
Expand All @@ -17,7 +17,7 @@ type OutputShare = bigint[];
type Prio3Vdaf<Measurement> = Vdaf<
Measurement,
AggregationParameter,
PrepareMessage,
PrepareState,
AggregatorShare,
AggregationResult,
OutputShare
Expand Down Expand Up @@ -104,13 +104,13 @@ export class Prio3<Measurement> implements Prio3Vdaf<Measurement> {
return this.encodeShares([leaderShare, ...helperShares]);
}

async initialPrepareMessage(
async initialPrepareState(
verifyKey: Buffer,
aggregatorId: number,
_aggParam: AggregationParameter,
nonce: Buffer,
encodedInputShare: Buffer
): Promise<PrepareMessage> {
): Promise<PrepareState> {
const { prg, flp, field } = this;

const share = await this.decodeShare(aggregatorId, encodedInputShare);
Expand Down Expand Up @@ -155,7 +155,7 @@ export class Prio3<Measurement> implements Prio3Vdaf<Measurement> {
this.shares
);

const outboundMessage = this.encodePrepareMessage(
const outboundMessage = this.encodePrepareState(
verifierShare,
shareJointRandSeed
);
Expand All @@ -164,52 +164,52 @@ export class Prio3<Measurement> implements Prio3Vdaf<Measurement> {
}

prepareNext(
prepareMessage: PrepareMessage,
prepareState: PrepareState,
inbound: Buffer | null
):
| { prepareMessage: PrepareMessage; prepareShare: Buffer }
| { prepareState: PrepareState; prepareShare: Buffer }
| { outputShare: OutputShare } {
if (!inbound) {
return { prepareMessage, prepareShare: prepareMessage.outboundMessage };
return {
prepareState,
prepareShare: prepareState.outboundMessage,
};
}

const { verifier, jointRand } = this.decodePrepareMessage(inbound);
const { verifier, jointRand } = this.decodePrepareState(inbound);

const jointRandEquality =
(jointRand &&
prepareMessage.jointRandSeed &&
0 === Buffer.compare(jointRand, prepareMessage.jointRandSeed)) ||
jointRand === prepareMessage.jointRandSeed; // both null
prepareState.jointRandSeed &&
0 === Buffer.compare(jointRand, prepareState.jointRandSeed)) ||
jointRand === prepareState.jointRandSeed; // both null

if (!jointRandEquality || !this.flp.decide(verifier)) {
throw new Error("Verify error");
}

return { outputShare: prepareMessage.outputShare };
return { outputShare: prepareState.outputShare };
}

prepSharesToPrepareMessage(
prepSharesToPrepareState(
_aggParam: AggregationParameter,
encodedPrepShares: Buffer[]
): Buffer {
const { flp, prg, field } = this;
const jointRandCheck = Buffer.alloc(prg.seedSize);

const verifier = encodedPrepShares.reduce(
(verifier, encodedPrepMessage) => {
const { verifier: shareVerifier, jointRand: shareJointRand } =
this.decodePrepareMessage(encodedPrepMessage);
const verifier = encodedPrepShares.reduce((verifier, encodedPrepState) => {
const { verifier: shareVerifier, jointRand: shareJointRand } =
this.decodePrepareState(encodedPrepState);

if (flp.jointRandLen > 0 && shareJointRand) {
xorInPlace(jointRandCheck, shareJointRand);
}
if (flp.jointRandLen > 0 && shareJointRand) {
xorInPlace(jointRandCheck, shareJointRand);
}

return field.vecAdd(verifier, shareVerifier);
},
fill(flp.verifierLen, 0n)
);
return field.vecAdd(verifier, shareVerifier);
}, fill(flp.verifierLen, 0n));

return this.encodePrepareMessage(verifier, jointRandCheck);
return this.encodePrepareState(verifier, jointRandCheck);
}

outputSharesToAggregatorShare(
Expand Down Expand Up @@ -242,7 +242,7 @@ export class Prio3<Measurement> implements Prio3Vdaf<Measurement> {
: await this.decodeHelperShare(aggregatorId, encoded);
}

private decodePrepareMessage(input: Buffer): {
private decodePrepareState(input: Buffer): {
verifier: bigint[];
jointRand: Buffer | null;
} {
Expand All @@ -267,7 +267,7 @@ export class Prio3<Measurement> implements Prio3Vdaf<Measurement> {
return { verifier, jointRand };
}

private encodePrepareMessage(
private encodePrepareState(
verifier: bigint[],
jointRandShare: Buffer | null
): Buffer {
Expand Down
18 changes: 9 additions & 9 deletions packages/vdaf/src/index.spec.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { Field128 } from "@divviup/field";
import { Vdaf, testVdaf } from ".";

type PrepareMessage = {
type PrepareState = {
inputRange: { min: number; max: number };
encodedInputShare: Buffer;
};
Expand All @@ -13,7 +13,7 @@ type Measurement = number;
type TestVdaf = Vdaf<
Measurement,
AggregationParameter,
PrepareMessage,
PrepareState,
AggregatorShare,
AggregationResult,
OutputShare
Expand Down Expand Up @@ -41,27 +41,27 @@ export class VdafTest implements TestVdaf {
]);
}

initialPrepareMessage(
initialPrepareState(
_verifyKey: Buffer,
_aggregatorId: number,
_aggParam: AggregationParameter,
_nonce: Buffer,
inputShare: Buffer
): Promise<PrepareMessage> {
): Promise<PrepareState> {
return Promise.resolve({
inputRange: this.inputRange,
encodedInputShare: inputShare,
});
}

prepareNext(
prepareMessage: PrepareMessage,
prepareState: PrepareState,
inbound: Buffer | null
):
| { prepareMessage: PrepareMessage; prepareShare: Buffer }
| { prepareState: PrepareState; prepareShare: Buffer }
| { outputShare: bigint[] } {
if (!inbound) {
return { prepareMessage, prepareShare: prepareMessage.encodedInputShare };
return { prepareState, prepareShare: prepareState.encodedInputShare };
}

const measurement = Number(this.field.decode(inbound)[0]);
Expand All @@ -70,10 +70,10 @@ export class VdafTest implements TestVdaf {
throw new Error(`measurement ${measurement} was not in [${min}, ${max})`);
}

return { outputShare: this.field.decode(prepareMessage.encodedInputShare) };
return { outputShare: this.field.decode(prepareState.encodedInputShare) };
}

prepSharesToPrepareMessage(
prepSharesToPrepareState(
_aggParam: AggregationParameter,
prepShares: Buffer[]
): Buffer {
Expand Down
25 changes: 11 additions & 14 deletions packages/vdaf/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ export const VDAF_VERSION = "vdaf-00";
export interface Vdaf<
Measurement,
AggregationParameter,
PrepareMessage,
PrepareState,
AggregatorShare,
AggregationResult,
OutputShare
Expand All @@ -19,22 +19,22 @@ export interface Vdaf<

measurementToInputShares(measurement: Measurement): Promise<Buffer[]>;

initialPrepareMessage(
initialPrepareState(
verifyKey: Buffer,
aggId: number,
aggParam: AggregationParameter,
nonce: Buffer,
inputShare: Buffer
): Promise<PrepareMessage>;
): Promise<PrepareState>;

prepareNext(
prepareMessage: PrepareMessage,
prepareState: PrepareState,
inbound: Buffer | null
):
| { prepareMessage: PrepareMessage; prepareShare: Buffer }
| { prepareState: PrepareState; prepareShare: Buffer }
| { outputShare: OutputShare };

prepSharesToPrepareMessage(
prepSharesToPrepareState(
aggParam: AggregationParameter,
prepShares: Buffer[]
): Buffer;
Expand Down Expand Up @@ -103,7 +103,7 @@ export async function runVdaf<M, AP, P, AS, AR, OS>(

const prepStates: P[] = await Promise.all(
arr(vdaf.shares, (aggregatorId) =>
vdaf.initialPrepareMessage(
vdaf.initialPrepareState(
Buffer.from(verifyKey),
aggregatorId,
aggregationParameter,
Expand All @@ -118,10 +118,10 @@ export async function runVdaf<M, AP, P, AS, AR, OS>(
const outbound: Buffer[] = prepStates.map(
(state, aggregatorId, states) => {
const out = vdaf.prepareNext(state, inbound);
if (!("prepareMessage" in out) || !("prepareShare" in out)) {
throw new Error("expected prepareMessage and prepareShare");
if (!("prepareState" in out) || !("prepareShare" in out)) {
throw new Error("expected prepareState and prepareShare");
}
states[aggregatorId] = out.prepareMessage;
states[aggregatorId] = out.prepareState;
return out.prepareShare;
}
);
Expand All @@ -130,10 +130,7 @@ export async function runVdaf<M, AP, P, AS, AR, OS>(
prepTestVector.prep_shares[round].push(prepShare.toString("hex"));
}

inbound = vdaf.prepSharesToPrepareMessage(
aggregationParameter,
outbound
);
inbound = vdaf.prepSharesToPrepareState(aggregationParameter, outbound);
}

const outbound = prepStates.map((state) => {
Expand Down

0 comments on commit c26fc05

Please sign in to comment.