Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
hewigovens committed May 31, 2019
1 parent 1eb0acc commit a171e39
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 111 deletions.
6 changes: 2 additions & 4 deletions trezor-crypto/include/TrezorCrypto/schnorr.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ extern "C"

// result of sign operation
typedef struct {
bignum256 r, s;
uint8_t r[32];
uint8_t s[32];
} schnorr_sign_pair;

// sign/verify returns 0 if operation succeeded
Expand All @@ -44,9 +45,6 @@ int schnorr_sign(const ecdsa_curve *curve, const uint8_t *priv_key,
int schnorr_verify(const ecdsa_curve *curve, const uint8_t *pub_key,
const uint8_t *msg, const uint32_t msg_len,
const schnorr_sign_pair *sign);

void schnorr_to_hex_str(const schnorr_sign_pair *sign, char hex_str[128]);
void schnorr_from_hex_str(const char hex_str[128], schnorr_sign_pair *sign);
#ifdef __cplusplus
} /* extern "C" */
#endif
Expand Down
9 changes: 5 additions & 4 deletions trezor-crypto/src/ecdsa.c
Original file line number Diff line number Diff line change
Expand Up @@ -1176,8 +1176,8 @@ int zil_schnorr_sign(const ecdsa_curve *curve, const uint8_t *priv_key, const ui
}

// we're done
bn_write_be(&sign.r, sig);
bn_write_be(&sign.s, sig + 32);
memcpy(sig, sign.r, 32);
memcpy(sig + 32, sign.s, 32);

memzero(&k, sizeof(k));
memzero(&rng, sizeof(rng));
Expand All @@ -1195,8 +1195,9 @@ int zil_schnorr_sign(const ecdsa_curve *curve, const uint8_t *priv_key, const ui
int zil_schnorr_verify(const ecdsa_curve *curve, const uint8_t *pub_key, const uint8_t *sig, const uint8_t *msg, const uint32_t msg_len)
{
schnorr_sign_pair sign;
bn_read_be(sig, &sign.r);
bn_read_be(sig + 32, &sign.s);

memcpy(sign.r, sig, 32);
memcpy(sign.s, sig + 32, 32);

return schnorr_verify(curve, pub_key, msg, msg_len, &sign);
}
134 changes: 65 additions & 69 deletions trezor-crypto/src/schnorr.c
Original file line number Diff line number Diff line change
Expand Up @@ -21,42 +21,6 @@
*/

#include <TrezorCrypto/schnorr.h>
#include <TrezorCrypto/memzero.h>
#include "stdio.h"

static void hex_to_raw_bignum(const char *str, uint8_t bn_raw[32]) {
for (size_t i = 0; i < 32; i++) {
uint8_t c = 0;
if (str[i * 2] >= '0' && str[i * 2] <= '9') c += (str[i * 2] - '0') << 4;
if ((str[i * 2] & ~0x20) >= 'A' && (str[i * 2] & ~0x20) <= 'F')
c += (10 + (str[i * 2] & ~0x20) - 'A') << 4;
if (str[i * 2 + 1] >= '0' && str[i * 2 + 1] <= '9')
c += (str[i * 2 + 1] - '0');
if ((str[i * 2 + 1] & ~0x20) >= 'A' && (str[i * 2 + 1] & ~0x20) <= 'F')
c += (10 + (str[i * 2 + 1] & ~0x20) - 'A');
bn_raw[i] = c;
}
}

static void bn_be_to_hex_string(const bignum256 *b, char result[64]) {
uint8_t raw_number[32] = {0};
bn_write_be(b, raw_number);
for (int i = 0; i < 32; ++i)
sprintf(result + i * 2, "%02x", ((unsigned char *)raw_number)[i]);
}

void schnorr_to_hex_str(const schnorr_sign_pair *sign, char hex_str[128]) {
bn_be_to_hex_string(&sign->r, hex_str);
bn_be_to_hex_string(&sign->s, hex_str + 64);
}

void schnorr_from_hex_str(const char hex_str[128], schnorr_sign_pair *sign) {
uint8_t buf[32];
hex_to_raw_bignum(hex_str, buf);
bn_read_be(buf, &sign->r);
hex_to_raw_bignum(hex_str + 64, buf);
bn_read_be(buf, &sign->s);
}

// r = H(Q, kpub, m)
static void calc_r(const curve_point *Q, const uint8_t pub_key[33],
Expand All @@ -71,69 +35,101 @@ static void calc_r(const curve_point *Q, const uint8_t pub_key[33],
sha256_Update(&ctx, pub_key, 33);
sha256_Update(&ctx, msg, msg_len);
sha256_Final(&ctx, digest);

// Convert the raw bigendian 256 bit value to a normalized, partly reduced bignum
bn_read_be(digest, r);
}

// returns 0 if signing succeeded
// Returns 0 if signing succeeded
int schnorr_sign(const ecdsa_curve *curve, const uint8_t *priv_key,
const bignum256 *k, const uint8_t *msg, const uint32_t msg_len,
schnorr_sign_pair *result) {
uint8_t pub_key[33];
curve_point Q;
bignum256 private_key_scalar;
bignum256 r_temp;
bignum256 s_temp;
bignum256 r_kpriv_result;

bn_read_be(priv_key, &private_key_scalar);
uint8_t pub_key[33];
ecdsa_get_public_key33(curve, priv_key, pub_key);

/* Q = kG */
curve_point Q;
scalar_multiply(curve, k, &Q);
// Compute commitment Q = kG
point_multiply(curve, k, &curve->G, &Q);

/* r = H(Q, kpub, m) */
calc_r(&Q, pub_key, msg, msg_len, &result->r);
// Compute challenge r = H(Q, kpub, m)
calc_r(&Q, pub_key, msg, msg_len, &r_temp);

// Fully reduce the bignum
bn_mod(&r_temp, &curve->order);

/* s = k - r*kpriv mod(order) */
bignum256 s_temp;
bn_copy(&result->r, &s_temp);
bn_multiply(&private_key_scalar, &s_temp, &curve->order);
bn_subtractmod(k, &s_temp, &result->s, &curve->order);
memzero(&private_key_scalar, sizeof(private_key_scalar));
// Convert the normalized, fully reduced bignum to a raw bigendian 256 bit value
bn_write_be(&r_temp, result->r);

while (bn_is_less(&curve->order, &result->s)) {
bn_mod(&result->s, &curve->order);
}
// Compute s = k - r*kpriv
bn_copy(&r_temp, &r_kpriv_result);

if (bn_is_zero(&result->s) || bn_is_zero(&result->r)) {
return 1;
}
// r*kpriv result is partly reduced
bn_multiply(&private_key_scalar, &r_kpriv_result, &curve->order);

// k - r*kpriv result is normalized but not reduced
bn_subtractmod(k, &r_kpriv_result, &s_temp, &curve->order);

// Partly reduce the result
bn_fast_mod(&s_temp, &curve->order);

// Fully reduce the result
bn_mod(&s_temp, &curve->order);

// Convert the normalized, fully reduced bignum to a raw bigendian 256 bit value
bn_write_be(&s_temp, result->s);

if (bn_is_zero(&r_temp) || bn_is_zero(&s_temp)) return 1;

return 0;
}

// returns 0 if verification succeeded
// Returns 0 if verification succeeded
int schnorr_verify(const ecdsa_curve *curve, const uint8_t *pub_key,
const uint8_t *msg, const uint32_t msg_len,
const schnorr_sign_pair *sign) {
curve_point pub_key_point;
curve_point sG, Q;
bignum256 r_temp;
bignum256 s_temp;
bignum256 r_computed;

if (msg_len == 0) return 1;
if (bn_is_zero(&sign->r)) return 2;
if (bn_is_zero(&sign->s)) return 3;
if (bn_is_less(&curve->order, &sign->r)) return 4;
if (bn_is_less(&curve->order, &sign->s)) return 5;

curve_point pub_key_point;
// Convert the raw bigendian 256 bit values to normalized, partly reduced bignums
bn_read_be(sign->r, &r_temp);
bn_read_be(sign->s, &s_temp);

// Check if r,s are in [1, ..., order-1]
if (bn_is_zero(&r_temp)) return 2;
if (bn_is_zero(&s_temp)) return 3;
if (bn_is_less(&curve->order, &r_temp)) return 4;
if (bn_is_less(&curve->order, &s_temp)) return 5;
if (bn_is_equal(&curve->order, &r_temp)) return 6;
if (bn_is_equal(&curve->order, &s_temp)) return 7;

if (!ecdsa_read_pubkey(curve, pub_key, &pub_key_point)) {
return 6;
return 8;
}

// Compute Q = sG + r*kpub
curve_point sG, Q;
scalar_multiply(curve, &sign->s, &sG);
point_multiply(curve, &sign->r, &pub_key_point, &Q);
point_multiply(curve, &s_temp, &curve->G, &sG);
point_multiply(curve, &r_temp, &pub_key_point, &Q);
point_add(curve, &sG, &Q);

/* r = H(Q, kpub, m) */
bignum256 r;
calc_r(&Q, pub_key, msg, msg_len, &r);
// Compute r' = H(Q, kpub, m)
calc_r(&Q, pub_key, msg, msg_len, &r_computed);

// Fully reduce the bignum
bn_mod(&r_computed, &curve->order);

if (bn_is_equal(&r, &sign->r)) return 0; // success
// Check r == r'
if (bn_is_equal(&r_temp, &r_computed)) return 0; // success

return 10;
}
95 changes: 61 additions & 34 deletions trezor-crypto/tests/test_check.c
Original file line number Diff line number Diff line change
Expand Up @@ -4806,29 +4806,28 @@ START_TEST(test_schnorr_sign_verify) {

const ecdsa_curve *curve = &secp256k1;
bignum256 k;
uint8_t priv_key[32], buf_raw[32];
uint8_t priv_key[32];
uint8_t pub_key[33];
uint8_t buf_raw[32];
schnorr_sign_pair result;
schnorr_sign_pair expected;
int res;

for (size_t i = 0; i < sizeof(test_cases) / sizeof(*test_cases); i++) {
memcpy(priv_key, fromhex(test_cases[i].priv_key), 32);
memcpy(&buf_raw, fromhex(test_cases[i].k_hex), 32);
bn_read_be(buf_raw, &k);
schnorr_sign(curve, priv_key, &k, (const uint8_t *)test_cases[i].message,
strlen(test_cases[i].message), &result);

schnorr_sign_pair expected;

memcpy(&buf_raw, fromhex(test_cases[i].s_hex), 32);
bn_read_be(buf_raw, &expected.s);
memcpy(&buf_raw, fromhex(test_cases[i].r_hex), 32);
bn_read_be(buf_raw, &expected.r);
memcpy(&expected.s, fromhex(test_cases[i].s_hex), 32);
memcpy(&expected.r, fromhex(test_cases[i].r_hex), 32);

ck_assert_mem_eq(&expected.r, &result.r, 32);
ck_assert_mem_eq(&expected.s, &result.s, 32);

uint8_t pub_key[33];
ecdsa_get_public_key33(curve, priv_key, pub_key);
int res =
schnorr_verify(curve, pub_key, (const uint8_t *)test_cases[i].message,
res = schnorr_verify(curve, pub_key, (const uint8_t *)test_cases[i].message,
strlen(test_cases[i].message), &result);
ck_assert_int_eq(res, 0);
}
Expand All @@ -4852,61 +4851,89 @@ START_TEST(test_schnorr_fail_verify) {

const ecdsa_curve *curve = &secp256k1;
bignum256 k;
uint8_t priv_key[32], buf_raw[32];
bignum256 bn_temp;
uint8_t priv_key[32];
uint8_t pub_key[33];
uint8_t buf_raw[32];
schnorr_sign_pair result;
schnorr_sign_pair bad_result;
int res;

memcpy(priv_key, fromhex(test_case.priv_key), 32);
memcpy(&buf_raw, fromhex(test_case.k_hex), 32);
bn_read_be(buf_raw, &k);

schnorr_sign_pair result;
schnorr_sign(curve, priv_key, &k, (const uint8_t *)test_case.message,
strlen(test_case.message), &result);
uint8_t pub_key[33];

ecdsa_get_public_key33(curve, priv_key, pub_key);

// OK
int res = schnorr_verify(curve, pub_key, (const uint8_t *)test_case.message,
// Test result = 0 (OK)
res = schnorr_verify(curve, pub_key, (const uint8_t *)test_case.message,
strlen(test_case.message), &result);
ck_assert_int_eq(res, 0);

// Test result = 1 (empty message)
res = schnorr_verify(curve, pub_key, (const uint8_t *)test_case.message, 0,
&result);
ck_assert_int_eq(res, 1);

schnorr_sign_pair bad_result;

bn_copy(&result.s, &bad_result.s);
bn_zero(&bad_result.r);
// r == 0
// Test result = 2 (r = 0)
bn_zero(&bn_temp);
bn_write_be(&bn_temp, bad_result.r);
memcpy(bad_result.s, result.s, 32);
res = schnorr_verify(curve, pub_key, (const uint8_t *)test_case.message,
strlen(test_case.message), &bad_result);
ck_assert_int_eq(res, 2);

bn_copy(&result.r, &bad_result.r);
bn_zero(&bad_result.s);
// s == 0
// Test result = 3 (s = 0)
memcpy(bad_result.r, result.r, 32);
bn_zero(&bn_temp);
bn_write_be(&bn_temp, bad_result.s);
res = schnorr_verify(curve, pub_key, (const uint8_t *)test_case.message,
strlen(test_case.message), &bad_result);
ck_assert_int_eq(res, 3);

bn_copy(&result.s, &bad_result.s);
bn_copy(&curve->order, &bad_result.r);
bn_addi(&bad_result.r, 1);
// r == curve->order + 1
// Test result = 4 (curve->order < r)
bn_copy(&curve->order, &bn_temp);
bn_addi(&bn_temp, 1);
bn_write_be(&bn_temp, bad_result.r);
memcpy(bad_result.s, result.s, 32);
res = schnorr_verify(curve, pub_key, (const uint8_t *)test_case.message,
strlen(test_case.message), &bad_result);
ck_assert_int_eq(res, 4);

bn_copy(&result.r, &bad_result.r);
bn_copy(&curve->order, &bad_result.s);
bn_addi(&bad_result.s, 1);
// s == curve->order + 1
// Test result = 5 (curve->order < s)
memcpy(bad_result.r, result.r, 32);
bn_copy(&curve->order, &bn_temp);
bn_addi(&bn_temp, 1);
bn_write_be(&bn_temp, bad_result.s);
res = schnorr_verify(curve, pub_key, (const uint8_t *)test_case.message,
strlen(test_case.message), &bad_result);
ck_assert_int_eq(res, 5);

bn_copy(&result.r, &bad_result.r);
bn_copy(&result.s, &bad_result.s);
// change message
// Test result = 6 (curve->order = r)
bn_copy(&curve->order, &bn_temp);
bn_write_be(&bn_temp, bad_result.r);
memcpy(bad_result.s, result.s, 32);
res = schnorr_verify(curve, pub_key, (const uint8_t *)test_case.message,
strlen(test_case.message), &bad_result);
ck_assert_int_eq(res, 6);

// Test result = 7 (curve->order = s)
memcpy(bad_result.r, result.r, 32);
bn_copy(&curve->order, &bn_temp);
bn_write_be(&bn_temp, bad_result.s);
res = schnorr_verify(curve, pub_key, (const uint8_t *)test_case.message,
strlen(test_case.message), &bad_result);
ck_assert_int_eq(res, 7);

// Test result = 8 (failed ecdsa_read_pubkey)
// TBD

// Test result = 10 (r != r')
memcpy(bad_result.r, result.r, 32);
memcpy(bad_result.s, result.s, 32);
test_case.message = "12";
res = schnorr_verify(curve, pub_key, (const uint8_t *)test_case.message,
strlen(test_case.message), &bad_result);
Expand Down

0 comments on commit a171e39

Please sign in to comment.