-
Notifications
You must be signed in to change notification settings - Fork 27
/
nmod.pyx
276 lines (233 loc) · 7.91 KB
/
nmod.pyx
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
from flint.flint_base.flint_base cimport flint_scalar
from flint.utils.typecheck cimport typecheck
from flint.types.fmpq cimport any_as_fmpq
from flint.types.fmpz cimport any_as_fmpz
from flint.types.fmpz cimport fmpz
from flint.types.fmpq cimport fmpq
from flint.flintlib.flint cimport ulong
from flint.flintlib.fmpz cimport fmpz_t
from flint.flintlib.nmod cimport nmod_pow_fmpz, nmod_inv
from flint.flintlib.nmod_vec cimport *
from flint.flintlib.fmpz cimport fmpz_fdiv_ui, fmpz_init, fmpz_clear
from flint.flintlib.fmpz cimport fmpz_set_ui, fmpz_get_ui
from flint.flintlib.fmpq cimport fmpq_mod_fmpz
from flint.flintlib.ulong_extras cimport n_gcdinv, n_sqrtmod
from flint.utils.flint_exceptions import DomainError
cdef int any_as_nmod(mp_limb_t * val, obj, nmod_t mod) except -1:
    """
    Try to coerce ``obj`` to a residue modulo ``mod.n``, storing it in ``val[0]``.

    Accepted inputs, tried in order:

    * an ``nmod`` with the same modulus (its residue is copied as is);
    * anything ``any_as_fmpz`` accepts (reduced with ``fmpz_fdiv_ui``);
    * anything ``any_as_fmpq`` accepts (reduced with ``fmpq_mod_fmpz``).

    Returns 1 when a value was written and 0 when ``obj`` is of an
    unsupported type, so callers can fall back to ``NotImplemented``.

    Raises ValueError for an ``nmod`` with a different modulus, and
    ZeroDivisionError when ``fmpq_mod_fmpz`` reports that the rational has
    no residue modulo ``mod.n``.
    """
    cdef fmpz_t work
    cdef int ok

    # Same type: only compatible when both moduli coincide.
    if typecheck(obj, nmod):
        if (<nmod>obj).mod.n != mod.n:
            raise ValueError("cannot coerce integers mod n with different n")
        val[0] = (<nmod>obj).val
        return 1

    # Integer-like input: reduce into [0, n).
    as_integer = any_as_fmpz(obj)
    if as_integer is not NotImplemented:
        val[0] = fmpz_fdiv_ui((<fmpz>as_integer).val, mod.n)
        return 1

    as_rational = any_as_fmpq(obj)
    if as_rational is NotImplemented:
        # Unsupported type: let the caller decide (usually NotImplemented).
        return 0

    # ``work`` is loaded with the modulus and is also the output argument of
    # fmpq_mod_fmpz, which returns 0 when no residue exists.
    fmpz_init(work)
    fmpz_set_ui(work, mod.n)
    ok = fmpq_mod_fmpz(work, (<fmpq>as_rational).val, work)
    val[0] = fmpz_get_ui(work)
    fmpz_clear(work)
    if not ok:
        raise ZeroDivisionError("%s does not exist mod %i!" % (as_rational, mod.n))
    return 1
cdef class nmod(flint_scalar):
    """
    The nmod type represents elements of Z/nZ for word-size n.

    Each instance stores the residue ``val`` together with the ``nmod_t``
    context ``mod`` whose field ``n`` is the modulus.

        >>> nmod(10,17) * 2
        3

    """

    def __init__(self, val, mod):
        # Create the residue of ``val`` modulo ``mod``.
        #
        # ``val`` may be anything ``any_as_nmod`` accepts: an nmod with the
        # same modulus, an integer-like value, or a rational-like value.
        # Raises ValueError for a zero modulus and TypeError when ``val``
        # has an unsupported type.  Negative or oversized moduli are
        # presumably rejected by Cython's conversion to mp_limb_t
        # (OverflowError) -- confirm against the mp_limb_t typedef.
        cdef mp_limb_t m
        m = mod
        # A zero modulus would make nmod_init / fmpz_fdiv_ui divide by zero
        # inside FLINT, so reject it before it reaches C.
        if m == 0:
            raise ValueError("modulus must be positive")
        nmod_init(&self.mod, m)
        if not any_as_nmod(&self.val, val, self.mod):
            raise TypeError("cannot create nmod from object of type %s" % type(val))

    def repr(self):
        # Constructor-style text, e.g. "nmod(3, 17)".
        return "nmod(%s, %s)" % (self.val, self.mod.n)

    def str(self):
        # Bare residue as decimal text, e.g. "3".
        return str(int(self.val))

    def __int__(self):
        return int(self.val)

    def modulus(self):
        """Return the modulus n of this element as a Python int."""
        return self.mod.n

    def __richcmp__(s, t, int op):
        # Only equality (op == 2) and inequality (op == 3) are supported:
        # residues modulo n have no meaningful ordering.
        cdef bint res
        if op != 2 and op != 3:
            raise TypeError("nmods cannot be ordered")
        if typecheck(s, nmod) and typecheck(t, nmod):
            # Two nmods are equal only if residue AND modulus agree.
            res = ((<nmod>s).val == (<nmod>t).val) and \
                  ((<nmod>s).mod.n == (<nmod>t).mod.n)
            if op == 2:
                return res
            else:
                return not res
        elif typecheck(s, nmod) and typecheck(t, int):
            # Python's % yields a value in [0, n), so negative ints compare
            # against their canonical residue.
            res = s.val == (t % s.mod.n)
            if op == 2:
                return res
            else:
                return not res
        return NotImplemented

    def __hash__(self):
        # Hash the (residue, modulus) pair so that nmods comparing equal in
        # __richcmp__ also hash equal.  ``modulus`` must be *called*: hashing
        # the bound method object would not give equal hashes for equal
        # values.
        return hash((int(self.val), self.modulus()))

    def __bool__(self):
        return self.val != 0

    def __pos__(self):
        return self

    def __neg__(self):
        cdef nmod r = nmod.__new__(nmod)
        r.mod = self.mod
        r.val = nmod_neg(self.val, self.mod)
        return r

    # Binary arithmetic: the other operand is coerced with any_as_nmod.
    # An unsupported type gives NotImplemented; an nmod with a different
    # modulus raises ValueError from any_as_nmod.

    def __add__(s, t):
        cdef nmod r
        cdef mp_limb_t val
        if any_as_nmod(&val, t, (<nmod>s).mod):
            r = nmod.__new__(nmod)
            r.mod = (<nmod>s).mod
            r.val = nmod_add(val, (<nmod>s).val, r.mod)
            return r
        return NotImplemented

    def __radd__(s, t):
        cdef nmod r
        cdef mp_limb_t val
        if any_as_nmod(&val, t, (<nmod>s).mod):
            r = nmod.__new__(nmod)
            r.mod = (<nmod>s).mod
            r.val = nmod_add((<nmod>s).val, val, r.mod)
            return r
        return NotImplemented

    def __sub__(s, t):
        cdef nmod r
        cdef mp_limb_t val
        if any_as_nmod(&val, t, (<nmod>s).mod):
            r = nmod.__new__(nmod)
            r.mod = (<nmod>s).mod
            r.val = nmod_sub((<nmod>s).val, val, r.mod)
            return r
        return NotImplemented

    def __rsub__(s, t):
        cdef nmod r
        cdef mp_limb_t val
        if any_as_nmod(&val, t, (<nmod>s).mod):
            r = nmod.__new__(nmod)
            r.mod = (<nmod>s).mod
            # Reflected: computes t - s.
            r.val = nmod_sub(val, (<nmod>s).val, r.mod)
            return r
        return NotImplemented

    def __mul__(s, t):
        cdef nmod r
        cdef mp_limb_t val
        if any_as_nmod(&val, t, (<nmod>s).mod):
            r = nmod.__new__(nmod)
            r.mod = (<nmod>s).mod
            r.val = nmod_mul(val, (<nmod>s).val, r.mod)
            return r
        return NotImplemented

    def __rmul__(s, t):
        cdef nmod r
        cdef mp_limb_t val
        if any_as_nmod(&val, t, (<nmod>s).mod):
            r = nmod.__new__(nmod)
            r.mod = (<nmod>s).mod
            r.val = nmod_mul((<nmod>s).val, val, r.mod)
            return r
        return NotImplemented

    @staticmethod
    def _div_(s, t):
        """
        Shared implementation of ``s / t`` where at least one side is an nmod.

        Called as ``_div_(s, t)`` from ``__truediv__`` and as ``_div_(t, s)``
        from ``__rtruediv__``.  When ``s`` is not an nmod, ``t`` therefore is.
        The other operand is coerced with ``any_as_nmod``.

        Returns ``NotImplemented`` for unsupported operand types.  Raises
        ZeroDivisionError if the divisor is zero, or if it is not invertible
        and the dividend is nonzero.
        """
        cdef nmod r
        cdef mp_limb_t sval, tval
        cdef nmod_t mod
        cdef ulong tinvval, g

        if typecheck(s, nmod):
            mod = (<nmod>s).mod
            sval = (<nmod>s).val
            if not any_as_nmod(&tval, t, mod):
                return NotImplemented
        else:
            # Only reachable from __rtruediv__, where t is the nmod.
            mod = (<nmod>t).mod
            tval = (<nmod>t).val
            if not any_as_nmod(&sval, s, mod):
                return NotImplemented

        if tval == 0:
            raise ZeroDivisionError("%s is not invertible mod %s" % (tval, mod.n))

        r = nmod.__new__(nmod)
        r.mod = mod

        if sval == 0:
            # Zero divided by any nonzero residue is zero.  Test the reduced
            # residue and build an nmod, rather than returning ``s`` itself,
            # which could be a plain int such as 0.
            r.val = 0
            return r

        g = n_gcdinv(&tinvval, <ulong>tval, <ulong>mod.n)
        if g != 1:
            raise ZeroDivisionError("%s is not invertible mod %s" % (tval, mod.n))

        r.val = nmod_mul(sval, <mp_limb_t>tinvval, mod)
        return r

    def __truediv__(s, t):
        return nmod._div_(s, t)

    def __rtruediv__(s, t):
        return nmod._div_(t, s)

    def __invert__(self):
        # Multiplicative inverse; ZeroDivisionError when gcd(val, n) != 1.
        cdef nmod r
        cdef ulong g, inv, sval
        sval = <ulong>(<nmod>self).val
        g = n_gcdinv(&inv, sval, self.mod.n)
        if g != 1:
            raise ZeroDivisionError("%s is not invertible mod %s" % (sval, self.mod.n))
        r = nmod.__new__(nmod)
        r.mod = self.mod
        r.val = <mp_limb_t>inv
        return r

    def __pow__(self, exp, modulus=None):
        # self ** exp for any integer-like exponent.  A negative exponent
        # first inverts the base (ZeroDivisionError if not invertible).
        # Three-argument pow() is rejected.
        cdef nmod r
        cdef mp_limb_t rval, mod
        cdef ulong g, rinv

        if modulus is not None:
            raise TypeError("three-argument pow() not supported by nmod")

        e = any_as_fmpz(exp)
        if e is NotImplemented:
            return NotImplemented

        rval = (<nmod>self).val
        mod = (<nmod>self).mod.n

        # XXX: It is not clear that it is necessary to special case negative
        # exponents here. The nmod_pow_fmpz function seems to handle this fine
        # but the Flint docs say that the exponent must be nonnegative.
        if e < 0:
            g = n_gcdinv(&rinv, <ulong>rval, <ulong>mod)
            if g != 1:
                raise ZeroDivisionError("%s is not invertible mod %s" % (rval, mod))
            rval = <mp_limb_t>rinv
            e = -e

        r = nmod.__new__(nmod)
        r.mod = self.mod
        r.val = nmod_pow_fmpz(rval, (<fmpz>e).val, self.mod)
        return r

    def sqrt(self):
        """
        Return the square root of this nmod or raise an exception.

        The modulus must be prime.

            >>> s = nmod(10, 13).sqrt()
            >>> s
            6
            >>> s * s
            10
            >>> nmod(11, 13).sqrt()
            Traceback (most recent call last):
                ...
            flint.utils.flint_exceptions.DomainError: no square root exists for 11 mod 13

        """
        cdef nmod r
        cdef mp_limb_t val
        r = nmod.__new__(nmod)
        r.mod = self.mod

        # n_sqrtmod signals "no root" with 0, so zero must be handled first.
        # r.val is left at the zero value of a freshly allocated nmod.
        if self.val == 0:
            return r

        val = n_sqrtmod(self.val, self.mod.n)
        if val == 0:
            raise DomainError("no square root exists for %s mod %s" % (self.val, self.mod.n))

        r.val = val
        return r