@@ -42,6 +42,12 @@ var jsonnetCache, _ = ristretto.NewCache(&ristretto.Config{
42
42
43
43
type MetadataType string
44
44
45
+ type OIDCProviderData struct {
46
+ Provider string `json:"provider"`
47
+ Tokens * identity.CredentialsOIDCEncryptedTokens `json:"tokens"`
48
+ Claims Claims `json:"claims"`
49
+ }
50
+
45
51
type VerifiedAddress struct {
46
52
Value string `json:"value"`
47
53
Via identity.VerifiableAddressType `json:"via"`
@@ -52,6 +58,8 @@ const (
52
58
53
59
PublicMetadata MetadataType = "identity.metadata_public"
54
60
AdminMetadata MetadataType = "identity.metadata_admin"
61
+
62
+ InternalContextKeyProviderData = "provider_data"
55
63
)
56
64
57
65
func (s * Strategy ) RegisterRegistrationRoutes (r * x.RouterPublic ) {
@@ -213,6 +221,25 @@ func (s *Strategy) Register(w http.ResponseWriter, r *http.Request, f *registrat
213
221
return errors .WithStack (flow .ErrCompletedByStrategy )
214
222
}
215
223
224
+ if oidcProviderData := gjson .GetBytes (f .InternalContext , flow .PrefixInternalContextKey (s .ID (), InternalContextKeyProviderData )); oidcProviderData .IsObject () {
225
+ var providerData OIDCProviderData
226
+ if err := json .Unmarshal ([]byte (oidcProviderData .Raw ), & providerData ); err != nil {
227
+ return s .handleError (ctx , w , r , f , pid , nil , errors .WithStack (herodot .ErrInternalServerError .WithReasonf ("Expected OIDC provider data in internal context to be an object but got: %s" , err )))
228
+ }
229
+ if pid != providerData .Provider {
230
+ return s .handleError (ctx , w , r , f , pid , nil , errors .WithStack (herodot .ErrInternalServerError .WithReasonf ("Expected OIDC provider data in internal context to have matching provider but got: %s" , providerData .Provider )))
231
+ }
232
+ _ , err = s .processRegistration (ctx , w , r , f , providerData .Tokens , & providerData .Claims , provider , & AuthCodeContainer {
233
+ FlowID : f .ID .String (),
234
+ Traits : p .Traits ,
235
+ TransientPayload : f .TransientPayload ,
236
+ }, "" )
237
+ if err != nil {
238
+ return s .handleError (ctx , w , r , f , pid , nil , err )
239
+ }
240
+ return errors .WithStack (flow .ErrCompletedByStrategy )
241
+ }
242
+
216
243
state := generateState (f .ID .String ())
217
244
if code , hasCode , _ := s .d .SessionTokenExchangePersister ().CodeForFlow (ctx , f .ID ); hasCode {
218
245
state .setCode (code .InitCode )
@@ -309,6 +336,13 @@ func (s *Strategy) processRegistration(ctx context.Context, w http.ResponseWrite
309
336
return nil , nil
310
337
}
311
338
339
+ providerDataKey := flow .PrefixInternalContextKey (s .ID (), InternalContextKeyProviderData )
340
+ if hasOIDCProviderData := gjson .GetBytes (rf .InternalContext , providerDataKey ).IsObject (); ! hasOIDCProviderData {
341
+ if internalContext , err := sjson .SetBytes (rf .InternalContext , providerDataKey , & OIDCProviderData {Provider : provider .Config ().ID , Tokens : token , Claims : * claims }); err == nil {
342
+ rf .InternalContext = internalContext
343
+ }
344
+ }
345
+
312
346
fetch := fetcher .NewFetcher (fetcher .WithClient (s .d .HTTPClient (r .Context ())), fetcher .WithCache (jsonnetCache , 60 * time .Minute ))
313
347
jsonnetMapperSnippet , err := fetch .FetchContext (r .Context (), provider .Config ().Mapper )
314
348
if err != nil {
@@ -347,6 +381,10 @@ func (s *Strategy) processRegistration(ctx context.Context, w http.ResponseWrite
347
381
return nil , s .handleError (ctx , w , r , rf , provider .Config ().ID , i .Traits , err )
348
382
}
349
383
384
+ if internalContext , err := sjson .DeleteBytes (rf .InternalContext , providerDataKey ); err == nil {
385
+ rf .InternalContext = internalContext
386
+ }
387
+
350
388
return nil , nil
351
389
}
352
390
0 commit comments