-
Notifications
You must be signed in to change notification settings - Fork 1
/
opk.go
118 lines (99 loc) · 2.89 KB
/
opk.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
package main
import (
"context"
"crypto"
"crypto/rsa"
"encoding/json"
"fmt"
"github.com/awnumar/memguard"
"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/lestrrat-go/jwx/v2/jws"
"github.com/openpubkey/openpubkey/client"
"github.com/openpubkey/openpubkey/gq"
"github.com/openpubkey/openpubkey/pktoken"
"github.com/openpubkey/openpubkey/pktoken/clientinstance"
"github.com/openpubkey/openpubkey/util"
)
const (
	// GQSecurityParameter is the security parameter handed to
	// gq.NewSignerVerifier when GQ-signing the OP-issued ID token in OidcAuth.
	GQSecurityParameter = 256

	// algo is the JWS algorithm written into the "alg" field of the client
	// instance JWK built by generateCic.
	algo = jwa.ES256
)
// generateCic builds client instance claims (CIC) around signer's public key.
//
// The public key is converted to a JWK, its "alg" field is set to algo, and the
// JWK is wrapped in clientinstance claims with no extra fields. Every failure is
// returned wrapped with context; a nil signer is rejected up front.
//
// The OIDC nonce (a commitment to these claims, e.g. via cic.Hash()) is not
// computed here; presumably callers derive it from the returned claims — confirm.
//
// NOTE(review): algo is fixed to ES256, so signer is presumably an ECDSA P-256
// key. Any other key type would be labelled with the wrong alg — confirm with callers.
func generateCic(
	signer crypto.Signer,
) (*clientinstance.Claims, error) {
	if signer == nil {
		return nil, fmt.Errorf("error generating client instance claims: nil signer")
	}

	// Export the signer's public half as a JWK so it can be embedded in the CIC.
	jwkKey, err := jwk.PublicKeyOf(signer)
	if err != nil {
		return nil, fmt.Errorf("error converting signer public key to JWK: %w", err)
	}

	// The CIC must advertise which algorithm the client key signs with.
	if err := jwkKey.Set(jwk.AlgorithmKey, algo); err != nil {
		return nil, fmt.Errorf("error setting JWK alg header: %w", err)
	}

	cic, err := clientinstance.NewClaims(jwkKey, map[string]any{})
	if err != nil {
		return nil, fmt.Errorf("failed to instantiate client instance claims: %w", err)
	}
	return cic, nil
}
// OidcAuth turns an OP-issued ID token into a verified PK Token bound to the
// client instance claims token cicToken.
//
// Steps:
//  1. Read the ID token's protected header.
//  2. Ask Op for the matching public key.
//  3. GQ-sign the ID token with that (RSA) key.
//  4. Combine the result with cicToken into a PK Token.
//  5. Verify it against Op.
//  6. Add a JKT header for the OP key.
//
// idToken is only read, never destroyed; the caller keeps ownership of it.
// An error is returned for a nil idToken or Op, for an OP key that is not
// *rsa.PublicKey (GQ signing needs RSA), and for any failing step above.
func OidcAuth(
	idToken *memguard.LockedBuffer,
	cicToken []byte,
	Op client.OpenIdProvider,
) (*pktoken.PKToken, error) {
	return oidcAuthWithContext(context.Background(), idToken, cicToken, Op)
}

// oidcAuthWithContext implements OidcAuth, using ctx for the OP key lookup and
// PK Token verification so callers in this package can bound or cancel them.
func oidcAuthWithContext(
	ctx context.Context,
	idToken *memguard.LockedBuffer,
	cicToken []byte,
	op client.OpenIdProvider,
) (*pktoken.PKToken, error) {
	if idToken == nil {
		return nil, fmt.Errorf("error getting original headers: nil ID token buffer")
	}
	if op == nil {
		return nil, fmt.Errorf("error getting OP public key: nil OpenID provider")
	}

	headersB64, _, _, err := jws.SplitCompact(idToken.Bytes())
	if err != nil {
		return nil, fmt.Errorf("error getting original headers: %w", err)
	}
	headers := jws.NewHeaders()
	if err := parseJWTSegment(headersB64, &headers); err != nil {
		return nil, err
	}

	// The header (e.g. its key ID) tells the OP which signing key was used.
	opKey, err := op.PublicKey(ctx, headers)
	if err != nil {
		return nil, fmt.Errorf("error getting OP public key: %w", err)
	}

	// GQ signatures are defined over RSA keys. Use the comma-ok assertion so a
	// non-RSA OP key yields an error instead of a panic.
	rsaPubKey, ok := opKey.(*rsa.PublicKey)
	if !ok {
		return nil, fmt.Errorf("error creating GQ signer: OP public key is %T, want *rsa.PublicKey", opKey)
	}
	sv, err := gq.NewSignerVerifier(rsaPubKey, GQSecurityParameter)
	if err != nil {
		return nil, fmt.Errorf("error creating GQ signer: %w", err)
	}
	gqToken, err := sv.SignJWT(idToken.Bytes())
	if err != nil {
		return nil, fmt.Errorf("error creating GQ signature: %w", err)
	}

	// gqToken is passed as-is rather than copied into a fresh memguard buffer.
	// Nothing here would Destroy such a buffer, so it would leak locked memory
	// on every call, and pktoken.New receives a plain slice regardless.
	pkt, err := pktoken.New(gqToken, cicToken)
	if err != nil {
		return nil, fmt.Errorf("error creating PK Token: %w", err)
	}
	if err := client.VerifyPKToken(ctx, pkt, op); err != nil {
		return nil, fmt.Errorf("error verifying PK Token: %w", err)
	}
	if err := pkt.AddJKTHeader(opKey); err != nil {
		return nil, fmt.Errorf("error adding JKT header: %w", err)
	}
	return pkt, nil
}
// parseJWTSegment base64url-decodes one dot-separated segment of a compact JWT
// (via util.Base64DecodeForJWT) and JSON-unmarshals the result into v.
// v should be a non-nil pointer, as json.Unmarshal requires. A decode failure
// and a JSON failure are reported with distinct wrapped messages.
func parseJWTSegment(segment []byte, v any) error {
	raw, decodeErr := util.Base64DecodeForJWT(segment)
	if decodeErr != nil {
		return fmt.Errorf("error decoding segment: %w", decodeErr)
	}
	if jsonErr := json.Unmarshal(raw, v); jsonErr != nil {
		return fmt.Errorf("error parsing segment: %w", jsonErr)
	}
	return nil
}