Skip to content

Commit fe5f1e3

Browse files
RandomLatticeandozw
andcommitted
PR feedback from @real-of-random
Co-authored-by: Sean Andersen <6730974+andozw@users.noreply.github.com>
1 parent 3cba981 commit fe5f1e3

File tree

3 files changed

+22
-29
lines changed

3 files changed

+22
-29
lines changed

src/modules/ecdh/tests_impl.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -168,10 +168,8 @@ static void test_ecdh_wycheproof(void) {
168168
parsed_ok = secp256k1_ec_pubkey_parse(CTX, &point, pk, testvectors[t].pk_len);
169169

170170
expected_result = testvectors[t].expected_result;
171-
172171
CHECK(parsed_ok == expected_result);
173-
174-
if (!parsed_ok && expected_result == 0) {
172+
if (!parsed_ok) {
175173
continue;
176174
}
177175

tools/tests_wycheproof_generate_ecdh.py

Lines changed: 20 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,15 @@
1313
from wycheproof_utils import to_c_array
1414

1515
def should_skip_flags(test_vector_flags):
16-
# skip these vectors because they are for ASN.1 encoding issues and other curves
16+
# skip these vectors because they are for ASN.1 encoding issues and other curves.
17+
# for more details, see https://github.com/bitcoin-core/secp256k1/pull/1492#discussion_r1572491546
1718
flags_to_skip = {"InvalidAsn", "WrongCurve"}
1819
return any(flag in test_vector_flags for flag in flags_to_skip)
1920

2021
def should_skip_tcid(test_vector_tcid):
21-
'''
22-
We skip some test case IDs that have a public key whose custom ASN.1 representation explicitly
23-
encodes some curve parameters that are invalid. libsecp256k1 never parses this part so we do
24-
not care testing those. See https://github.com/bitcoin-core/secp256k1/pull/1492#discussion_r1572491546
25-
'''
22+
# We skip some test case IDs that have a public key whose custom ASN.1 representation explicitly
23+
# encodes some curve parameters that are invalid. libsecp256k1 never parses this part so we do
24+
# not care testing those. See https://github.com/bitcoin-core/secp256k1/pull/1492#discussion_r1572491546
2625
tcids_to_skip = [496, 497, 502, 503, 504, 505, 507]
2726
return test_vector_tcid in tcids_to_skip
2827

@@ -37,12 +36,12 @@ def parse_der_pk(s):
3736
L = int(s[2])
3837
offset = 1
3938
elif L == 0x82:
40-
L = 256*int(s[2]) + int(s[3])
39+
L = 256 * int(s[2]) + int(s[3])
4140
offset = 2
4241
else:
4342
raise ValueError("invalid L")
44-
value = s[(offset+2):(L+2+offset)]
45-
rest = s[(L+2+offset):]
43+
value = s[(offset + 2):(L + 2 + offset)]
44+
rest = s[(L + 2 + offset):]
4645

4746
if len(rest) > 0 or tag == 0x06: # OBJECT IDENTIFIER
4847
return parse_der_pk(rest)
@@ -54,12 +53,14 @@ def parse_der_pk(s):
5453

5554
def parse_public_key(pk):
5655
der_pub_key = parse_der_pk(unhexlify(pk)) # Convert back to str and strip off the `0x`
57-
return hexlify(der_pub_key).decode('utf-8')[2:]
56+
return hexlify(der_pub_key).decode()[2:]
5857

5958
def normalize_private_key(sk):
6059
# Ensure the private key is at most 64 characters long, retaining the last 64 if longer.
60+
# In the wycheproof test vectors, some private keys have leading zeroes
6161
normalized = sk[-64:].zfill(64)
62-
assert len(normalized) == 64, "private key must be exactly 64 characters long."
62+
if len(normalized) != 64:
63+
raise ValueError("private key must be exactly 64 characters long.")
6364
return normalized
6465

6566
def normalize_expected_result(er):
@@ -71,25 +72,19 @@ def normalize_expected_result(er):
7172
with open(filename_input) as f:
7273
doc = json.load(f)
7374

74-
num_groups = len(doc['testGroups'])
75-
7675
num_vectors = 0
7776
offset_sk_running, offset_pk_running, offset_shared = 0, 0, 0
78-
out = ""
77+
test_vectors_out = ""
7978
private_keys = ""
8079
shared_secrets = ""
8180
public_keys = ""
8281
cache_sks = {}
8382
cache_public_keys = {}
8483

85-
for i in range(num_groups):
86-
group = doc['testGroups'][i]
87-
num_tests = len(group['tests'])
84+
for group in doc['testGroups']:
8885
assert group["type"] == "EcdhTest"
8986
assert group["curve"] == "secp256k1"
90-
for j in range(num_tests):
91-
test_vector = group['tests'][j]
92-
87+
for test_vector in group['tests']:
9388
if should_skip_flags(test_vector['flags']) or should_skip_tcid(test_vector['tcId']):
9489
continue
9590

@@ -102,8 +97,6 @@ def normalize_expected_result(er):
10297
sk_size = len(private_key) // 2
10398
pk_size = len(public_key) // 2
10499

105-
shared_secrets += ",\n " if num_vectors and shared_size else ""
106-
107100
new_sk = False
108101
sk = to_c_array(private_key)
109102
sk_offset = offset_sk_running
@@ -132,11 +125,13 @@ def normalize_expected_result(er):
132125
else:
133126
pk_offset = cache_public_keys[pk]
134127

128+
129+
shared_secrets += ",\n " if num_vectors and shared_size else ""
135130
shared_secrets += to_c_array(test_vector['shared'])
136131
wycheproof_tcid = test_vector['tcId']
137132

138-
out += " /" + "* tcId: " + str(test_vector['tcId']) + ". " + test_vector['comment'] + " *" + "/\n"
139-
out += f" {{{pk_offset}, {pk_size}, {sk_offset}, {sk_size}, {offset_shared}, {shared_size}, {expected_result}, {wycheproof_tcid} }},\n"
133+
test_vectors_out += " /" + "* tcId: " + str(test_vector['tcId']) + ". " + test_vector['comment'] + " *" + "/\n"
134+
test_vectors_out += f" {{{pk_offset}, {pk_size}, {sk_offset}, {sk_size}, {offset_shared}, {shared_size}, {expected_result}, {wycheproof_tcid} }},\n"
140135
if new_sk:
141136
offset_sk_running += sk_size
142137
if new_pk:
@@ -167,5 +162,5 @@ def normalize_expected_result(er):
167162
print("static const unsigned char wycheproof_ecdh_shared_secrets[] = { " + shared_secrets + "};\n")
168163

169164
print("static const wycheproof_ecdh_testvector testvectors[SECP256K1_ECDH_WYCHEPROOF_NUMBER_TESTVECTORS] = {")
170-
print(out)
165+
print(test_vectors_out)
171166
print("};")

tools/wycheproof_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,5 @@
88
def to_c_array(x):
99
if x == "":
1010
return ""
11-
s = ',0x'.join(a+b for a,b in zip(x[::2], x[1::2]))
11+
s = ',0x'.join(a + b for a, b in zip(x[::2], x[1::2]))
1212
return "0x" + s

0 commit comments

Comments
 (0)