forked from google/shell-encryption
-
Notifications
You must be signed in to change notification settings - Fork 0
/
montgomery.h
472 lines (405 loc) · 18.7 KB
/
montgomery.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
/*
* Copyright 2017 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Defines types that are necessary for Ring-Learning with Errors (rlwe)
// encryption.
#ifndef RLWE_MONTGOMERY_H_
#define RLWE_MONTGOMERY_H_
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
#include <tuple>
#include <vector>
#include <glog/logging.h>
#include "absl/numeric/int128.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "bits_util.h"
#include "constants.h"
#include "int256.h"
#include "prng/prng.h"
#include "serialization.pb.h"
#include "status_macros.h"
#include "statusor.h"
namespace rlwe {
namespace internal {
// Struct to capture the "bigger int" type.
template <typename T>
struct BigInt;
// Specialization for uint8, uint16, uint32, uint64, and uint128.
template <>
struct BigInt<Uint8> {
typedef Uint16 value_type;
};
template <>
struct BigInt<Uint16> {
typedef Uint32 value_type;
};
template <>
struct BigInt<Uint32> {
typedef Uint64 value_type;
};
template <>
struct BigInt<Uint64> {
typedef absl::uint128 value_type;
};
template <>
struct BigInt<absl::uint128> {
typedef uint256 value_type;
};
} // namespace internal
// The parameters necessary for a Montgomery integer. Note that the template
// parameters ensure that T is an unsigned integral of at least 8 bits.
template <typename T>
struct MontgomeryIntParams {
// Expose Int and its greater type. BigInt is required in order to multiply
// two Int and ensure that no overflow occurs.
//
// Thread safe.
using Int = T;
using BigInt = typename internal::BigInt<Int>::value_type;
static const size_t bitsize_int = sizeof(Int) * 8;
// Factory function to create MontgomeryIntParams.
static rlwe::StatusOr<std::unique_ptr<const MontgomeryIntParams>> Create(
Int modulus);
// The value R to be used in Montgomery multiplication. R will be selected as
// 2^bitsize(Int) and hence automatically verifies R > modulus.
const BigInt r = static_cast<BigInt>(1) << bitsize_int;
// The modulus over which these modular operations are being performed.
const Int modulus;
// The modulus over which these modular operations are being performed, cast
// as a BigInt.
const BigInt modulus_bigint;
// The number of bits in the modulus.
const unsigned int log_modulus;
// The value of R taken modulo the modulus.
const Int r_mod_modulus;
const Int
r_mod_modulus_barrett; // = (r_mod_modulus << bitsize_int) / modulus, to
// speed up multiplication by r_mod_modulus.
// The values below, inv_modulus and inv_r, satisfy the formula:
// R * inv_r - modulus * inv_modulus = 1
// Note that inv_modulus is not literally the multiplicative inverse of
// modulus modulo R.
const Int inv_modulus;
const Int
inv_r; // needed to translate from Montgomery to normal representation
const Int inv_r_barrett; // = (inv_r << bitsize_int) / modulus, to speed up
// multiplication by inv_r.
// The numerator used in the Barrett reduction.
const BigInt barrett_numerator; // = 2^(sizeof(Int)*8) / modulus
inline const Int Zero() const { return 0; }
inline const Int One() const { return 1; }
inline const Int Two() const { return 2; }
// Functions to perform Barrett reduction. For more details, see
// https://en.wikipedia.org/wiki/Barrett_reduction.
// This function should remains in the header file to avoid performance
// regressions.
inline Int BarrettReduce(Int input) const {
Int out =
static_cast<Int>((this->barrett_numerator * input) >> bitsize_int);
out = input - (out * this->modulus);
// The steps above produce an integer that is in the range [0, 2N).
// We now reduce to the range [0, N).
return (out >= this->modulus) ? out - this->modulus : out;
}
// Computes the serialized byte length of an integer.
inline unsigned int SerializedSize() const { return (log_modulus + 7) / 8; }
// Check whether (1 << log_n) fits into the underlying Int type.
static bool DoesLogNFit(Uint64 log_n) { return (log_n < bitsize_int - 1); }
private:
MontgomeryIntParams(Int mod)
: modulus(mod),
modulus_bigint(static_cast<BigInt>(this->modulus)),
log_modulus(internal::BitLength(this->modulus)),
r_mod_modulus(static_cast<Int>(this->r % this->modulus_bigint)),
r_mod_modulus_barrett(static_cast<Int>(
(static_cast<BigInt>(r_mod_modulus) << bitsize_int) / modulus)),
inv_modulus(static_cast<Int>(
std::get<1>(MontgomeryIntParams::Inverses(modulus_bigint, r)))),
inv_r(std::get<0>(MontgomeryIntParams::Inverses(modulus_bigint, r))),
inv_r_barrett(static_cast<Int>(
(static_cast<BigInt>(inv_r) << bitsize_int) / modulus)),
barrett_numerator(this->r / this->modulus_bigint) {}
// Computes the Montgomery inverse coefficients for r and modulus using
// the Extended Euclidean Algorithm.
//
// modulus must be odd.
// Returns a tuple of (inv_r, inv_modulus) such that:
// r * inv_r - modulus * inv_modulus = 1
static std::tuple<Int, Int> Inverses(BigInt modulus_bigint, BigInt r);
};
// Stores an integer in Montgomery representation. The goal of this
// class is to provide a static type that differentiates between integers
// in Montgomery representation and standard integers. Once a Montgomery
// integer is created, it has the * and + operators necessary to be treated
// as another integer.
// The underlying integer type T must be unsigned and must not be bool.
// This class is thread safe.
template <typename T>
class ABSL_MUST_USE_RESULT MontgomeryInt {
public:
// Expose Int and its greater type. BigInt is required in order to multiply
// two Int and ensure that no overflow occurs. This should also be used by
// external classes.
using Int = T;
using BigInt = typename internal::BigInt<Int>::value_type;
// Expose the parameter type.
using Params = MontgomeryIntParams<T>;
// Static factory that converts a non-Montgomery representation integer, the
// underlying integer type, into a Montgomery representation integer. Does not
// take ownership of params. i.e., import "a".
static rlwe::StatusOr<MontgomeryInt> ImportInt(Int n, const Params* params);
// Static functions to create a MontgomeryInt of 0 and 1.
static MontgomeryInt ImportZero(const Params* params);
static MontgomeryInt ImportOne(const Params* params);
// Import a random integer using entropy from specified prng. Does not take
// ownership of params or prng.
template <typename Prng = rlwe::SecurePrng>
static rlwe::StatusOr<MontgomeryInt> ImportRandom(Prng* prng,
const Params* params) {
// In order to generate unbiased randomness, we uniformly and randomly
// sample integers in [0, 2^params->log_modulus) until the generated integer
// is less than the modulus (i.e., we perform rejection sampling).
RLWE_ASSIGN_OR_RETURN(Int random_int,
GenerateRandomInt(params->log_modulus, prng));
while (random_int >= params->modulus) {
RLWE_ASSIGN_OR_RETURN(random_int,
GenerateRandomInt(params->log_modulus, prng));
}
return MontgomeryInt(random_int);
}
static BigInt DivAndTruncate(BigInt dividend, BigInt divisor);
// No default constructor.
MontgomeryInt() = delete;
// Default copy constructor.
MontgomeryInt(const MontgomeryInt& that) = default;
MontgomeryInt& operator=(const MontgomeryInt& that) = default;
// Convert a Montgomery representation integer back to the underlying integer.
// i.e., export "a".
inline Int ExportInt(const Params* params) const {
Int out =
static_cast<Int>((static_cast<BigInt>(params->inv_r_barrett) * n_) >>
params->bitsize_int);
out = n_ * params->inv_r - out * params->modulus;
// The steps above produce an integer that is in the range [0, 2N).
// We now reduce to the range [0, N).
out -= (out >= params->modulus) ? params->modulus : 0;
return out;
}
// Returns the least significant 64 bits of n.
static Uint64 ExportUInt64(Int n) { return static_cast<Uint64>(n); }
// Serialization.
rlwe::StatusOr<std::string> Serialize(const Params* params) const;
static rlwe::StatusOr<std::string> SerializeVector(
const std::vector<MontgomeryInt>& coeffs, const Params* params);
// Deserialization.
static rlwe::StatusOr<MontgomeryInt> Deserialize(absl::string_view payload,
const Params* params);
static rlwe::StatusOr<std::vector<MontgomeryInt>> DeserializeVector(
int num_coeffs, absl::string_view serialized, const Params* params);
// Modular multiplication.
// Perform a multiply followed by a modular reduction. Produces an output
// in the range [0, N).
// Taken from Hacker's Delight chapter on Montgomery multiplication.
MontgomeryInt Mul(const MontgomeryInt& that, const Params* params) const {
MontgomeryInt out(*this);
return out.MulInPlace(that, params);
}
// This function should remains in the header file to avoid performance
// regressions.
MontgomeryInt& MulInPlace(const MontgomeryInt& that, const Params* params) {
// This function computes the product of the two numbers (a and b), which
// will equal a * R * b * R in Montgomery representation. It then performs
// the reduction, creating a * b * R (mod N).
Int u = static_cast<Int>(n_ * that.n_) * params->inv_modulus;
BigInt t = static_cast<BigInt>(n_) * that.n_ + params->modulus_bigint * u;
Int t_msb = static_cast<Int>(t >> Params::bitsize_int);
// The steps above produce an integer that is in the range [0, 2N).
// We now reduce to the range [0, N).
n_ = (t_msb >= params->modulus) ? t_msb - params->modulus : t_msb;
return *this;
}
// Modular multiplication by a constant with precomputation.
// MontgomeryInt are stored under the form n * R, where R = 1 << bitsize_int.
// When multiplying by a constant c, one can precompute
// constant = (c * R^(-1) mod modulus)
// constant_barrett = (constant << bitsize_int) / modulus
// and perform modular reduction using Barrett reduction instead of
// Montgomery multiplication.
// The funciton GetConstant() returns a tuple (constant, constant_barrett):
// constant = ExportInt(params),
// and
// constant_barrett = (constant << bitsize_int) / modulus
std::tuple<Int, Int> GetConstant(const Params* params) const;
MontgomeryInt MulConstant(const Int& constant, const Int& constant_barrett,
const Params* params) const {
MontgomeryInt out(*this);
return out.MulConstantInPlace(constant, constant_barrett, params);
}
// This function should remains in the header file to avoid performance
// regressions.
MontgomeryInt& MulConstantInPlace(const Int& constant,
const Int& constant_barrett,
const Params* params) {
Int out = static_cast<Int>((static_cast<BigInt>(constant_barrett) * n_) >>
Params::bitsize_int);
n_ = n_ * constant - out * params->modulus;
// The steps above produce an integer that is in the range [0, 2N).
// We now reduce to the range [0, N).
n_ -= (n_ >= params->modulus) ? params->modulus : 0;
return *this;
}
// Montgomery addition.
MontgomeryInt Add(const MontgomeryInt& that, const Params* params) const {
MontgomeryInt out(*this);
return out.AddInPlace(that, params);
}
// This function should remains in the header file to avoid performance
// regressions.
MontgomeryInt& AddInPlace(const MontgomeryInt& that, const Params* params) {
// We can use Barrett reduction because n_ <= modulus < Max(Int)/2.
n_ = params->BarrettReduce(n_ + that.n_);
return *this;
}
// Modular negation.
MontgomeryInt Negate(const Params* params) const {
return MontgomeryInt(params->modulus - n_);
}
// This function should remains in the header file to avoid performance
// regressions.
MontgomeryInt& NegateInPlace(const Params* params) {
n_ = params->modulus - n_;
return *this;
}
// Modular subtraction.
MontgomeryInt Sub(const MontgomeryInt& that, const Params* params) const {
MontgomeryInt out(*this);
return out.SubInPlace(that, params);
}
// This function should remains in the header file to avoid performance
// regressions.
MontgomeryInt& SubInPlace(const MontgomeryInt& that, const Params* params) {
// We can use Barrett reduction because n_ <= modulus < Max(Int)/2.
n_ = params->BarrettReduce(n_ + (params->modulus - that.n_));
return *this;
}
// Batch operations (and in-place variants) over vectors of MontgomeryInt.
// We define two versions of the batch operations:
// - the first input is a vector and the second is a MontgomeryInt scalar:
// in that case, the scalar will be added/subtracted/multiplied with each
// element of the first input.
// - both inputs are vectors of same length: in that case, the operations
// will be performed component wise.
// These batch operations may fail if the input vectors are not of the same
// size.
// Batch addition of two vectors.
static rlwe::StatusOr<std::vector<MontgomeryInt>> BatchAdd(
const std::vector<MontgomeryInt>& in1,
const std::vector<MontgomeryInt>& in2, const Params* params);
static absl::Status BatchAddInPlace(std::vector<MontgomeryInt>* in1,
const std::vector<MontgomeryInt>& in2,
const Params* params);
// Batch addition of one vector with a scalar.
static rlwe::StatusOr<std::vector<MontgomeryInt>> BatchAdd(
const std::vector<MontgomeryInt>& in1, const MontgomeryInt& in2,
const Params* params);
static absl::Status BatchAddInPlace(std::vector<MontgomeryInt>* in1,
const MontgomeryInt& in2,
const Params* params);
// Batch subtraction of two vectors.
static rlwe::StatusOr<std::vector<MontgomeryInt>> BatchSub(
const std::vector<MontgomeryInt>& in1,
const std::vector<MontgomeryInt>& in2, const Params* params);
static absl::Status BatchSubInPlace(std::vector<MontgomeryInt>* in1,
const std::vector<MontgomeryInt>& in2,
const Params* params);
// Batch subtraction of one vector with a scalar.
static rlwe::StatusOr<std::vector<MontgomeryInt>> BatchSub(
const std::vector<MontgomeryInt>& in1, const MontgomeryInt& in2,
const Params* params);
static absl::Status BatchSubInPlace(std::vector<MontgomeryInt>* in1,
const MontgomeryInt& in2,
const Params* params);
// Batch multiplication of two vectors.
static rlwe::StatusOr<std::vector<MontgomeryInt>> BatchMul(
const std::vector<MontgomeryInt>& in1,
const std::vector<MontgomeryInt>& in2, const Params* params);
static absl::Status BatchMulInPlace(std::vector<MontgomeryInt>* in1,
const std::vector<MontgomeryInt>& in2,
const Params* params);
// Batch multiplication of two vectors, where the second vector is a constant.
static rlwe::StatusOr<std::vector<MontgomeryInt>> BatchMulConstant(
const std::vector<MontgomeryInt>& in1,
const std::vector<Int>& in2_constant,
const std::vector<Int>& in2_constant_barrett, const Params* params);
static absl::Status BatchMulConstantInPlace(
std::vector<MontgomeryInt>* in1, const std::vector<Int>& in2_constant,
const std::vector<Int>& in2_constant_barrett, const Params* params);
// Batch multiplication of a vector with a scalar.
static rlwe::StatusOr<std::vector<MontgomeryInt>> BatchMul(
const std::vector<MontgomeryInt>& in1, const MontgomeryInt& in2,
const Params* params);
static absl::Status BatchMulInPlace(std::vector<MontgomeryInt>* in1,
const MontgomeryInt& in2,
const Params* params);
// Batch multiplication of a vector with a constant scalar.
static rlwe::StatusOr<std::vector<MontgomeryInt>> BatchMulConstant(
const std::vector<MontgomeryInt>& in1, const Int& constant,
const Int& constant_barrett, const Params* params);
static absl::Status BatchMulConstantInPlace(std::vector<MontgomeryInt>* in1,
const Int& constant,
const Int& constant_barrett,
const Params* params);
// Equality.
bool operator==(const MontgomeryInt& that) const { return (n_ == that.n_); }
bool operator!=(const MontgomeryInt& that) const { return !(*this == that); }
// Modular exponentiation.
MontgomeryInt ModExp(Int exponent, const Params* params) const;
// Inverse.
MontgomeryInt MultiplicativeInverse(const Params* params) const;
private:
template <typename Prng = rlwe::SecurePrng>
static rlwe::StatusOr<Int> GenerateRandomInt(int log_modulus, Prng* prng) {
// Generate a random Int. As the modulus is always smaller than max(Int),
// there will be no issues with overflow.
int max_bits_per_step = std::min((int)Params::bitsize_int, (int)64);
auto bits_required = log_modulus;
Int rand = 0;
while (bits_required > 0) {
Int rand_bits = 0;
if (bits_required <= 8) {
// Generate 8 bits of randomness.
RLWE_ASSIGN_OR_RETURN(rand_bits, prng->Rand8());
// Extract bits_required bits and add them to rand.
Int needed_bits =
rand_bits & ((static_cast<Int>(1) << bits_required) - 1);
rand = (rand << bits_required) + needed_bits;
break;
} else {
// Generate 64 bits of randomness.
RLWE_ASSIGN_OR_RETURN(rand_bits, prng->Rand64());
// Extract min(64, bits in Int, bits_required) bits and add them to rand
int bits_to_extract = std::min(bits_required, max_bits_per_step);
Int needed_bits =
rand_bits & ((static_cast<Int>(1) << bits_to_extract) - 1);
rand = (rand << bits_to_extract) + needed_bits;
bits_required -= bits_to_extract;
}
}
return rand;
}
explicit MontgomeryInt(Int n) : n_(n) {}
Int n_;
};
} // namespace rlwe
#endif // RLWE_MONTGOMERY_H_