forked from triplewz/poseidon
-
Notifications
You must be signed in to change notification settings - Fork 0
/
mds.go
236 lines (199 loc) · 6.43 KB
/
mds.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
package poseidon
import (
"github.com/pkg/errors"
ff "github.com/triplewz/poseidon/bls12_381"
)
// mdsMatrices bundles an mds matrix with the matrices derived from it,
// which are used to improve the efficiency of the Poseidon hash.
// See the paper https://eprint.iacr.org/2019/458.pdf, page 20, for more details.
type mdsMatrices struct {
// m is the input mds matrix.
m Matrix
// mInv is the inverse of the mds matrix.
mInv Matrix
// mHat is the matrix obtained by eliminating the first row and the first column of m.
mHat Matrix
// mHatInv is the inverse of the mHat matrix.
mHatInv Matrix
// mPrime is the matrix m' in the paper; it satisfies m = m'*m''.
// mPrime consists of:
// 1 | 0
// 0 | mHat
mPrime Matrix
// mDoublePrime is the matrix m'' in the paper; it satisfies m = m'*m''.
// mDoublePrime consists of:
// m_00 | v
// w_hat | I
// where m_00 is the first element of the mds matrix,
// w_hat and v are vectors of length t-1,
// and I is the (t-1)*(t-1) identity matrix.
mDoublePrime Matrix
}
// SparseMatrix is a compact representation of a matrix of the form m''.
// Such a matrix has a dense first row and a dense first column, while the interior matrix
// (the minor obtained by removing that row and column) is the identity.
// For simplicity, the identity part of m'' is not stored.
type SparseMatrix struct {
// wHat is the whole first column of the m'' matrix. It differs slightly from the wHat
// in the paper because m_00 is placed at the beginning of it.
wHat Vector
// v is the first row of m'' without its first element, since that element
// is already stored at the beginning of wHat.
v Vector
}
// createMDSMatrix generates the mds matrix of width t and derives
// the auxiliary matrices (inverse, minor, m', m'') from it.
// Any error reported by deriveMatrices is returned unchanged.
func createMDSMatrix(t int) (*mdsMatrices, error) {
	return deriveMatrices(genMDS(t))
}
// genMDS generates the t x t mds matrix as a Cauchy matrix whose entries are
// m[i][j] = 1 / (x_i + y_j), with x_i = i and y_j = j + t.
//
// The x values are pairwise distinct, the y values are pairwise distinct, and
// x_i + y_j = i + j + t is non-zero for any width far below the field modulus,
// so every entry can be inverted and the matrix (as well as its square
// sub-matrices) is expected to be invertible. Because the entries only depend
// on i + j, the matrix is also symmetric.
//
// Both properties are verified before returning; genMDS panics if either one
// does not hold, since that means the construction itself is broken and
// regenerating the same deterministic values cannot repair it.
func genMDS(t int) Matrix {
	xVec := make([]*ff.Element, t)
	yVec := make([]*ff.Element, t)
	for i := 0; i < t; i++ {
		xVec[i] = new(ff.Element).SetUint64(uint64(i))
		yVec[i] = new(ff.Element).SetUint64(uint64(i + t))
	}

	m := make([][]*ff.Element, t)
	for i := 0; i < t; i++ {
		m[i] = make([]*ff.Element, t)
		for j := 0; j < t; j++ {
			// m[i][j] = (x_i + y_j)^-1
			m[i][j] = new(ff.Element).Add(xVec[i], yVec[j])
			m[i][j].Inverse(m[i][j])
		}
	}

	// m must be invertible. The previous retry path incremented t and jumped
	// back to the generation loop, which indexed past the end of xVec/yVec and
	// would have produced a matrix of the wrong width; fail explicitly instead.
	if !IsInvertible(m) {
		panic("m is not invertible!")
	}

	// m must be symmetric.
	transm := transpose(m)
	if !IsEqual(transm, m) {
		panic("m is not symmetric!")
	}

	return m
}
// deriveMatrices computes every auxiliary matrix of mdsMatrices from m:
// the inverse of m, the minor mHat (m without its first row and column),
// the inverse of mHat, and the factors m' and m''.
// It returns an error as soon as one of the intermediate steps fails.
func deriveMatrices(m Matrix) (*mdsMatrices, error) {
	inverse, err := Invert(m)
	if err != nil {
		return nil, errors.Errorf("gen mInv failed, err: %s", err)
	}

	hat, err := minor(m, 0, 0)
	if err != nil {
		return nil, errors.Errorf("gen mHat failed, err: %s", err)
	}

	hatInverse, err := Invert(hat)
	if err != nil {
		return nil, errors.Errorf("gen mHatInv failed, err: %s", err)
	}

	prime := genPrime(m)

	doublePrime, err := genDoublePrime(m, hatInverse)
	if err != nil {
		return nil, errors.Errorf("gen double prime m failed, err: %s", err)
	}

	derived := &mdsMatrices{
		m:            m,
		mInv:         inverse,
		mHat:         hat,
		mHatInv:      hatInverse,
		mPrime:       prime,
		mDoublePrime: doublePrime,
	}
	return derived, nil
}
// genPrime builds the matrix m', where m = m'*m''.
// The first row of m' is (1, 0, ..., 0), the first column below it is all zeros,
// and the remaining entries are taken from m at the same positions.
// The entries copied from m are shared with m, not cloned.
func genPrime(m Matrix) Matrix {
	rows := row(m)
	result := make([][]*ff.Element, rows)

	// first row: a leading one followed by zeros.
	result[0] = []*ff.Element{one}
	cols := column(m)
	for c := 1; c < cols; c++ {
		result[0] = append(result[0], zero)
	}

	// remaining rows: a leading zero followed by the matching row of m.
	for r := 1; r < rows; r++ {
		line := make([]*ff.Element, cols)
		line[0] = zero
		for c := 1; c < cols; c++ {
			line[c] = m[r][c]
		}
		result[r] = line
	}

	return result
}
// genDoublePrime builds the matrix m'', where m = m'*m''.
// The first row of m'' is the first row of m. The first column below it is
// wHat = LeftMatMul(mHatInv, w), with w being the first column of m without
// its first element. The remaining entries form the identity matrix.
// It returns an error if the computation of wHat fails.
func genDoublePrime(m, mHatInv Matrix) (Matrix, error) {
	w, v := genPreVectors(m)

	wHat, err := LeftMatMul(mHatInv, w)
	if err != nil {
		return nil, errors.Errorf("compute wHat failed, err: %s", err)
	}

	rows := row(m)
	result := make([][]*ff.Element, rows)

	// first row: m_00 followed by v.
	result[0] = append([]*ff.Element{m[0][0]}, v...)

	cols := column(m)
	for r := 1; r < rows; r++ {
		line := make([]*ff.Element, cols)
		line[0] = wHat[r-1]
		// identity part: one on the diagonal, zero elsewhere.
		for c := 1; c < cols; c++ {
			line[c] = zero
			if c == r {
				line[c] = one
			}
		}
		result[r] = line
	}

	return result, nil
}
// genPreVectors extracts the two vectors used to build the sparse matrix.
// The first returned vector, w, is the first column of m without its first element;
// the second one, v, is the first row of m without its first element.
func genPreVectors(m Matrix) (Vector, Vector) {
	// v: tail of the first row.
	rowTail := make([]*ff.Element, column(m)-1)
	copy(rowTail, m[0][1:])

	// w: tail of the first column.
	rows := row(m)
	colTail := make([]*ff.Element, rows-1)
	for r := 1; r < rows; r++ {
		colTail[r-1] = m[r][0]
	}

	return colTail, rowTail
}
// parseSparseMatrix converts a matrix of the form m'' into its compact
// SparseMatrix representation.
// It returns an error if the minor of m cannot be computed, if m is not square,
// or if the minor obtained by removing the first row and column is not the identity.
func parseSparseMatrix(m Matrix) (*SparseMatrix, error) {
	inner, err := minor(m, 0, 0)
	if err != nil {
		return nil, errors.Errorf("get the sub matrix err: %s", err)
	}

	// m should be the sparse matrix, which has a (t-1)*(t-1) sub identity matrix.
	if !(IsSquareMatrix(m) && IsIdentity(inner)) {
		return nil, errors.Errorf("cannot parse the sparse matrix!")
	}

	// wHat is the whole first column of the sparse matrix.
	firstColumn := make([]*ff.Element, row(m))
	for r, n := 0, column(m); r < n; r++ {
		firstColumn[r] = m[r][0]
	}

	// v is the first row without its first element.
	rowTail := make([]*ff.Element, column(m)-1)
	copy(rowTail, m[0][1:])

	return &SparseMatrix{wHat: firstColumn, v: rowTail}, nil
}
// genSparseMatrix generates the sparse and pre-sparse matrices for fast computation
// of the Poseidon hash. It follows the paper https://eprint.iacr.org/2019/458.pdf (page 20)
// and the implementation in https://github.com/filecoin-project/neptune.
//
// At each partial round a sparse matrix is used instead of a dense one. To do so,
// the current matrix is factored into two components such that m' x m'' = m;
// the sparse factor m'' is used as the mds matrix of that round, and
// the matrix of the previous layer is replaced by m x m' = m*.
// The work starts from the last partial round and proceeds back to the first one.
//
// It returns the rp sparse matrices, the remaining pre-sparse matrix, and
// an error if any derivation, multiplication or parsing step fails.
func genSparseMatrix(m Matrix, rp int) ([]*SparseMatrix, Matrix, error) {
	sparses := make([]*SparseMatrix, rp)
	current := copyMatrixRows(m, 0, row(m))

	// fill the result from the last slot down to the first one.
	for slot := rp - 1; slot >= 0; slot-- {
		derived, err := deriveMatrices(current)
		if err != nil {
			return nil, nil, errors.Errorf("derive mds matrices err: %s", err)
		}

		// m* = m x m'
		previous, err := MatMul(m, derived.mPrime)
		if err != nil {
			return nil, nil, errors.Errorf("get the previous layer's matrix err: %s", err)
		}

		sparse, err := parseSparseMatrix(derived.mDoublePrime)
		if err != nil {
			return nil, nil, errors.Errorf("parse sparse matrix err: %s", err)
		}
		sparses[slot] = sparse

		current = copyMatrixRows(previous, 0, row(previous))
	}

	return sparses, current, nil
}