diff --git a/auth.go b/auth.go index ff6415776..658259b24 100644 --- a/auth.go +++ b/auth.go @@ -228,40 +228,40 @@ func encryptPassword(password string, seed []byte, pub *rsa.PublicKey) ([]byte, return rsa.EncryptOAEP(sha1, rand.Reader, pub, plain, nil) } -// Derived from https://github.com/MariaDB/server/blob/d8e6bb00888b1f82c031938f4c8ac5d97f6874c3/plugin/auth_ed25519/ref10/sign.c -func doEd25519Auth(scramble []byte, password string) ([]byte, error) { +// authEd25519 does ed25519 authentication used by MariaDB. +func authEd25519(scramble []byte, password string) ([]byte, error) { + // Derived from https://github.com/MariaDB/server/blob/d8e6bb00888b1f82c031938f4c8ac5d97f6874c3/plugin/auth_ed25519/ref10/sign.c + // Code style is from https://cs.opensource.google/go/go/+/refs/tags/go1.21.5:src/crypto/ed25519/ed25519.go;l=207 h := sha512.Sum512([]byte(password)) s, err := edwards25519.NewScalar().SetBytesWithClamping(h[:32]) if err != nil { return nil, err } + A := (&edwards25519.Point{}).ScalarBaseMult(s) - nonceHash := sha512.New() - nonceHash.Write(h[32:]) - nonceHash.Write(scramble) - nonce := nonceHash.Sum(nil) - - r, err := edwards25519.NewScalar().SetUniformBytes(nonce) + mh := sha512.New() + mh.Write(h[32:]) + mh.Write(scramble) + messageDigest := mh.Sum(nil) + r, err := edwards25519.NewScalar().SetUniformBytes(messageDigest) if err != nil { return nil, err } - R := (&edwards25519.Point{}).ScalarBaseMult(r) - A := (&edwards25519.Point{}).ScalarBaseMult(s) - - kHash := sha512.New() - kHash.Write(R.Bytes()) - kHash.Write(A.Bytes()) - kHash.Write(scramble) - k := kHash.Sum(nil) + R := (&edwards25519.Point{}).ScalarBaseMult(r) - K, err := edwards25519.NewScalar().SetUniformBytes(k) + kh := sha512.New() + kh.Write(R.Bytes()) + kh.Write(A.Bytes()) + kh.Write(scramble) + hramDigest := kh.Sum(nil) + k, err := edwards25519.NewScalar().SetUniformBytes(hramDigest) if err != nil { return nil, err } - S := K.MultiplyAdd(K, s, r) + S := k.MultiplyAdd(k, s, r) return append(R.Bytes(), S.Bytes()...), nil } @@ -335,8 +335,7 @@ func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) { if len(authData) != 32 { return nil, ErrMalformPkt } - - return doEd25519Auth(authData, mc.cfg.Passwd) + return authEd25519(authData, mc.cfg.Passwd) default: mc.cfg.Logger.Print("unknown auth plugin:", plugin)