Skip to content

Commit

Permalink
initial refactoring to free up digest access for signature generation.
Browse files Browse the repository at this point in the history
  • Loading branch information
dghgit committed Sep 14, 2024
1 parent 7352222 commit 499cb95
Show file tree
Hide file tree
Showing 5 changed files with 176 additions and 74 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,18 @@
import org.bouncycastle.crypto.Digest;
import org.bouncycastle.crypto.Signer;
import org.bouncycastle.crypto.digests.SHA512Digest;
import org.bouncycastle.crypto.digests.SHAKEDigest;
import org.bouncycastle.crypto.params.ParametersWithRandom;
import org.bouncycastle.pqc.crypto.DigestUtils;
import org.bouncycastle.util.Arrays;

public class HashMLDSASigner
implements Signer
{
private MLDSAPrivateKeyParameters privKey;
private MLDSAPublicKeyParameters pubKey;

private MLDSAEngine engine;
private SecureRandom random;
private Digest digest;
private byte[] digestOidEncoding;
Expand All @@ -44,6 +47,16 @@ public void init(boolean forSigning, CipherParameters param)
random = null;
}

engine = privKey.getParameters().getEngine(this.random);

byte[] ctx = privKey.getContext();
if (ctx.length > 255)
{
throw new IllegalArgumentException("context too long");
}

engine.initSign(privKey.tr, true, ctx);

initDigest(privKey);
}
else
Expand Down Expand Up @@ -88,13 +101,7 @@ public void update(byte[] in, int off, int len)
@Override
public byte[] generateSignature() throws CryptoException, DataLengthException
{
MLDSAEngine engine = privKey.getParameters().getEngine(random);

byte[] ctx = privKey.getContext();
if (ctx.length > 255)
{
throw new RuntimeException("Context too long");
}
SHAKEDigest msgDigest = engine.getShake256Digest();

byte[] rnd = new byte[MLDSAEngine.RndBytes];
if (random != null)
Expand All @@ -105,14 +112,9 @@ public byte[] generateSignature() throws CryptoException, DataLengthException
byte[] hash = new byte[digest.getDigestSize()];
digest.doFinal(hash, 0);

byte[] ds_message = new byte[1 + 1 + ctx.length + + digestOidEncoding.length + hash.length];
ds_message[0] = 1;
ds_message[1] = (byte)ctx.length;
System.arraycopy(ctx, 0, ds_message, 2, ctx.length);
System.arraycopy(digestOidEncoding, 0, ds_message, 2 + ctx.length, digestOidEncoding.length);
System.arraycopy(hash, 0, ds_message, 2 + ctx.length + digestOidEncoding.length, hash.length);
byte[] ds_message = Arrays.concatenate(digestOidEncoding, hash);

return engine.signInternal(ds_message, ds_message.length, privKey.rho, privKey.k, privKey.tr, privKey.t0, privKey.s1, privKey.s2, rnd);
return engine.signInternal(ds_message, ds_message.length, privKey.rho, privKey.k, privKey.t0, privKey.s1, privKey.s2, rnd);
}

@Override
Expand Down Expand Up @@ -153,7 +155,7 @@ public byte[] internalGenerateSignature(byte[] message, byte[] random)
{
MLDSAEngine engine = privKey.getParameters().getEngine(this.random);

return engine.signInternal(message, message.length, privKey.rho, privKey.k, privKey.tr, privKey.t0, privKey.s1, privKey.s2, random);
return engine.signInternal(message, message.length, privKey.rho, privKey.k, privKey.t0, privKey.s1, privKey.s2, random);
}

public boolean internalVerifySignature(byte[] message, byte[] signature)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ class MLDSAEngine

private final int PolyUniformGamma1NBlocks;

private final Symmetric symmetric;;
private final Symmetric symmetric;
;

protected Symmetric GetSymmetric()
{
Expand Down Expand Up @@ -120,7 +121,7 @@ int getDilithiumOmega()
{
return DilithiumOmega;
}

int getDilithiumCTilde()
{
return DilithiumCTilde;
Expand All @@ -146,16 +147,6 @@ int getPolyUniformGamma1NBlocks()
return this.PolyUniformGamma1NBlocks;
}

SHAKEDigest getShake256Digest()
{
return this.shake256Digest;
}

SHAKEDigest getShake128Digest()
{
return this.shake128Digest;
}

MLDSAEngine(int mode, SecureRandom random)
{
this.DilithiumMode = mode;
Expand Down Expand Up @@ -206,7 +197,7 @@ SHAKEDigest getShake128Digest()
default:
throw new IllegalArgumentException("The mode " + mode + "is not supported by Crystals Dilithium!");
}

this.symmetric = new Symmetric.ShakeSymmetric();

this.random = random;
Expand Down Expand Up @@ -243,16 +234,15 @@ private byte[][] generateKeyPairInternal(byte[] seed)
byte[] tr = new byte[TrBytes];

byte[] rho = new byte[SeedBytes],
rhoPrime = new byte[CrhBytes],
key = new byte[SeedBytes];
rhoPrime = new byte[CrhBytes],
key = new byte[SeedBytes];

PolyVecMatrix aMatrix = new PolyVecMatrix(this);

PolyVecL s1 = new PolyVecL(this), s1hat;
PolyVecK s2 = new PolyVecK(this), t1 = new PolyVecK(this), t0 = new PolyVecK(this);



shake256Digest.update(seed, 0, SeedBytes);

//Domain separation
Expand Down Expand Up @@ -312,14 +302,42 @@ private byte[][] generateKeyPairInternal(byte[] seed)

byte[][] sk = Packing.packSecretKey(rho, tr, key, t0, s1, s2, this);

return new byte[][]{ sk[0], sk[1], sk[2], sk[3], sk[4], sk[5], encT1};
return new byte[][]{sk[0], sk[1], sk[2], sk[3], sk[4], sk[5], encT1};
}

public byte[] signInternal(byte[] msg, int msglen, byte[] rho, byte[] key, byte[] tr, byte[] t0Enc, byte[] s1Enc, byte[] s2Enc, byte[] rnd)
SHAKEDigest getShake256Digest()
{
return new SHAKEDigest(shake256Digest);
}
void initSign(byte[] tr, boolean isPreHash, byte[] ctx)
{
this.shake256Digest.update(tr, 0, TrBytes);
if (ctx != null)
{
this.shake256Digest.update((isPreHash) ? (byte)1 : (byte)0);
this.shake256Digest.update((byte)ctx.length);
this.shake256Digest.update(ctx, 0, ctx.length);
}
}

public byte[] signInternal(byte[] msg, int msglen, byte[] rho, byte[] key, byte[] t0Enc, byte[] s1Enc, byte[] s2Enc, byte[] rnd)
{
SHAKEDigest shake256 = new SHAKEDigest(shake256Digest);

shake256.update(msg, 0, msglen);

return generateSignature(shake256, rho, key, t0Enc, s1Enc, s2Enc, rnd);
}

byte[] generateSignature(SHAKEDigest shake256Digest, byte[] rho, byte[] key, byte[] t0Enc, byte[] s1Enc, byte[] s2Enc, byte[] rnd)
{
byte[] mu = new byte[CrhBytes];

shake256Digest.doFinal(mu, 0, CrhBytes);

int n;
byte[] outSig = new byte[CryptoBytes + msglen];
byte[] mu = new byte[CrhBytes], rhoPrime = new byte[CrhBytes];
byte[] outSig = new byte[CryptoBytes];
byte[] rhoPrime = new byte[CrhBytes];
short nonce = 0;
PolyVecL s1 = new PolyVecL(this), y = new PolyVecL(this), z = new PolyVecL(this);
PolyVecK t0 = new PolyVecK(this), s2 = new PolyVecK(this), w1 = new PolyVecK(this), w0 = new PolyVecK(this), h = new PolyVecK(this);
Expand All @@ -328,12 +346,6 @@ public byte[] signInternal(byte[] msg, int msglen, byte[] rho, byte[] key, byte[

Packing.unpackSecretKey(t0, s1, s2, t0Enc, s1Enc, s2Enc, this);

this.shake256Digest.update(tr, 0, TrBytes);
this.shake256Digest.update(msg, 0, msglen);
this.shake256Digest.doFinal(mu, 0, CrhBytes);



byte[] keyMu = Arrays.copyOf(key, SeedBytes + RndBytes + CrhBytes);
System.arraycopy(rnd, 0, keyMu, SeedBytes, RndBytes);
System.arraycopy(mu, 0, keyMu, SeedBytes + RndBytes, CrhBytes);
Expand Down Expand Up @@ -418,22 +430,21 @@ public byte[] signInternal(byte[] msg, int msglen, byte[] rho, byte[] key, byte[

public boolean verifyInternal(byte[] sig, int siglen, byte[] msg, int msglen, byte[] rho, byte[] encT1)
{
byte[] buf,
mu = new byte[CrhBytes],
c,
c2 = new byte[DilithiumCTilde];
Poly cp = new Poly(this);
PolyVecMatrix aMatrix = new PolyVecMatrix(this);
PolyVecL z = new PolyVecL(this);
PolyVecK t1 = new PolyVecK(this), w1 = new PolyVecK(this), h = new PolyVecK(this);

if (siglen != CryptoBytes)
{
return false;
}

// System.out.println("publickey = ");
// Helper.printByteArray(publicKey);
byte[] buf,
mu = new byte[CrhBytes],
c,
c2 = new byte[DilithiumCTilde];
Poly cp = new Poly(this);
PolyVecMatrix aMatrix = new PolyVecMatrix(this);
PolyVecL z = new PolyVecL(this);
PolyVecK t1 = new PolyVecK(this), w1 = new PolyVecK(this), h = new PolyVecK(this);

t1 = Packing.unpackPublicKey(t1, encT1, this);

Expand Down
Original file line number Diff line number Diff line change
@@ -1,19 +1,29 @@
package org.bouncycastle.pqc.crypto.mldsa;

import java.io.ByteArrayOutputStream;
import java.security.SecureRandom;

import org.bouncycastle.crypto.CipherParameters;
import org.bouncycastle.crypto.CryptoException;
import org.bouncycastle.crypto.DataLengthException;
import org.bouncycastle.crypto.Signer;
import org.bouncycastle.crypto.digests.SHAKEDigest;
import org.bouncycastle.crypto.params.ParametersWithRandom;
import org.bouncycastle.pqc.crypto.MessageSigner;

public class MLDSASigner
implements MessageSigner
implements Signer
{
private MLDSAPrivateKeyParameters privKey;
private MLDSAPublicKeyParameters pubKey;

private MLDSAEngine engine;
private SHAKEDigest msgDigest;

private SecureRandom random;

// TODO: temporary
private ByteArrayOutputStream bOut = new ByteArrayOutputStream();

public MLDSASigner()
{
}
Expand All @@ -35,11 +45,25 @@ public void init(boolean forSigning, CipherParameters param)
random = null;
}

engine = privKey.getParameters().getEngine(this.random);

byte[] ctx = privKey.getContext();
if (ctx.length > 255)
{
throw new IllegalArgumentException("context too long");
}

engine.initSign(privKey.tr, false, ctx);

msgDigest = engine.getShake256Digest();

isPreHash = privKey.getParameters().isPreHash();
}
else
{
pubKey = (MLDSAPublicKeyParameters)param;
engine = null;
msgDigest = null;
isPreHash = pubKey.getParameters().isPreHash();
}

Expand All @@ -49,39 +73,77 @@ public void init(boolean forSigning, CipherParameters param)
}
}

public byte[] generateSignature(byte[] message)
public void update(byte b)
{
MLDSAEngine engine = privKey.getParameters().getEngine(random);
if (msgDigest != null)
{
msgDigest.update(b);
}
else
{
bOut.write(b);
}
}

byte[] ctx = privKey.getContext();
if (ctx.length > 255)
public void update(byte[] in, int off, int len)
{
if (msgDigest != null)
{
throw new RuntimeException("Context too long");
msgDigest.update(in, off, len);
}
else
{
bOut.write(in, off, len);
}
}

public byte[] generateSignature()
throws CryptoException, DataLengthException
{
byte[] rnd = new byte[MLDSAEngine.RndBytes];
if (random != null)
{
random.nextBytes(rnd);
}

byte[] ds_message = new byte[1 + 1 + ctx.length + message.length];
ds_message[0] = 0;
ds_message[1] = (byte)ctx.length;
System.arraycopy(ctx, 0, ds_message, 2, ctx.length);
System.arraycopy(message, 0, ds_message, 2 + ctx.length, message.length);
return engine.generateSignature(msgDigest, privKey.rho, privKey.k, privKey.t0, privKey.s1, privKey.s2, rnd);
}

public boolean verifySignature(byte[] signature)
{
boolean isTrue = verifySignature(bOut.toByteArray(), signature);

bOut.reset();

return isTrue;
}

return engine.signInternal(ds_message, ds_message.length, privKey.rho, privKey.k, privKey.tr, privKey.t0, privKey.s1, privKey.s2, rnd);
public void reset()
{
bOut.reset();
}

byte[] generateSignature(byte[] message)
{
byte[] rnd = new byte[MLDSAEngine.RndBytes];
if (random != null)
{
random.nextBytes(rnd);
}

return engine.signInternal(message, message.length, privKey.rho, privKey.k, privKey.t0, privKey.s1, privKey.s2, rnd);
}

public byte[] internalGenerateSignature(byte[] message, byte[] random)
protected byte[] internalGenerateSignature(byte[] message, byte[] random)
{
MLDSAEngine engine = privKey.getParameters().getEngine(this.random);

return engine.signInternal(message, message.length, privKey.rho, privKey.k, privKey.tr, privKey.t0, privKey.s1, privKey.s2, random);
engine.initSign(privKey.tr, false, null);

return engine.signInternal(message, message.length, privKey.rho, privKey.k, privKey.t0, privKey.s1, privKey.s2, random);
}

public boolean verifySignature(byte[] message, byte[] signature)
boolean verifySignature(byte[] message, byte[] signature)
{
MLDSAEngine engine = pubKey.getParameters().getEngine(random);

Expand All @@ -99,6 +161,7 @@ public boolean verifySignature(byte[] message, byte[] signature)

return engine.verifyInternal(signature, signature.length, ds_message, ds_message.length, pubKey.rho, pubKey.t1);
}

public boolean internalVerifySignature(byte[] message, byte[] signature)
{
MLDSAEngine engine = pubKey.getParameters().getEngine(random);
Expand Down
Loading

0 comments on commit 499cb95

Please sign in to comment.