Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 32 additions & 10 deletions service/entityresolution/multi-strategy/registration.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package multistrategy

import (
"context"
"encoding/json"
"fmt"
"log/slog"

Expand All @@ -14,6 +15,7 @@ import (
"github.com/opentdf/platform/service/logger"
"github.com/opentdf/platform/service/pkg/serviceregistry"
"go.opentelemetry.io/otel/trace"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/types/known/structpb"
)

Expand Down Expand Up @@ -43,25 +45,42 @@ func (ers *ERS) ResolveEntities(
ctx context.Context,
req *connect.Request[entityresolution.ResolveEntitiesRequest],
) (*connect.Response[entityresolution.ResolveEntitiesResponse], error) {
// Extract JWT claims from context (this would be set by authentication middleware)
jwtClaims, ok := ctx.Value(types.JWTClaimsContextKey).(types.JWTClaims)
if !ok {
ers.logger.Warn("no JWT claims found in context for multi-strategy ERS")
jwtClaims = make(types.JWTClaims)
}

payload := req.Msg.GetEntities()
resolvedEntities := make([]*entityresolution.EntityRepresentation, 0, len(payload))

for _, entity := range payload {
entityID := entity.GetId()
if entityID == "" {
ers.logger.Warn("empty entity ID in request")
continue
}

var claimsMap types.JWTClaims
switch entity.GetEntityType().(type) {
case *authorization.Entity_Claims:
claims := entity.GetClaims()
if claims != nil {
// First unmarshal to structpb.Struct
var claimsStruct structpb.Struct
err := claims.UnmarshalTo(&claimsStruct)
if err != nil {
return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("error unpacking anypb.Any to structpb.Struct: %w", err))
}
// Convert to map[string]interface{}
claimsMap = claimsStruct.AsMap()
}
default:
entityBytes, err := protojson.Marshal(entity)
if err != nil {
return nil, err
}
err = json.Unmarshal(entityBytes, &claimsMap)
if err != nil {
return nil, err
}
}

// Resolve entity using multi-strategy service
result, err := ers.service.ResolveEntity(ctx, entityID, jwtClaims)
result, err := ers.service.ResolveEntity(ctx, entityID, claimsMap)
if err != nil {
ers.logger.Error("failed to resolve entity",
slog.String("entity_id", entityID),
Expand Down Expand Up @@ -212,8 +231,11 @@ func (ers *ERS) createEntityChainFromSingleToken(ctx context.Context, token *aut
for _, strategy := range strategies {
attemptedStrategies = append(attemptedStrategies, strategy.Name)

// Put JWT claims into context for providers to access
ctxWithClaims := context.WithValue(ctx, types.JWTClaimsContextKey, jwtClaims)

// Resolve entity using this strategy
entityResult, err := ers.service.ResolveEntity(ctx, token.GetId(), jwtClaims)
entityResult, err := ers.service.ResolveEntity(ctxWithClaims, token.GetId(), jwtClaims)
if err != nil {
lastError = err
ers.logger.WarnContext(ctx, "strategy failed for token",
Expand Down
10 changes: 5 additions & 5 deletions service/entityresolution/multi-strategy/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,17 +60,17 @@ func (s *Service) GetConfig() types.MultiStrategyConfig {
}

// ResolveEntity resolves entity information using the configured strategies
func (s *Service) ResolveEntity(ctx context.Context, entityID string, jwtClaims types.JWTClaims) (*types.EntityResult, error) {
func (s *Service) ResolveEntity(ctx context.Context, entityID string, claimsMap types.JWTClaims) (*types.EntityResult, error) {
// Get all matching strategies based on JWT claims
strategies, err := s.strategyMatcher.SelectStrategies(ctx, jwtClaims)
strategies, err := s.strategyMatcher.SelectStrategies(ctx, claimsMap)
if err != nil {
return nil, types.WrapMultiStrategyError(
types.ErrorTypeStrategy,
"failed to select strategies",
err,
map[string]interface{}{
"entity_id": entityID,
"jwt_claims": extractClaimNames(jwtClaims),
"entity_map": extractClaimNames(claimsMap),
},
)
}
Expand All @@ -88,7 +88,7 @@ func (s *Service) ResolveEntity(ctx context.Context, entityID string, jwtClaims
for _, strategy := range strategies {
attemptedStrategies = append(attemptedStrategies, strategy.Name)

result, err := s.executeStrategy(ctx, entityID, jwtClaims, strategy)
result, err := s.executeStrategy(ctx, entityID, claimsMap, strategy)
if err != nil {
lastError = err

Expand Down Expand Up @@ -130,7 +130,7 @@ func (s *Service) ResolveEntity(ctx context.Context, entityID string, jwtClaims
"entity_id": entityID,
"failure_strategy": failureStrategy,
"attempted_strategies": attemptedStrategies,
"jwt_claims": extractClaimNames(jwtClaims),
"entity_map": extractClaimNames(claimsMap),
},
)
}
Expand Down
12 changes: 6 additions & 6 deletions service/entityresolution/multi-strategy/strategy_matcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,34 +31,34 @@ func (sm *StrategyMatcher) SelectStrategy(_ context.Context, claims types.JWTCla

return nil, types.NewStrategyError("no matching strategy found", map[string]interface{}{
"available_strategies": len(sm.strategies),
"jwt_claims": extractClaimNames(claims),
"entity_map": extractClaimNames(claims),
})
}

// SelectStrategies returns all strategies that match the JWT claims in configuration order
func (sm *StrategyMatcher) SelectStrategies(_ context.Context, claims types.JWTClaims) ([]*types.MappingStrategy, error) {
func (sm *StrategyMatcher) SelectStrategies(_ context.Context, claimsMap types.JWTClaims) ([]*types.MappingStrategy, error) {
var matchingStrategies []*types.MappingStrategy

for _, strategy := range sm.strategies {
if sm.matchesConditions(claims, strategy.Conditions) {
if sm.matchesConditions(claimsMap, strategy.Conditions) {
matchingStrategies = append(matchingStrategies, &strategy)
}
}

if len(matchingStrategies) == 0 {
return nil, types.NewStrategyError("no matching strategy found", map[string]interface{}{
"available_strategies": len(sm.strategies),
"jwt_claims": extractClaimNames(claims),
"entity_map": extractClaimNames(claimsMap),
})
}

return matchingStrategies, nil
}

// matchesConditions checks if JWT claims match strategy conditions
func (sm *StrategyMatcher) matchesConditions(claims types.JWTClaims, conditions types.StrategyConditions) bool {
func (sm *StrategyMatcher) matchesConditions(claimsMap types.JWTClaims, conditions types.StrategyConditions) bool {
for _, claimCondition := range conditions.JWTClaims {
if !sm.matchesClaimCondition(claims, claimCondition) {
if !sm.matchesClaimCondition(claimsMap, claimCondition) {
return false
}
}
Expand Down
56 changes: 38 additions & 18 deletions service/entityresolution/multi-strategy/v2/registration.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,23 @@ package multistrategy

import (
"context"
"encoding/json"
"fmt"
"log/slog"
"strconv"

"connectrpc.com/connect"
"github.com/go-viper/mapstructure/v2"
"github.com/lestrrat-go/jwx/v2/jwt"
"github.com/opentdf/platform/protocol/go/entity"
ersV2 "github.com/opentdf/platform/protocol/go/entityresolution/v2"
ent "github.com/opentdf/platform/service/entity"
multistrategy "github.com/opentdf/platform/service/entityresolution/multi-strategy"
"github.com/opentdf/platform/service/entityresolution/multi-strategy/types"
"github.com/opentdf/platform/service/logger"
"github.com/opentdf/platform/service/pkg/serviceregistry"
"go.opentelemetry.io/otel/trace"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/types/known/structpb"
)

Expand Down Expand Up @@ -49,27 +53,43 @@ func (ers *ERSV2) ResolveEntities(
ctx context.Context,
req *connect.Request[ersV2.ResolveEntitiesRequest],
) (*connect.Response[ersV2.ResolveEntitiesResponse], error) {
// Extract JWT claims from context (this would be set by authentication middleware)
jwtClaims, ok := ctx.Value(types.JWTClaimsContextKey).(types.JWTClaims)
if !ok {
ers.logger.Warn("no JWT claims found in context for multi-strategy ERS v2")
// For ResolveEntities, we need JWT claims to be provided by middleware
// This is different from CreateEntityChainsFromTokens which has the JWT token directly
jwtClaims = make(types.JWTClaims)
}

payload := req.Msg.GetEntities()
resolvedEntities := make([]*ersV2.EntityRepresentation, 0, len(payload))

for _, entityV2 := range payload {
for idx, entityV2 := range payload {
entityID := entityV2.GetEphemeralId()
if entityID == "" {
ers.logger.Warn("empty entity ID in request")
continue
entityID = ent.EntityIDPrefix + strconv.Itoa(idx)
ers.logger.Warn("empty entity ID in request; using generated ID", slog.String("entity_id", entityID))
}

var claimsMap types.JWTClaims
switch entityV2.GetEntityType().(type) {
case *entity.Entity_Claims:
claims := entityV2.GetClaims()
if claims != nil {
// First unmarshal to structpb.Struct
var claimsStruct structpb.Struct
err := claims.UnmarshalTo(&claimsStruct)
if err != nil {
return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("error unpacking anypb.Any to structpb.Struct: %w", err))
}
// Convert to map[string]interface{}
claimsMap = claimsStruct.AsMap()
}
default:
entityBytes, err := protojson.Marshal(entityV2)
if err != nil {
return nil, err
}
err = json.Unmarshal(entityBytes, &claimsMap)
if err != nil {
return nil, err
}
}

// Resolve entity using multi-strategy service
result, err := ers.service.ResolveEntity(ctx, entityID, jwtClaims)
result, err := ers.service.ResolveEntity(ctx, entityID, claimsMap)
if err != nil {
ers.logger.Error("failed to resolve entity",
slog.String("entity_id", entityID),
Expand Down Expand Up @@ -188,7 +208,7 @@ func (ers *ERSV2) createEntityChainFromSingleTokenV2(ctx context.Context, token
err,
map[string]interface{}{
"token_id": token.GetEphemeralId(),
"jwt_claims": extractClaimNames(jwtClaims),
"entity_map": extractClaimNames(jwtClaims),
},
)
}
Expand All @@ -198,7 +218,7 @@ func (ers *ERSV2) createEntityChainFromSingleTokenV2(ctx context.Context, token
"no matching strategies found for JWT claims",
map[string]interface{}{
"token_id": token.GetEphemeralId(),
"jwt_claims": extractClaimNames(jwtClaims),
"entity_map": extractClaimNames(jwtClaims),
},
)
}
Expand Down Expand Up @@ -276,7 +296,7 @@ func (ers *ERSV2) createEntityChainFromSingleTokenV2(ctx context.Context, token
"token_id": token.GetEphemeralId(),
"failure_strategy": failureStrategy,
"attempted_strategies": attemptedStrategies,
"jwt_claims": extractClaimNames(jwtClaims),
"entity_map": extractClaimNames(jwtClaims),
},
)
}
Expand Down Expand Up @@ -416,7 +436,7 @@ func getEntityValueV2(entityType interface{}) string {
}

// RegisterMultiStrategyERSV2 registers the v2 multi-strategy ERS service
func RegisterERSV2(config map[string]interface{}, logger *logger.Logger) (*ERSV2, serviceregistry.HandlerServer) {
func RegisterMultiStrategyERSV2(config map[string]interface{}, logger *logger.Logger) (*ERSV2, serviceregistry.HandlerServer) {
var multiStrategyConfig types.MultiStrategyConfig

if err := mapstructure.Decode(config, &multiStrategyConfig); err != nil {
Expand All @@ -433,7 +453,7 @@ func RegisterERSV2(config map[string]interface{}, logger *logger.Logger) (*ERSV2
return ers, nil
}

// extractClaimNames extracts the names of claims from JWTClaims for logging
// extractClaimNames extracts the names of fields from JWTClaims for logging
func extractClaimNames(claims types.JWTClaims) []string {
names := make([]string, 0, len(claims))
for name := range claims {
Expand Down
2 changes: 1 addition & 1 deletion service/entityresolution/v2/entity_resolution.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func NewRegistration() *serviceregistry.Service[entityresolutionv2connect.Entity
claimsSVC.Tracer = srp.Tracer
return EntityResolution{EntityResolutionServiceHandler: claimsSVC}, claimsHandler
case MultiStrategyMode:
multiSVC, multiHandler := multistrategyv2.RegisterERSV2(srp.Config, srp.Logger)
multiSVC, multiHandler := multistrategyv2.RegisterMultiStrategyERSV2(srp.Config, srp.Logger)
multiSVC.Tracer = srp.Tracer
return EntityResolution{EntityResolutionServiceHandler: multiSVC}, multiHandler
default:
Expand Down
Loading