Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add x25519 AKE test vectors #404

Merged
merged 15 commits into from
May 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
686 changes: 506 additions & 180 deletions draft-irtf-cfrg-opaque.md

Large diffs are not rendered by default.

55 changes: 55 additions & 0 deletions poc/ake_group.sage
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import sys
import hashlib

########## Definitions from RFC 7748 ##################
from sagelib.rfc7748 import *
from sagelib.groups import *
from sagelib.string_utils import *

try:
from sagelib.opaque_common import curve25519_clamp
except ImportError as e:
sys.exit("Error loading preprocessed sage files. Try running `make setup && make clean pyfiles`. Full error: " + e)


class GroupCurve25519(Group):
def __init__(self):
Group.__init__(self, "curve25519")

def generator(self):
return IntegerToByteArray(9)

def serialize(self, element):
# Curve25519 points are bytes
return element

def deserialize(self, encoded):
# Curve25519 points are bytes
return encoded

def serialize_scalar(self, scalar):
# Curve25519 scalars are represented as bytes
return scalar

def element_byte_length(self):
return 32

def scalar_byte_length(self):
return 32

def random_scalar(self, rng):
return curve25519_clamp(rng.random_bytes(32))

def scalar_mult(self, x, y):
return X25519(x, y)

def __str__(self):
return self.name

if __name__ == "__main__":
# From RFC7748: https://www.rfc-editor.org/rfc/rfc7748#section-6.1
a = bytes.fromhex("77076d0a7318a57d3c16c17251b26645df4c2f87ebc0992ab177fba51db92c2a")
A = bytes.fromhex("8520f0098930a754748b7ddcb43ef75a0dbf3a0d26381af4eba4a98eaa9b4e6a")
G = GroupCurve25519()
A_exp = G.scalar_mult(a, G.generator())
assert(A_exp == A)
30 changes: 19 additions & 11 deletions poc/format_test_vectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,8 @@
"server_public_key",
"server_nonce",
"client_nonce",
"server_public_keyshare",
"client_keyshare",
"server_private_keyshare",
"client_private_keyshare",
"client_keyshare_seed",
"server_keyshare_seed",
"blind_registration",
"blind_login",
]
Expand Down Expand Up @@ -64,7 +62,7 @@
"session_key",
]

### Fake Vector Keys
# Fake Vector Keys

fake_input_keys = [
"client_identity",
Expand All @@ -80,10 +78,8 @@
"server_public_key",
"server_nonce",
"client_nonce",
"server_public_keyshare",
"client_keyshare",
"server_private_keyshare",
"client_private_keyshare",
"client_keyshare_seed",
"server_keyshare_seed",
"blind_registration",
"blind_login",
"masking_key",
Expand All @@ -94,6 +90,7 @@
"KE2",
]


def to_hex(octet_string):
if isinstance(octet_string, str):
return "".join("{:02x}".format(ord(c)) for c in octet_string)
Expand All @@ -102,40 +99,47 @@ def to_hex(octet_string):
assert isinstance(octet_string, bytearray)
return ''.join(format(x, '02x') for x in octet_string)


def wrap_print(arg, *args):
line_length = 69
string = arg + " " + " ".join(args)
for hunk in (string[0+i:line_length+i] for i in range(0, len(string), line_length)):
if hunk and len(hunk.strip()) > 0:
print(hunk)


def format_vector_name(vector):
return "OPAQUE-" + vector["config"]["Name"]


def print_vector_config(vector):
for key in config_keys:
for config_key in vector["config"]:
if key == config_key:
wrap_print(key + ":", vector["config"][key])


def print_vector_inputs(arr, vector):
for key in arr:
for input_key in vector["inputs"]:
if key == input_key:
wrap_print(key + ":", vector["inputs"][key])


def print_vector_intermediates(arr, vector):
for key in arr:
for int_key in vector["intermediates"]:
if key == int_key:
wrap_print(key + ":", vector["intermediates"][key])


def print_vector_outputs(arr, vector):
for key in arr:
for output_key in vector["outputs"]:
if key == output_key:
wrap_print(key + ":", vector["outputs"][key])


def format_vector(vector, i):
print("\n#### Configuration\n")
print("~~~")
Expand All @@ -155,6 +159,7 @@ def format_vector(vector, i):
print("~~~")
print("")


def format_fake_vector(vector, i):
print("\n#### Configuration\n")
print("~~~")
Expand All @@ -170,6 +175,7 @@ def format_fake_vector(vector, i):
print("~~~")
print("")


with open(sys.argv[1], "r") as fh:
vectors = json.loads(fh.read())
real_vectors = []
Expand All @@ -181,10 +187,12 @@ def format_fake_vector(vector, i):
real_vectors.append(vector)
print("## Real Test Vectors {#real-vectors}\n")
for i, vector in enumerate(real_vectors):
print("### " + format_vector_name(vector) + " Real Test Vector " + str(i+1))
print("### " + format_vector_name(vector) +
" Real Test Vector " + str(i+1))
format_vector(vector, i)

print("## Fake Test Vectors {#fake-vectors}\n")
for i, vector in enumerate(fake_vectors):
print("### " + format_vector_name(vector) + " Fake Test Vector " + str(i+1))
print("### " + format_vector_name(vector) +
" Fake Test Vector " + str(i+1))
format_fake_vector(vector, i)
72 changes: 35 additions & 37 deletions poc/opaque_ake.sage
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ from collections import namedtuple

try:
from sagelib.opaque_common import derive_secret, hkdf_expand_label, hkdf_extract, I2OSP, OS2IP, OS2IP_le, encode_vector, encode_vector_len, to_hex, OPAQUE_NONCE_LENGTH
from sagelib.opaque_core import OPAQUECore
from sagelib.opaque_core import OPAQUECore, OPAQUE_SEED_LENGTH
from sagelib.opaque_messages import deserialize_credential_request, deserialize_credential_response
except ImportError as e:
sys.exit("Error loading preprocessed sage files. Try running `make setup && make clean pyfiles`. Full error: " + e)
Expand Down Expand Up @@ -76,9 +76,9 @@ class OPAQUE3DH(KeyExchange):
}

def derive_3dh_keys(self, dh_components, info):
dh1 = dh_components.sk1 * dh_components.pk1
dh2 = dh_components.sk2 * dh_components.pk2
dh3 = dh_components.sk3 * dh_components.pk3
dh1 = self.config.group.scalar_mult(dh_components.sk1, self.config.group.deserialize(dh_components.pk1))
dh2 = self.config.group.scalar_mult(dh_components.sk2, self.config.group.deserialize(dh_components.pk2))
dh3 = self.config.group.scalar_mult(dh_components.sk3, self.config.group.deserialize(dh_components.pk3))

dh1_encoded = self.config.group.serialize(dh1)
dh2_encoded = self.config.group.serialize(dh2)
Expand All @@ -101,10 +101,9 @@ class OPAQUE3DH(KeyExchange):

def auth_client_start(self):
self.client_nonce = self.rng.random_bytes(OPAQUE_NONCE_LENGTH)
self.client_private_keyshare = ZZ(self.config.group.random_scalar(self.rng))
self.client_public_keyshare_bytes = self.config.group.serialize(self.client_private_keyshare * self.config.group.generator())

return TripleDHMessageInit(self.client_nonce, self.client_public_keyshare_bytes)
self.client_keyshare_seed = self.rng.random_bytes(OPAQUE_SEED_LENGTH)
self.client_private_keyshare, self.client_public_keyshare = self.core.derive_diffie_hellman_key_pair(self.client_keyshare_seed)
return TripleDHMessageInit(self.client_nonce, self.client_public_keyshare)

def generate_ke1(self, password):
cred_request, cred_metadata = self.core.create_credential_request(password)
Expand All @@ -116,14 +115,14 @@ class OPAQUE3DH(KeyExchange):

return self.serialized_request + ke1.serialize()

def transcript_hasher(self, serialized_request, serialized_response, cleartext_credentials, client_nonce, client_public_keyshare_bytes, server_nonce, server_public_keyshare_bytes):
def transcript_hasher(self, serialized_request, serialized_response, cleartext_credentials, client_nonce, client_public_keyshare, server_nonce, server_public_keyshare_bytes):
hasher = self.config.hash()
hasher.update(_as_bytes("RFCXXXX")) # RFCXXXX
hasher.update(encode_vector(self.config.context)) # context
hasher.update(encode_vector_len(cleartext_credentials.client_identity, 2)) # client_identity
hasher.update(serialized_request) # ke1: cred request
hasher.update(client_nonce) # ke1: client nonce
hasher.update(client_public_keyshare_bytes) # ke1: client keyshare
hasher.update(client_public_keyshare) # ke1: client keyshare
hasher.update(encode_vector_len(cleartext_credentials.server_identity, 2)) # server identity
hasher.update(serialized_response) # ke2: cred response
hasher.update(server_nonce) # ke2: server nonce
Expand All @@ -133,21 +132,19 @@ class OPAQUE3DH(KeyExchange):

return hasher.digest()

def auth_server_respond(self, cred_request, cred_response, ke1, cleartext_credentials, server_private_key, client_public_key):
def auth_server_respond(self, cred_request, cred_response, ke1, cleartext_credentials, server_private_key, client_public_keyshare):
self.server_nonce = self.rng.random_bytes(OPAQUE_NONCE_LENGTH)
self.server_private_keyshare = ZZ(self.config.group.random_scalar(self.rng))
self.server_public_keyshare = self.server_private_keyshare * self.config.group.generator()
server_public_keyshare_bytes = self.config.group.serialize(self.server_public_keyshare)
client_public_keyshare = self.config.group.deserialize(ke1.client_public_keyshare)
self.server_keyshare_seed = self.rng.random_bytes(OPAQUE_SEED_LENGTH)
self.server_private_keyshare, self.server_public_keyshare_bytes = self.core.derive_diffie_hellman_key_pair(self.server_keyshare_seed)

transcript_hash = self.transcript_hasher(cred_request.serialize(), cred_response.serialize(), cleartext_credentials, ke1.client_nonce, ke1.client_public_keyshare, self.server_nonce, server_public_keyshare_bytes)
transcript_hash = self.transcript_hasher(cred_request.serialize(), cred_response.serialize(), cleartext_credentials, ke1.client_nonce, ke1.client_public_keyshare, self.server_nonce, self.server_public_keyshare_bytes)

# K3dh = epkU^eskS || epkU^skS || pkU^eskS
dh_components = TripleDHComponents(client_public_keyshare, self.server_private_keyshare, client_public_keyshare, server_private_key, client_public_key, self.server_private_keyshare)
dh_components = TripleDHComponents(ke1.client_public_keyshare, self.server_private_keyshare, ke1.client_public_keyshare, server_private_key, client_public_keyshare, self.server_private_keyshare)

server_mac_key, client_mac_key, session_key, handshake_secret = self.derive_3dh_keys(dh_components, self.hasher.digest())
mac = hmac.digest(server_mac_key, transcript_hash, self.config.hash)
ake2 = TripleDHMessageRespond(self.server_nonce, server_public_keyshare_bytes, mac)
ake2 = TripleDHMessageRespond(self.server_nonce, self.server_public_keyshare_bytes, mac)

self.server_mac_key = server_mac_key
self.ake2 = ake2
Expand All @@ -158,27 +155,24 @@ class OPAQUE3DH(KeyExchange):

return ake2

def generate_ke2(self, msg, oprf_seed, credential_identifier, envU, masking_key, server_identity, server_private_key, server_public_key, client_identity, client_public_key):
def generate_ke2(self, msg, oprf_seed, credential_identifier, envU, masking_key, server_identity, server_private_key, server_public_keyshare, client_identity, client_public_keyshare):
cred_request, offset = deserialize_credential_request(self.config, msg)
ke1 = deserialize_tripleDH_init(self.config, msg[offset:])

server_public_key_bytes = self.config.group.serialize(server_public_key)
cred_response = self.core.create_credential_response(cred_request, server_public_key_bytes, oprf_seed, envU, credential_identifier, masking_key)
cred_response = self.core.create_credential_response(cred_request, server_public_keyshare, oprf_seed, envU, credential_identifier, masking_key)
serialized_response = cred_response.serialize()
self.masking_nonce = cred_response.masking_nonce

cleartext_credentials = self.core.create_cleartext_credentials(server_public_key_bytes, self.config.group.serialize(client_public_key), server_identity, client_identity)
ake2 = self.auth_server_respond(cred_request, cred_response, ke1, cleartext_credentials, server_private_key, client_public_key)
cleartext_credentials = self.core.create_cleartext_credentials(server_public_keyshare, client_public_keyshare, server_identity, client_identity)
ake2 = self.auth_server_respond(cred_request, cred_response, ke1, cleartext_credentials, server_private_key, client_public_keyshare)

return serialized_response + ake2.serialize()

def auth_client_finalize(self, cred_response, ake2, cleartext_credentials, client_private_key, client_public_key):
transcript_hash = self.transcript_hasher(self.serialized_request, cred_response.serialize(), cleartext_credentials, self.client_nonce, self.client_public_keyshare_bytes, ake2.server_nonce, ake2.server_public_keyshare_bytes)
server_public_key = self.config.group.deserialize(cleartext_credentials.server_public_key_bytes)
server_public_keyshare = self.config.group.deserialize(ake2.server_public_keyshare_bytes)
def auth_client_finalize(self, cred_response, ake2, cleartext_credentials, client_private_key):
transcript_hash = self.transcript_hasher(self.serialized_request, cred_response.serialize(), cleartext_credentials, self.client_nonce, self.client_public_keyshare, ake2.server_nonce, ake2.server_public_keyshare_bytes)

# K3dh = epkS^eskU || pkS^eskU || epkS^skU
dh_components = TripleDHComponents(server_public_keyshare, self.client_private_keyshare, server_public_key, self.client_private_keyshare, server_public_keyshare, client_private_key)
dh_components = TripleDHComponents(ake2.server_public_keyshare_bytes, self.client_private_keyshare, cleartext_credentials.server_public_key_bytes, self.client_private_keyshare, ake2.server_public_keyshare_bytes, client_private_key)

server_mac_key, client_mac_key, session_key, handshake_secret = self.derive_3dh_keys(dh_components, self.hasher.digest())
server_mac = hmac.digest(server_mac_key, transcript_hash, self.config.hash)
Expand All @@ -197,16 +191,20 @@ class OPAQUE3DH(KeyExchange):

return TripleDHMessageFinish(client_mac)

def generate_ke3(self, msg, client_identity, client_public_key, server_identity):
def generate_ke3(self, msg, client_identity, server_identity):
cred_response, offset = deserialize_credential_response(self.config, msg)
ake2 = deserialize_tripleDH_respond(self.config, msg[offset:])
client_private_key_bytes, cleartext_credentials, export_key = self.core.recover_credentials(self.password, self.cred_metadata, cred_response, client_identity, server_identity)
client_private_key = OS2IP(client_private_key_bytes)
if "ristretto" in self.config.group.name or "decaf" in self.config.group.name:

if "curve25519" in self.config.group.name:
client_private_key = client_private_key_bytes
elif "ristretto" in self.config.group.name or "decaf" in self.config.group.name:
client_private_key = OS2IP_le(client_private_key_bytes)
self.export_key = export_key
else:
client_private_key = OS2IP(client_private_key_bytes)

ke3 = self.auth_client_finalize(cred_response, ake2, cleartext_credentials, client_private_key, client_public_key)
self.export_key = export_key
ke3 = self.auth_client_finalize(cred_response, ake2, cleartext_credentials, client_private_key)

return ke3.serialize()

Expand All @@ -228,11 +226,11 @@ class OPAQUE3DH(KeyExchange):
# } KE1M;
def deserialize_tripleDH_init(config, data):
client_nonce = data[0:OPAQUE_NONCE_LENGTH]
client_public_keyshare_bytes = data[OPAQUE_NONCE_LENGTH:]
client_public_keyshare = data[OPAQUE_NONCE_LENGTH:]
length = config.oprf_suite.group.element_byte_length()
if len(client_public_keyshare_bytes) != length:
raise Exception("Invalid client_public_keyshare length: %d %d" % (len(client_public_keyshare_bytes), length))
return TripleDHMessageInit(client_nonce, client_public_keyshare_bytes)
if len(client_public_keyshare) != length:
raise Exception("Invalid client_public_keyshare length: %d %d" % (len(client_public_keyshare), length))
return TripleDHMessageInit(client_nonce, client_public_keyshare)

class TripleDHMessageInit(object):
def __init__(self, client_nonce, client_public_keyshare):
Expand Down
8 changes: 8 additions & 0 deletions poc/opaque_common.sage
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,14 @@ def xor(a, b):
c[i] = c[i] ^^ v # bitwise XOR
return bytes(c)

# Performs the curve25519 clamping operation
def curve25519_clamp(scalar):
arr = bytearray(scalar)
arr[0] &= 248
arr[31] &= 127
arr[31] |= 64
return bytes(arr)

def hkdf_extract(config, salt, ikm):
return hmac.digest(salt, ikm, config.hash)

Expand Down
Loading