diff --git a/src/noise.ts b/src/noise.ts index 820d98e0..62a5bd9b 100644 --- a/src/noise.ts +++ b/src/noise.ts @@ -32,7 +32,7 @@ export class Noise implements INoiseConnection { public protocol = '/noise' public crypto: ICryptoInterface - private readonly prologue = new Uint8Array(0) + private readonly prologue: Uint8Array private readonly staticKeys: KeyPair private readonly earlyData?: bytes private readonly useNoisePipes: boolean @@ -41,7 +41,7 @@ export class Noise implements INoiseConnection { * @param {bytes} staticNoiseKey - x25519 private key, reuse for faster handshakes * @param {bytes} earlyData */ - constructor (staticNoiseKey?: bytes, earlyData?: bytes, crypto: ICryptoInterface = stablelib) { + constructor (staticNoiseKey?: bytes, earlyData?: bytes, crypto: ICryptoInterface = stablelib, prologueBytes?: Uint8Array) { this.earlyData = earlyData ?? new Uint8Array(0) // disabled until properly specked this.useNoisePipes = false @@ -53,6 +53,7 @@ export class Noise implements INoiseConnection { } else { this.staticKeys = this.crypto.generateX25519KeyPair() } + this.prologue = prologueBytes ?? new Uint8Array(0) } /** diff --git a/test/noise.spec.ts b/test/noise.spec.ts index 1d7fde3a..93b82fb0 100644 --- a/test/noise.spec.ts +++ b/test/noise.spec.ts @@ -120,9 +120,9 @@ describe('Noise', () => { const wrappedInbound = pbStream(inbound.conn) const wrappedOutbound = pbStream(outbound.conn) - const largePlaintext = randomBytes(100000) + const largePlaintext = randomBytes(60000) wrappedOutbound.writeLP(Buffer.from(largePlaintext)) - const response = await wrappedInbound.read(100000) + const response = await wrappedInbound.read(60000) expect(response.length).equals(largePlaintext.length) } catch (e) { @@ -374,4 +374,26 @@ describe('Noise', () => { assert(false, err.message) } }) + + it('should accept a prologue', async () => { + try { + const noiseInit = new Noise(undefined, undefined, stablelib, Buffer.from('Some prologue')) + const noiseResp = new Noise(undefined, undefined, stablelib, Buffer.from('Some prologue')) + + const [inboundConnection, outboundConnection] = duplexPair() + const [outbound, inbound] = await Promise.all([ + noiseInit.secureOutbound(localPeer, outboundConnection, remotePeer), + noiseResp.secureInbound(remotePeer, inboundConnection, localPeer) + ]) + const wrappedInbound = pbStream(inbound.conn) + const wrappedOutbound = pbStream(outbound.conn) + + wrappedOutbound.writeLP(Buffer.from('test')) + const response = await wrappedInbound.readLP() + expect(uint8ArrayToString(response.slice())).equal('test') + } catch (e) { + const err = e as Error + assert(false, err.message) + } + }) })