From bef0d9ab2baad0fc2076a804f422b750b8364349 Mon Sep 17 00:00:00 2001 From: Joseph Schorr Date: Tue, 1 Oct 2024 12:28:49 -0400 Subject: [PATCH 01/15] Relationships selected in SQL-based datastores should elide columns that have static values This vastly reduces data over the wire, as well as deserialization time and memory usage --- internal/datastore/common/relationships.go | 163 +++++++++ .../common/{tuple.go => sliceiter.go} | 0 internal/datastore/common/sql.go | 326 +++++++++++++----- internal/datastore/common/sql_test.go | 242 ++++--------- internal/datastore/crdb/caveat.go | 6 +- internal/datastore/crdb/crdb.go | 54 ++- internal/datastore/crdb/reader.go | 84 +---- internal/datastore/crdb/readwrite.go | 8 +- internal/datastore/crdb/stats.go | 8 +- internal/datastore/crdb/watch.go | 4 +- internal/datastore/mysql/caveat.go | 4 +- internal/datastore/mysql/datastore.go | 120 +++---- internal/datastore/mysql/reader.go | 33 +- internal/datastore/postgres/caveat.go | 4 +- internal/datastore/postgres/common/pgx.go | 136 +------- internal/datastore/postgres/postgres.go | 20 ++ internal/datastore/postgres/reader.go | 45 +-- internal/datastore/postgres/readwrite.go | 4 +- internal/datastore/postgres/stats.go | 4 +- internal/datastore/spanner/reader.go | 68 ++-- 20 files changed, 660 insertions(+), 673 deletions(-) create mode 100644 internal/datastore/common/relationships.go rename internal/datastore/common/{tuple.go => sliceiter.go} (100%) diff --git a/internal/datastore/common/relationships.go b/internal/datastore/common/relationships.go new file mode 100644 index 0000000000..a08cafd958 --- /dev/null +++ b/internal/datastore/common/relationships.go @@ -0,0 +1,163 @@ +package common + +import ( + "context" + "database/sql" + "fmt" + "time" + + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" + "google.golang.org/protobuf/types/known/timestamppb" + + "github.com/authzed/spicedb/pkg/datastore" + corev1 "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/tuple" +) + +const errUnableToQueryRels = "unable to query relationships: %w" + +// StaticValueOrAddColumnForSelect adds a column to the list of columns to select if the value +// is not static, otherwise it sets the value to the static value. +func StaticValueOrAddColumnForSelect(colsToSelect []any, queryInfo QueryInfo, colName string, field *string) []any { + // If the value is static, set the field to it and return. + if found, ok := queryInfo.FilteringValues[colName]; ok && found.SingleValue != nil { + *field = *found.SingleValue + return colsToSelect + } + + // Otherwise, add the column to the list of columns to select, as the value is not static. + colsToSelect = append(colsToSelect, field) + return colsToSelect +} + +// Querier is an interface for querying the database. +type Querier[R Rows] interface { + QueryFunc(ctx context.Context, f func(context.Context, R) error, sql string, args ...any) error +} + +// Rows is a common interface for database rows reading. +type Rows interface { + Scan(dest ...any) error + Next() bool + Err() error +} + +type closeRowsWithError interface { + Rows + Close() error +} + +type closeRows interface { + Rows + Close() +} + +// QueryRelationships queries relationships for the given query and transaction. +func QueryRelationships[R Rows, C ~map[string]any](ctx context.Context, queryInfo QueryInfo, sqlStatement string, args []any, span trace.Span, tx Querier[R], withIntegrity bool) (datastore.RelationshipIterator, error) { + defer span.End() + + colsToSelect := make([]any, 0, 8) + var resourceObjectType string + var resourceObjectID string + var resourceRelation string + var subjectObjectType string + var subjectObjectID string + var subjectRelation string + var caveatName sql.NullString + var caveatCtx C + var expiration *time.Time + + var integrityKeyID string + var integrityHash []byte + var timestamp time.Time + + colsToSelect = StaticValueOrAddColumnForSelect(colsToSelect, queryInfo, queryInfo.Schema.ColNamespace, &resourceObjectType) + colsToSelect = StaticValueOrAddColumnForSelect(colsToSelect, queryInfo, queryInfo.Schema.ColObjectID, &resourceObjectID) + colsToSelect = StaticValueOrAddColumnForSelect(colsToSelect, queryInfo, queryInfo.Schema.ColRelation, &resourceRelation) + colsToSelect = StaticValueOrAddColumnForSelect(colsToSelect, queryInfo, queryInfo.Schema.ColUsersetNamespace, &subjectObjectType) + colsToSelect = StaticValueOrAddColumnForSelect(colsToSelect, queryInfo, queryInfo.Schema.ColUsersetObjectID, &subjectObjectID) + colsToSelect = StaticValueOrAddColumnForSelect(colsToSelect, queryInfo, queryInfo.Schema.ColUsersetRelation, &subjectRelation) + + colsToSelect = append(colsToSelect, &caveatName, &caveatCtx, &expiration) + if withIntegrity { + colsToSelect = append(colsToSelect, &integrityKeyID, &integrityHash, ×tamp) + } + + return func(yield func(tuple.Relationship, error) bool) { + err := tx.QueryFunc(ctx, func(ctx context.Context, rows R) error { + var r Rows = rows + if crwe, ok := r.(closeRowsWithError); ok { + defer LogOnError(ctx, crwe.Close) + } else if cr, ok := r.(closeRows); ok { + defer cr.Close() + } + + span.AddEvent("Query issued to database") + relCount := 0 + for rows.Next() { + if err := rows.Scan(colsToSelect...); err != nil { + return fmt.Errorf(errUnableToQueryRels, fmt.Errorf("scan err: %w", err)) + } + + var caveat *corev1.ContextualizedCaveat + if caveatName.Valid { + var err error + caveat, err = ContextualizedCaveatFrom(caveatName.String, caveatCtx) + if err != nil { + return fmt.Errorf(errUnableToQueryRels, fmt.Errorf("unable to fetch caveat context: %w", err)) + } + } + + var integrity *corev1.RelationshipIntegrity + if integrityKeyID != "" { + integrity = &corev1.RelationshipIntegrity{ + KeyId: integrityKeyID, + Hash: integrityHash, + HashedAt: timestamppb.New(timestamp), + } + } + + if expiration != nil { + // Ensure the expiration is always read in UTC, since some datastores (like CRDB) + // will normalize to local time. + t := expiration.UTC() + expiration = &t + } + + relCount++ + if !yield(tuple.Relationship{ + RelationshipReference: tuple.RelationshipReference{ + Resource: tuple.ObjectAndRelation{ + ObjectType: resourceObjectType, + ObjectID: resourceObjectID, + Relation: resourceRelation, + }, + Subject: tuple.ObjectAndRelation{ + ObjectType: subjectObjectType, + ObjectID: subjectObjectID, + Relation: subjectRelation, + }, + }, + OptionalCaveat: caveat, + OptionalExpiration: expiration, + OptionalIntegrity: integrity, + }, nil) { + return nil + } + } + + if err := rows.Err(); err != nil { + return fmt.Errorf(errUnableToQueryRels, fmt.Errorf("rows err: %w", err)) + } + + span.AddEvent("Rels loaded", trace.WithAttributes(attribute.Int("relCount", relCount))) + return nil + }, sqlStatement, args...) + if err != nil { + if !yield(tuple.Relationship{}, err) { + return + } + } + }, nil +} diff --git a/internal/datastore/common/tuple.go b/internal/datastore/common/sliceiter.go similarity index 100% rename from internal/datastore/common/tuple.go rename to internal/datastore/common/sliceiter.go diff --git a/internal/datastore/common/sql.go b/internal/datastore/common/sql.go index d886927931..f313edf322 100644 --- a/internal/datastore/common/sql.go +++ b/internal/datastore/common/sql.go @@ -64,19 +64,34 @@ const ( // SchemaInformation holds the schema information from the SQL datastore implementation. type SchemaInformation struct { - colNamespace string - colObjectID string - colRelation string - colUsersetNamespace string - colUsersetObjectID string - colUsersetRelation string - colCaveatName string - colExpiration string - paginationFilterType PaginationFilterType - nowFunction string + RelationshipTableName string + ColNamespace string + ColObjectID string + ColRelation string + ColUsersetNamespace string + ColUsersetObjectID string + ColUsersetRelation string + ColCaveatName string + ColCaveatContext string + ColExpiration string + + // PaginationFilterType is the type of pagination filter to use for this schema. + PaginationFilterType PaginationFilterType + + // PlaceholderFormat is the format of placeholders to use for this schema. + PlaceholderFormat sq.PlaceholderFormat + + // NowFunction is the function to use to get the current time in the datastore. + NowFunction string + + // ExtaFields are additional fields that are not part of the core schema, but are + // requested by the caller for this query. + ExtraFields []string } +// NewSchemaInformation creates a new SchemaInformation object for a query. func NewSchemaInformation( + relationshipTableName, colNamespace, colObjectID, colRelation, @@ -84,11 +99,15 @@ func NewSchemaInformation( colUsersetObjectID, colUsersetRelation, colCaveatName string, + colCaveatContext string, colExpiration string, paginationFilterType PaginationFilterType, + placeholderFormat sq.PlaceholderFormat, nowFunction string, + extraFields ...string, ) SchemaInformation { return SchemaInformation{ + relationshipTableName, colNamespace, colObjectID, colRelation, @@ -96,43 +115,111 @@ func NewSchemaInformation( colUsersetObjectID, colUsersetRelation, colCaveatName, + colCaveatContext, colExpiration, paginationFilterType, + placeholderFormat, nowFunction, + extraFields, } } +type ColumnTracker struct { + SingleValue *string +} + // SchemaQueryFilterer wraps a SchemaInformation and SelectBuilder to give an opinionated // way to build query objects. type SchemaQueryFilterer struct { - schema SchemaInformation - queryBuilder sq.SelectBuilder - filteringColumnCounts map[string]int - filterMaximumIDCount uint16 + schema SchemaInformation + queryBuilder sq.SelectBuilder + filteringColumnTracker map[string]ColumnTracker + filterMaximumIDCount uint16 + isCustomQuery bool + extraFields []string + fromSuffix string +} + +// NewSchemaQueryFiltererForRelationshipsSelect creates a new SchemaQueryFilterer object for selecting +// relationships. This method will automatically filter the columns retrieved from the database, only +// selecting the columns that are not already specified with a single static value in the query. +func NewSchemaQueryFiltererForRelationshipsSelect(schema SchemaInformation, filterMaximumIDCount uint16, extraFields ...string) SchemaQueryFilterer { + if filterMaximumIDCount == 0 { + filterMaximumIDCount = 100 + log.Warn().Msg("SchemaQueryFilterer: filterMaximumIDCount not set, defaulting to 100") + } + + // Filter out any expired relationships. + // TODO(jschorr): Make this depend on whether expiration is necessary. + queryBuilder := sq.StatementBuilder.PlaceholderFormat(schema.PlaceholderFormat).Select().Where(sq.Or{ + sq.Eq{schema.ColExpiration: nil}, + sq.Expr(schema.ColExpiration + " > " + schema.NowFunction + "()"), + }) + + return SchemaQueryFilterer{ + schema: schema, + queryBuilder: queryBuilder, + filteringColumnTracker: map[string]ColumnTracker{}, + filterMaximumIDCount: filterMaximumIDCount, + isCustomQuery: false, + extraFields: extraFields, + } } -// NewSchemaQueryFilterer creates a new SchemaQueryFilterer object. -func NewSchemaQueryFilterer(schema SchemaInformation, initialQuery sq.SelectBuilder, filterMaximumIDCount uint16) SchemaQueryFilterer { +// NewSchemaQueryFiltererWithStartingQuery creates a new SchemaQueryFilterer object for selecting +// relationships, with a custom starting query. Unlike NewSchemaQueryFiltererForRelationshipsSelect, +// this method will not auto-filter the columns retrieved from the database. +func NewSchemaQueryFiltererWithStartingQuery(schema SchemaInformation, startingQuery sq.SelectBuilder, filterMaximumIDCount uint16) SchemaQueryFilterer { if filterMaximumIDCount == 0 { filterMaximumIDCount = 100 log.Warn().Msg("SchemaQueryFilterer: filterMaximumIDCount not set, defaulting to 100") } // Filter out any expired relationships. - initialQuery = initialQuery.Where(sq.Or{ - sq.Eq{schema.colExpiration: nil}, - sq.Expr(schema.colExpiration + " > " + schema.nowFunction + "()"), + // TODO(jschorr): Make this depend on whether expiration is necessary. + startingQuery = startingQuery.Where(sq.Or{ + sq.Eq{schema.ColExpiration: nil}, + sq.Expr(schema.ColExpiration + " > " + schema.NowFunction + "()"), }) return SchemaQueryFilterer{ - schema: schema, - queryBuilder: initialQuery, - filteringColumnCounts: map[string]int{}, - filterMaximumIDCount: filterMaximumIDCount, + schema: schema, + queryBuilder: startingQuery, + filteringColumnTracker: map[string]ColumnTracker{}, + filterMaximumIDCount: filterMaximumIDCount, + isCustomQuery: true, + extraFields: nil, + } +} + +// WithAdditionalFilter returns a new SchemaQueryFilterer with an additional filter applied to the query. +func (sqf SchemaQueryFilterer) WithAdditionalFilter(filter func(original sq.SelectBuilder) sq.SelectBuilder) SchemaQueryFilterer { + return SchemaQueryFilterer{ + schema: sqf.schema, + queryBuilder: filter(sqf.queryBuilder), + filteringColumnTracker: sqf.filteringColumnTracker, + filterMaximumIDCount: sqf.filterMaximumIDCount, + isCustomQuery: sqf.isCustomQuery, + extraFields: sqf.extraFields, + } +} + +func (sqf SchemaQueryFilterer) WithFromSuffix(fromSuffix string) SchemaQueryFilterer { + return SchemaQueryFilterer{ + schema: sqf.schema, + queryBuilder: sqf.queryBuilder, + filteringColumnTracker: sqf.filteringColumnTracker, + filterMaximumIDCount: sqf.filterMaximumIDCount, + isCustomQuery: sqf.isCustomQuery, + extraFields: sqf.extraFields, + fromSuffix: fromSuffix, } } func (sqf SchemaQueryFilterer) UnderlyingQueryBuilder() sq.SelectBuilder { + spiceerrors.DebugAssert(func() bool { + return sqf.isCustomQuery + }, "UnderlyingQueryBuilder should only be called on custom queries") return sqf.queryBuilder } @@ -140,22 +227,22 @@ func (sqf SchemaQueryFilterer) TupleOrder(order options.SortOrder) SchemaQueryFi switch order { case options.ByResource: sqf.queryBuilder = sqf.queryBuilder.OrderBy( - sqf.schema.colNamespace, - sqf.schema.colObjectID, - sqf.schema.colRelation, - sqf.schema.colUsersetNamespace, - sqf.schema.colUsersetObjectID, - sqf.schema.colUsersetRelation, + sqf.schema.ColNamespace, + sqf.schema.ColObjectID, + sqf.schema.ColRelation, + sqf.schema.ColUsersetNamespace, + sqf.schema.ColUsersetObjectID, + sqf.schema.ColUsersetRelation, ) case options.BySubject: sqf.queryBuilder = sqf.queryBuilder.OrderBy( - sqf.schema.colUsersetNamespace, - sqf.schema.colUsersetObjectID, - sqf.schema.colUsersetRelation, - sqf.schema.colNamespace, - sqf.schema.colObjectID, - sqf.schema.colRelation, + sqf.schema.ColUsersetNamespace, + sqf.schema.ColUsersetObjectID, + sqf.schema.ColUsersetRelation, + sqf.schema.ColNamespace, + sqf.schema.ColObjectID, + sqf.schema.ColRelation, ) } @@ -174,47 +261,47 @@ func (sqf SchemaQueryFilterer) After(cursor options.Cursor, order options.SortOr columnsAndValues := map[options.SortOrder][]nameAndValue{ options.ByResource: { { - sqf.schema.colNamespace, cursor.Resource.ObjectType, + sqf.schema.ColNamespace, cursor.Resource.ObjectType, }, { - sqf.schema.colObjectID, cursor.Resource.ObjectID, + sqf.schema.ColObjectID, cursor.Resource.ObjectID, }, { - sqf.schema.colRelation, cursor.Resource.Relation, + sqf.schema.ColRelation, cursor.Resource.Relation, }, { - sqf.schema.colUsersetNamespace, cursor.Subject.ObjectType, + sqf.schema.ColUsersetNamespace, cursor.Subject.ObjectType, }, { - sqf.schema.colUsersetObjectID, cursor.Subject.ObjectID, + sqf.schema.ColUsersetObjectID, cursor.Subject.ObjectID, }, { - sqf.schema.colUsersetRelation, cursor.Subject.Relation, + sqf.schema.ColUsersetRelation, cursor.Subject.Relation, }, }, options.BySubject: { { - sqf.schema.colUsersetNamespace, cursor.Subject.ObjectType, + sqf.schema.ColUsersetNamespace, cursor.Subject.ObjectType, }, { - sqf.schema.colUsersetObjectID, cursor.Subject.ObjectID, + sqf.schema.ColUsersetObjectID, cursor.Subject.ObjectID, }, { - sqf.schema.colNamespace, cursor.Resource.ObjectType, + sqf.schema.ColNamespace, cursor.Resource.ObjectType, }, { - sqf.schema.colObjectID, cursor.Resource.ObjectID, + sqf.schema.ColObjectID, cursor.Resource.ObjectID, }, { - sqf.schema.colRelation, cursor.Resource.Relation, + sqf.schema.ColRelation, cursor.Resource.Relation, }, { - sqf.schema.colUsersetRelation, cursor.Subject.Relation, + sqf.schema.ColUsersetRelation, cursor.Subject.Relation, }, }, }[order] - switch sqf.schema.paginationFilterType { + switch sqf.schema.PaginationFilterType { case TupleComparison: // For performance reasons, remove any column names that have static values in the query. columnNames := make([]string, 0, len(columnsAndValues)) @@ -222,7 +309,7 @@ func (sqf SchemaQueryFilterer) After(cursor options.Cursor, order options.SortOr comparisonSlotCount := 0 for _, cav := range columnsAndValues { - if sqf.filteringColumnCounts[cav.name] != 1 { + if r, ok := sqf.filteringColumnTracker[cav.name]; !ok || r.SingleValue == nil { columnNames = append(columnNames, cav.name) valueSlots = append(valueSlots, cav.value) comparisonSlotCount++ @@ -242,10 +329,10 @@ func (sqf SchemaQueryFilterer) After(cursor options.Cursor, order options.SortOr orClause := sq.Or{} for index, cav := range columnsAndValues { - if sqf.filteringColumnCounts[cav.name] != 1 { + if r, ok := sqf.filteringColumnTracker[cav.name]; !ok || r.SingleValue != nil { andClause := sq.And{} for _, previous := range columnsAndValues[0:index] { - if sqf.filteringColumnCounts[previous.name] != 1 { + if r, ok := sqf.filteringColumnTracker[previous.name]; !ok || r.SingleValue != nil { andClause = append(andClause, sq.Eq{previous.name: previous.value}) } } @@ -266,25 +353,31 @@ func (sqf SchemaQueryFilterer) After(cursor options.Cursor, order options.SortOr // FilterToResourceType returns a new SchemaQueryFilterer that is limited to resources of the // specified type. func (sqf SchemaQueryFilterer) FilterToResourceType(resourceType string) SchemaQueryFilterer { - sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.colNamespace: resourceType}) - sqf.recordColumnValue(sqf.schema.colNamespace) + sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.ColNamespace: resourceType}) + sqf.recordColumnValue(sqf.schema.ColNamespace, resourceType) return sqf } -func (sqf SchemaQueryFilterer) recordColumnValue(colName string) { - if value, ok := sqf.filteringColumnCounts[colName]; ok { - sqf.filteringColumnCounts[colName] = value + 1 - return +func (sqf SchemaQueryFilterer) recordColumnValue(colName string, colValue string) { + existing, ok := sqf.filteringColumnTracker[colName] + if ok { + if existing.SingleValue != nil && *existing.SingleValue != colValue { + sqf.filteringColumnTracker[colName] = ColumnTracker{SingleValue: nil} + } + } else { + sqf.filteringColumnTracker[colName] = ColumnTracker{SingleValue: &colValue} } +} - sqf.filteringColumnCounts[colName] = 1 +func (sqf SchemaQueryFilterer) recordMutableColumnValue(colName string) { + sqf.filteringColumnTracker[colName] = ColumnTracker{SingleValue: nil} } // FilterToResourceID returns a new SchemaQueryFilterer that is limited to resources with the // specified ID. func (sqf SchemaQueryFilterer) FilterToResourceID(objectID string) SchemaQueryFilterer { - sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.colObjectID: objectID}) - sqf.recordColumnValue(sqf.schema.colObjectID) + sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.ColObjectID: objectID}) + sqf.recordColumnValue(sqf.schema.ColObjectID, objectID) return sqf } @@ -309,7 +402,7 @@ func (sqf SchemaQueryFilterer) FilterWithResourceIDPrefix(prefix string) (Schema prefix = strings.ReplaceAll(prefix, `\`, `\\`) prefix = strings.ReplaceAll(prefix, "_", `\_`) - sqf.queryBuilder = sqf.queryBuilder.Where(sq.Like{sqf.schema.colObjectID: prefix + "%"}) + sqf.queryBuilder = sqf.queryBuilder.Where(sq.Like{sqf.schema.ColObjectID: prefix + "%"}) // NOTE: we do *not* record the use of the resource ID column here, because it is not used // statically and thus is necessary for sorting operations. @@ -332,7 +425,7 @@ func (sqf SchemaQueryFilterer) FilterToResourceIDs(resourceIds []string) (Schema }, "cannot have more than %d resource IDs in a single filter", sqf.filterMaximumIDCount) var builder strings.Builder - builder.WriteString(sqf.schema.colObjectID) + builder.WriteString(sqf.schema.ColObjectID) builder.WriteString(" IN (") args := make([]any, 0, len(resourceIds)) @@ -342,7 +435,7 @@ func (sqf SchemaQueryFilterer) FilterToResourceIDs(resourceIds []string) (Schema } args = append(args, resourceID) - sqf.recordColumnValue(sqf.schema.colObjectID) + sqf.recordColumnValue(sqf.schema.ColObjectID, resourceID) } builder.WriteString("?") @@ -358,8 +451,8 @@ func (sqf SchemaQueryFilterer) FilterToResourceIDs(resourceIds []string) (Schema // FilterToRelation returns a new SchemaQueryFilterer that is limited to resources with the // specified relation. func (sqf SchemaQueryFilterer) FilterToRelation(relation string) SchemaQueryFilterer { - sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.colRelation: relation}) - sqf.recordColumnValue(sqf.schema.colRelation) + sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.ColRelation: relation}) + sqf.recordColumnValue(sqf.schema.ColRelation, relation) return sqf } @@ -417,9 +510,9 @@ func (sqf SchemaQueryFilterer) FilterWithRelationshipsFilter(filter datastore.Re } if filter.OptionalExpirationOption == datastore.ExpirationFilterOptionHasExpiration { - csqf.queryBuilder = csqf.queryBuilder.Where(sq.NotEq{csqf.schema.colExpiration: nil}) + csqf.queryBuilder = csqf.queryBuilder.Where(sq.NotEq{csqf.schema.ColExpiration: nil}) } else if filter.OptionalExpirationOption == datastore.ExpirationFilterOptionNoExpiration { - csqf.queryBuilder = csqf.queryBuilder.Where(sq.Eq{csqf.schema.colExpiration: nil}) + csqf.queryBuilder = csqf.queryBuilder.Where(sq.Eq{csqf.schema.ColExpiration: nil}) } return csqf, nil @@ -440,12 +533,21 @@ func (sqf SchemaQueryFilterer) MustFilterWithSubjectsSelectors(selectors ...data func (sqf SchemaQueryFilterer) FilterWithSubjectsSelectors(selectors ...datastore.SubjectsSelector) (SchemaQueryFilterer, error) { selectorsOrClause := sq.Or{} + // If there is more than a single filter, record all the subjects as mutable, as the subjects returned + // can differ for each branch. + // TODO(jschorr): Optimize this further where applicable. + if len(selectors) > 1 { + sqf.recordMutableColumnValue(sqf.schema.ColUsersetNamespace) + sqf.recordMutableColumnValue(sqf.schema.ColUsersetObjectID) + sqf.recordMutableColumnValue(sqf.schema.ColUsersetRelation) + } + for _, selector := range selectors { selectorClause := sq.And{} if len(selector.OptionalSubjectType) > 0 { - selectorClause = append(selectorClause, sq.Eq{sqf.schema.colUsersetNamespace: selector.OptionalSubjectType}) - sqf.recordColumnValue(sqf.schema.colUsersetNamespace) + selectorClause = append(selectorClause, sq.Eq{sqf.schema.ColUsersetNamespace: selector.OptionalSubjectType}) + sqf.recordColumnValue(sqf.schema.ColUsersetNamespace, selector.OptionalSubjectType) } if len(selector.OptionalSubjectIds) > 0 { @@ -454,7 +556,7 @@ func (sqf SchemaQueryFilterer) FilterWithSubjectsSelectors(selectors ...datastor }, "cannot have more than %d subject IDs in a single filter", sqf.filterMaximumIDCount) var builder strings.Builder - builder.WriteString(sqf.schema.colUsersetObjectID) + builder.WriteString(sqf.schema.ColUsersetObjectID) builder.WriteString(" IN (") args := make([]any, 0, len(selector.OptionalSubjectIds)) @@ -464,7 +566,7 @@ func (sqf SchemaQueryFilterer) FilterWithSubjectsSelectors(selectors ...datastor } args = append(args, subjectID) - sqf.recordColumnValue(sqf.schema.colUsersetObjectID) + sqf.recordColumnValue(sqf.schema.ColUsersetObjectID, subjectID) } builder.WriteString("?") @@ -478,8 +580,8 @@ func (sqf SchemaQueryFilterer) FilterWithSubjectsSelectors(selectors ...datastor if !selector.RelationFilter.IsEmpty() { if selector.RelationFilter.OnlyNonEllipsisRelations { - selectorClause = append(selectorClause, sq.NotEq{sqf.schema.colUsersetRelation: datastore.Ellipsis}) - sqf.recordColumnValue(sqf.schema.colUsersetRelation) + selectorClause = append(selectorClause, sq.NotEq{sqf.schema.ColUsersetRelation: datastore.Ellipsis}) + sqf.recordMutableColumnValue(sqf.schema.ColUsersetRelation) } else { relations := make([]string, 0, 2) if selector.RelationFilter.IncludeEllipsisRelation { @@ -492,14 +594,14 @@ func (sqf SchemaQueryFilterer) FilterWithSubjectsSelectors(selectors ...datastor if len(relations) == 1 { relName := relations[0] - selectorClause = append(selectorClause, sq.Eq{sqf.schema.colUsersetRelation: relName}) - sqf.recordColumnValue(sqf.schema.colUsersetRelation) + selectorClause = append(selectorClause, sq.Eq{sqf.schema.ColUsersetRelation: relName}) + sqf.recordColumnValue(sqf.schema.ColUsersetRelation, relName) } else { orClause := sq.Or{} for _, relationName := range relations { dsRelationName := stringz.DefaultEmpty(relationName, datastore.Ellipsis) - orClause = append(orClause, sq.Eq{sqf.schema.colUsersetRelation: dsRelationName}) - sqf.recordColumnValue(sqf.schema.colUsersetRelation) + orClause = append(orClause, sq.Eq{sqf.schema.ColUsersetRelation: dsRelationName}) + sqf.recordColumnValue(sqf.schema.ColUsersetRelation, dsRelationName) } selectorClause = append(selectorClause, orClause) @@ -517,27 +619,27 @@ func (sqf SchemaQueryFilterer) FilterWithSubjectsSelectors(selectors ...datastor // FilterToSubjectFilter returns a new SchemaQueryFilterer that is limited to resources with // subjects that match the specified filter. func (sqf SchemaQueryFilterer) FilterToSubjectFilter(filter *v1.SubjectFilter) SchemaQueryFilterer { - sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.colUsersetNamespace: filter.SubjectType}) - sqf.recordColumnValue(sqf.schema.colUsersetNamespace) + sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.ColUsersetNamespace: filter.SubjectType}) + sqf.recordColumnValue(sqf.schema.ColUsersetNamespace, filter.SubjectType) if filter.OptionalSubjectId != "" { - sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.colUsersetObjectID: filter.OptionalSubjectId}) - sqf.recordColumnValue(sqf.schema.colUsersetObjectID) + sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.ColUsersetObjectID: filter.OptionalSubjectId}) + sqf.recordColumnValue(sqf.schema.ColUsersetObjectID, filter.OptionalSubjectId) } if filter.OptionalRelation != nil { dsRelationName := stringz.DefaultEmpty(filter.OptionalRelation.Relation, datastore.Ellipsis) - sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.colUsersetRelation: dsRelationName}) - sqf.recordColumnValue(sqf.schema.colUsersetRelation) + sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.ColUsersetRelation: dsRelationName}) + sqf.recordColumnValue(sqf.schema.ColUsersetRelation, datastore.Ellipsis) } return sqf } func (sqf SchemaQueryFilterer) FilterWithCaveatName(caveatName string) SchemaQueryFilterer { - sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.colCaveatName: caveatName}) - sqf.recordColumnValue(sqf.schema.colCaveatName) + sqf.queryBuilder = sqf.queryBuilder.Where(sq.Eq{sqf.schema.ColCaveatName: caveatName}) + sqf.recordColumnValue(sqf.schema.ColCaveatName, caveatName) return sqf } @@ -549,7 +651,7 @@ func (sqf SchemaQueryFilterer) limit(limit uint64) SchemaQueryFilterer { // QueryExecutor is a tuple query runner shared by SQL implementations of the datastore. type QueryExecutor struct { - Executor ExecuteQueryFunc + Executor ExecuteReadRelsQueryFunc } // ExecuteQuery executes the query. @@ -558,6 +660,10 @@ func (tqs QueryExecutor) ExecuteQuery( query SchemaQueryFilterer, opts ...options.QueryOptionsOption, ) (datastore.RelationshipIterator, error) { + if query.isCustomQuery { + return nil, spiceerrors.MustBugf("ExecuteQuery should not be called on custom queries") + } + queryOpts := options.NewQueryOptionsWithOptions(opts...) query = query.TupleOrder(queryOpts.Sort) @@ -580,19 +686,57 @@ func (tqs QueryExecutor) ExecuteQuery( limit = *queryOpts.Limit } - toExecute := query.limit(limit) + if limit < math.MaxInt64 { + query = query.limit(limit) + } + + toExecute := query + + // Set the column names to select. + columnNamesToSelect := make([]string, 0, 8+len(query.extraFields)) + + columnNamesToSelect = checkColumn(columnNamesToSelect, query.filteringColumnTracker, query.schema.ColNamespace) + columnNamesToSelect = checkColumn(columnNamesToSelect, query.filteringColumnTracker, query.schema.ColObjectID) + columnNamesToSelect = checkColumn(columnNamesToSelect, query.filteringColumnTracker, query.schema.ColRelation) + columnNamesToSelect = checkColumn(columnNamesToSelect, query.filteringColumnTracker, query.schema.ColUsersetNamespace) + columnNamesToSelect = checkColumn(columnNamesToSelect, query.filteringColumnTracker, query.schema.ColUsersetObjectID) + columnNamesToSelect = checkColumn(columnNamesToSelect, query.filteringColumnTracker, query.schema.ColUsersetRelation) + + columnNamesToSelect = append(columnNamesToSelect, query.schema.ColCaveatName, query.schema.ColCaveatContext, query.schema.ColExpiration) + columnNamesToSelect = append(columnNamesToSelect, query.schema.ExtraFields...) + + toExecute.queryBuilder = toExecute.queryBuilder.Columns(columnNamesToSelect...) + + from := query.schema.RelationshipTableName + if query.fromSuffix != "" { + from += " " + query.fromSuffix + } + + toExecute.queryBuilder = toExecute.queryBuilder.From(from) - // Run the query. sql, args, err := toExecute.queryBuilder.ToSql() if err != nil { return nil, err } - return tqs.Executor(ctx, sql, args) + return tqs.Executor(ctx, QueryInfo{query.schema, query.filteringColumnTracker}, sql, args) +} + +func checkColumn(columns []string, tracker map[string]ColumnTracker, colName string) []string { + if r, ok := tracker[colName]; !ok || r.SingleValue == nil { + return append(columns, colName) + } + return columns +} + +// QueryInfo holds the schema information and filtering values for a query. +type QueryInfo struct { + Schema SchemaInformation + FilteringValues map[string]ColumnTracker } -// ExecuteQueryFunc is a function that can be used to execute a single rendered SQL query. -type ExecuteQueryFunc func(ctx context.Context, sql string, args []any) (datastore.RelationshipIterator, error) +// ExecuteReadRelsQueryFunc is a function that can be used to execute a single rendered SQL query. +type ExecuteReadRelsQueryFunc func(ctx context.Context, queryInfo QueryInfo, sql string, args []any) (datastore.RelationshipIterator, error) // TxCleanupFunc is a function that should be executed when the caller of // TransactionFactory is done with the transaction. diff --git a/internal/datastore/common/sql_test.go b/internal/datastore/common/sql_test.go index 11043512b9..5e6bc060fd 100644 --- a/internal/datastore/common/sql_test.go +++ b/internal/datastore/common/sql_test.go @@ -3,8 +3,6 @@ package common import ( "testing" - "github.com/google/uuid" - "github.com/authzed/spicedb/pkg/datastore/options" "github.com/authzed/spicedb/pkg/tuple" @@ -20,11 +18,11 @@ var toCursor = options.ToCursor func TestSchemaQueryFilterer(t *testing.T) { tests := []struct { - name string - run func(filterer SchemaQueryFilterer) SchemaQueryFilterer - expectedSQL string - expectedArgs []any - expectedColumnCounts map[string]int + name string + run func(filterer SchemaQueryFilterer) SchemaQueryFilterer + expectedSQL string + expectedArgs []any + expectedStaticColumns []string }{ { "relation filter", @@ -33,9 +31,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND relation = ?", []any{"somerelation"}, - map[string]int{ - "relation": 1, - }, + []string{"relation"}, }, { "resource ID filter", @@ -44,9 +40,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND object_id = ?", []any{"someresourceid"}, - map[string]int{ - "object_id": 1, - }, + []string{"object_id"}, }, { "resource IDs filter", @@ -55,7 +49,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND object_id LIKE ?", []any{"someprefix%"}, - map[string]int{}, // object_id is not statically used, so not present in the map + []string{}, }, { "resource IDs prefix filter", @@ -64,9 +58,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND object_id IN (?,?)", []any{"someresourceid", "anotherresourceid"}, - map[string]int{ - "object_id": 2, - }, + []string{}, }, { "resource type filter", @@ -75,9 +67,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ns = ?", []any{"sometype"}, - map[string]int{ - "ns": 1, - }, + []string{"ns"}, }, { "resource filter", @@ -86,11 +76,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ns = ? AND object_id = ? AND relation = ?", []any{"sometype", "someobj", "somerel"}, - map[string]int{ - "ns": 1, - "object_id": 1, - "relation": 1, - }, + []string{"ns", "object_id", "relation"}, }, { "relationships filter with no IDs or relations", @@ -101,9 +87,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ns = ?", []any{"sometype"}, - map[string]int{ - "ns": 1, - }, + []string{"ns"}, }, { "relationships filter with single ID", @@ -115,10 +99,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ns = ? AND object_id IN (?)", []any{"sometype", "someid"}, - map[string]int{ - "ns": 1, - "object_id": 1, - }, + []string{"ns", "object_id"}, }, { "relationships filter with no IDs", @@ -130,9 +111,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ns = ?", []any{"sometype"}, - map[string]int{ - "ns": 1, - }, + []string{"ns"}, }, { "relationships filter with multiple IDs", @@ -144,10 +123,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ns = ? AND object_id IN (?,?)", []any{"sometype", "someid", "anotherid"}, - map[string]int{ - "ns": 1, - "object_id": 2, - }, + []string{"ns"}, }, { "subjects filter with no IDs or relations", @@ -158,9 +134,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ((subject_ns = ?))", []any{"somesubjectype"}, - map[string]int{ - "subject_ns": 1, - }, + []string{"subject_ns"}, }, { "multiple subjects filters with just types", @@ -173,9 +147,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ((subject_ns = ?) OR (subject_ns = ?))", []any{"somesubjectype", "anothersubjectype"}, - map[string]int{ - "subject_ns": 2, - }, + []string{}, }, { "subjects filter with single ID", @@ -187,10 +159,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ((subject_ns = ? AND subject_object_id IN (?)))", []any{"somesubjectype", "somesubjectid"}, - map[string]int{ - "subject_ns": 1, - "subject_object_id": 1, - }, + []string{"subject_ns", "subject_object_id"}, }, { "subjects filter with single ID and no type", @@ -201,9 +170,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ((subject_object_id IN (?)))", []any{"somesubjectid"}, - map[string]int{ - "subject_object_id": 1, - }, + []string{"subject_object_id"}, }, { "empty subjects filter", @@ -212,7 +179,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ((1=1))", nil, - map[string]int{}, + []string{}, }, { "subjects filter with multiple IDs", @@ -224,10 +191,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ((subject_ns = ? AND subject_object_id IN (?,?)))", []any{"somesubjectype", "somesubjectid", "anothersubjectid"}, - map[string]int{ - "subject_ns": 1, - "subject_object_id": 2, - }, + []string{"subject_ns"}, }, { "subjects filter with single ellipsis relation", @@ -239,10 +203,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ((subject_ns = ? AND subject_relation = ?))", []any{"somesubjectype", "..."}, - map[string]int{ - "subject_ns": 1, - "subject_relation": 1, - }, + []string{"subject_ns", "subject_relation"}, }, { "subjects filter with single defined relation", @@ -254,10 +215,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ((subject_ns = ? AND subject_relation = ?))", []any{"somesubjectype", "somesubrel"}, - map[string]int{ - "subject_ns": 1, - "subject_relation": 1, - }, + []string{"subject_ns", "subject_relation"}, }, { "subjects filter with only non-ellipsis", @@ -269,10 +227,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ((subject_ns = ? AND subject_relation <> ?))", []any{"somesubjectype", "..."}, - map[string]int{ - "subject_ns": 1, - "subject_relation": 1, - }, + []string{"subject_ns"}, }, { "subjects filter with defined relation and ellipsis", @@ -284,10 +239,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ((subject_ns = ? AND (subject_relation = ? OR subject_relation = ?)))", []any{"somesubjectype", "...", "somesubrel"}, - map[string]int{ - "subject_ns": 1, - "subject_relation": 2, - }, + []string{"subject_ns"}, }, { "subjects filter", @@ -300,11 +252,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ((subject_ns = ? AND subject_object_id IN (?,?) AND (subject_relation = ? OR subject_relation = ?)))", []any{"somesubjectype", "somesubjectid", "anothersubjectid", "...", "somesubrel"}, - map[string]int{ - "subject_ns": 1, - "subject_object_id": 2, - "subject_relation": 2, - }, + []string{"subject_ns"}, }, { "multiple subjects filter", @@ -328,11 +276,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ((subject_ns = ? AND subject_object_id IN (?,?) AND (subject_relation = ? OR subject_relation = ?)) OR (subject_ns = ? AND subject_object_id IN (?,?) AND (subject_relation = ? OR subject_relation = ?)) OR (subject_ns = ? AND subject_relation <> ?))", []any{"somesubjectype", "a", "b", "...", "somesubrel", "anothersubjecttype", "b", "c", "...", "anotherrel", "thirdsubjectype", "..."}, - map[string]int{ - "subject_ns": 3, - "subject_object_id": 4, - "subject_relation": 5, - }, + []string{}, }, { "v1 subject filter with namespace", @@ -343,9 +287,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND subject_ns = ?", []any{"subns"}, - map[string]int{ - "subject_ns": 1, - }, + []string{"subject_ns"}, }, { "v1 subject filter with subject id", @@ -357,10 +299,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND subject_ns = ? AND subject_object_id = ?", []any{"subns", "subid"}, - map[string]int{ - "subject_ns": 1, - "subject_object_id": 1, - }, + []string{"subject_ns", "subject_object_id"}, }, { "v1 subject filter with relation", @@ -374,10 +313,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND subject_ns = ? AND subject_relation = ?", []any{"subns", "subrel"}, - map[string]int{ - "subject_ns": 1, - "subject_relation": 1, - }, + []string{"subject_ns", "subject_relation"}, }, { "v1 subject filter with empty relation", @@ -391,10 +327,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND subject_ns = ? AND subject_relation = ?", []any{"subns", "..."}, - map[string]int{ - "subject_ns": 1, - "subject_relation": 1, - }, + []string{"subject_ns", "subject_relation"}, }, { "v1 subject filter", @@ -409,11 +342,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND subject_ns = ? AND subject_object_id = ? AND subject_relation = ?", []any{"subns", "subid", "somerel"}, - map[string]int{ - "subject_ns": 1, - "subject_object_id": 1, - "subject_relation": 1, - }, + []string{"subject_ns", "subject_object_id", "subject_relation"}, }, { "limit", @@ -422,7 +351,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) LIMIT 100", nil, - map[string]int{}, + []string{}, }, { "full resources filter", @@ -444,14 +373,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ns = ? AND relation = ? AND object_id IN (?,?) AND ((subject_ns = ? AND subject_object_id IN (?,?) AND (subject_relation = ? OR subject_relation = ?)))", []any{"someresourcetype", "somerelation", "someid", "anotherid", "somesubjectype", "somesubjectid", "anothersubjectid", "...", "somesubrel"}, - map[string]int{ - "ns": 1, - "object_id": 2, - "relation": 1, - "subject_ns": 1, - "subject_object_id": 2, - "subject_relation": 2, - }, + []string{"ns", "relation", "subject_ns"}, }, { "order by", @@ -464,9 +386,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ns = ? ORDER BY ns, object_id, relation, subject_ns, subject_object_id, subject_relation", []any{"someresourcetype"}, - map[string]int{ - "ns": 1, - }, + []string{"ns"}, }, { "after with just namespace", @@ -479,9 +399,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ns = ? AND (object_id,relation,subject_ns,subject_object_id,subject_relation) > (?,?,?,?,?)", []any{"someresourcetype", "foo", "viewer", "user", "bar", "..."}, - map[string]int{ - "ns": 1, - }, + []string{"ns"}, }, { "after with just relation", @@ -494,9 +412,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND relation = ? AND (ns,object_id,subject_ns,subject_object_id,subject_relation) > (?,?,?,?,?)", []any{"somerelation", "someresourcetype", "foo", "user", "bar", "..."}, - map[string]int{ - "relation": 1, - }, + []string{"relation"}, }, { "after with namespace and single resource id", @@ -510,10 +426,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ns = ? AND object_id IN (?) AND (relation,subject_ns,subject_object_id,subject_relation) > (?,?,?,?)", []any{"someresourcetype", "one", "viewer", "user", "bar", "..."}, - map[string]int{ - "ns": 1, - "object_id": 1, - }, + []string{"ns", "object_id"}, }, { "after with single resource id", @@ -526,9 +439,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND object_id IN (?) AND (ns,relation,subject_ns,subject_object_id,subject_relation) > (?,?,?,?,?)", []any{"one", "someresourcetype", "viewer", "user", "bar", "..."}, - map[string]int{ - "object_id": 1, - }, + []string{"object_id"}, }, { "after with namespace and resource ids", @@ -542,10 +453,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ns = ? AND object_id IN (?,?) AND (object_id,relation,subject_ns,subject_object_id,subject_relation) > (?,?,?,?,?)", []any{"someresourcetype", "one", "two", "foo", "viewer", "user", "bar", "..."}, - map[string]int{ - "ns": 1, - "object_id": 2, - }, + []string{"ns"}, }, { "after with namespace and relation", @@ -559,10 +467,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ns = ? AND relation = ? AND (object_id,subject_ns,subject_object_id,subject_relation) > (?,?,?,?)", []any{"someresourcetype", "somerelation", "foo", "user", "bar", "..."}, - map[string]int{ - "ns": 1, - "relation": 1, - }, + []string{"ns", "relation"}, }, { "after with subject namespace", @@ -573,9 +478,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ((subject_ns = ?)) AND (ns,object_id,relation,subject_object_id,subject_relation) > (?,?,?,?,?)", []any{"somesubjectype", "someresourcetype", "foo", "viewer", "bar", "..."}, - map[string]int{ - "subject_ns": 1, - }, + []string{"subject_ns"}, }, { "after with subject namespaces", @@ -590,9 +493,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ((subject_ns = ?)) AND ((subject_ns = ?)) AND (ns,object_id,relation,subject_ns,subject_object_id,subject_relation) > (?,?,?,?,?,?)", []any{"somesubjectype", "anothersubjectype", "someresourcetype", "foo", "viewer", "user", "bar", "..."}, - map[string]int{ - "subject_ns": 2, - }, + []string{}, }, { "after with resource ID prefix", @@ -601,7 +502,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND object_id LIKE ? AND (ns,object_id,relation,subject_ns,subject_object_id,subject_relation) > (?,?,?,?,?,?)", []any{"someprefix%", "someresourcetype", "foo", "viewer", "user", "bar", "..."}, - map[string]int{}, + []string{}, }, { "order by subject", @@ -614,9 +515,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ns = ? ORDER BY subject_ns, subject_object_id, subject_relation, ns, object_id, relation", []any{"someresourcetype"}, - map[string]int{ - "ns": 1, - }, + []string{"ns"}, }, { "order by subject, after with subject namespace", @@ -627,9 +526,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ((subject_ns = ?)) AND (subject_object_id,ns,object_id,relation,subject_relation) > (?,?,?,?,?)", []any{"somesubjectype", "bar", "someresourcetype", "foo", "viewer", "..."}, - map[string]int{ - "subject_ns": 1, - }, + []string{"subject_ns"}, }, { "order by subject, after with subject namespace and subject object id", @@ -641,7 +538,7 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ((subject_ns = ? AND subject_object_id IN (?))) AND (ns,object_id,relation,subject_relation) > (?,?,?,?)", []any{"somesubjectype", "foo", "someresourcetype", "someresource", "viewer", "..."}, - map[string]int{"subject_ns": 1, "subject_object_id": 1}, + []string{"subject_ns", "subject_object_id"}, }, { "order by subject, after with subject namespace and multiple subject object IDs", @@ -653,15 +550,15 @@ func TestSchemaQueryFilterer(t *testing.T) { }, "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ((subject_ns = ? AND subject_object_id IN (?,?))) AND (subject_object_id,ns,object_id,relation,subject_relation) > (?,?,?,?,?)", []any{"somesubjectype", "foo", "bar", "next", "someresourcetype", "someresource", "viewer", "..."}, - map[string]int{"subject_ns": 1, "subject_object_id": 2}, + []string{"subject_ns"}, }, } for _, test := range tests { test := test t.Run(test.name, func(t *testing.T) { - base := sq.Select("*") schema := NewSchemaInformation( + "relationtuples", "ns", "object_id", "relation", @@ -669,14 +566,25 @@ func TestSchemaQueryFilterer(t *testing.T) { "subject_object_id", "subject_relation", "caveat", + "caveat_context", "expiration", TupleComparison, + sq.Question, "NOW", ) - filterer := NewSchemaQueryFilterer(schema, base, 100) + filterer := NewSchemaQueryFiltererForRelationshipsSelect(schema, 100) ran := test.run(filterer) - require.Equal(t, test.expectedColumnCounts, ran.filteringColumnCounts) + foundStaticColumns := []string{} + for col, tracker := range ran.filteringColumnTracker { + if tracker.SingleValue != nil { + foundStaticColumns = append(foundStaticColumns, col) + } + } + + require.ElementsMatch(t, test.expectedStaticColumns, foundStaticColumns) + + ran.queryBuilder = ran.queryBuilder.Columns("*") sql, args, err := ran.queryBuilder.ToSql() require.NoError(t, err) @@ -685,29 +593,3 @@ func TestSchemaQueryFilterer(t *testing.T) { }) } } - -func BenchmarkSchemaFilterer(b *testing.B) { - si := NewSchemaInformation( - "namespace", - "object_id", - "object_relation", - "resource_type", - "resource_id", - "resource_relation", - "caveat_name", - "expiration", - TupleComparison, - "NOW", - ) - sqf := NewSchemaQueryFilterer(si, sq.Select("*"), 100) - var names []string - for i := 0; i < 500; i++ { - names = append(names, uuid.NewString()) - } - - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - _, _ = sqf.FilterToResourceIDs(names) - } -} diff --git a/internal/datastore/crdb/caveat.go b/internal/datastore/crdb/caveat.go index 3f66b95810..ebaae37301 100644 --- a/internal/datastore/crdb/caveat.go +++ b/internal/datastore/crdb/caveat.go @@ -23,7 +23,7 @@ var ( ) writeCaveat = psql.Insert(tableCaveat).Columns(colCaveatName, colCaveatDefinition).Suffix(upsertCaveatSuffix) readCaveat = psql.Select(colCaveatDefinition, colTimestamp) - listCaveat = psql.Select(colCaveatName, colCaveatDefinition, colTimestamp).From(tableCaveat).OrderBy(colCaveatName) + listCaveat = psql.Select(colCaveatName, colCaveatDefinition, colTimestamp).OrderBy(colCaveatName) deleteCaveat = psql.Delete(tableCaveat) ) @@ -35,7 +35,7 @@ const ( ) func (cr *crdbReader) ReadCaveatByName(ctx context.Context, name string) (*core.CaveatDefinition, datastore.Revision, error) { - query := cr.fromBuilder(readCaveat, tableCaveat).Where(sq.Eq{colCaveatName: name}) + query := cr.fromWithAsOfSystemTime(readCaveat.Where(sq.Eq{colCaveatName: name}), tableCaveat) sql, args, err := query.ToSql() if err != nil { return nil, datastore.NoRevision, fmt.Errorf(errReadCaveat, name, err) @@ -79,7 +79,7 @@ type bytesAndTimestamp struct { } func (cr *crdbReader) lookupCaveats(ctx context.Context, caveatNames []string) ([]datastore.RevisionedCaveat, error) { - caveatsWithNames := cr.fromBuilder(listCaveat, tableCaveat) + caveatsWithNames := cr.fromWithAsOfSystemTime(listCaveat, tableCaveat) if len(caveatNames) > 0 { caveatsWithNames = caveatsWithNames.Where(sq.Eq{colCaveatName: caveatNames}) } diff --git a/internal/datastore/crdb/crdb.go b/internal/datastore/crdb/crdb.go index 88d6e3e9ae..de66a70135 100644 --- a/internal/datastore/crdb/crdb.go +++ b/internal/datastore/crdb/crdb.go @@ -200,6 +200,34 @@ func newCRDBDatastore(ctx context.Context, url string, options ...Option) (datas return nil, fmt.Errorf("invalid head migration found for cockroach: %w", err) } + var extraFields []string + relTableName := tableTuple + if config.withIntegrity { + relTableName = tableTupleWithIntegrity + extraFields = []string{ + colIntegrityKeyID, + colIntegrityHash, + colTimestamp, + } + } + + schema := common.NewSchemaInformation( + relTableName, + colNamespace, + colObjectID, + colRelation, + colUsersetNamespace, + colUsersetObjectID, + colUsersetRelation, + colCaveatContextName, + colCaveatContext, + colExpiration, + common.ExpandedLogicComparison, + sq.Dollar, + "NOW", + extraFields..., + ) + ds := &crdbDatastore{ RemoteClockRevisions: revisions.NewRemoteClockRevisions( config.gcWindow, @@ -221,6 +249,7 @@ func newCRDBDatastore(ctx context.Context, url string, options ...Option) (datas filterMaximumIDCount: config.filterMaximumIDCount, supportsIntegrity: config.withIntegrity, gcWindow: config.gcWindow, + schema: schema, } ds.RemoteClockRevisions.SetNowFunc(ds.headRevisionInternal) @@ -305,6 +334,7 @@ type crdbDatastore struct { overlapKeyInit func(ctx context.Context) keySet analyzeBeforeStatistics bool gcWindow time.Duration + schema common.SchemaInformation beginChangefeedQuery string transactionNowQuery string @@ -323,11 +353,12 @@ func (cds *crdbDatastore) SnapshotReader(rev datastore.Revision) datastore.Reade Executor: pgxcommon.NewPGXExecutorWithIntegrityOption(cds.readPool, cds.supportsIntegrity), } - fromBuilder := func(query sq.SelectBuilder, fromStr string) sq.SelectBuilder { - return query.From(fromStr + " AS OF SYSTEM TIME " + rev.String()) + withAsOfSystemTime := func(query sq.SelectBuilder, tableName string) sq.SelectBuilder { + return query.From(tableName + " AS OF SYSTEM TIME " + rev.String()) } - return &crdbReader{cds.readPool, executor, noOverlapKeyer, nil, fromBuilder, cds.filterMaximumIDCount, cds.tableTupleName(), cds.supportsIntegrity} + asOfSystemTimeSuffix := "AS OF SYSTEM TIME " + rev.String() + return &crdbReader{cds.readPool, executor, noOverlapKeyer, nil, withAsOfSystemTime, asOfSystemTimeSuffix, cds.filterMaximumIDCount, cds.schema, cds.supportsIntegrity} } func (cds *crdbDatastore) ReadWriteTx( @@ -375,11 +406,12 @@ func (cds *crdbDatastore) ReadWriteTx( executor, cds.writeOverlapKeyer, cds.overlapKeyInit(ctx), - func(query sq.SelectBuilder, fromStr string) sq.SelectBuilder { - return query.From(fromStr) + func(query sq.SelectBuilder, tableName string) sq.SelectBuilder { + return query.From(tableName) }, + "", // No AS OF SYSTEM TIME for writes cds.filterMaximumIDCount, - cds.tableTupleName(), + cds.schema, cds.supportsIntegrity, }, tx, @@ -526,14 +558,6 @@ func (cds *crdbDatastore) Features(ctx context.Context) (*datastore.Features, er return features, err } -func (cds *crdbDatastore) tableTupleName() string { - if cds.supportsIntegrity { - return tableTupleWithIntegrity - } - - return tableTuple -} - func (cds *crdbDatastore) features(ctx context.Context) (*datastore.Features, error) { features := datastore.Features{ ContinuousCheckpointing: datastore.Feature{ @@ -567,7 +591,7 @@ func (cds *crdbDatastore) features(ctx context.Context) (*datastore.Features, er features.Watch.Reason = fmt.Sprintf("Range feeds must be enabled in CockroachDB and the user must have permission to create them in order to enable the Watch API: %s", err.Error()) } return nil - }, fmt.Sprintf(cds.beginChangefeedQuery, cds.tableTupleName(), head, "1s")) + }, fmt.Sprintf(cds.beginChangefeedQuery, cds.schema.RelationshipTableName, head, "1s")) <-streamCtx.Done() diff --git a/internal/datastore/crdb/reader.go b/internal/datastore/crdb/reader.go index 70002a8965..ce9a950b05 100644 --- a/internal/datastore/crdb/reader.go +++ b/internal/datastore/crdb/reader.go @@ -29,19 +29,6 @@ var ( countRels = psql.Select("count(*)") - schema = common.NewSchemaInformation( - colNamespace, - colObjectID, - colRelation, - colUsersetNamespace, - colUsersetObjectID, - colUsersetRelation, - colCaveatContextName, - colExpiration, - common.ExpandedLogicComparison, - "NOW", - ) - queryCounters = psql.Select( colCounterName, colCounterSerializedFilter, @@ -51,14 +38,15 @@ var ( ) type crdbReader struct { - query pgxcommon.DBFuncQuerier - executor common.QueryExecutor - keyer overlapKeyer - overlapKeySet keySet - fromBuilder func(query sq.SelectBuilder, fromStr string) sq.SelectBuilder - filterMaximumIDCount uint16 - tupleTableName string - withIntegrity bool + query pgxcommon.DBFuncQuerier + executor common.QueryExecutor + keyer overlapKeyer + overlapKeySet keySet + fromWithAsOfSystemTime func(query sq.SelectBuilder, tableName string) sq.SelectBuilder + asOfSystemTimeSuffix string + filterMaximumIDCount uint16 + schema common.SchemaInformation + withIntegrity bool } func (cr *crdbReader) CountRelationships(ctx context.Context, name string) (int, error) { @@ -76,8 +64,8 @@ func (cr *crdbReader) CountRelationships(ctx context.Context, name string) (int, return 0, err } - query := cr.fromBuilder(countRels, cr.tupleTableName) - builder, err := common.NewSchemaQueryFilterer(schema, query, cr.filterMaximumIDCount).FilterWithRelationshipsFilter(relFilter) + query := cr.fromWithAsOfSystemTime(countRels, cr.schema.RelationshipTableName) + builder, err := common.NewSchemaQueryFiltererWithStartingQuery(cr.schema, query, cr.filterMaximumIDCount).FilterWithRelationshipsFilter(relFilter) if err != nil { return 0, err } @@ -105,8 +93,7 @@ func (cr *crdbReader) LookupCounters(ctx context.Context) ([]datastore.Relations } func (cr *crdbReader) lookupCounters(ctx context.Context, optionalFilterName string) ([]datastore.RelationshipCounter, error) { - query := cr.fromBuilder(queryCounters, tableRelationshipCounter) - + query := cr.fromWithAsOfSystemTime(queryCounters, tableRelationshipCounter) if optionalFilterName != noFilterOnCounterName { query = query.Where(sq.Eq{colCounterName: optionalFilterName}) } @@ -178,44 +165,13 @@ func (cr *crdbReader) ReadNamespaceByName( } func (cr *crdbReader) ListAllNamespaces(ctx context.Context) ([]datastore.RevisionedNamespace, error) { - nsDefs, err := loadAllNamespaces(ctx, cr.query, cr.fromBuilder) + nsDefs, err := loadAllNamespaces(ctx, cr.query, cr.fromWithAsOfSystemTime) if err != nil { return nil, fmt.Errorf(errUnableToListNamespaces, err) } return nsDefs, nil } -func (cr *crdbReader) queryTuples() sq.SelectBuilder { - if cr.withIntegrity { - return psql.Select( - colNamespace, - colObjectID, - colRelation, - colUsersetNamespace, - colUsersetObjectID, - colUsersetRelation, - colCaveatContextName, - colCaveatContext, - colExpiration, - colIntegrityKeyID, - colIntegrityHash, - colTimestamp, - ) - } - - return psql.Select( - colNamespace, - colObjectID, - colRelation, - colUsersetNamespace, - colUsersetObjectID, - colUsersetRelation, - colCaveatContextName, - colCaveatContext, - colExpiration, - ) -} - func (cr *crdbReader) LookupNamespacesWithNames(ctx context.Context, nsNames []string) ([]datastore.RevisionedNamespace, error) { if len(nsNames) == 0 { return nil, nil @@ -232,8 +188,7 @@ func (cr *crdbReader) QueryRelationships( filter datastore.RelationshipsFilter, opts ...options.QueryOptionsOption, ) (iter datastore.RelationshipIterator, err error) { - query := cr.fromBuilder(cr.queryTuples(), cr.tupleTableName) - qBuilder, err := common.NewSchemaQueryFilterer(schema, query, cr.filterMaximumIDCount).FilterWithRelationshipsFilter(filter) + qBuilder, err := common.NewSchemaQueryFiltererForRelationshipsSelect(cr.schema, cr.filterMaximumIDCount).WithFromSuffix(cr.asOfSystemTimeSuffix).FilterWithRelationshipsFilter(filter) if err != nil { return nil, err } @@ -246,8 +201,8 @@ func (cr *crdbReader) ReverseQueryRelationships( subjectsFilter datastore.SubjectsFilter, opts ...options.ReverseQueryOptionsOption, ) (iter datastore.RelationshipIterator, err error) { - query := cr.fromBuilder(cr.queryTuples(), cr.tupleTableName) - qBuilder, err := common.NewSchemaQueryFilterer(schema, query, cr.filterMaximumIDCount). + qBuilder, err := common.NewSchemaQueryFiltererForRelationshipsSelect(cr.schema, cr.filterMaximumIDCount). + WithFromSuffix(cr.asOfSystemTimeSuffix). FilterWithSubjectsSelectors(subjectsFilter.AsSelector()) if err != nil { return nil, err @@ -270,8 +225,7 @@ func (cr *crdbReader) ReverseQueryRelationships( } func (cr crdbReader) loadNamespace(ctx context.Context, tx pgxcommon.DBFuncQuerier, nsName string) (*core.NamespaceDefinition, time.Time, error) { - query := cr.fromBuilder(queryReadNamespace, tableNamespace).Where(sq.Eq{colNamespace: nsName}) - + query := cr.fromWithAsOfSystemTime(queryReadNamespace, tableNamespace).Where(sq.Eq{colNamespace: nsName}) sql, args, err := query.ToSql() if err != nil { return nil, time.Time{}, err @@ -304,8 +258,7 @@ func (cr crdbReader) lookupNamespaces(ctx context.Context, tx pgxcommon.DBFuncQu clause = append(clause, sq.Eq{colNamespace: nsName}) } - query := cr.fromBuilder(queryReadNamespace, tableNamespace).Where(clause) - + query := cr.fromWithAsOfSystemTime(queryReadNamespace, tableNamespace).Where(clause) sql, args, err := query.ToSql() if err != nil { return nil, err @@ -346,7 +299,6 @@ func (cr crdbReader) lookupNamespaces(ctx context.Context, tx pgxcommon.DBFuncQu func loadAllNamespaces(ctx context.Context, tx pgxcommon.DBFuncQuerier, fromBuilder func(sq.SelectBuilder, string) sq.SelectBuilder) ([]datastore.RevisionedNamespace, error) { query := fromBuilder(queryReadNamespace, tableNamespace) - sql, args, err := query.ToSql() if err != nil { return nil, err diff --git a/internal/datastore/crdb/readwrite.go b/internal/datastore/crdb/readwrite.go index 53a25cc586..acfaf97709 100644 --- a/internal/datastore/crdb/readwrite.go +++ b/internal/datastore/crdb/readwrite.go @@ -123,11 +123,11 @@ var ( ) func (rwt *crdbReadWriteTXN) insertQuery() sq.InsertBuilder { - return psql.Insert(rwt.tupleTableName) + return psql.Insert(rwt.schema.RelationshipTableName) } func (rwt *crdbReadWriteTXN) queryDeleteTuples() sq.DeleteBuilder { - return psql.Delete(rwt.tupleTableName) + return psql.Delete(rwt.schema.RelationshipTableName) } func (rwt *crdbReadWriteTXN) queryWriteTuple() sq.InsertBuilder { @@ -555,10 +555,10 @@ var copyColsWithIntegrity = []string{ func (rwt *crdbReadWriteTXN) BulkLoad(ctx context.Context, iter datastore.BulkWriteRelationshipSource) (uint64, error) { if rwt.withIntegrity { - return pgxcommon.BulkLoad(ctx, rwt.tx, rwt.tupleTableName, copyColsWithIntegrity, iter) + return pgxcommon.BulkLoad(ctx, rwt.tx, rwt.schema.RelationshipTableName, copyColsWithIntegrity, iter) } - return pgxcommon.BulkLoad(ctx, rwt.tx, rwt.tupleTableName, copyCols, iter) + return pgxcommon.BulkLoad(ctx, rwt.tx, rwt.schema.RelationshipTableName, copyCols, iter) } var _ datastore.ReadWriteTransaction = &crdbReadWriteTXN{} diff --git a/internal/datastore/crdb/stats.go b/internal/datastore/crdb/stats.go index 2b66297d91..b01a1f3722 100644 --- a/internal/datastore/crdb/stats.go +++ b/internal/datastore/crdb/stats.go @@ -44,8 +44,8 @@ func (cds *crdbDatastore) Statistics(ctx context.Context) (datastore.Stats, erro if err != nil { return fmt.Errorf("unable to read namespaces: %w", err) } - nsDefs, err = loadAllNamespaces(ctx, pgxcommon.QuerierFuncsFor(tx), func(sb squirrel.SelectBuilder, fromStr string) squirrel.SelectBuilder { - return sb.From(fromStr) + nsDefs, err = loadAllNamespaces(ctx, pgxcommon.QuerierFuncsFor(tx), func(sb squirrel.SelectBuilder, tableName string) squirrel.SelectBuilder { + return sb.From(tableName) }) if err != nil { return fmt.Errorf("unable to read namespaces: %w", err) @@ -57,7 +57,7 @@ func (cds *crdbDatastore) Statistics(ctx context.Context) (datastore.Stats, erro if cds.analyzeBeforeStatistics { if err := cds.readPool.BeginTxFunc(ctx, pgx.TxOptions{AccessMode: pgx.ReadOnly}, func(tx pgx.Tx) error { - if _, err := tx.Exec(ctx, "ANALYZE "+cds.tableTupleName()); err != nil { + if _, err := tx.Exec(ctx, "ANALYZE "+cds.schema.RelationshipTableName); err != nil { return fmt.Errorf("unable to analyze tuple table: %w", err) } @@ -131,7 +131,7 @@ func (cds *crdbDatastore) Statistics(ctx context.Context) (datastore.Stats, erro log.Warn().Bool("has-rows", hasRows).Msg("unable to find row count in statistics query result") return nil - }, "SHOW STATISTICS FOR TABLE "+cds.tableTupleName()); err != nil { + }, "SHOW STATISTICS FOR TABLE "+cds.schema.RelationshipTableName); err != nil { return datastore.Stats{}, fmt.Errorf("unable to query unique estimated row count: %w", err) } diff --git a/internal/datastore/crdb/watch.go b/internal/datastore/crdb/watch.go index de341112f7..2008a743ad 100644 --- a/internal/datastore/crdb/watch.go +++ b/internal/datastore/crdb/watch.go @@ -128,7 +128,7 @@ func (cds *crdbDatastore) watch( tableNames := make([]string, 0, 4) tableNames = append(tableNames, tableTransactionMetadata) if opts.Content&datastore.WatchRelationships == datastore.WatchRelationships { - tableNames = append(tableNames, cds.tableTupleName()) + tableNames = append(tableNames, cds.schema.RelationshipTableName) } if opts.Content&datastore.WatchSchema == datastore.WatchSchema { tableNames = append(tableNames, tableNamespace) @@ -433,7 +433,7 @@ func (cds *crdbDatastore) processChanges(ctx context.Context, changes pgx.Rows, } switch tableName { - case cds.tableTupleName(): + case cds.schema.RelationshipTableName: var caveatName string var caveatContext map[string]any if details.After != nil && details.After.RelationshipCaveatName != "" { diff --git a/internal/datastore/mysql/caveat.go b/internal/datastore/mysql/caveat.go index 84283a3bb6..6cb7edafab 100644 --- a/internal/datastore/mysql/caveat.go +++ b/internal/datastore/mysql/caveat.go @@ -22,7 +22,7 @@ const ( ) func (mr *mysqlReader) ReadCaveatByName(ctx context.Context, name string) (*core.CaveatDefinition, datastore.Revision, error) { - filteredReadCaveat := mr.filterer(mr.ReadCaveatQuery) + filteredReadCaveat := mr.aliveFilter(mr.ReadCaveatQuery) sqlStatement, args, err := filteredReadCaveat.Where(sq.Eq{colName: name}).ToSql() if err != nil { return nil, datastore.NoRevision, err @@ -68,7 +68,7 @@ func (mr *mysqlReader) lookupCaveats(ctx context.Context, caveatNames []string) caveatsWithNames = caveatsWithNames.Where(sq.Eq{colName: caveatNames}) } - filteredListCaveat := mr.filterer(caveatsWithNames) + filteredListCaveat := mr.aliveFilter(caveatsWithNames) listSQL, listArgs, err := filteredListCaveat.ToSql() if err != nil { return nil, err diff --git a/internal/datastore/mysql/datastore.go b/internal/datastore/mysql/datastore.go index e70ae293f7..d0959620f2 100644 --- a/internal/datastore/mysql/datastore.go +++ b/internal/datastore/mysql/datastore.go @@ -18,7 +18,6 @@ import ( "github.com/google/uuid" "github.com/prometheus/client_golang/prometheus" "go.opentelemetry.io/otel" - "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" "golang.org/x/sync/errgroup" @@ -29,7 +28,6 @@ import ( log "github.com/authzed/spicedb/internal/logging" "github.com/authzed/spicedb/pkg/datastore" "github.com/authzed/spicedb/pkg/datastore/options" - "github.com/authzed/spicedb/pkg/tuple" ) const ( @@ -246,6 +244,22 @@ func newMySQLDatastore(ctx context.Context, uri string, replicaIndex int, option -1*config.gcWindow.Seconds(), ) + schema := common.NewSchemaInformation( + driver.RelationTuple(), + colNamespace, + colObjectID, + colRelation, + colUsersetNamespace, + colUsersetObjectID, + colUsersetRelation, + colCaveatName, + colCaveatContext, + colExpiration, + common.ExpandedLogicComparison, + sq.Question, + "NOW", + ) + store := &Datastore{ MigrationValidator: common.NewMigrationValidator(headMigration, config.allowedMigrations), db: db, @@ -267,6 +281,7 @@ func newMySQLDatastore(ctx context.Context, uri string, replicaIndex int, option readTxOptions: &sql.TxOptions{Isolation: sql.LevelSerializable, ReadOnly: true}, maxRetries: config.maxRetries, analyzeBeforeStats: config.analyzeBeforeStats, + schema: schema, CachedOptimizedRevisions: revisions.NewCachedOptimizedRevisions( maxRevisionStaleness, ), @@ -326,6 +341,7 @@ func (mds *Datastore) SnapshotReader(rev datastore.Revision) datastore.Reader { executor, buildLivingObjectFilterForRevision(rev), mds.filterMaximumIDCount, + mds.schema, } } @@ -369,6 +385,7 @@ func (mds *Datastore) ReadWriteTx( executor, currentlyLivingObjects, mds.filterMaximumIDCount, + mds.schema, }, mds.driver.RelationTuple(), tx, @@ -417,7 +434,24 @@ type querier interface { QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) } -func newMySQLExecutor(tx querier) common.ExecuteQueryFunc { +type wrappedTX struct { + tx querier +} + +func (wtx wrappedTX) QueryFunc(ctx context.Context, f func(context.Context, common.Rows) error, sql string, args ...any) error { + rows, err := wtx.tx.QueryContext(ctx, sql, args...) + if err != nil { + return err + } + + if rows.Err() != nil { + return rows.Err() + } + + return f(ctx, rows) +} + +func newMySQLExecutor(tx querier) common.ExecuteReadRelsQueryFunc { // This implementation does not create a transaction because it's redundant for single statements, and it avoids // the network overhead and reduce contention on the connection pool. From MySQL docs: // @@ -433,82 +467,9 @@ func newMySQLExecutor(tx querier) common.ExecuteQueryFunc { // // Prepared statements are also not used given they perform poorly on environments where connections have // short lifetime (e.g. to gracefully handle load-balancer connection drain) - return func(ctx context.Context, sqlQuery string, args []interface{}) (datastore.RelationshipIterator, error) { - return func(yield func(tuple.Relationship, error) bool) { - span := trace.SpanFromContext(ctx) - - rows, err := tx.QueryContext(ctx, sqlQuery, args...) - if err != nil { - yield(tuple.Relationship{}, fmt.Errorf(errUnableToQueryTuples, err)) - return - } - defer common.LogOnError(ctx, rows.Close) - - span.AddEvent("Query issued to database") - - relCount := 0 - - defer func() { - span.AddEvent("Relationships loaded", trace.WithAttributes(attribute.Int("relCount", relCount))) - }() - - for rows.Next() { - var resourceObjectType string - var resourceObjectID string - var relation string - var subjectObjectType string - var subjectObjectID string - var subjectRelation string - var caveatName string - var caveatContext structpbWrapper - var expiration *time.Time - err := rows.Scan( - &resourceObjectType, - &resourceObjectID, - &relation, - &subjectObjectType, - &subjectObjectID, - &subjectRelation, - &caveatName, - &caveatContext, - &expiration, - ) - if err != nil { - yield(tuple.Relationship{}, fmt.Errorf(errUnableToQueryTuples, err)) - return - } - - caveat, err := common.ContextualizedCaveatFrom(caveatName, caveatContext) - if err != nil { - yield(tuple.Relationship{}, fmt.Errorf(errUnableToQueryTuples, err)) - return - } - - relCount++ - if !yield(tuple.Relationship{ - RelationshipReference: tuple.RelationshipReference{ - Resource: tuple.ObjectAndRelation{ - ObjectType: resourceObjectType, - ObjectID: resourceObjectID, - Relation: relation, - }, - Subject: tuple.ObjectAndRelation{ - ObjectType: subjectObjectType, - ObjectID: subjectObjectID, - Relation: subjectRelation, - }, - }, - OptionalCaveat: caveat, - OptionalExpiration: expiration, - }, nil) { - return - } - } - if err := rows.Err(); err != nil { - yield(tuple.Relationship{}, fmt.Errorf(errUnableToQueryTuples, err)) - return - } - }, nil + return func(ctx context.Context, queryInfo common.QueryInfo, sqlQuery string, args []interface{}) (datastore.RelationshipIterator, error) { + span := trace.SpanFromContext(ctx) + return common.QueryRelationships[common.Rows, structpbWrapper](ctx, queryInfo, sqlQuery, args, span, wrappedTX{tx}, false) } } @@ -529,6 +490,7 @@ type Datastore struct { watchBufferWriteTimeout time.Duration maxRetries uint8 filterMaximumIDCount uint16 + schema common.SchemaInformation optimizedRevisionQuery string validTransactionQuery string diff --git a/internal/datastore/mysql/reader.go b/internal/datastore/mysql/reader.go index 55eaaf4ea7..592b844575 100644 --- a/internal/datastore/mysql/reader.go +++ b/internal/datastore/mysql/reader.go @@ -24,8 +24,9 @@ type mysqlReader struct { txSource txFactory executor common.QueryExecutor - filterer queryFilterer + aliveFilter queryFilterer filterMaximumIDCount uint16 + schema common.SchemaInformation } type queryFilterer func(original sq.SelectBuilder) sq.SelectBuilder @@ -39,19 +40,6 @@ const ( errUnableToReadCount = "unable to read count: %w" ) -var schema = common.NewSchemaInformation( - colNamespace, - colObjectID, - colRelation, - colUsersetNamespace, - colUsersetObjectID, - colUsersetRelation, - colCaveatName, - colExpiration, - common.ExpandedLogicComparison, - "NOW", -) - func (mr *mysqlReader) CountRelationships(ctx context.Context, name string) (int, error) { // Ensure the counter is registered. counters, err := mr.lookupCounters(ctx, name) @@ -68,7 +56,7 @@ func (mr *mysqlReader) CountRelationships(ctx context.Context, name string) (int return 0, err } - qBuilder, err := common.NewSchemaQueryFilterer(schema, mr.filterer(mr.CountRelsQuery), mr.filterMaximumIDCount).FilterWithRelationshipsFilter(relFilter) + qBuilder, err := common.NewSchemaQueryFiltererWithStartingQuery(mr.schema, mr.aliveFilter(mr.CountRelsQuery), mr.filterMaximumIDCount).FilterWithRelationshipsFilter(relFilter) if err != nil { return 0, err } @@ -116,7 +104,7 @@ func (mr *mysqlReader) LookupCounters(ctx context.Context) ([]datastore.Relation } func (mr *mysqlReader) lookupCounters(ctx context.Context, optionalName string) ([]datastore.RelationshipCounter, error) { - query := mr.filterer(mr.ReadCounterQuery) + query := mr.aliveFilter(mr.ReadCounterQuery) if optionalName != noFilterOnCounterName { query = query.Where(sq.Eq{colCounterName: optionalName}) } @@ -177,7 +165,9 @@ func (mr *mysqlReader) QueryRelationships( filter datastore.RelationshipsFilter, opts ...options.QueryOptionsOption, ) (iter datastore.RelationshipIterator, err error) { - qBuilder, err := common.NewSchemaQueryFilterer(schema, mr.filterer(mr.QueryRelsQuery), mr.filterMaximumIDCount).FilterWithRelationshipsFilter(filter) + qBuilder, err := common.NewSchemaQueryFiltererForRelationshipsSelect(mr.schema, mr.filterMaximumIDCount). + WithAdditionalFilter(mr.aliveFilter). + FilterWithRelationshipsFilter(filter) if err != nil { return nil, err } @@ -190,7 +180,8 @@ func (mr *mysqlReader) ReverseQueryRelationships( subjectsFilter datastore.SubjectsFilter, opts ...options.ReverseQueryOptionsOption, ) (iter datastore.RelationshipIterator, err error) { - qBuilder, err := common.NewSchemaQueryFilterer(schema, mr.filterer(mr.QueryRelsQuery), mr.filterMaximumIDCount). + qBuilder, err := common.NewSchemaQueryFiltererForRelationshipsSelect(mr.schema, mr.filterMaximumIDCount). + WithAdditionalFilter(mr.aliveFilter). FilterWithSubjectsSelectors(subjectsFilter.AsSelector()) if err != nil { return nil, err @@ -220,7 +211,7 @@ func (mr *mysqlReader) ReadNamespaceByName(ctx context.Context, nsName string) ( } defer common.LogOnError(ctx, txCleanup) - loaded, version, err := loadNamespace(ctx, nsName, tx, mr.filterer(mr.ReadNamespaceQuery)) + loaded, version, err := loadNamespace(ctx, nsName, tx, mr.aliveFilter(mr.ReadNamespaceQuery)) switch { case errors.As(err, &datastore.NamespaceNotFoundError{}): return nil, datastore.NoRevision, err @@ -265,7 +256,7 @@ func (mr *mysqlReader) ListAllNamespaces(ctx context.Context) ([]datastore.Revis } defer common.LogOnError(ctx, txCleanup) - query := mr.filterer(mr.ReadNamespaceQuery) + query := mr.aliveFilter(mr.ReadNamespaceQuery) nsDefs, err := loadAllNamespaces(ctx, tx, query) if err != nil { @@ -291,7 +282,7 @@ func (mr *mysqlReader) LookupNamespacesWithNames(ctx context.Context, nsNames [] clause = append(clause, sq.Eq{colNamespace: nsName}) } - query := mr.filterer(mr.ReadNamespaceQuery.Where(clause)) + query := mr.aliveFilter(mr.ReadNamespaceQuery.Where(clause)) nsDefs, err := loadAllNamespaces(ctx, tx, query) if err != nil { diff --git a/internal/datastore/postgres/caveat.go b/internal/datastore/postgres/caveat.go index 567dac4a97..4688bad766 100644 --- a/internal/datastore/postgres/caveat.go +++ b/internal/datastore/postgres/caveat.go @@ -33,7 +33,7 @@ const ( ) func (r *pgReader) ReadCaveatByName(ctx context.Context, name string) (*core.CaveatDefinition, datastore.Revision, error) { - filteredReadCaveat := r.filterer(readCaveat) + filteredReadCaveat := r.aliveFilter(readCaveat) sql, args, err := filteredReadCaveat.Where(sq.Eq{colCaveatName: name}).ToSql() if err != nil { return nil, datastore.NoRevision, fmt.Errorf(errReadCaveat, err) @@ -78,7 +78,7 @@ func (r *pgReader) lookupCaveats(ctx context.Context, caveatNames []string) ([]d caveatsWithNames = caveatsWithNames.Where(sq.Eq{colCaveatName: caveatNames}) } - filteredListCaveat := r.filterer(caveatsWithNames) + filteredListCaveat := r.aliveFilter(caveatsWithNames) sql, args, err := filteredListCaveat.ToSql() if err != nil { return nil, fmt.Errorf(errListCaveats, err) diff --git a/internal/datastore/postgres/common/pgx.go b/internal/datastore/postgres/common/pgx.go index 4187d96334..3a0de6803a 100644 --- a/internal/datastore/postgres/common/pgx.go +++ b/internal/datastore/postgres/common/pgx.go @@ -2,9 +2,7 @@ package common import ( "context" - "database/sql" "errors" - "fmt" "time" "github.com/ccoveille/go-safecast" @@ -16,150 +14,28 @@ import ( "github.com/jackc/pgx/v5/pgxpool" "github.com/jackc/pgx/v5/tracelog" "github.com/rs/zerolog" - "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" - "google.golang.org/protobuf/types/known/timestamppb" "github.com/authzed/spicedb/internal/datastore/common" log "github.com/authzed/spicedb/internal/logging" "github.com/authzed/spicedb/pkg/datastore" - corev1 "github.com/authzed/spicedb/pkg/proto/core/v1" - "github.com/authzed/spicedb/pkg/tuple" ) -const errUnableToQueryTuples = "unable to query tuples: %w" - // NewPGXExecutor creates an executor that uses the pgx library to make the specified queries. -func NewPGXExecutor(querier DBFuncQuerier) common.ExecuteQueryFunc { - return func(ctx context.Context, sql string, args []any) (datastore.RelationshipIterator, error) { +func NewPGXExecutor(querier DBFuncQuerier) common.ExecuteReadRelsQueryFunc { + return func(ctx context.Context, queryInfo common.QueryInfo, sql string, args []any) (datastore.RelationshipIterator, error) { span := trace.SpanFromContext(ctx) - return queryRels(ctx, sql, args, span, querier, false) + return common.QueryRelationships[pgx.Rows, map[string]any](ctx, queryInfo, sql, args, span, querier, false) } } -func NewPGXExecutorWithIntegrityOption(querier DBFuncQuerier, withIntegrity bool) common.ExecuteQueryFunc { - return func(ctx context.Context, sql string, args []any) (datastore.RelationshipIterator, error) { +func NewPGXExecutorWithIntegrityOption(querier DBFuncQuerier, withIntegrity bool) common.ExecuteReadRelsQueryFunc { + return func(ctx context.Context, queryInfo common.QueryInfo, sql string, args []any) (datastore.RelationshipIterator, error) { span := trace.SpanFromContext(ctx) - return queryRels(ctx, sql, args, span, querier, withIntegrity) + return common.QueryRelationships[pgx.Rows, map[string]any](ctx, queryInfo, sql, args, span, querier, withIntegrity) } } -// queryRels queries relationships for the given query and transaction. -func queryRels(ctx context.Context, sqlStatement string, args []any, span trace.Span, tx DBFuncQuerier, withIntegrity bool) (datastore.RelationshipIterator, error) { - return func(yield func(tuple.Relationship, error) bool) { - err := tx.QueryFunc(ctx, func(ctx context.Context, rows pgx.Rows) error { - span.AddEvent("Query issued to database") - - var resourceObjectType string - var resourceObjectID string - var resourceRelation string - var subjectObjectType string - var subjectObjectID string - var subjectRelation string - var caveatName sql.NullString - var caveatCtx map[string]any - var expiration *time.Time - - relCount := 0 - for rows.Next() { - var integrity *corev1.RelationshipIntegrity - - if withIntegrity { - var integrityKeyID string - var integrityHash []byte - var timestamp time.Time - - if err := rows.Scan( - &resourceObjectType, - &resourceObjectID, - &resourceRelation, - &subjectObjectType, - &subjectObjectID, - &subjectRelation, - &caveatName, - &caveatCtx, - &expiration, - &integrityKeyID, - &integrityHash, - ×tamp, - ); err != nil { - return fmt.Errorf(errUnableToQueryTuples, fmt.Errorf("scan err: %w", err)) - } - - integrity = &corev1.RelationshipIntegrity{ - KeyId: integrityKeyID, - Hash: integrityHash, - HashedAt: timestamppb.New(timestamp), - } - } else { - if err := rows.Scan( - &resourceObjectType, - &resourceObjectID, - &resourceRelation, - &subjectObjectType, - &subjectObjectID, - &subjectRelation, - &caveatName, - &caveatCtx, - &expiration, - ); err != nil { - return fmt.Errorf(errUnableToQueryTuples, fmt.Errorf("scan err: %w", err)) - } - } - - var caveat *corev1.ContextualizedCaveat - if caveatName.Valid { - var err error - caveat, err = common.ContextualizedCaveatFrom(caveatName.String, caveatCtx) - if err != nil { - return fmt.Errorf(errUnableToQueryTuples, fmt.Errorf("unable to fetch caveat context: %w", err)) - } - } - - if expiration != nil { - // Ensure the returned expiration is always in UTC, as some datastores (like CRDB) - // convert to the local timezone when reading. - utc := expiration.UTC() - expiration = &utc - } - - relCount++ - if !yield(tuple.Relationship{ - RelationshipReference: tuple.RelationshipReference{ - Resource: tuple.ObjectAndRelation{ - ObjectType: resourceObjectType, - ObjectID: resourceObjectID, - Relation: resourceRelation, - }, - Subject: tuple.ObjectAndRelation{ - ObjectType: subjectObjectType, - ObjectID: subjectObjectID, - Relation: subjectRelation, - }, - }, - OptionalCaveat: caveat, - OptionalIntegrity: integrity, - OptionalExpiration: expiration, - }, nil) { - return nil - } - } - - if err := rows.Err(); err != nil { - return fmt.Errorf(errUnableToQueryTuples, fmt.Errorf("rows err: %w", err)) - } - - span.AddEvent("Rels loaded", trace.WithAttributes(attribute.Int("relCount", relCount))) - return nil - }, sqlStatement, args...) - if err != nil { - if !yield(tuple.Relationship{}, err) { - return - } - } - }, nil -} - // ParseConfigWithInstrumentation returns a pgx.ConnConfig that has been instrumented for observability func ParseConfigWithInstrumentation(url string) (*pgx.ConnConfig, error) { connConfig, err := pgx.ParseConfig(url) diff --git a/internal/datastore/postgres/postgres.go b/internal/datastore/postgres/postgres.go index bc1aae6303..3607942b45 100644 --- a/internal/datastore/postgres/postgres.go +++ b/internal/datastore/postgres/postgres.go @@ -316,6 +316,22 @@ func newPostgresDatastore( maxRevisionStaleness := time.Duration(float64(config.revisionQuantization.Nanoseconds())* config.maxRevisionStalenessPercent) * time.Nanosecond + schema := common.NewSchemaInformation( + tableTuple, + colNamespace, + colObjectID, + colRelation, + colUsersetNamespace, + colUsersetObjectID, + colUsersetRelation, + colCaveatContextName, + colCaveatContext, + colExpiration, + common.TupleComparison, + sq.Dollar, + "NOW", + ) + datastore := &pgDatastore{ CachedOptimizedRevisions: revisions.NewCachedOptimizedRevisions( maxRevisionStaleness, @@ -341,6 +357,7 @@ func newPostgresDatastore( isPrimary: isPrimary, inStrictReadMode: config.readStrictMode, filterMaximumIDCount: config.filterMaximumIDCount, + schema: schema, } if isPrimary && config.readStrictMode { @@ -393,6 +410,7 @@ type pgDatastore struct { watchEnabled bool isPrimary bool inStrictReadMode bool + schema common.SchemaInformation includeQueryParametersInTraces bool credentialsProvider datastore.CredentialsProvider @@ -425,6 +443,7 @@ func (pgd *pgDatastore) SnapshotReader(revRaw datastore.Revision) datastore.Read executor, buildLivingObjectFilterForRevision(rev), pgd.filterMaximumIDCount, + pgd.schema, } } @@ -468,6 +487,7 @@ func (pgd *pgDatastore) ReadWriteTx( executor, currentlyLivingObjects, pgd.filterMaximumIDCount, + pgd.schema, }, tx, newXID, diff --git a/internal/datastore/postgres/reader.go b/internal/datastore/postgres/reader.go index 14a1b99a0e..8ab3097e60 100644 --- a/internal/datastore/postgres/reader.go +++ b/internal/datastore/postgres/reader.go @@ -18,40 +18,16 @@ import ( type pgReader struct { query pgxcommon.DBFuncQuerier executor common.QueryExecutor - filterer queryFilterer + aliveFilter queryFilterer filterMaximumIDCount uint16 + schema common.SchemaInformation } type queryFilterer func(original sq.SelectBuilder) sq.SelectBuilder var ( - queryTuples = psql.Select( - colNamespace, - colObjectID, - colRelation, - colUsersetNamespace, - colUsersetObjectID, - colUsersetRelation, - colCaveatContextName, - colCaveatContext, - colExpiration, - ).From(tableTuple) - countRels = psql.Select("COUNT(*)").From(tableTuple) - schema = common.NewSchemaInformation( - colNamespace, - colObjectID, - colRelation, - colUsersetNamespace, - colUsersetObjectID, - colUsersetRelation, - colCaveatContextName, - colExpiration, - common.TupleComparison, - "NOW", - ) - readNamespace = psql. Select(colConfig, colCreatedXid). From(tableNamespace) @@ -85,7 +61,7 @@ func (r *pgReader) CountRelationships(ctx context.Context, name string) (int, er return 0, err } - qBuilder, err := common.NewSchemaQueryFilterer(schema, r.filterer(countRels), r.filterMaximumIDCount).FilterWithRelationshipsFilter(relFilter) + qBuilder, err := common.NewSchemaQueryFiltererWithStartingQuery(r.schema, r.aliveFilter(countRels), r.filterMaximumIDCount).FilterWithRelationshipsFilter(relFilter) if err != nil { return 0, err } @@ -125,7 +101,7 @@ func (r *pgReader) lookupCounters(ctx context.Context, optionalName string) ([]d query = query.Where(sq.Eq{colCounterName: optionalName}) } - sql, args, err := r.filterer(query).ToSql() + sql, args, err := r.aliveFilter(query).ToSql() if err != nil { return nil, fmt.Errorf("unable to lookup counters: %w", err) } @@ -173,7 +149,9 @@ func (r *pgReader) QueryRelationships( filter datastore.RelationshipsFilter, opts ...options.QueryOptionsOption, ) (iter datastore.RelationshipIterator, err error) { - qBuilder, err := common.NewSchemaQueryFilterer(schema, r.filterer(queryTuples), r.filterMaximumIDCount).FilterWithRelationshipsFilter(filter) + qBuilder, err := common.NewSchemaQueryFiltererForRelationshipsSelect(r.schema, r.filterMaximumIDCount). + WithAdditionalFilter(r.aliveFilter). + FilterWithRelationshipsFilter(filter) if err != nil { return nil, err } @@ -186,7 +164,8 @@ func (r *pgReader) ReverseQueryRelationships( subjectsFilter datastore.SubjectsFilter, opts ...options.ReverseQueryOptionsOption, ) (iter datastore.RelationshipIterator, err error) { - qBuilder, err := common.NewSchemaQueryFilterer(schema, r.filterer(queryTuples), r.filterMaximumIDCount). + qBuilder, err := common.NewSchemaQueryFiltererForRelationshipsSelect(r.schema, r.filterMaximumIDCount). + WithAdditionalFilter(r.aliveFilter). FilterWithSubjectsSelectors(subjectsFilter.AsSelector()) if err != nil { return nil, err @@ -209,7 +188,7 @@ func (r *pgReader) ReverseQueryRelationships( } func (r *pgReader) ReadNamespaceByName(ctx context.Context, nsName string) (*core.NamespaceDefinition, datastore.Revision, error) { - loaded, version, err := r.loadNamespace(ctx, nsName, r.query, r.filterer) + loaded, version, err := r.loadNamespace(ctx, nsName, r.query, r.aliveFilter) switch { case errors.As(err, &datastore.NamespaceNotFoundError{}): return nil, datastore.NoRevision, err @@ -239,7 +218,7 @@ func (r *pgReader) loadNamespace(ctx context.Context, namespace string, tx pgxco } func (r *pgReader) ListAllNamespaces(ctx context.Context) ([]datastore.RevisionedNamespace, error) { - nsDefsWithRevisions, err := loadAllNamespaces(ctx, r.query, r.filterer) + nsDefsWithRevisions, err := loadAllNamespaces(ctx, r.query, r.aliveFilter) if err != nil { return nil, fmt.Errorf(errUnableToListNamespaces, err) } @@ -258,7 +237,7 @@ func (r *pgReader) LookupNamespacesWithNames(ctx context.Context, nsNames []stri } nsDefsWithRevisions, err := loadAllNamespaces(ctx, r.query, func(original sq.SelectBuilder) sq.SelectBuilder { - return r.filterer(original).Where(clause) + return r.aliveFilter(original).Where(clause) }) if err != nil { return nil, fmt.Errorf(errUnableToListNamespaces, err) diff --git a/internal/datastore/postgres/readwrite.go b/internal/datastore/postgres/readwrite.go index 6bc3ead8c4..e135b158ad 100644 --- a/internal/datastore/postgres/readwrite.go +++ b/internal/datastore/postgres/readwrite.go @@ -585,7 +585,7 @@ func (rwt *pgReadWriteTXN) WriteNamespaces(ctx context.Context, newConfigs ...*c } func (rwt *pgReadWriteTXN) DeleteNamespaces(ctx context.Context, nsNames ...string) error { - filterer := func(original sq.SelectBuilder) sq.SelectBuilder { + aliveFilter := func(original sq.SelectBuilder) sq.SelectBuilder { return original.Where(sq.Eq{colDeletedXid: liveDeletedTxnID}) } @@ -593,7 +593,7 @@ func (rwt *pgReadWriteTXN) DeleteNamespaces(ctx context.Context, nsNames ...stri tplClauses := make([]sq.Sqlizer, 0, len(nsNames)) querier := pgxcommon.QuerierFuncsFor(rwt.tx) for _, nsName := range nsNames { - _, _, err := rwt.loadNamespace(ctx, nsName, querier, filterer) + _, _, err := rwt.loadNamespace(ctx, nsName, querier, aliveFilter) switch { case errors.As(err, &datastore.NamespaceNotFoundError{}): return err diff --git a/internal/datastore/postgres/stats.go b/internal/datastore/postgres/stats.go index b428cd4657..0e0bea63f6 100644 --- a/internal/datastore/postgres/stats.go +++ b/internal/datastore/postgres/stats.go @@ -51,7 +51,7 @@ func (pgd *pgDatastore) Statistics(ctx context.Context) (datastore.Stats, error) return datastore.Stats{}, fmt.Errorf("unable to prepare row count sql: %w", err) } - filterer := func(original sq.SelectBuilder) sq.SelectBuilder { + aliveFilter := func(original sq.SelectBuilder) sq.SelectBuilder { return original.Where(sq.Eq{colDeletedXid: liveDeletedTxnID}) } @@ -69,7 +69,7 @@ func (pgd *pgDatastore) Statistics(ctx context.Context) (datastore.Stats, error) return fmt.Errorf("unable to query unique ID: %w", err) } - nsDefsWithRevisions, err := loadAllNamespaces(ctx, pgxcommon.QuerierFuncsFor(tx), filterer) + nsDefsWithRevisions, err := loadAllNamespaces(ctx, pgxcommon.QuerierFuncsFor(tx), aliveFilter) if err != nil { return fmt.Errorf("unable to load namespaces: %w", err) } diff --git a/internal/datastore/spanner/reader.go b/internal/datastore/spanner/reader.go index c165dc7699..318b44f36b 100644 --- a/internal/datastore/spanner/reader.go +++ b/internal/datastore/spanner/reader.go @@ -7,6 +7,7 @@ import ( "time" "cloud.google.com/go/spanner" + sq "github.com/Masterminds/squirrel" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" "google.golang.org/grpc/codes" @@ -54,7 +55,7 @@ func (sr spannerReader) CountRelationships(ctx context.Context, name string) (in return 0, err } - builder, err := common.NewSchemaQueryFilterer(schema, countRels, sr.filterMaximumIDCount).FilterWithRelationshipsFilter(relFilter) + builder, err := common.NewSchemaQueryFiltererWithStartingQuery(schema, countRels, sr.filterMaximumIDCount).FilterWithRelationshipsFilter(relFilter) if err != nil { return 0, err } @@ -134,7 +135,7 @@ func (sr spannerReader) QueryRelationships( filter datastore.RelationshipsFilter, opts ...options.QueryOptionsOption, ) (iter datastore.RelationshipIterator, err error) { - qBuilder, err := common.NewSchemaQueryFilterer(schema, queryTuples, sr.filterMaximumIDCount).FilterWithRelationshipsFilter(filter) + qBuilder, err := common.NewSchemaQueryFiltererForRelationshipsSelect(schema, sr.filterMaximumIDCount).FilterWithRelationshipsFilter(filter) if err != nil { return nil, err } @@ -147,7 +148,7 @@ func (sr spannerReader) ReverseQueryRelationships( subjectsFilter datastore.SubjectsFilter, opts ...options.ReverseQueryOptionsOption, ) (iter datastore.RelationshipIterator, err error) { - qBuilder, err := common.NewSchemaQueryFilterer(schema, queryTuples, sr.filterMaximumIDCount). + qBuilder, err := common.NewSchemaQueryFiltererForRelationshipsSelect(schema, sr.filterMaximumIDCount). FilterWithSubjectsSelectors(subjectsFilter.AsSelector()) if err != nil { return nil, err @@ -171,8 +172,8 @@ func (sr spannerReader) ReverseQueryRelationships( var errStopIterator = fmt.Errorf("stop iteration") -func queryExecutor(txSource txFactory) common.ExecuteQueryFunc { - return func(ctx context.Context, sql string, args []any) (datastore.RelationshipIterator, error) { +func queryExecutor(txSource txFactory) common.ExecuteReadRelsQueryFunc { + return func(ctx context.Context, queryInfo common.QueryInfo, sql string, args []any) (datastore.RelationshipIterator, error) { return func(yield func(tuple.Relationship, error) bool) { span := trace.SpanFromContext(ctx) span.AddEvent("Query issued to database") @@ -185,27 +186,29 @@ func queryExecutor(txSource txFactory) common.ExecuteQueryFunc { relCount := 0 defer span.SetAttributes(attribute.Int("count", relCount)) + var resourceObjectType string + var resourceObjectID string + var relation string + var subjectObjectType string + var subjectObjectID string + var subjectRelation string + var caveatName spanner.NullString + var caveatCtx spanner.NullJSON + var expirationOrNull spanner.NullTime + + colsToSelect := make([]any, 0, 8) + + colsToSelect = common.StaticValueOrAddColumnForSelect(colsToSelect, queryInfo, queryInfo.Schema.ColNamespace, &resourceObjectType) + colsToSelect = common.StaticValueOrAddColumnForSelect(colsToSelect, queryInfo, queryInfo.Schema.ColObjectID, &resourceObjectID) + colsToSelect = common.StaticValueOrAddColumnForSelect(colsToSelect, queryInfo, queryInfo.Schema.ColRelation, &relation) + colsToSelect = common.StaticValueOrAddColumnForSelect(colsToSelect, queryInfo, queryInfo.Schema.ColUsersetNamespace, &subjectObjectType) + colsToSelect = common.StaticValueOrAddColumnForSelect(colsToSelect, queryInfo, queryInfo.Schema.ColUsersetObjectID, &subjectObjectID) + colsToSelect = common.StaticValueOrAddColumnForSelect(colsToSelect, queryInfo, queryInfo.Schema.ColUsersetRelation, &subjectRelation) + + colsToSelect = append(colsToSelect, &caveatName, &caveatCtx, &expirationOrNull) + if err := iter.Do(func(row *spanner.Row) error { - var resourceObjectType string - var resourceObjectID string - var relation string - var subjectObjectType string - var subjectObjectID string - var subjectRelation string - var caveatName spanner.NullString - var caveatCtx spanner.NullJSON - var expirationOrNull spanner.NullTime - err := row.Columns( - &resourceObjectType, - &resourceObjectID, - &relation, - &subjectObjectType, - &subjectObjectID, - &subjectRelation, - &caveatName, - &caveatCtx, - &expirationOrNull, - ) + err := row.Columns(colsToSelect...) if err != nil { return err } @@ -355,18 +358,6 @@ func readAllNamespaces(iter *spanner.RowIterator, span trace.Span) ([]datastore. return allNamespaces, nil } -var queryTuples = sql.Select( - colNamespace, - colObjectID, - colRelation, - colUsersetNamespace, - colUsersetObjectID, - colUsersetRelation, - colCaveatName, - colCaveatContext, - colExpiration, -).From(tableRelationship) - var countRels = sql.Select("COUNT(*)").From(tableRelationship) var queryTuplesForDelete = sql.Select( @@ -379,6 +370,7 @@ var queryTuplesForDelete = sql.Select( ).From(tableRelationship) var schema = common.NewSchemaInformation( + tableRelationship, colNamespace, colObjectID, colRelation, @@ -386,8 +378,10 @@ var schema = common.NewSchemaInformation( colUsersetObjectID, colUsersetRelation, colCaveatName, + colCaveatContext, colExpiration, common.ExpandedLogicComparison, + sq.AtP, "CURRENT_TIMESTAMP", ) From 60fc3d930eba0172589448a20e94bd28ad9cbf73 Mon Sep 17 00:00:00 2001 From: Joseph Schorr Date: Sun, 27 Oct 2024 16:58:41 -0400 Subject: [PATCH 02/15] Skip loading of caveats in SQL when unnecessary --- internal/datastore/common/relationships.go | 24 ++++++++--- internal/datastore/common/sql.go | 20 +++++++-- internal/datastore/memdb/readonly.go | 21 ++++++--- internal/datastore/mysql/readwrite.go | 3 ++ internal/datastore/spanner/reader.go | 11 ++++- internal/dispatch/graph/check_test.go | 4 +- internal/graph/check.go | 14 +++++- internal/testfixtures/datastore.go | 4 ++ pkg/datastore/options/options.go | 7 +-- .../options/zz_generated.query_options.go | 9 ++++ pkg/datastore/test/relationships.go | 43 ++++++++++++++++--- 11 files changed, 131 insertions(+), 29 deletions(-) diff --git a/internal/datastore/common/relationships.go b/internal/datastore/common/relationships.go index a08cafd958..cffb92c693 100644 --- a/internal/datastore/common/relationships.go +++ b/internal/datastore/common/relationships.go @@ -79,11 +79,21 @@ func QueryRelationships[R Rows, C ~map[string]any](ctx context.Context, queryInf colsToSelect = StaticValueOrAddColumnForSelect(colsToSelect, queryInfo, queryInfo.Schema.ColUsersetObjectID, &subjectObjectID) colsToSelect = StaticValueOrAddColumnForSelect(colsToSelect, queryInfo, queryInfo.Schema.ColUsersetRelation, &subjectRelation) - colsToSelect = append(colsToSelect, &caveatName, &caveatCtx, &expiration) + if !queryInfo.SkipCaveats { + colsToSelect = append(colsToSelect, &caveatName, &caveatCtx) + } + + colsToSelect = append(colsToSelect, &expiration) + if withIntegrity { colsToSelect = append(colsToSelect, &integrityKeyID, &integrityHash, ×tamp) } + if len(colsToSelect) == 0 { + var unused int + colsToSelect = append(colsToSelect, &unused) + } + return func(yield func(tuple.Relationship, error) bool) { err := tx.QueryFunc(ctx, func(ctx context.Context, rows R) error { var r Rows = rows @@ -101,11 +111,13 @@ func QueryRelationships[R Rows, C ~map[string]any](ctx context.Context, queryInf } var caveat *corev1.ContextualizedCaveat - if caveatName.Valid { - var err error - caveat, err = ContextualizedCaveatFrom(caveatName.String, caveatCtx) - if err != nil { - return fmt.Errorf(errUnableToQueryRels, fmt.Errorf("unable to fetch caveat context: %w", err)) + if !queryInfo.SkipCaveats { + if caveatName.Valid { + var err error + caveat, err = ContextualizedCaveatFrom(caveatName.String, caveatCtx) + if err != nil { + return fmt.Errorf(errUnableToQueryRels, fmt.Errorf("unable to fetch caveat context: %w", err)) + } } } diff --git a/internal/datastore/common/sql.go b/internal/datastore/common/sql.go index f313edf322..a781cd7b8e 100644 --- a/internal/datastore/common/sql.go +++ b/internal/datastore/common/sql.go @@ -702,8 +702,18 @@ func (tqs QueryExecutor) ExecuteQuery( columnNamesToSelect = checkColumn(columnNamesToSelect, query.filteringColumnTracker, query.schema.ColUsersetObjectID) columnNamesToSelect = checkColumn(columnNamesToSelect, query.filteringColumnTracker, query.schema.ColUsersetRelation) - columnNamesToSelect = append(columnNamesToSelect, query.schema.ColCaveatName, query.schema.ColCaveatContext, query.schema.ColExpiration) + columnNamesToSelect = append(columnNamesToSelect, query.schema.ColExpiration) + + if !queryOpts.SkipCaveats { + columnNamesToSelect = append(columnNamesToSelect, query.schema.ColCaveatName, query.schema.ColCaveatContext) + } + + selectingNoColumns := false columnNamesToSelect = append(columnNamesToSelect, query.schema.ExtraFields...) + if len(columnNamesToSelect) == 0 { + columnNamesToSelect = append(columnNamesToSelect, "1") + selectingNoColumns = true + } toExecute.queryBuilder = toExecute.queryBuilder.Columns(columnNamesToSelect...) @@ -719,7 +729,7 @@ func (tqs QueryExecutor) ExecuteQuery( return nil, err } - return tqs.Executor(ctx, QueryInfo{query.schema, query.filteringColumnTracker}, sql, args) + return tqs.Executor(ctx, QueryInfo{query.schema, query.filteringColumnTracker, queryOpts.SkipCaveats, selectingNoColumns}, sql, args) } func checkColumn(columns []string, tracker map[string]ColumnTracker, colName string) []string { @@ -731,8 +741,10 @@ func checkColumn(columns []string, tracker map[string]ColumnTracker, colName str // QueryInfo holds the schema information and filtering values for a query. type QueryInfo struct { - Schema SchemaInformation - FilteringValues map[string]ColumnTracker + Schema SchemaInformation + FilteringValues map[string]ColumnTracker + SkipCaveats bool + SelectingNoColumns bool } // ExecuteReadRelsQueryFunc is a function that can be used to execute a single rendered SQL query. diff --git a/internal/datastore/memdb/readonly.go b/internal/datastore/memdb/readonly.go index 87ef405c1f..7878b3c276 100644 --- a/internal/datastore/memdb/readonly.go +++ b/internal/datastore/memdb/readonly.go @@ -151,11 +151,11 @@ func (r *memdbReader) QueryRelationships( fallthrough case options.ByResource: - iter := newMemdbTupleIterator(r.now, filteredIterator, queryOpts.Limit) + iter := newMemdbTupleIterator(r.now, filteredIterator, queryOpts.Limit, queryOpts.SkipCaveats) return iter, nil case options.BySubject: - return newSubjectSortedIterator(r.now, filteredIterator, queryOpts.Limit) + return newSubjectSortedIterator(r.now, filteredIterator, queryOpts.Limit, queryOpts.SkipCaveats) default: return nil, spiceerrors.MustBugf("unsupported sort order: %v", queryOpts.Sort) @@ -214,11 +214,11 @@ func (r *memdbReader) ReverseQueryRelationships( fallthrough case options.ByResource: - iter := newMemdbTupleIterator(r.now, filteredIterator, queryOpts.LimitForReverse) + iter := newMemdbTupleIterator(r.now, filteredIterator, queryOpts.LimitForReverse, false) return iter, nil case options.BySubject: - return newSubjectSortedIterator(r.now, filteredIterator, queryOpts.LimitForReverse) + return newSubjectSortedIterator(r.now, filteredIterator, queryOpts.LimitForReverse, false) default: return nil, spiceerrors.MustBugf("unsupported sort order: %v", queryOpts.SortForReverse) @@ -476,7 +476,7 @@ func makeCursorFilterFn(after options.Cursor, order options.SortOrder) func(tpl return noopCursorFilter } -func newSubjectSortedIterator(now time.Time, it memdb.ResultIterator, limit *uint64) (datastore.RelationshipIterator, error) { +func newSubjectSortedIterator(now time.Time, it memdb.ResultIterator, limit *uint64, skipCaveats bool) (datastore.RelationshipIterator, error) { results := make([]tuple.Relationship, 0) // Coalesce all of the results into memory @@ -490,6 +490,10 @@ func newSubjectSortedIterator(now time.Time, it memdb.ResultIterator, limit *uin continue } + if skipCaveats && rt.OptionalCaveat != nil { + return nil, spiceerrors.MustBugf("unexpected caveat in result for relationship: %v", rt) + } + results = append(results, rt) } @@ -526,7 +530,7 @@ func eq(lhsNamespace, lhsObjectID, lhsRelation string, rhs tuple.ObjectAndRelati return lhsNamespace == rhs.ObjectType && lhsObjectID == rhs.ObjectID && lhsRelation == rhs.Relation } -func newMemdbTupleIterator(now time.Time, it memdb.ResultIterator, limit *uint64) datastore.RelationshipIterator { +func newMemdbTupleIterator(now time.Time, it memdb.ResultIterator, limit *uint64, skipCaveats bool) datastore.RelationshipIterator { var count uint64 return func(yield func(tuple.Relationship, error) bool) { for { @@ -551,6 +555,11 @@ func newMemdbTupleIterator(now time.Time, it memdb.ResultIterator, limit *uint64 continue } + if skipCaveats && rt.OptionalCaveat != nil { + yield(rt, fmt.Errorf("unexpected caveat in result for relationship: %v", rt)) + return + } + if !yield(rt, err) { return } diff --git a/internal/datastore/mysql/readwrite.go b/internal/datastore/mysql/readwrite.go index cdc93c454e..d8b5b40647 100644 --- a/internal/datastore/mysql/readwrite.go +++ b/internal/datastore/mysql/readwrite.go @@ -17,6 +17,7 @@ import ( "github.com/ccoveille/go-safecast" "github.com/go-sql-driver/mysql" "github.com/jzelinskie/stringz" + "golang.org/x/exp/maps" "github.com/authzed/spicedb/internal/datastore/common" "github.com/authzed/spicedb/internal/datastore/revisions" @@ -59,6 +60,8 @@ func (cc *structpbWrapper) Scan(val any) error { if !ok { return fmt.Errorf("unsupported type: %T", v) } + + maps.Clear(*cc) return json.Unmarshal(v, &cc) } diff --git a/internal/datastore/spanner/reader.go b/internal/datastore/spanner/reader.go index 318b44f36b..216361886b 100644 --- a/internal/datastore/spanner/reader.go +++ b/internal/datastore/spanner/reader.go @@ -205,7 +205,16 @@ func queryExecutor(txSource txFactory) common.ExecuteReadRelsQueryFunc { colsToSelect = common.StaticValueOrAddColumnForSelect(colsToSelect, queryInfo, queryInfo.Schema.ColUsersetObjectID, &subjectObjectID) colsToSelect = common.StaticValueOrAddColumnForSelect(colsToSelect, queryInfo, queryInfo.Schema.ColUsersetRelation, &subjectRelation) - colsToSelect = append(colsToSelect, &caveatName, &caveatCtx, &expirationOrNull) + colsToSelect = append(colsToSelect, &expirationOrNull) + + if !queryInfo.SkipCaveats { + colsToSelect = append(colsToSelect, &caveatName, &caveatCtx) + } + + if len(colsToSelect) == 0 { + var unused int64 + colsToSelect = append(colsToSelect, &unused) + } if err := iter.Do(func(row *spanner.Row) error { err := row.Columns(colsToSelect...) diff --git a/internal/dispatch/graph/check_test.go b/internal/dispatch/graph/check_test.go index 83453c7eb5..189337c207 100644 --- a/internal/dispatch/graph/check_test.go +++ b/internal/dispatch/graph/check_test.go @@ -1251,7 +1251,7 @@ func TestCheckPermissionOverSchema(t *testing.T) { definition user {} definition role { - relation member: user + relation member: user with somecaveat } definition resource { @@ -1287,7 +1287,7 @@ func TestCheckPermissionOverSchema(t *testing.T) { definition user {} definition role { - relation member: user + relation member: user with somecaveat } definition resource { diff --git a/internal/graph/check.go b/internal/graph/check.go index a4d6906b41..63df50d0bf 100644 --- a/internal/graph/check.go +++ b/internal/graph/check.go @@ -19,6 +19,7 @@ import ( "github.com/authzed/spicedb/internal/namespace" "github.com/authzed/spicedb/internal/taskrunner" "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/datastore/options" "github.com/authzed/spicedb/pkg/genutil/mapz" "github.com/authzed/spicedb/pkg/middleware/nodeid" nspkg "github.com/authzed/spicedb/pkg/namespace" @@ -324,6 +325,8 @@ func (cc *ConcurrentChecker) checkDirect(ctx context.Context, crc currentRequest hasNonTerminals := false hasDirectSubject := false hasWildcardSubject := false + directSubjectOrWildcardCanHaveCaveats := false + nonTerminalsCanHaveCaveats := false defer func() { if hasNonTerminals { @@ -348,6 +351,10 @@ func (cc *ConcurrentChecker) checkDirect(ctx context.Context, crc currentRequest } else if allowedDirectRelation.GetRelation() == crc.parentReq.Subject.Relation { hasDirectSubject = true } + + if allowedDirectRelation.RequiredCaveat != nil { + directSubjectOrWildcardCanHaveCaveats = true + } } // If the relation found is not an ellipsis, then this is a nested relation that @@ -357,6 +364,9 @@ func (cc *ConcurrentChecker) checkDirect(ctx context.Context, crc currentRequest // relations can reach the target subject type. if allowedDirectRelation.GetRelation() != tuple.Ellipsis { hasNonTerminals = true + if allowedDirectRelation.RequiredCaveat != nil { + nonTerminalsCanHaveCaveats = true + } } } @@ -395,7 +405,7 @@ func (cc *ConcurrentChecker) checkDirect(ctx context.Context, crc currentRequest OptionalSubjectsSelectors: subjectSelectors, } - it, err := ds.QueryRelationships(ctx, filter) + it, err := ds.QueryRelationships(ctx, filter, options.WithSkipCaveats(!directSubjectOrWildcardCanHaveCaveats)) if err != nil { return checkResultError(NewCheckFailureErr(err), emptyMetadata) } @@ -444,7 +454,7 @@ func (cc *ConcurrentChecker) checkDirect(ctx context.Context, crc currentRequest }, } - it, err := ds.QueryRelationships(ctx, filter) + it, err := ds.QueryRelationships(ctx, filter, options.WithSkipCaveats(!nonTerminalsCanHaveCaveats)) if err != nil { return checkResultError(NewCheckFailureErr(err), emptyMetadata) } diff --git a/internal/testfixtures/datastore.go b/internal/testfixtures/datastore.go index 53d0b2e8d4..b73d9235dc 100644 --- a/internal/testfixtures/datastore.go +++ b/internal/testfixtures/datastore.go @@ -37,6 +37,7 @@ var DocumentNS = ns.Namespace( ns.MustRelation("owner", nil, ns.AllowedRelation("user", "..."), + ns.AllowedRelationWithCaveat("user", "...", ns.AllowedCaveat("test")), ), ns.MustRelation("editor", nil, @@ -45,6 +46,7 @@ var DocumentNS = ns.Namespace( ns.MustRelation("viewer", nil, ns.AllowedRelation("user", "..."), + ns.AllowedRelationWithCaveat("user", "...", ns.AllowedCaveat("test")), ), ns.MustRelation("viewer_and_editor", nil, @@ -85,6 +87,7 @@ var FolderNS = ns.Namespace( ns.MustRelation("owner", nil, ns.AllowedRelation("user", "..."), + ns.AllowedRelationWithCaveat("user", "...", ns.AllowedCaveat("test")), ), ns.MustRelation("editor", nil, @@ -94,6 +97,7 @@ var FolderNS = ns.Namespace( nil, ns.AllowedRelation("user", "..."), ns.AllowedRelation("folder", "viewer"), + ns.AllowedRelationWithCaveat("folder", "viewer", ns.AllowedCaveat("test")), ), ns.MustRelation("parent", nil, ns.AllowedRelation("folder", "...")), ns.MustRelation("edit", diff --git a/pkg/datastore/options/options.go b/pkg/datastore/options/options.go index f8fe66f4f4..6a55c0582d 100644 --- a/pkg/datastore/options/options.go +++ b/pkg/datastore/options/options.go @@ -43,9 +43,10 @@ func ToRelationship(c Cursor) *tuple.Relationship { // QueryOptions are the options that can affect the results of a normal forward query. type QueryOptions struct { - Limit *uint64 `debugmap:"visible"` - Sort SortOrder `debugmap:"visible"` - After Cursor `debugmap:"visible"` + Limit *uint64 `debugmap:"visible"` + Sort SortOrder `debugmap:"visible"` + After Cursor `debugmap:"visible"` + SkipCaveats bool `debugmap:"visible"` } // ReverseQueryOptions are the options that can affect the results of a reverse query. diff --git a/pkg/datastore/options/zz_generated.query_options.go b/pkg/datastore/options/zz_generated.query_options.go index f761b06b66..79db45999e 100644 --- a/pkg/datastore/options/zz_generated.query_options.go +++ b/pkg/datastore/options/zz_generated.query_options.go @@ -34,6 +34,7 @@ func (q *QueryOptions) ToOption() QueryOptionsOption { to.Limit = q.Limit to.Sort = q.Sort to.After = q.After + to.SkipCaveats = q.SkipCaveats } } @@ -43,6 +44,7 @@ func (q QueryOptions) DebugMap() map[string]any { debugMap["Limit"] = helpers.DebugValue(q.Limit, false) debugMap["Sort"] = helpers.DebugValue(q.Sort, false) debugMap["After"] = helpers.DebugValue(q.After, false) + debugMap["SkipCaveats"] = helpers.DebugValue(q.SkipCaveats, false) return debugMap } @@ -83,6 +85,13 @@ func WithAfter(after Cursor) QueryOptionsOption { } } +// WithSkipCaveats returns an option that can set SkipCaveats on a QueryOptions +func WithSkipCaveats(skipCaveats bool) QueryOptionsOption { + return func(q *QueryOptions) { + q.SkipCaveats = skipCaveats + } +} + type ReverseQueryOptionsOption func(r *ReverseQueryOptions) // NewReverseQueryOptionsWithOptions creates a new ReverseQueryOptions with the passed in options set diff --git a/pkg/datastore/test/relationships.go b/pkg/datastore/test/relationships.go index cac514944a..19f534d3dd 100644 --- a/pkg/datastore/test/relationships.go +++ b/pkg/datastore/test/relationships.go @@ -1113,10 +1113,11 @@ func RecreateRelationshipsAfterDeleteWithFilter(t *testing.T, tester DatastoreTe // QueryRelationshipsWithVariousFiltersTest tests various relationship filters for query relationships. func QueryRelationshipsWithVariousFiltersTest(t *testing.T, tester DatastoreTester) { tcs := []struct { - name string - filter datastore.RelationshipsFilter - relationships []string - expected []string + name string + filter datastore.RelationshipsFilter + withoutCaveats bool + relationships []string + expected []string }{ { name: "resource type", @@ -1475,6 +1476,38 @@ func QueryRelationshipsWithVariousFiltersTest(t *testing.T, tester DatastoreTest "folder:someotherfolder#viewer@user:tom", }, }, + { + name: "resource type with caveats", + filter: datastore.RelationshipsFilter{ + OptionalResourceType: "document", + }, + relationships: []string{ + "document:first#viewer@user:tom[firstcaveat]", + "document:second#viewer@user:tom[secondcaveat]", + "folder:secondfolder#viewer@user:tom", + "folder:someotherfolder#viewer@user:tom", + }, + expected: []string{"document:first#viewer@user:tom[firstcaveat]", "document:second#viewer@user:tom[secondcaveat]"}, + }, + { + name: "resource type with caveats and context", + filter: datastore.RelationshipsFilter{ + OptionalResourceType: "document", + }, + relationships: []string{ + "document:first#viewer@user:tom[firstcaveat:{\"foo\":\"bar\"}]", + "document:second#viewer@user:tom[secondcaveat]", + "document:third#viewer@user:tom[secondcaveat:{\"bar\":\"baz\"}]", + "folder:secondfolder#viewer@user:tom", + "folder:someotherfolder#viewer@user:tom", + }, + expected: []string{ + "document:first#viewer@user:tom[firstcaveat:{\"foo\":\"bar\"}]", + "document:second#viewer@user:tom[secondcaveat]", + "document:third#viewer@user:tom[secondcaveat:{\"bar\":\"baz\"}]", + }, + }, + { name: "relationship expiration", filter: datastore.RelationshipsFilter{ @@ -1610,7 +1643,7 @@ func QueryRelationshipsWithVariousFiltersTest(t *testing.T, tester DatastoreTest require.NoError(err) reader := ds.SnapshotReader(headRev) - iter, err := reader.QueryRelationships(ctx, tc.filter) + iter, err := reader.QueryRelationships(ctx, tc.filter, options.WithSkipCaveats(tc.withoutCaveats)) require.NoError(err) var results []string From 2270811875328cff69136ad56d65f70bb54e2512 Mon Sep 17 00:00:00 2001 From: Joseph Schorr Date: Tue, 5 Nov 2024 12:44:29 -0500 Subject: [PATCH 03/15] Add additional testing for column elision --- internal/datastore/common/sql_test.go | 259 ++++++++++++++++++++++++++ 1 file changed, 259 insertions(+) diff --git a/internal/datastore/common/sql_test.go b/internal/datastore/common/sql_test.go index 5e6bc060fd..042bacc3f1 100644 --- a/internal/datastore/common/sql_test.go +++ b/internal/datastore/common/sql_test.go @@ -1,6 +1,7 @@ package common import ( + "context" "testing" "github.com/authzed/spicedb/pkg/datastore/options" @@ -593,3 +594,261 @@ func TestSchemaQueryFilterer(t *testing.T) { }) } } + +func TestExecuteQuery(t *testing.T) { + tcs := []struct { + name string + run func(filterer SchemaQueryFilterer) SchemaQueryFilterer + options []options.QueryOptionsOption + expectedSQL string + expectedArgs []any + expectedSelectingNoColumns bool + expectedSkipCaveats bool + }{ + { + name: "filter by static resource type", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + return filterer.FilterToResourceType("sometype") + }, + expectedSQL: "SELECT object_id, relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context FROM relationtuples WHERE ns = ?", + expectedArgs: []any{"sometype"}, + }, + { + name: "filter by static resource type and resource ID", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + return filterer.FilterToResourceType("sometype").FilterToResourceID("someobj") + }, + expectedSQL: "SELECT relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context FROM relationtuples WHERE ns = ? AND object_id = ?", + expectedArgs: []any{"sometype", "someobj"}, + }, + { + name: "filter by static resource type and resource ID prefix", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + return filterer.FilterToResourceType("sometype").MustFilterWithResourceIDPrefix("someprefix") + }, + expectedSQL: "SELECT object_id, relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context FROM relationtuples WHERE ns = ? AND object_id LIKE ?", + expectedArgs: []any{"sometype", "someprefix%"}, + }, + { + name: "filter by static resource type and resource IDs", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + return filterer.FilterToResourceType("sometype").MustFilterToResourceIDs([]string{"someobj", "anotherobj"}) + }, + expectedSQL: "SELECT object_id, relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context FROM relationtuples WHERE ns = ? AND object_id IN (?,?)", + expectedArgs: []any{"sometype", "someobj", "anotherobj"}, + }, + { + name: "filter by static resource type, resource ID and relation", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + return filterer.FilterToResourceType("sometype").FilterToResourceID("someobj").FilterToRelation("somerel") + }, + expectedSQL: "SELECT subject_ns, subject_object_id, subject_relation, caveat, caveat_context FROM relationtuples WHERE ns = ? AND object_id = ? AND relation = ?", + expectedArgs: []any{"sometype", "someobj", "somerel"}, + }, + { + name: "filter by static resource type, resource ID, relation and subject type", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + return filterer.FilterToResourceType("sometype").FilterToResourceID("someobj").FilterToRelation("somerel").FilterToSubjectFilter(&v1.SubjectFilter{ + SubjectType: "subns", + }) + }, + expectedSQL: "SELECT subject_object_id, subject_relation, caveat, caveat_context FROM relationtuples WHERE ns = ? AND object_id = ? AND relation = ? AND subject_ns = ?", + expectedArgs: []any{"sometype", "someobj", "somerel", "subns"}, + }, + { + name: "filter by static resource type, resource ID, relation, subject type and subject ID", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + return filterer.FilterToResourceType("sometype").FilterToResourceID("someobj").FilterToRelation("somerel").FilterToSubjectFilter(&v1.SubjectFilter{ + SubjectType: "subns", + OptionalSubjectId: "subid", + }) + }, + expectedSQL: "SELECT subject_relation, caveat, caveat_context FROM relationtuples WHERE ns = ? AND object_id = ? AND relation = ? AND subject_ns = ? AND subject_object_id = ?", + expectedArgs: []any{"sometype", "someobj", "somerel", "subns", "subid"}, + }, + { + name: "filter by static resource type, resource ID, relation, subject type, subject ID and subject relation", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + return filterer.FilterToResourceType("sometype").FilterToResourceID("someobj").FilterToRelation("somerel").FilterToSubjectFilter(&v1.SubjectFilter{ + SubjectType: "subns", + OptionalSubjectId: "subid", + OptionalRelation: &v1.SubjectFilter_RelationFilter{ + Relation: "subrel", + }, + }) + }, + expectedSQL: "SELECT caveat, caveat_context FROM relationtuples WHERE ns = ? AND object_id = ? AND relation = ? AND subject_ns = ? AND subject_object_id = ? AND subject_relation = ?", + expectedArgs: []any{"sometype", "someobj", "somerel", "subns", "subid", "subrel"}, + }, + { + name: "filter by static everything without caveats", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + return filterer.FilterToResourceType("sometype").FilterToResourceID("someobj").FilterToRelation("somerel").FilterToSubjectFilter(&v1.SubjectFilter{ + SubjectType: "subns", + OptionalSubjectId: "subid", + OptionalRelation: &v1.SubjectFilter_RelationFilter{ + Relation: "subrel", + }, + }) + }, + options: []options.QueryOptionsOption{ + options.WithSkipCaveats(true), + }, + expectedSkipCaveats: true, + expectedSelectingNoColumns: true, + expectedSQL: "SELECT 1 FROM relationtuples WHERE ns = ? AND object_id = ? AND relation = ? AND subject_ns = ? AND subject_object_id = ? AND subject_relation = ?", + expectedArgs: []any{"sometype", "someobj", "somerel", "subns", "subid", "subrel"}, + }, + { + name: "filter by static everything (except one field) without caveats", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + return filterer.FilterToResourceType("sometype").MustFilterToResourceIDs([]string{"someobj", "anotherobj"}).FilterToRelation("somerel").FilterToSubjectFilter(&v1.SubjectFilter{ + SubjectType: "subns", + OptionalSubjectId: "subid", + OptionalRelation: &v1.SubjectFilter_RelationFilter{ + Relation: "subrel", + }, + }) + }, + options: []options.QueryOptionsOption{ + options.WithSkipCaveats(true), + }, + expectedSkipCaveats: true, + expectedSQL: "SELECT object_id FROM relationtuples WHERE ns = ? AND object_id IN (?,?) AND relation = ? AND subject_ns = ? AND subject_object_id = ? AND subject_relation = ?", + expectedArgs: []any{"sometype", "someobj", "anotherobj", "somerel", "subns", "subid", "subrel"}, + }, + { + name: "filter by static resource type with no caveats", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + return filterer.FilterToResourceType("sometype") + }, + options: []options.QueryOptionsOption{ + options.WithSkipCaveats(true), + }, + expectedSkipCaveats: true, + expectedSQL: "SELECT object_id, relation, subject_ns, subject_object_id, subject_relation FROM relationtuples WHERE ns = ?", + expectedArgs: []any{"sometype"}, + }, + { + name: "filter by just subject type", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + return filterer.FilterToSubjectFilter(&v1.SubjectFilter{ + SubjectType: "subns", + }) + }, + expectedSQL: "SELECT ns, object_id, relation, subject_object_id, subject_relation, caveat, caveat_context FROM relationtuples WHERE subject_ns = ?", + expectedArgs: []any{"subns"}, + }, + { + name: "filter by just subject type and subject ID", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + return filterer.FilterToSubjectFilter(&v1.SubjectFilter{ + SubjectType: "subns", + OptionalSubjectId: "subid", + }) + }, + expectedSQL: "SELECT ns, object_id, relation, subject_relation, caveat, caveat_context FROM relationtuples WHERE subject_ns = ? AND subject_object_id = ?", + expectedArgs: []any{"subns", "subid"}, + }, + { + name: "filter by just subject type and subject relation", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + return filterer.FilterToSubjectFilter(&v1.SubjectFilter{ + SubjectType: "subns", + OptionalRelation: &v1.SubjectFilter_RelationFilter{ + Relation: "subrel", + }, + }) + }, + expectedSQL: "SELECT ns, object_id, relation, subject_object_id, caveat, caveat_context FROM relationtuples WHERE subject_ns = ? AND subject_relation = ?", + expectedArgs: []any{"subns", "subrel"}, + }, + { + name: "filter by just subject type and subject ID and relation", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + return filterer.FilterToSubjectFilter(&v1.SubjectFilter{ + SubjectType: "subns", + OptionalSubjectId: "subid", + OptionalRelation: &v1.SubjectFilter_RelationFilter{ + Relation: "subrel", + }, + }) + }, + expectedSQL: "SELECT ns, object_id, relation, caveat, caveat_context FROM relationtuples WHERE subject_ns = ? AND subject_object_id = ? AND subject_relation = ?", + expectedArgs: []any{"subns", "subid", "subrel"}, + }, + { + name: "filter by multiple subject types, but static subject ID", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + return filterer.FilterToSubjectFilter(&v1.SubjectFilter{ + SubjectType: "subns", + OptionalSubjectId: "subid", + }).FilterToSubjectFilter(&v1.SubjectFilter{ + SubjectType: "anothersubns", + OptionalSubjectId: "subid", + }) + }, + expectedSQL: "SELECT ns, object_id, relation, subject_ns, subject_relation, caveat, caveat_context FROM relationtuples WHERE subject_ns = ? AND subject_object_id = ? AND subject_ns = ? AND subject_object_id = ?", + expectedArgs: []any{"subns", "subid", "anothersubns", "subid"}, + }, + { + name: "multiple subjects filters with just types", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + return filterer.MustFilterWithSubjectsSelectors(datastore.SubjectsSelector{ + OptionalSubjectType: "somesubjectype", + }, datastore.SubjectsSelector{ + OptionalSubjectType: "anothersubjectype", + }) + }, + expectedSQL: "SELECT ns, object_id, relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context FROM relationtuples WHERE ((subject_ns = ?) OR (subject_ns = ?))", + expectedArgs: []any{"somesubjectype", "anothersubjectype"}, + }, + { + name: "multiple subjects filters with just types and static resource type", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + return filterer.MustFilterWithSubjectsSelectors(datastore.SubjectsSelector{ + OptionalSubjectType: "somesubjectype", + }, datastore.SubjectsSelector{ + OptionalSubjectType: "anothersubjectype", + }).FilterToResourceType("sometype") + }, + expectedSQL: "SELECT object_id, relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context FROM relationtuples WHERE ((subject_ns = ?) OR (subject_ns = ?)) AND ns = ?", + expectedArgs: []any{"somesubjectype", "anothersubjectype", "sometype"}, + }, + } + + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + schema := NewSchemaInformation( + "relationtuples", + "ns", + "object_id", + "relation", + "subject_ns", + "subject_object_id", + "subject_relation", + "caveat", + "caveat_context", + TupleComparison, + sq.Question, + ) + filterer := NewSchemaQueryFiltererForRelationshipsSelect(schema, 100) + ran := tc.run(filterer) + + var wasRun bool + fake := QueryExecutor{ + Executor: func(ctx context.Context, queryInfo QueryInfo, sql string, args []any) (datastore.RelationshipIterator, error) { + wasRun = true + require.Equal(t, tc.expectedSQL, sql) + require.Equal(t, tc.expectedArgs, args) + require.Equal(t, tc.expectedSelectingNoColumns, queryInfo.SelectingNoColumns) + require.Equal(t, tc.expectedSkipCaveats, queryInfo.SkipCaveats) + return nil, nil + }, + } + _, err := fake.ExecuteQuery(context.Background(), ran, tc.options...) + require.NoError(t, err) + require.True(t, wasRun) + }) + } +} From 767ab8dbc18c1aa302d8c14bbf4eadf19f0853ac Mon Sep 17 00:00:00 2001 From: Joseph Schorr Date: Tue, 5 Nov 2024 13:01:08 -0500 Subject: [PATCH 04/15] Change tests to use a new entrypoint for creating memdb for testing This will allow us to centrally register additional datastore validation that only runs at test time --- internal/caveats/run_test.go | 7 ++++--- .../datastore/dsfortesting/dsfortesting.go | 16 ++++++++++++++ internal/datastore/proxy/observable_test.go | 3 ++- .../proxy/relationshipintegrity_test.go | 16 +++++++------- .../proxy/schemacaching/estimatedsize_test.go | 3 ++- .../schemacaching/standardcaching_test.go | 3 ++- internal/dispatch/combined/combined_test.go | 3 ++- internal/dispatch/graph/check_test.go | 15 ++++++------- internal/dispatch/graph/expand_test.go | 3 ++- .../dispatch/graph/lookupresources2_test.go | 11 +++++----- .../dispatch/graph/lookupresources_test.go | 9 ++++---- .../dispatch/graph/lookupsubjects_test.go | 5 +++-- .../dispatch/graph/reachableresources_test.go | 21 ++++++++++--------- internal/graph/computed/computecheck_test.go | 7 ++++--- internal/graph/hints/checkhints_test.go | 3 ++- internal/namespace/aliasing_test.go | 3 ++- internal/namespace/annotate_test.go | 3 ++- internal/namespace/canonicalization_test.go | 5 +++-- internal/namespace/util_test.go | 3 ++- internal/relationships/validation_test.go | 3 ++- .../services/integrationtesting/cert_test.go | 4 ++-- .../consistencytestutil/clusteranddata.go | 3 ++- internal/services/shared/schema_test.go | 3 ++- .../steelthreadtesting/steelthread_test.go | 8 ++++++- internal/services/v1/preconditions_test.go | 3 ++- pkg/cmd/server/server_test.go | 6 +++--- pkg/typesystem/reachabilitygraph_test.go | 7 ++++--- pkg/typesystem/typesystem_test.go | 5 +++-- pkg/validationfile/loader_test.go | 6 +++--- 29 files changed, 116 insertions(+), 71 deletions(-) create mode 100644 internal/datastore/dsfortesting/dsfortesting.go diff --git a/internal/caveats/run_test.go b/internal/caveats/run_test.go index 390edb1884..2a799a1da9 100644 --- a/internal/caveats/run_test.go +++ b/internal/caveats/run_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/require" "github.com/authzed/spicedb/internal/caveats" + "github.com/authzed/spicedb/internal/datastore/dsfortesting" "github.com/authzed/spicedb/internal/datastore/memdb" "github.com/authzed/spicedb/internal/testfixtures" "github.com/authzed/spicedb/pkg/datastore" @@ -448,7 +449,7 @@ func TestRunCaveatExpressions(t *testing.T) { t.Run(tc.name, func(t *testing.T) { req := require.New(t) - rawDS, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + rawDS, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) req.NoError(err) ds, _ := testfixtures.DatastoreFromSchemaAndTestRelationships(rawDS, ` @@ -507,7 +508,7 @@ func TestRunCaveatExpressions(t *testing.T) { func TestRunCaveatWithMissingMap(t *testing.T) { req := require.New(t) - rawDS, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + rawDS, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) req.NoError(err) ds, _ := testfixtures.DatastoreFromSchemaAndTestRelationships(rawDS, ` @@ -536,7 +537,7 @@ func TestRunCaveatWithMissingMap(t *testing.T) { func TestRunCaveatWithEmptyMap(t *testing.T) { req := require.New(t) - rawDS, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + rawDS, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) req.NoError(err) ds, _ := testfixtures.DatastoreFromSchemaAndTestRelationships(rawDS, ` diff --git a/internal/datastore/dsfortesting/dsfortesting.go b/internal/datastore/dsfortesting/dsfortesting.go new file mode 100644 index 0000000000..04e8ae10d7 --- /dev/null +++ b/internal/datastore/dsfortesting/dsfortesting.go @@ -0,0 +1,16 @@ +package dsfortesting + +import ( + "time" + + "github.com/authzed/spicedb/internal/datastore/memdb" + "github.com/authzed/spicedb/pkg/datastore" +) + +func NewMemDBDatastoreForTesting( + watchBufferLength uint16, + revisionQuantization, + gcWindow time.Duration, +) (datastore.Datastore, error) { + return memdb.NewMemdbDatastore(watchBufferLength, revisionQuantization, gcWindow) +} diff --git a/internal/datastore/proxy/observable_test.go b/internal/datastore/proxy/observable_test.go index 63a388957c..874be163aa 100644 --- a/internal/datastore/proxy/observable_test.go +++ b/internal/datastore/proxy/observable_test.go @@ -4,6 +4,7 @@ import ( "testing" "time" + "github.com/authzed/spicedb/internal/datastore/dsfortesting" "github.com/authzed/spicedb/internal/datastore/memdb" "github.com/authzed/spicedb/pkg/datastore" "github.com/authzed/spicedb/pkg/datastore/test" @@ -12,7 +13,7 @@ import ( type observableTest struct{} func (obs observableTest) New(revisionQuantization, _, gcWindow time.Duration, watchBufferLength uint16) (datastore.Datastore, error) { - db, err := memdb.NewMemdbDatastore(watchBufferLength, revisionQuantization, gcWindow) + db, err := dsfortesting.NewMemDBDatastoreForTesting(watchBufferLength, revisionQuantization, gcWindow) if err != nil { return nil, err } diff --git a/internal/datastore/proxy/relationshipintegrity_test.go b/internal/datastore/proxy/relationshipintegrity_test.go index f8f03ecc54..a591720dde 100644 --- a/internal/datastore/proxy/relationshipintegrity_test.go +++ b/internal/datastore/proxy/relationshipintegrity_test.go @@ -10,7 +10,7 @@ import ( "github.com/stretchr/testify/require" "google.golang.org/protobuf/types/known/timestamppb" - "github.com/authzed/spicedb/internal/datastore/memdb" + "github.com/authzed/spicedb/internal/datastore/dsfortesting" "github.com/authzed/spicedb/pkg/datastore" core "github.com/authzed/spicedb/pkg/proto/core/v1" "github.com/authzed/spicedb/pkg/tuple" @@ -58,7 +58,7 @@ var expiredKeyForTesting = KeyConfig{ } func TestWriteWithPredefinedIntegrity(t *testing.T) { - ds, err := memdb.NewMemdbDatastore(0, 5*time.Second, 1*time.Hour) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 5*time.Second, 1*time.Hour) require.NoError(t, err) pds, err := NewRelationshipIntegrityProxy(ds, DefaultKeyForTesting, nil) @@ -76,7 +76,7 @@ func TestWriteWithPredefinedIntegrity(t *testing.T) { } func TestReadWithMissingIntegrity(t *testing.T) { - ds, err := memdb.NewMemdbDatastore(0, 5*time.Second, 1*time.Hour) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 5*time.Second, 1*time.Hour) require.NoError(t, err) // Write a relationship to the underlying datastore without integrity information. @@ -108,7 +108,7 @@ func TestReadWithMissingIntegrity(t *testing.T) { } func TestBasicIntegrityFailureDueToInvalidHashVersion(t *testing.T) { - ds, err := memdb.NewMemdbDatastore(0, 5*time.Second, 1*time.Hour) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 5*time.Second, 1*time.Hour) require.NoError(t, err) pds, err := NewRelationshipIntegrityProxy(ds, DefaultKeyForTesting, nil) @@ -157,7 +157,7 @@ func TestBasicIntegrityFailureDueToInvalidHashVersion(t *testing.T) { } func TestBasicIntegrityFailureDueToInvalidHashSignature(t *testing.T) { - ds, err := memdb.NewMemdbDatastore(0, 5*time.Second, 1*time.Hour) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 5*time.Second, 1*time.Hour) require.NoError(t, err) pds, err := NewRelationshipIntegrityProxy(ds, DefaultKeyForTesting, nil) @@ -206,7 +206,7 @@ func TestBasicIntegrityFailureDueToInvalidHashSignature(t *testing.T) { } func TestBasicIntegrityFailureDueToWriteWithExpiredKey(t *testing.T) { - ds, err := memdb.NewMemdbDatastore(0, 5*time.Second, 1*time.Hour) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 5*time.Second, 1*time.Hour) require.NoError(t, err) // Create a proxy with the to-be-expired key and write some relationships. @@ -245,7 +245,7 @@ func TestBasicIntegrityFailureDueToWriteWithExpiredKey(t *testing.T) { } func TestWatchIntegrityFailureDueToInvalidHashSignature(t *testing.T) { - ds, err := memdb.NewMemdbDatastore(0, 5*time.Second, 1*time.Hour) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 5*time.Second, 1*time.Hour) require.NoError(t, err) headRev, err := ds.HeadRevision(context.Background()) @@ -289,7 +289,7 @@ func TestWatchIntegrityFailureDueToInvalidHashSignature(t *testing.T) { func BenchmarkQueryRelsWithIntegrity(b *testing.B) { for _, withIntegrity := range []bool{true, false} { b.Run(fmt.Sprintf("withIntegrity=%t", withIntegrity), func(b *testing.B) { - ds, err := memdb.NewMemdbDatastore(0, 5*time.Second, 1*time.Hour) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 5*time.Second, 1*time.Hour) require.NoError(b, err) pds, err := NewRelationshipIntegrityProxy(ds, DefaultKeyForTesting, nil) diff --git a/internal/datastore/proxy/schemacaching/estimatedsize_test.go b/internal/datastore/proxy/schemacaching/estimatedsize_test.go index 275f544bf7..f0a6db168f 100644 --- a/internal/datastore/proxy/schemacaching/estimatedsize_test.go +++ b/internal/datastore/proxy/schemacaching/estimatedsize_test.go @@ -14,6 +14,7 @@ import ( "github.com/ccoveille/go-safecast" "github.com/stretchr/testify/require" + "github.com/authzed/spicedb/internal/datastore/dsfortesting" "github.com/authzed/spicedb/internal/datastore/memdb" core "github.com/authzed/spicedb/pkg/proto/core/v1" "github.com/authzed/spicedb/pkg/validationfile" @@ -46,7 +47,7 @@ func TestEstimatedDefinitionSizes(t *testing.T) { filePath := filePath t.Run(path.Base(filePath), func(t *testing.T) { require := require.New(t) - ds, err := memdb.NewMemdbDatastore(0, 1*time.Second, memdb.DisableGC) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 1*time.Second, memdb.DisableGC) require.NoError(err) fullyResolved, _, err := validationfile.PopulateFromFiles(context.Background(), ds, []string{filePath}) diff --git a/internal/datastore/proxy/schemacaching/standardcaching_test.go b/internal/datastore/proxy/schemacaching/standardcaching_test.go index 2bc6e20ec1..5cdf164ceb 100644 --- a/internal/datastore/proxy/schemacaching/standardcaching_test.go +++ b/internal/datastore/proxy/schemacaching/standardcaching_test.go @@ -11,6 +11,7 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/sync/errgroup" + "github.com/authzed/spicedb/internal/datastore/dsfortesting" "github.com/authzed/spicedb/internal/datastore/memdb" "github.com/authzed/spicedb/internal/datastore/proxy/proxy_test" "github.com/authzed/spicedb/internal/datastore/revisions" @@ -366,7 +367,7 @@ func TestSnapshotCachingRealDatastore(t *testing.T) { for _, tc := range tcs { tc := tc t.Run(tc.name, func(t *testing.T) { - rawDS, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + rawDS, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(t, err) ctx := context.Background() diff --git a/internal/dispatch/combined/combined_test.go b/internal/dispatch/combined/combined_test.go index a15e3196ff..79663e23f0 100644 --- a/internal/dispatch/combined/combined_test.go +++ b/internal/dispatch/combined/combined_test.go @@ -6,6 +6,7 @@ import ( "github.com/stretchr/testify/require" + "github.com/authzed/spicedb/internal/datastore/dsfortesting" "github.com/authzed/spicedb/internal/datastore/memdb" datastoremw "github.com/authzed/spicedb/internal/middleware/datastore" "github.com/authzed/spicedb/internal/testfixtures" @@ -22,7 +23,7 @@ func TestCombinedRecursiveCall(t *testing.T) { ctx := datastoremw.ContextWithHandle(context.Background()) - rawDS, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + rawDS, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(t, err) ds, revision := testfixtures.DatastoreFromSchemaAndTestRelationships(rawDS, ` diff --git a/internal/dispatch/graph/check_test.go b/internal/dispatch/graph/check_test.go index 189337c207..5c75adc40a 100644 --- a/internal/dispatch/graph/check_test.go +++ b/internal/dispatch/graph/check_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/require" "github.com/authzed/spicedb/internal/datastore/common" + "github.com/authzed/spicedb/internal/datastore/dsfortesting" "github.com/authzed/spicedb/internal/datastore/memdb" "github.com/authzed/spicedb/internal/dispatch" "github.com/authzed/spicedb/internal/dispatch/caching" @@ -153,7 +154,7 @@ func TestMaxDepth(t *testing.T) { t.Parallel() require := require.New(t) - rawDS, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + rawDS, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(err) ds, _ := testfixtures.StandardDatastoreWithSchema(rawDS, require) @@ -1322,7 +1323,7 @@ func TestCheckPermissionOverSchema(t *testing.T) { dispatcher := NewLocalOnlyDispatcher(10, 100) - ds, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(err) ds, revision := testfixtures.DatastoreFromSchemaAndTestRelationships(ds, tc.schema, tc.relationships, require) @@ -1823,7 +1824,7 @@ func TestCheckWithHints(t *testing.T) { dispatcher := NewLocalOnlyDispatcher(10, 100) - ds, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(err) ds, revision := testfixtures.DatastoreFromSchemaAndTestRelationships(ds, tc.schema, tc.relationships, require) @@ -1863,7 +1864,7 @@ func TestCheckHintsPartialApplication(t *testing.T) { dispatcher := NewLocalOnlyDispatcher(10, 100) - ds, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(err) ds, revision := testfixtures.DatastoreFromSchemaAndTestRelationships(ds, ` @@ -1909,7 +1910,7 @@ func TestCheckHintsPartialApplicationOverArrow(t *testing.T) { dispatcher := NewLocalOnlyDispatcher(10, 100) - ds, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(err) ds, revision := testfixtures.DatastoreFromSchemaAndTestRelationships(ds, ` @@ -1955,7 +1956,7 @@ func TestCheckHintsPartialApplicationOverArrow(t *testing.T) { } func newLocalDispatcherWithConcurrencyLimit(t testing.TB, concurrencyLimit uint16) (context.Context, dispatch.Dispatcher, datastore.Revision) { - rawDS, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + rawDS, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(t, err) ds, revision := testfixtures.StandardDatastoreWithData(rawDS, require.New(t)) @@ -1977,7 +1978,7 @@ func newLocalDispatcher(t testing.TB) (context.Context, dispatch.Dispatcher, dat } func newLocalDispatcherWithSchemaAndRels(t testing.TB, schema string, rels []tuple.Relationship) (context.Context, dispatch.Dispatcher, datastore.Revision) { - rawDS, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + rawDS, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(t, err) ds, revision := testfixtures.DatastoreFromSchemaAndTestRelationships(rawDS, schema, rels, require.New(t)) diff --git a/internal/dispatch/graph/expand_test.go b/internal/dispatch/graph/expand_test.go index 051199bacb..d48adaf554 100644 --- a/internal/dispatch/graph/expand_test.go +++ b/internal/dispatch/graph/expand_test.go @@ -15,6 +15,7 @@ import ( "google.golang.org/protobuf/testing/protocmp" "github.com/authzed/spicedb/internal/datastore/common" + "github.com/authzed/spicedb/internal/datastore/dsfortesting" "github.com/authzed/spicedb/internal/datastore/memdb" expand "github.com/authzed/spicedb/internal/graph" datastoremw "github.com/authzed/spicedb/internal/middleware/datastore" @@ -280,7 +281,7 @@ func TestMaxDepthExpand(t *testing.T) { require := require.New(t) - rawDS, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + rawDS, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(err) ds, _ := testfixtures.StandardDatastoreWithSchema(rawDS, require) diff --git a/internal/dispatch/graph/lookupresources2_test.go b/internal/dispatch/graph/lookupresources2_test.go index 7d7a1851d9..1a3b702209 100644 --- a/internal/dispatch/graph/lookupresources2_test.go +++ b/internal/dispatch/graph/lookupresources2_test.go @@ -13,6 +13,7 @@ import ( "go.uber.org/goleak" "google.golang.org/protobuf/types/known/structpb" + "github.com/authzed/spicedb/internal/datastore/dsfortesting" "github.com/authzed/spicedb/internal/datastore/memdb" "github.com/authzed/spicedb/internal/dispatch" datastoremw "github.com/authzed/spicedb/internal/middleware/datastore" @@ -310,7 +311,7 @@ func TestMaxDepthLookup2(t *testing.T) { t.Parallel() require := require.New(t) - rawDS, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + rawDS, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(err) ds, revision := testfixtures.StandardDatastoreWithData(rawDS, require) @@ -754,7 +755,7 @@ func TestLookupResources2OverSchemaWithCursors(t *testing.T) { dispatcher := NewLocalOnlyDispatcher(10, 100) - ds, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(err) ds, revision := testfixtures.DatastoreFromSchemaAndTestRelationships(ds, tc.schema, tc.relationships, require) @@ -830,7 +831,7 @@ func TestLookupResources2ImmediateTimeout(t *testing.T) { require := require.New(t) - rawDS, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + rawDS, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(err) ds, revision := testfixtures.StandardDatastoreWithData(rawDS, require) @@ -865,7 +866,7 @@ func TestLookupResources2WithError(t *testing.T) { require := require.New(t) - rawDS, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + rawDS, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(err) ds, revision := testfixtures.StandardDatastoreWithData(rawDS, require) @@ -1341,7 +1342,7 @@ func TestLookupResources2EnsureCheckHints(t *testing.T) { t.Run(tc.name, func(t *testing.T) { require := require.New(t) - rawDS, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + rawDS, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(err) ds, revision := testfixtures.DatastoreFromSchemaAndTestRelationships(rawDS, tc.schema, tc.relationships, require) diff --git a/internal/dispatch/graph/lookupresources_test.go b/internal/dispatch/graph/lookupresources_test.go index ec0f731c23..9766cbedde 100644 --- a/internal/dispatch/graph/lookupresources_test.go +++ b/internal/dispatch/graph/lookupresources_test.go @@ -10,6 +10,7 @@ import ( "github.com/ccoveille/go-safecast" "github.com/stretchr/testify/require" + "github.com/authzed/spicedb/internal/datastore/dsfortesting" "github.com/authzed/spicedb/internal/datastore/memdb" "github.com/authzed/spicedb/internal/dispatch" datastoremw "github.com/authzed/spicedb/internal/middleware/datastore" @@ -302,7 +303,7 @@ func TestMaxDepthLookup(t *testing.T) { t.Parallel() require := require.New(t) - rawDS, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + rawDS, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(err) ds, revision := testfixtures.StandardDatastoreWithData(rawDS, require) @@ -606,7 +607,7 @@ func TestLookupResourcesOverSchemaWithCursors(t *testing.T) { dispatcher := NewLocalOnlyDispatcher(10, 100) - ds, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(err) ds, revision := testfixtures.DatastoreFromSchemaAndTestRelationships(ds, tc.schema, tc.relationships, require) @@ -664,7 +665,7 @@ func TestLookupResourcesImmediateTimeout(t *testing.T) { require := require.New(t) - rawDS, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + rawDS, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(err) ds, revision := testfixtures.StandardDatastoreWithData(rawDS, require) @@ -697,7 +698,7 @@ func TestLookupResourcesWithError(t *testing.T) { require := require.New(t) - rawDS, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + rawDS, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(err) ds, revision := testfixtures.StandardDatastoreWithData(rawDS, require) diff --git a/internal/dispatch/graph/lookupsubjects_test.go b/internal/dispatch/graph/lookupsubjects_test.go index d441749a12..23c1d84038 100644 --- a/internal/dispatch/graph/lookupsubjects_test.go +++ b/internal/dispatch/graph/lookupsubjects_test.go @@ -10,6 +10,7 @@ import ( "github.com/authzed/spicedb/internal/caveats" "github.com/authzed/spicedb/internal/datastore/common" + "github.com/authzed/spicedb/internal/datastore/dsfortesting" "github.com/authzed/spicedb/internal/datastore/memdb" "github.com/authzed/spicedb/internal/dispatch" log "github.com/authzed/spicedb/internal/logging" @@ -194,7 +195,7 @@ func TestLookupSubjectsMaxDepth(t *testing.T) { t.Parallel() require := require.New(t) - rawDS, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + rawDS, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(err) ds, _ := testfixtures.StandardDatastoreWithSchema(rawDS, require) @@ -997,7 +998,7 @@ func TestLookupSubjectsOverSchema(t *testing.T) { dispatcher := NewLocalOnlyDispatcher(10, 100) - ds, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(err) ds, revision := testfixtures.DatastoreFromSchemaAndTestRelationships(ds, tc.schema, tc.relationships, require) diff --git a/internal/dispatch/graph/reachableresources_test.go b/internal/dispatch/graph/reachableresources_test.go index e4442b2a29..759022cb6c 100644 --- a/internal/dispatch/graph/reachableresources_test.go +++ b/internal/dispatch/graph/reachableresources_test.go @@ -14,6 +14,7 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/sync/errgroup" + "github.com/authzed/spicedb/internal/datastore/dsfortesting" "github.com/authzed/spicedb/internal/datastore/memdb" "github.com/authzed/spicedb/internal/dispatch" "github.com/authzed/spicedb/internal/dispatch/caching" @@ -257,7 +258,7 @@ func BenchmarkReachableResources(b *testing.B) { ) require := require.New(b) - rawDS, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + rawDS, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(err) ds, revision := testfixtures.StandardDatastoreWithData(rawDS, require) @@ -567,7 +568,7 @@ func TestCaveatedReachableResources(t *testing.T) { dispatcher := NewLocalOnlyDispatcher(10, 100) - ds, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(err) ds, revision := testfixtures.DatastoreFromSchemaAndTestRelationships(ds, tc.schema, tc.relationships, require) @@ -636,7 +637,7 @@ func TestReachableResourcesWithConsistencyLimitOf1(t *testing.T) { func TestReachableResourcesMultipleEntrypointEarlyCancel(t *testing.T) { t.Parallel() - rawDS, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + rawDS, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(t, err) testRels := make([]tuple.Relationship, 0) @@ -712,7 +713,7 @@ func TestReachableResourcesMultipleEntrypointEarlyCancel(t *testing.T) { func TestReachableResourcesCursors(t *testing.T) { t.Parallel() - rawDS, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + rawDS, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(t, err) testRels := make([]tuple.Relationship, 0) @@ -828,7 +829,7 @@ func TestReachableResourcesCursors(t *testing.T) { func TestReachableResourcesPaginationWithLimit(t *testing.T) { t.Parallel() - rawDS, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + rawDS, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(t, err) testRels := make([]tuple.Relationship, 0) @@ -909,7 +910,7 @@ func TestReachableResourcesPaginationWithLimit(t *testing.T) { func TestReachableResourcesWithQueryError(t *testing.T) { t.Parallel() - rawDS, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + rawDS, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(t, err) testRels := make([]tuple.Relationship, 0) @@ -1209,7 +1210,7 @@ func TestReachableResourcesOverSchema(t *testing.T) { dispatcher := NewLocalOnlyDispatcher(10, 100) - ds, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(err) ds, revision := testfixtures.DatastoreFromSchemaAndTestRelationships(ds, tc.schema, tc.relationships, require) @@ -1269,7 +1270,7 @@ func TestReachableResourcesOverSchema(t *testing.T) { func TestReachableResourcesWithPreCancelation(t *testing.T) { t.Parallel() - rawDS, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + rawDS, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(t, err) testRels := make([]tuple.Relationship, 0) @@ -1323,7 +1324,7 @@ func TestReachableResourcesWithPreCancelation(t *testing.T) { func TestReachableResourcesWithUnexpectedContextCancelation(t *testing.T) { t.Parallel() - rawDS, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + rawDS, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(t, err) testRels := make([]tuple.Relationship, 0) @@ -1407,7 +1408,7 @@ func (cr *cancelingReader) ReverseQueryRelationships( func TestReachableResourcesWithCachingInParallelTest(t *testing.T) { t.Parallel() - rawDS, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + rawDS, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(t, err) testRels := make([]tuple.Relationship, 0) diff --git a/internal/graph/computed/computecheck_test.go b/internal/graph/computed/computecheck_test.go index c00102fc34..0eea7e9bee 100644 --- a/internal/graph/computed/computecheck_test.go +++ b/internal/graph/computed/computecheck_test.go @@ -8,6 +8,7 @@ import ( "google.golang.org/protobuf/types/known/structpb" + "github.com/authzed/spicedb/internal/datastore/dsfortesting" "github.com/authzed/spicedb/internal/datastore/memdb" "github.com/authzed/spicedb/internal/dispatch/graph" "github.com/authzed/spicedb/internal/graph/computed" @@ -805,7 +806,7 @@ func TestComputeCheckWithCaveats(t *testing.T) { for _, tt := range testCases { tt := tt t.Run(tt.name, func(t *testing.T) { - ds, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(t, err) dispatch := graph.NewLocalOnlyDispatcher(10, 100) @@ -855,7 +856,7 @@ func TestComputeCheckWithCaveats(t *testing.T) { } func TestComputeCheckError(t *testing.T) { - ds, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(t, err) dispatch := graph.NewLocalOnlyDispatcher(10, 100) @@ -878,7 +879,7 @@ func TestComputeCheckError(t *testing.T) { } func TestComputeBulkCheck(t *testing.T) { - ds, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(t, err) dispatch := graph.NewLocalOnlyDispatcher(10, 100) diff --git a/internal/graph/hints/checkhints_test.go b/internal/graph/hints/checkhints_test.go index 1513fc0f01..bcce7a3b4e 100644 --- a/internal/graph/hints/checkhints_test.go +++ b/internal/graph/hints/checkhints_test.go @@ -6,6 +6,7 @@ import ( "github.com/stretchr/testify/require" + "github.com/authzed/spicedb/internal/datastore/dsfortesting" "github.com/authzed/spicedb/internal/datastore/memdb" datastoremw "github.com/authzed/spicedb/internal/middleware/datastore" "github.com/authzed/spicedb/pkg/datastore" @@ -97,7 +98,7 @@ func TestHintForEntrypoint(t *testing.T) { func buildReachabilityGraph(t *testing.T, schema string) *typesystem.ReachabilityGraph { require := require.New(t) - ds, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(err) ctx := datastoremw.ContextWithDatastore(context.Background(), ds) diff --git a/internal/namespace/aliasing_test.go b/internal/namespace/aliasing_test.go index ebe2649d2a..8483c9b5ef 100644 --- a/internal/namespace/aliasing_test.go +++ b/internal/namespace/aliasing_test.go @@ -9,6 +9,7 @@ import ( core "github.com/authzed/spicedb/pkg/proto/core/v1" "github.com/authzed/spicedb/pkg/typesystem" + "github.com/authzed/spicedb/internal/datastore/dsfortesting" "github.com/authzed/spicedb/internal/datastore/memdb" ns "github.com/authzed/spicedb/pkg/namespace" ) @@ -195,7 +196,7 @@ func TestAliasing(t *testing.T) { t.Run(tc.name, func(t *testing.T) { require := require.New(t) - ds, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(err) lastRevision, err := ds.HeadRevision(context.Background()) diff --git a/internal/namespace/annotate_test.go b/internal/namespace/annotate_test.go index e1deb67b28..3ada546a37 100644 --- a/internal/namespace/annotate_test.go +++ b/internal/namespace/annotate_test.go @@ -6,6 +6,7 @@ import ( "github.com/stretchr/testify/require" + "github.com/authzed/spicedb/internal/datastore/dsfortesting" "github.com/authzed/spicedb/internal/datastore/memdb" "github.com/authzed/spicedb/pkg/schemadsl/compiler" "github.com/authzed/spicedb/pkg/schemadsl/input" @@ -15,7 +16,7 @@ import ( func TestAnnotateNamespace(t *testing.T) { require := require.New(t) - ds, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(err) compiled, err := compiler.Compile(compiler.InputSchema{ diff --git a/internal/namespace/canonicalization_test.go b/internal/namespace/canonicalization_test.go index 3250faf2e1..c8c313966a 100644 --- a/internal/namespace/canonicalization_test.go +++ b/internal/namespace/canonicalization_test.go @@ -10,6 +10,7 @@ import ( core "github.com/authzed/spicedb/pkg/proto/core/v1" "github.com/authzed/spicedb/pkg/typesystem" + "github.com/authzed/spicedb/internal/datastore/dsfortesting" "github.com/authzed/spicedb/internal/datastore/memdb" ns "github.com/authzed/spicedb/pkg/namespace" "github.com/authzed/spicedb/pkg/schemadsl/compiler" @@ -425,7 +426,7 @@ func TestCanonicalization(t *testing.T) { t.Run(tc.name, func(t *testing.T) { require := require.New(t) - ds, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(err) ctx := context.Background() @@ -552,7 +553,7 @@ func TestCanonicalizationComparison(t *testing.T) { t.Run(tc.name, func(t *testing.T) { require := require.New(t) - ds, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(err) ctx := context.Background() diff --git a/internal/namespace/util_test.go b/internal/namespace/util_test.go index 7865860244..82a69d3a4b 100644 --- a/internal/namespace/util_test.go +++ b/internal/namespace/util_test.go @@ -7,6 +7,7 @@ import ( "github.com/stretchr/testify/require" + "github.com/authzed/spicedb/internal/datastore/dsfortesting" "github.com/authzed/spicedb/internal/datastore/memdb" "github.com/authzed/spicedb/internal/namespace" "github.com/authzed/spicedb/internal/testfixtures" @@ -162,7 +163,7 @@ func TestCheckNamespaceAndRelations(t *testing.T) { for _, tc := range tcs { t.Run(tc.name, func(t *testing.T) { req := require.New(t) - rawDS, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + rawDS, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) req.NoError(err) ds, _ := testfixtures.DatastoreFromSchemaAndTestRelationships(rawDS, tc.schema, nil, req) diff --git a/internal/relationships/validation_test.go b/internal/relationships/validation_test.go index ef645c50fe..6fc5c3a9ab 100644 --- a/internal/relationships/validation_test.go +++ b/internal/relationships/validation_test.go @@ -6,6 +6,7 @@ import ( "github.com/stretchr/testify/require" + "github.com/authzed/spicedb/internal/datastore/dsfortesting" "github.com/authzed/spicedb/internal/datastore/memdb" "github.com/authzed/spicedb/internal/testfixtures" core "github.com/authzed/spicedb/pkg/proto/core/v1" @@ -311,7 +312,7 @@ func TestValidateRelationshipOperations(t *testing.T) { t.Parallel() req := require.New(t) - ds, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) req.NoError(err) uds, rev := testfixtures.DatastoreFromSchemaAndTestRelationships(ds, tc.schema, nil, req) diff --git a/internal/services/integrationtesting/cert_test.go b/internal/services/integrationtesting/cert_test.go index 329d6e114a..465122a34e 100644 --- a/internal/services/integrationtesting/cert_test.go +++ b/internal/services/integrationtesting/cert_test.go @@ -23,7 +23,7 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/backoff" - "github.com/authzed/spicedb/internal/datastore/memdb" + "github.com/authzed/spicedb/internal/datastore/dsfortesting" "github.com/authzed/spicedb/internal/dispatch/graph" datastoremw "github.com/authzed/spicedb/internal/middleware/datastore" "github.com/authzed/spicedb/internal/middleware/servicespecific" @@ -115,7 +115,7 @@ func TestCertRotation(t *testing.T) { require.NoError(t, certFile.Close()) // start a server with an initial set of certs - emptyDS, err := memdb.NewMemdbDatastore(0, 10, time.Duration(90_000_000_000_000)) + emptyDS, err := dsfortesting.NewMemDBDatastoreForTesting(0, 10, time.Duration(90_000_000_000_000)) require.NoError(t, err) ds, revision := tf.StandardDatastoreWithData(emptyDS, require.New(t)) ctx, cancel := context.WithCancel(context.Background()) diff --git a/internal/services/integrationtesting/consistencytestutil/clusteranddata.go b/internal/services/integrationtesting/consistencytestutil/clusteranddata.go index cb38dbf365..d273e7bd86 100644 --- a/internal/services/integrationtesting/consistencytestutil/clusteranddata.go +++ b/internal/services/integrationtesting/consistencytestutil/clusteranddata.go @@ -10,6 +10,7 @@ import ( "github.com/stretchr/testify/require" "google.golang.org/grpc" + "github.com/authzed/spicedb/internal/datastore/dsfortesting" "github.com/authzed/spicedb/internal/datastore/memdb" "github.com/authzed/spicedb/internal/dispatch/caching" "github.com/authzed/spicedb/internal/dispatch/graph" @@ -38,7 +39,7 @@ type ConsistencyClusterAndData struct { func LoadDataAndCreateClusterForTesting(t *testing.T, consistencyTestFilePath string, revisionDelta time.Duration, additionalServerOptions ...server.ConfigOption) ConsistencyClusterAndData { require := require.New(t) - ds, err := memdb.NewMemdbDatastore(0, revisionDelta, memdb.DisableGC) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, revisionDelta, memdb.DisableGC) require.NoError(err) return BuildDataAndCreateClusterForTesting(t, consistencyTestFilePath, ds, additionalServerOptions...) diff --git a/internal/services/shared/schema_test.go b/internal/services/shared/schema_test.go index d2efcad058..983fd35895 100644 --- a/internal/services/shared/schema_test.go +++ b/internal/services/shared/schema_test.go @@ -6,6 +6,7 @@ import ( "github.com/stretchr/testify/require" + "github.com/authzed/spicedb/internal/datastore/dsfortesting" "github.com/authzed/spicedb/internal/datastore/memdb" "github.com/authzed/spicedb/internal/testfixtures" "github.com/authzed/spicedb/pkg/datastore" @@ -292,7 +293,7 @@ func TestApplySchemaChanges(t *testing.T) { for _, tc := range tcs { t.Run(tc.name, func(t *testing.T) { require := require.New(t) - rawDS, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + rawDS, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(err) // Write the initial schema. diff --git a/internal/services/steelthreadtesting/steelthread_test.go b/internal/services/steelthreadtesting/steelthread_test.go index 79fa8211ef..b34485c489 100644 --- a/internal/services/steelthreadtesting/steelthread_test.go +++ b/internal/services/steelthreadtesting/steelthread_test.go @@ -17,7 +17,7 @@ import ( v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" - "github.com/authzed/spicedb/internal/datastore/memdb" + "github.com/authzed/spicedb/internal/datastore/dsfortesting" "github.com/authzed/spicedb/internal/testserver" testdatastore "github.com/authzed/spicedb/internal/testserver/datastore" "github.com/authzed/spicedb/internal/testserver/datastore/config" @@ -31,7 +31,13 @@ const defaultConnBufferSize = humanize.MiByte func TestMemdbSteelThreads(t *testing.T) { for _, tc := range steelThreadTestCases { t.Run(tc.name, func(t *testing.T) { +<<<<<<< HEAD emptyDS, err := memdb.NewMemdbDatastore(0, 5*time.Second, 2*time.Hour) +======= + t.Parallel() + + emptyDS, err := dsfortesting.NewMemDBDatastoreForTesting(0, 5*time.Second, 2*time.Hour) +>>>>>>> 963e4a60 (Change tests to use a new entrypoint for creating memdb for testing) require.NoError(t, err) runSteelThreadTest(t, tc, emptyDS) diff --git a/internal/services/v1/preconditions_test.go b/internal/services/v1/preconditions_test.go index 88ab71b7a0..3ac10366ce 100644 --- a/internal/services/v1/preconditions_test.go +++ b/internal/services/v1/preconditions_test.go @@ -7,6 +7,7 @@ import ( v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" "github.com/stretchr/testify/require" + "github.com/authzed/spicedb/internal/datastore/dsfortesting" "github.com/authzed/spicedb/internal/datastore/memdb" "github.com/authzed/spicedb/internal/testfixtures" "github.com/authzed/spicedb/pkg/datastore" @@ -32,7 +33,7 @@ var prefixNoMatch = &v1.RelationshipFilter{ func TestPreconditions(t *testing.T) { require := require.New(t) - uninitialized, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + uninitialized, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(err) ds, _ := testfixtures.StandardDatastoreWithData(uninitialized, require) diff --git a/pkg/cmd/server/server_test.go b/pkg/cmd/server/server_test.go index dadcee86ad..b36e405da5 100644 --- a/pkg/cmd/server/server_test.go +++ b/pkg/cmd/server/server_test.go @@ -7,7 +7,7 @@ import ( "testing" "time" - "github.com/authzed/spicedb/internal/datastore/memdb" + "github.com/authzed/spicedb/internal/datastore/dsfortesting" "github.com/authzed/spicedb/internal/logging" "github.com/authzed/spicedb/pkg/cmd/datastore" "github.com/authzed/spicedb/pkg/cmd/util" @@ -26,7 +26,7 @@ func TestServerGracefulTermination(t *testing.T) { defer goleak.VerifyNone(t, append(testutil.GoLeakIgnores(), goleak.IgnoreCurrent())...) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - ds, err := memdb.NewMemdbDatastore(0, 1*time.Second, 10*time.Second) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 1*time.Second, 10*time.Second) require.NoError(t, err) c := ConfigWithOptions( @@ -164,7 +164,7 @@ func TestServerGracefulTerminationOnError(t *testing.T) { defer goleak.VerifyNone(t, append(testutil.GoLeakIgnores(), goleak.IgnoreCurrent())...) ctx, cancel := context.WithCancel(context.Background()) - ds, err := memdb.NewMemdbDatastore(0, 1*time.Second, 10*time.Second) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 1*time.Second, 10*time.Second) require.NoError(t, err) c := ConfigWithOptions(&Config{ diff --git a/pkg/typesystem/reachabilitygraph_test.go b/pkg/typesystem/reachabilitygraph_test.go index 8ca9386d09..02dfdf80ca 100644 --- a/pkg/typesystem/reachabilitygraph_test.go +++ b/pkg/typesystem/reachabilitygraph_test.go @@ -7,6 +7,7 @@ import ( "github.com/stretchr/testify/require" + "github.com/authzed/spicedb/internal/datastore/dsfortesting" "github.com/authzed/spicedb/internal/datastore/memdb" datastoremw "github.com/authzed/spicedb/internal/middleware/datastore" "github.com/authzed/spicedb/pkg/datastore" @@ -206,7 +207,7 @@ func TestRelationsEncounteredForSubject(t *testing.T) { t.Run(tc.name, func(t *testing.T) { require := require.New(t) - ds, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(err) ctx := datastoremw.ContextWithDatastore(context.Background(), ds) @@ -575,7 +576,7 @@ func TestRelationsEncounteredForResource(t *testing.T) { t.Run(tc.name, func(t *testing.T) { require := require.New(t) - ds, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(err) ctx := datastoremw.ContextWithDatastore(context.Background(), ds) @@ -1186,7 +1187,7 @@ func TestReachabilityGraph(t *testing.T) { t.Run(tc.name, func(t *testing.T) { require := require.New(t) - ds, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(err) ctx := datastoremw.ContextWithDatastore(context.Background(), ds) diff --git a/pkg/typesystem/typesystem_test.go b/pkg/typesystem/typesystem_test.go index 3f576b38da..b865d24386 100644 --- a/pkg/typesystem/typesystem_test.go +++ b/pkg/typesystem/typesystem_test.go @@ -10,6 +10,7 @@ import ( "github.com/authzed/spicedb/pkg/genutil/mapz" core "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/internal/datastore/dsfortesting" "github.com/authzed/spicedb/internal/datastore/memdb" datastoremw "github.com/authzed/spicedb/internal/middleware/datastore" "github.com/authzed/spicedb/pkg/caveats" @@ -416,7 +417,7 @@ func TestTypeSystem(t *testing.T) { t.Run(tc.name, func(t *testing.T) { require := require.New(t) - ds, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(err) ctx := context.Background() @@ -937,7 +938,7 @@ func TestTypeSystemAccessors(t *testing.T) { t.Run(tc.name, func(t *testing.T) { require := require.New(t) - ds, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) require.NoError(err) ctx := datastoremw.ContextWithDatastore(context.Background(), ds) diff --git a/pkg/validationfile/loader_test.go b/pkg/validationfile/loader_test.go index 09741be873..fc1a9e67cc 100644 --- a/pkg/validationfile/loader_test.go +++ b/pkg/validationfile/loader_test.go @@ -5,7 +5,7 @@ import ( "sort" "testing" - "github.com/authzed/spicedb/internal/datastore/memdb" + "github.com/authzed/spicedb/internal/datastore/dsfortesting" "github.com/authzed/spicedb/internal/datastore/proxy/proxy_test" "github.com/authzed/spicedb/pkg/datastore" "github.com/authzed/spicedb/pkg/datastore/options" @@ -127,7 +127,7 @@ func TestPopulateFromFiles(t *testing.T) { tt := tt t.Run(tt.name, func(t *testing.T) { require := require.New(t) - ds, err := memdb.NewMemdbDatastore(0, 0, 0) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, 0) require.NoError(err) parsed, _, err := PopulateFromFiles(context.Background(), ds, tt.filePaths) @@ -153,7 +153,7 @@ func TestPopulateFromFiles(t *testing.T) { func TestPopulationChunking(t *testing.T) { require := require.New(t) - ds, err := memdb.NewMemdbDatastore(0, 0, 0) + ds, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, 0) require.NoError(err) cs := txCountingDatastore{delegate: ds} From 55eb074f994d3a79135eee883386b1d18b982ea7 Mon Sep 17 00:00:00 2001 From: Joseph Schorr Date: Tue, 5 Nov 2024 13:26:17 -0500 Subject: [PATCH 05/15] Add validation test for elision of columns This validation test acts as a proxy in the memdb testing datastore and validates that the column elision code (which *isn't* used in memdb) matches the static fields to the values returned for all relationships loaded --- .../datastore/dsfortesting/dsfortesting.go | 139 +++++++++++++++++- 1 file changed, 138 insertions(+), 1 deletion(-) diff --git a/internal/datastore/dsfortesting/dsfortesting.go b/internal/datastore/dsfortesting/dsfortesting.go index 04e8ae10d7..42fad1209a 100644 --- a/internal/datastore/dsfortesting/dsfortesting.go +++ b/internal/datastore/dsfortesting/dsfortesting.go @@ -1,16 +1,153 @@ package dsfortesting import ( + "context" + "fmt" "time" + sq "github.com/Masterminds/squirrel" + + "github.com/authzed/spicedb/internal/datastore/common" "github.com/authzed/spicedb/internal/datastore/memdb" "github.com/authzed/spicedb/pkg/datastore" + "github.com/authzed/spicedb/pkg/datastore/options" + "github.com/authzed/spicedb/pkg/tuple" ) +// NewMemDBDatastoreForTesting creates a new in-memory datastore for testing. +// This is a convenience function that wraps the creation of a new MemDB datastore, +// and injects additional proxies for validation at test time. +// NOTE: These additional proxies are not performant for use in production (but then, +// neither is memdb) func NewMemDBDatastoreForTesting( watchBufferLength uint16, revisionQuantization, gcWindow time.Duration, ) (datastore.Datastore, error) { - return memdb.NewMemdbDatastore(watchBufferLength, revisionQuantization, gcWindow) + ds, err := memdb.NewMemdbDatastore(watchBufferLength, revisionQuantization, gcWindow) + if err != nil { + return nil, err + } + + return validatingDatastore{ds}, nil +} + +type validatingDatastore struct { + datastore.Datastore +} + +func (vds validatingDatastore) SnapshotReader(rev datastore.Revision) datastore.Reader { + return validatingReader{vds.Datastore.SnapshotReader(rev)} +} + +type validatingReader struct { + datastore.Reader +} + +func (vr validatingReader) QueryRelationships( + ctx context.Context, + filter datastore.RelationshipsFilter, + options ...options.QueryOptionsOption, +) (datastore.RelationshipIterator, error) { + schema := common.NewSchemaInformation( + "relationtuples", + "ns", + "object_id", + "relation", + "subject_ns", + "subject_object_id", + "subject_relation", + "caveat", + "caveat_context", + common.TupleComparison, + sq.Question, + ) + + qBuilder, err := common.NewSchemaQueryFiltererForRelationshipsSelect(schema, 100). + FilterWithRelationshipsFilter(filter) + if err != nil { + return nil, err + } + + // Run the filter through the common SQL ellison system and ensure that any + // relationships return have values matching the static fields, if applicable. + var queryInfo *common.QueryInfo + executor := common.QueryExecutor{ + Executor: func(ctx context.Context, qi common.QueryInfo, sql string, args []any) (datastore.RelationshipIterator, error) { + queryInfo = &qi + return nil, nil + }, + } + + _, _ = executor.ExecuteQuery(ctx, qBuilder, options...) + if queryInfo == nil { + return nil, fmt.Errorf("no query info returned") + } + + checkStaticField := func(returnedValue string, fieldName string) error { + if found, ok := queryInfo.FilteringValues[fieldName]; ok && found.SingleValue != nil { + if returnedValue != *found.SingleValue { + return fmt.Errorf("static field `%s` does not match expected value `%s`: `%s", fieldName, returnedValue, *found.SingleValue) + } + } + + return nil + } + + // Run the actual query on the memdb instance. + iter, err := vr.Reader.QueryRelationships(ctx, filter, options...) + if err != nil { + return nil, err + } + + return func(yield func(tuple.Relationship, error) bool) { + for rel, err := range iter { + if err != nil { + if !yield(rel, err) { + return + } + continue + } + + if err := checkStaticField(rel.Resource.ObjectType, "ns"); err != nil { + if !yield(rel, err) { + return + } + } + + if err := checkStaticField(rel.Resource.ObjectID, "object_id"); err != nil { + if !yield(rel, err) { + return + } + } + + if err := checkStaticField(rel.Resource.Relation, "relation"); err != nil { + if !yield(rel, err) { + return + } + } + + if err := checkStaticField(rel.Subject.ObjectType, "subject_ns"); err != nil { + if !yield(rel, err) { + return + } + } + + if err := checkStaticField(rel.Subject.ObjectID, "subject_object_id"); err != nil { + if !yield(rel, err) { + return + } + } + + if err := checkStaticField(rel.Subject.Relation, "subject_relation"); err != nil { + if !yield(rel, err) { + return + } + } + + if !yield(rel, err) { + return + } + } + }, nil } From 8abe931bc2088087dba24b35f0d4991e21dd1ab8 Mon Sep 17 00:00:00 2001 From: Joseph Schorr Date: Fri, 22 Nov 2024 18:21:08 +0000 Subject: [PATCH 06/15] Move column elision behind an experimental flag --- internal/datastore/common/relationships.go | 10 ++++- internal/datastore/common/sql.go | 42 ++++++++++++++----- internal/datastore/common/sql_test.go | 4 ++ internal/datastore/crdb/crdb.go | 1 + internal/datastore/crdb/options.go | 15 +++++++ .../datastore/dsfortesting/dsfortesting.go | 3 ++ internal/datastore/mysql/datastore.go | 1 + internal/datastore/mysql/options.go | 15 +++++++ internal/datastore/postgres/options.go | 15 +++++++ internal/datastore/postgres/postgres.go | 1 + internal/datastore/spanner/options.go | 16 +++++++ internal/datastore/spanner/reader.go | 30 ++++--------- internal/datastore/spanner/spanner.go | 21 +++++++++- pkg/cmd/datastore/datastore.go | 10 +++++ pkg/cmd/datastore/zz_generated.options.go | 9 ++++ 15 files changed, 155 insertions(+), 38 deletions(-) diff --git a/internal/datastore/common/relationships.go b/internal/datastore/common/relationships.go index cffb92c693..7860b18650 100644 --- a/internal/datastore/common/relationships.go +++ b/internal/datastore/common/relationships.go @@ -20,6 +20,12 @@ const errUnableToQueryRels = "unable to query relationships: %w" // StaticValueOrAddColumnForSelect adds a column to the list of columns to select if the value // is not static, otherwise it sets the value to the static value. func StaticValueOrAddColumnForSelect(colsToSelect []any, queryInfo QueryInfo, colName string, field *string) []any { + if queryInfo.Schema.ColumnOptimization == ColumnOptimizationOptionNone { + // If column optimization is disabled, always add the column to the list of columns to select. + colsToSelect = append(colsToSelect, field) + return colsToSelect + } + // If the value is static, set the field to it and return. if found, ok := queryInfo.FilteringValues[colName]; ok && found.SingleValue != nil { *field = *found.SingleValue @@ -79,7 +85,7 @@ func QueryRelationships[R Rows, C ~map[string]any](ctx context.Context, queryInf colsToSelect = StaticValueOrAddColumnForSelect(colsToSelect, queryInfo, queryInfo.Schema.ColUsersetObjectID, &subjectObjectID) colsToSelect = StaticValueOrAddColumnForSelect(colsToSelect, queryInfo, queryInfo.Schema.ColUsersetRelation, &subjectRelation) - if !queryInfo.SkipCaveats { + if !queryInfo.SkipCaveats || queryInfo.Schema.ColumnOptimization == ColumnOptimizationOptionNone { colsToSelect = append(colsToSelect, &caveatName, &caveatCtx) } @@ -111,7 +117,7 @@ func QueryRelationships[R Rows, C ~map[string]any](ctx context.Context, queryInf } var caveat *corev1.ContextualizedCaveat - if !queryInfo.SkipCaveats { + if !queryInfo.SkipCaveats || queryInfo.Schema.ColumnOptimization == ColumnOptimizationOptionNone { if caveatName.Valid { var err error caveat, err = ContextualizedCaveatFrom(caveatName.String, caveatCtx) diff --git a/internal/datastore/common/sql.go b/internal/datastore/common/sql.go index a781cd7b8e..83fc2cbd8e 100644 --- a/internal/datastore/common/sql.go +++ b/internal/datastore/common/sql.go @@ -47,7 +47,7 @@ var ( tracer = otel.Tracer("spicedb/internal/datastore/common") ) -// PaginationFilterType is an enumerator +// PaginationFilterType is an enumerator for pagination filter types. type PaginationFilterType uint8 const ( @@ -62,6 +62,17 @@ const ( ExpandedLogicComparison ) +// ColumnOptimizationOption is an enumerator for column optimization options. +type ColumnOptimizationOption int + +const ( + // ColumnOptimizationOptionNone is the default option, which does not optimize the static columns. + ColumnOptimizationOptionNone ColumnOptimizationOption = iota + + // ColumnOptimizationOptionStaticValue is an option that optimizes the column for a static value. + ColumnOptimizationOptionStaticValues +) + // SchemaInformation holds the schema information from the SQL datastore implementation. type SchemaInformation struct { RelationshipTableName string @@ -84,6 +95,9 @@ type SchemaInformation struct { // NowFunction is the function to use to get the current time in the datastore. NowFunction string + // ColumnOptimization is the optimization to use for columns in the schema, if any. + ColumnOptimization ColumnOptimizationOption + // ExtaFields are additional fields that are not part of the core schema, but are // requested by the caller for this query. ExtraFields []string @@ -104,6 +118,7 @@ func NewSchemaInformation( paginationFilterType PaginationFilterType, placeholderFormat sq.PlaceholderFormat, nowFunction string, + columnOptionizationOption ColumnOptimizationOption, extraFields ...string, ) SchemaInformation { return SchemaInformation{ @@ -120,6 +135,7 @@ func NewSchemaInformation( paginationFilterType, placeholderFormat, nowFunction, + columnOptionizationOption, extraFields, } } @@ -695,19 +711,19 @@ func (tqs QueryExecutor) ExecuteQuery( // Set the column names to select. columnNamesToSelect := make([]string, 0, 8+len(query.extraFields)) - columnNamesToSelect = checkColumn(columnNamesToSelect, query.filteringColumnTracker, query.schema.ColNamespace) - columnNamesToSelect = checkColumn(columnNamesToSelect, query.filteringColumnTracker, query.schema.ColObjectID) - columnNamesToSelect = checkColumn(columnNamesToSelect, query.filteringColumnTracker, query.schema.ColRelation) - columnNamesToSelect = checkColumn(columnNamesToSelect, query.filteringColumnTracker, query.schema.ColUsersetNamespace) - columnNamesToSelect = checkColumn(columnNamesToSelect, query.filteringColumnTracker, query.schema.ColUsersetObjectID) - columnNamesToSelect = checkColumn(columnNamesToSelect, query.filteringColumnTracker, query.schema.ColUsersetRelation) + columnNamesToSelect = checkColumn(columnNamesToSelect, query.schema.ColumnOptimization, query.filteringColumnTracker, query.schema.ColNamespace) + columnNamesToSelect = checkColumn(columnNamesToSelect, query.schema.ColumnOptimization, query.filteringColumnTracker, query.schema.ColObjectID) + columnNamesToSelect = checkColumn(columnNamesToSelect, query.schema.ColumnOptimization, query.filteringColumnTracker, query.schema.ColRelation) + columnNamesToSelect = checkColumn(columnNamesToSelect, query.schema.ColumnOptimization, query.filteringColumnTracker, query.schema.ColUsersetNamespace) + columnNamesToSelect = checkColumn(columnNamesToSelect, query.schema.ColumnOptimization, query.filteringColumnTracker, query.schema.ColUsersetObjectID) + columnNamesToSelect = checkColumn(columnNamesToSelect, query.schema.ColumnOptimization, query.filteringColumnTracker, query.schema.ColUsersetRelation) - columnNamesToSelect = append(columnNamesToSelect, query.schema.ColExpiration) - - if !queryOpts.SkipCaveats { + if !queryOpts.SkipCaveats || query.schema.ColumnOptimization == ColumnOptimizationOptionNone { columnNamesToSelect = append(columnNamesToSelect, query.schema.ColCaveatName, query.schema.ColCaveatContext) } + columnNamesToSelect = append(columnNamesToSelect, query.schema.ColExpiration) + selectingNoColumns := false columnNamesToSelect = append(columnNamesToSelect, query.schema.ExtraFields...) if len(columnNamesToSelect) == 0 { @@ -732,7 +748,11 @@ func (tqs QueryExecutor) ExecuteQuery( return tqs.Executor(ctx, QueryInfo{query.schema, query.filteringColumnTracker, queryOpts.SkipCaveats, selectingNoColumns}, sql, args) } -func checkColumn(columns []string, tracker map[string]ColumnTracker, colName string) []string { +func checkColumn(columns []string, option ColumnOptimizationOption, tracker map[string]ColumnTracker, colName string) []string { + if option == ColumnOptimizationOptionNone { + return append(columns, colName) + } + if r, ok := tracker[colName]; !ok || r.SingleValue == nil { return append(columns, colName) } diff --git a/internal/datastore/common/sql_test.go b/internal/datastore/common/sql_test.go index 042bacc3f1..0195eb6d40 100644 --- a/internal/datastore/common/sql_test.go +++ b/internal/datastore/common/sql_test.go @@ -572,6 +572,7 @@ func TestSchemaQueryFilterer(t *testing.T) { TupleComparison, sq.Question, "NOW", + ColumnOptimizationOptionStaticValues, ) filterer := NewSchemaQueryFiltererForRelationshipsSelect(schema, 100) @@ -829,8 +830,11 @@ func TestExecuteQuery(t *testing.T) { "subject_relation", "caveat", "caveat_context", + "expiration", TupleComparison, sq.Question, + "NOW", + ColumnOptimizationOptionStaticValues, ) filterer := NewSchemaQueryFiltererForRelationshipsSelect(schema, 100) ran := tc.run(filterer) diff --git a/internal/datastore/crdb/crdb.go b/internal/datastore/crdb/crdb.go index de66a70135..5ee9109bb3 100644 --- a/internal/datastore/crdb/crdb.go +++ b/internal/datastore/crdb/crdb.go @@ -225,6 +225,7 @@ func newCRDBDatastore(ctx context.Context, url string, options ...Option) (datas common.ExpandedLogicComparison, sq.Dollar, "NOW", + config.columnOptimizationOption, extraFields..., ) diff --git a/internal/datastore/crdb/options.go b/internal/datastore/crdb/options.go index 67d8933638..80ead5a4f4 100644 --- a/internal/datastore/crdb/options.go +++ b/internal/datastore/crdb/options.go @@ -4,6 +4,7 @@ import ( "fmt" "time" + "github.com/authzed/spicedb/internal/datastore/common" pgxcommon "github.com/authzed/spicedb/internal/datastore/postgres/common" log "github.com/authzed/spicedb/internal/logging" ) @@ -28,6 +29,7 @@ type crdbOptions struct { enablePrometheusStats bool withIntegrity bool includeQueryParametersInTraces bool + columnOptimizationOption common.ColumnOptimizationOption allowedMigrations []string } @@ -56,6 +58,7 @@ const ( defaultConnectRate = 100 * time.Millisecond defaultFilterMaximumIDCount = 100 defaultWithIntegrity = false + defaultColumnOptimizationOption = common.ColumnOptimizationOptionNone defaultIncludeQueryParametersInTraces = false ) @@ -80,6 +83,7 @@ func generateConfig(options []Option) (crdbOptions, error) { connectRate: defaultConnectRate, filterMaximumIDCount: defaultFilterMaximumIDCount, withIntegrity: defaultWithIntegrity, + columnOptimizationOption: defaultColumnOptimizationOption, includeQueryParametersInTraces: defaultIncludeQueryParametersInTraces, } @@ -353,3 +357,14 @@ func AllowedMigrations(allowedMigrations []string) Option { func IncludeQueryParametersInTraces(includeQueryParametersInTraces bool) Option { return func(po *crdbOptions) { po.includeQueryParametersInTraces = includeQueryParametersInTraces } } + +// WithColumnOptimization configures the column optimization option for the datastore. +func WithColumnOptimization(isEnabled bool) Option { + return func(po *crdbOptions) { + if isEnabled { + po.columnOptimizationOption = common.ColumnOptimizationOptionStaticValues + } else { + po.columnOptimizationOption = common.ColumnOptimizationOptionNone + } + } +} diff --git a/internal/datastore/dsfortesting/dsfortesting.go b/internal/datastore/dsfortesting/dsfortesting.go index 42fad1209a..5e83d9ce44 100644 --- a/internal/datastore/dsfortesting/dsfortesting.go +++ b/internal/datastore/dsfortesting/dsfortesting.go @@ -59,8 +59,11 @@ func (vr validatingReader) QueryRelationships( "subject_relation", "caveat", "caveat_context", + "expiration", common.TupleComparison, sq.Question, + "NOW", + common.ColumnOptimizationOptionStaticValues, ) qBuilder, err := common.NewSchemaQueryFiltererForRelationshipsSelect(schema, 100). diff --git a/internal/datastore/mysql/datastore.go b/internal/datastore/mysql/datastore.go index d0959620f2..da7780a48b 100644 --- a/internal/datastore/mysql/datastore.go +++ b/internal/datastore/mysql/datastore.go @@ -258,6 +258,7 @@ func newMySQLDatastore(ctx context.Context, uri string, replicaIndex int, option common.ExpandedLogicComparison, sq.Question, "NOW", + config.columnOptimizationOption, ) store := &Datastore{ diff --git a/internal/datastore/mysql/options.go b/internal/datastore/mysql/options.go index 4a48e44fad..713e97e671 100644 --- a/internal/datastore/mysql/options.go +++ b/internal/datastore/mysql/options.go @@ -4,6 +4,7 @@ import ( "fmt" "time" + "github.com/authzed/spicedb/internal/datastore/common" log "github.com/authzed/spicedb/internal/logging" ) @@ -25,6 +26,7 @@ const ( defaultGCEnabled = true defaultCredentialsProviderName = "" defaultFilterMaximumIDCount = 100 + defaultColumnOptimizationOption = common.ColumnOptimizationOptionNone ) type mysqlOptions struct { @@ -47,6 +49,7 @@ type mysqlOptions struct { credentialsProviderName string filterMaximumIDCount uint16 allowedMigrations []string + columnOptimizationOption common.ColumnOptimizationOption } // Option provides the facility to configure how clients within the @@ -70,6 +73,7 @@ func generateConfig(options []Option) (mysqlOptions, error) { gcEnabled: defaultGCEnabled, credentialsProviderName: defaultCredentialsProviderName, filterMaximumIDCount: defaultFilterMaximumIDCount, + columnOptimizationOption: defaultColumnOptimizationOption, } for _, option := range options { @@ -269,3 +273,14 @@ func FilterMaximumIDCount(filterMaximumIDCount uint16) Option { func AllowedMigrations(allowedMigrations []string) Option { return func(mo *mysqlOptions) { mo.allowedMigrations = allowedMigrations } } + +// WithColumnOptimization configures the column optimization strategy for the MySQL datastore. +func WithColumnOptimization(isEnabled bool) Option { + return func(mo *mysqlOptions) { + if isEnabled { + mo.columnOptimizationOption = common.ColumnOptimizationOptionStaticValues + } else { + mo.columnOptimizationOption = common.ColumnOptimizationOptionNone + } + } +} diff --git a/internal/datastore/postgres/options.go b/internal/datastore/postgres/options.go index be79c03d9a..0997c5be16 100644 --- a/internal/datastore/postgres/options.go +++ b/internal/datastore/postgres/options.go @@ -4,6 +4,7 @@ import ( "fmt" "time" + "github.com/authzed/spicedb/internal/datastore/common" pgxcommon "github.com/authzed/spicedb/internal/datastore/postgres/common" log "github.com/authzed/spicedb/internal/logging" ) @@ -28,6 +29,7 @@ type postgresOptions struct { analyzeBeforeStatistics bool gcEnabled bool readStrictMode bool + columnOptimizationOption common.ColumnOptimizationOption includeQueryParametersInTraces bool migrationPhase string @@ -68,6 +70,7 @@ const ( defaultCredentialsProviderName = "" defaultReadStrictMode = false defaultFilterMaximumIDCount = 100 + defaultColumnOptimizationOption = common.ColumnOptimizationOptionNone defaultIncludeQueryParametersInTraces = false ) @@ -92,6 +95,7 @@ func generateConfig(options []Option) (postgresOptions, error) { queryInterceptor: nil, filterMaximumIDCount: defaultFilterMaximumIDCount, includeQueryParametersInTraces: defaultIncludeQueryParametersInTraces, + columnOptimizationOption: defaultColumnOptimizationOption, } for _, option := range options { @@ -385,3 +389,14 @@ func FilterMaximumIDCount(filterMaximumIDCount uint16) Option { func IncludeQueryParametersInTraces(includeQueryParametersInTraces bool) Option { return func(po *postgresOptions) { po.includeQueryParametersInTraces = includeQueryParametersInTraces } } + +// WithColumnOptimization sets the column optimization option for the datastore. +func WithColumnOptimization(isEnabled bool) Option { + return func(po *postgresOptions) { + if isEnabled { + po.columnOptimizationOption = common.ColumnOptimizationOptionStaticValues + } else { + po.columnOptimizationOption = common.ColumnOptimizationOptionNone + } + } +} diff --git a/internal/datastore/postgres/postgres.go b/internal/datastore/postgres/postgres.go index 3607942b45..0cef2e43ac 100644 --- a/internal/datastore/postgres/postgres.go +++ b/internal/datastore/postgres/postgres.go @@ -330,6 +330,7 @@ func newPostgresDatastore( common.TupleComparison, sq.Dollar, "NOW", + config.columnOptimizationOption, ) datastore := &pgDatastore{ diff --git a/internal/datastore/spanner/options.go b/internal/datastore/spanner/options.go index 29d9617428..a0ae438a23 100644 --- a/internal/datastore/spanner/options.go +++ b/internal/datastore/spanner/options.go @@ -6,6 +6,7 @@ import ( "runtime" "time" + "github.com/authzed/spicedb/internal/datastore/common" log "github.com/authzed/spicedb/internal/logging" ) @@ -26,6 +27,7 @@ type spannerOptions struct { migrationPhase string allowedMigrations []string filterMaximumIDCount uint16 + columnOptimizationOption common.ColumnOptimizationOption } type migrationPhase uint8 @@ -49,6 +51,7 @@ const ( defaultDisableStats = false maxRevisionQuantization = 24 * time.Hour defaultFilterMaximumIDCount = 100 + defaultColumnOptimizationOption = common.ColumnOptimizationOptionNone ) // Option provides the facility to configure how clients within the Spanner @@ -72,6 +75,7 @@ func generateConfig(options []Option) (spannerOptions, error) { maxSessions: 400, migrationPhase: "", // no migration filterMaximumIDCount: defaultFilterMaximumIDCount, + columnOptimizationOption: defaultColumnOptimizationOption, } for _, option := range options { @@ -224,3 +228,15 @@ func AllowedMigrations(allowedMigrations []string) Option { func FilterMaximumIDCount(filterMaximumIDCount uint16) Option { return func(po *spannerOptions) { po.filterMaximumIDCount = filterMaximumIDCount } } + +// WithColumnOptimization configures the Spanner driver to optimize the columns +// in the underlying tables. +func WithColumnOptimization(isEnabled bool) Option { + return func(po *spannerOptions) { + if isEnabled { + po.columnOptimizationOption = common.ColumnOptimizationOptionStaticValues + } else { + po.columnOptimizationOption = common.ColumnOptimizationOptionNone + } + } +} diff --git a/internal/datastore/spanner/reader.go b/internal/datastore/spanner/reader.go index 216361886b..fb23c0e06e 100644 --- a/internal/datastore/spanner/reader.go +++ b/internal/datastore/spanner/reader.go @@ -7,7 +7,6 @@ import ( "time" "cloud.google.com/go/spanner" - sq "github.com/Masterminds/squirrel" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" "google.golang.org/grpc/codes" @@ -35,6 +34,7 @@ type spannerReader struct { executor common.QueryExecutor txSource txFactory filterMaximumIDCount uint16 + schema common.SchemaInformation } func (sr spannerReader) CountRelationships(ctx context.Context, name string) (int, error) { @@ -55,7 +55,7 @@ func (sr spannerReader) CountRelationships(ctx context.Context, name string) (in return 0, err } - builder, err := common.NewSchemaQueryFiltererWithStartingQuery(schema, countRels, sr.filterMaximumIDCount).FilterWithRelationshipsFilter(relFilter) + builder, err := common.NewSchemaQueryFiltererWithStartingQuery(sr.schema, countRels, sr.filterMaximumIDCount).FilterWithRelationshipsFilter(relFilter) if err != nil { return 0, err } @@ -135,7 +135,7 @@ func (sr spannerReader) QueryRelationships( filter datastore.RelationshipsFilter, opts ...options.QueryOptionsOption, ) (iter datastore.RelationshipIterator, err error) { - qBuilder, err := common.NewSchemaQueryFiltererForRelationshipsSelect(schema, sr.filterMaximumIDCount).FilterWithRelationshipsFilter(filter) + qBuilder, err := common.NewSchemaQueryFiltererForRelationshipsSelect(sr.schema, sr.filterMaximumIDCount).FilterWithRelationshipsFilter(filter) if err != nil { return nil, err } @@ -148,7 +148,7 @@ func (sr spannerReader) ReverseQueryRelationships( subjectsFilter datastore.SubjectsFilter, opts ...options.ReverseQueryOptionsOption, ) (iter datastore.RelationshipIterator, err error) { - qBuilder, err := common.NewSchemaQueryFiltererForRelationshipsSelect(schema, sr.filterMaximumIDCount). + qBuilder, err := common.NewSchemaQueryFiltererForRelationshipsSelect(sr.schema, sr.filterMaximumIDCount). FilterWithSubjectsSelectors(subjectsFilter.AsSelector()) if err != nil { return nil, err @@ -205,12 +205,12 @@ func queryExecutor(txSource txFactory) common.ExecuteReadRelsQueryFunc { colsToSelect = common.StaticValueOrAddColumnForSelect(colsToSelect, queryInfo, queryInfo.Schema.ColUsersetObjectID, &subjectObjectID) colsToSelect = common.StaticValueOrAddColumnForSelect(colsToSelect, queryInfo, queryInfo.Schema.ColUsersetRelation, &subjectRelation) - colsToSelect = append(colsToSelect, &expirationOrNull) - - if !queryInfo.SkipCaveats { + if !queryInfo.SkipCaveats || queryInfo.Schema.ColumnOptimization == common.ColumnOptimizationOptionNone { colsToSelect = append(colsToSelect, &caveatName, &caveatCtx) } + colsToSelect = append(colsToSelect, &expirationOrNull) + if len(colsToSelect) == 0 { var unused int64 colsToSelect = append(colsToSelect, &unused) @@ -378,20 +378,4 @@ var queryTuplesForDelete = sql.Select( colUsersetRelation, ).From(tableRelationship) -var schema = common.NewSchemaInformation( - tableRelationship, - colNamespace, - colObjectID, - colRelation, - colUsersetNamespace, - colUsersetObjectID, - colUsersetRelation, - colCaveatName, - colCaveatContext, - colExpiration, - common.ExpandedLogicComparison, - sq.AtP, - "CURRENT_TIMESTAMP", -) - var _ datastore.Reader = spannerReader{} diff --git a/internal/datastore/spanner/spanner.go b/internal/datastore/spanner/spanner.go index a0304a07e2..38d4190576 100644 --- a/internal/datastore/spanner/spanner.go +++ b/internal/datastore/spanner/spanner.go @@ -90,6 +90,7 @@ type spannerDatastore struct { client *spanner.Client config spannerOptions database string + schema common.SchemaInformation cachedEstimatedBytesPerRelationshipLock sync.RWMutex cachedEstimatedBytesPerRelationship uint64 @@ -200,6 +201,22 @@ func NewSpannerDatastore(ctx context.Context, database string, opts ...Option) ( cachedEstimatedBytesPerRelationshipLock: sync.RWMutex{}, tableSizesStatsTable: tableSizesStatsTable, filterMaximumIDCount: config.filterMaximumIDCount, + schema: common.NewSchemaInformation( + tableRelationship, + colNamespace, + colObjectID, + colRelation, + colUsersetNamespace, + colUsersetObjectID, + colUsersetRelation, + colCaveatName, + colCaveatContext, + colExpiration, + common.ExpandedLogicComparison, + sq.AtP, + "CURRENT_TIMESTAMP", + config.columnOptimizationOption, + ), } // Optimized revision and revision checking use a stale read for the // current timestamp. @@ -248,7 +265,7 @@ func (sd *spannerDatastore) SnapshotReader(revisionRaw datastore.Revision) datas return &traceableRTX{delegate: sd.client.Single().WithTimestampBound(spanner.ReadTimestamp(r.Time()))} } executor := common.QueryExecutor{Executor: queryExecutor(txSource)} - return spannerReader{executor, txSource, sd.filterMaximumIDCount} + return spannerReader{executor, txSource, sd.filterMaximumIDCount, sd.schema} } func (sd *spannerDatastore) readTransactionMetadata(ctx context.Context, transactionTag string) (map[string]any, error) { @@ -297,7 +314,7 @@ func (sd *spannerDatastore) ReadWriteTx(ctx context.Context, fn datastore.TxUser executor := common.QueryExecutor{Executor: queryExecutor(txSource)} rwt := spannerReadWriteTXN{ - spannerReader{executor, txSource, sd.filterMaximumIDCount}, + spannerReader{executor, txSource, sd.filterMaximumIDCount, sd.schema}, spannerRWT, } err := func() error { diff --git a/pkg/cmd/datastore/datastore.go b/pkg/cmd/datastore/datastore.go index 8b7da6726b..2248697e51 100644 --- a/pkg/cmd/datastore/datastore.go +++ b/pkg/cmd/datastore/datastore.go @@ -166,6 +166,9 @@ type Config struct { // Migrations MigrationPhase string `debugmap:"visible"` AllowedMigrations []string `debugmap:"visible"` + + // Expermimental + ExperimentalColumnOptimization bool `debugmap:"visible"` } //go:generate go run github.com/ecordell/optgen -sensitive-field-name-matches uri,secure -output zz_generated.relintegritykey.options.go . RelIntegrityKey @@ -271,6 +274,8 @@ func RegisterDatastoreFlagsWithPrefix(flagSet *pflag.FlagSet, prefix string, opt return fmt.Errorf("failed to mark flag as hidden: %w", err) } + flagSet.BoolVar(&opts.ExperimentalColumnOptimization, flagName("datastore-experimental-column-optimization"), false, "enable experimental column optimization") + return nil } @@ -317,6 +322,7 @@ func DefaultDatastoreConfig() *Config { RelationshipIntegrityCurrentKey: RelIntegrityKey{}, RelationshipIntegrityExpiredKeys: []string{}, AllowedMigrations: []string{}, + ExperimentalColumnOptimization: false, IncludeQueryParametersInTraces: false, } } @@ -512,6 +518,7 @@ func newCRDBDatastore(ctx context.Context, opts Config) (datastore.Datastore, er crdb.FilterMaximumIDCount(opts.FilterMaximumIDCount), crdb.WithIntegrity(opts.RelationshipIntegrityEnabled), crdb.AllowedMigrations(opts.AllowedMigrations), + crdb.WithColumnOptimization(opts.ExperimentalColumnOptimization), crdb.IncludeQueryParametersInTraces(opts.IncludeQueryParametersInTraces), ) } @@ -553,6 +560,7 @@ func commonPostgresDatastoreOptions(opts Config) ([]postgres.Option, error) { postgres.WithEnablePrometheusStats(opts.EnableDatastoreMetrics), postgres.MaxRetries(maxRetries), postgres.FilterMaximumIDCount(opts.FilterMaximumIDCount), + postgres.WithColumnOptimization(opts.ExperimentalColumnOptimization), postgres.IncludeQueryParametersInTraces(opts.IncludeQueryParametersInTraces), }, nil } @@ -636,6 +644,7 @@ func newSpannerDatastore(ctx context.Context, opts Config) (datastore.Datastore, spanner.MigrationPhase(opts.MigrationPhase), spanner.AllowedMigrations(opts.AllowedMigrations), spanner.FilterMaximumIDCount(opts.FilterMaximumIDCount), + spanner.WithColumnOptimization(opts.ExperimentalColumnOptimization), ) } @@ -680,6 +689,7 @@ func commonMySQLDatastoreOptions(opts Config) ([]mysql.Option, error) { mysql.RevisionQuantization(opts.RevisionQuantization), mysql.FilterMaximumIDCount(opts.FilterMaximumIDCount), mysql.AllowedMigrations(opts.AllowedMigrations), + mysql.WithColumnOptimization(opts.ExperimentalColumnOptimization), }, nil } diff --git a/pkg/cmd/datastore/zz_generated.options.go b/pkg/cmd/datastore/zz_generated.options.go index 4bf6a58d30..ab9d39c92e 100644 --- a/pkg/cmd/datastore/zz_generated.options.go +++ b/pkg/cmd/datastore/zz_generated.options.go @@ -78,6 +78,7 @@ func (c *Config) ToOption() ConfigOption { to.WatchConnectTimeout = c.WatchConnectTimeout to.MigrationPhase = c.MigrationPhase to.AllowedMigrations = c.AllowedMigrations + to.ExperimentalColumnOptimization = c.ExperimentalColumnOptimization } } @@ -130,6 +131,7 @@ func (c Config) DebugMap() map[string]any { debugMap["WatchConnectTimeout"] = helpers.DebugValue(c.WatchConnectTimeout, false) debugMap["MigrationPhase"] = helpers.DebugValue(c.MigrationPhase, false) debugMap["AllowedMigrations"] = helpers.DebugValue(c.AllowedMigrations, false) + debugMap["ExperimentalColumnOptimization"] = helpers.DebugValue(c.ExperimentalColumnOptimization, false) return debugMap } @@ -519,3 +521,10 @@ func SetAllowedMigrations(allowedMigrations []string) ConfigOption { c.AllowedMigrations = allowedMigrations } } + +// WithExperimentalColumnOptimization returns an option that can set ExperimentalColumnOptimization on a Config +func WithExperimentalColumnOptimization(experimentalColumnOptimization bool) ConfigOption { + return func(c *Config) { + c.ExperimentalColumnOptimization = experimentalColumnOptimization + } +} From 4a5880dcb7814bc614f1e275584487026d8af4cd Mon Sep 17 00:00:00 2001 From: Joseph Schorr Date: Tue, 10 Dec 2024 16:25:51 -0500 Subject: [PATCH 07/15] Update tests for expiration filtering --- internal/datastore/common/sql_test.go | 38 +++++++++++++-------------- internal/datastore/mysql/datastore.go | 8 +++--- 2 files changed, 23 insertions(+), 23 deletions(-) diff --git a/internal/datastore/common/sql_test.go b/internal/datastore/common/sql_test.go index 0195eb6d40..35cd4cc833 100644 --- a/internal/datastore/common/sql_test.go +++ b/internal/datastore/common/sql_test.go @@ -611,7 +611,7 @@ func TestExecuteQuery(t *testing.T) { run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.FilterToResourceType("sometype") }, - expectedSQL: "SELECT object_id, relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context FROM relationtuples WHERE ns = ?", + expectedSQL: "SELECT object_id, relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE (expiration IS NULL OR expiration > NOW()) AND ns = ?", expectedArgs: []any{"sometype"}, }, { @@ -619,7 +619,7 @@ func TestExecuteQuery(t *testing.T) { run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.FilterToResourceType("sometype").FilterToResourceID("someobj") }, - expectedSQL: "SELECT relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context FROM relationtuples WHERE ns = ? AND object_id = ?", + expectedSQL: "SELECT relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE (expiration IS NULL OR expiration > NOW()) AND ns = ? AND object_id = ?", expectedArgs: []any{"sometype", "someobj"}, }, { @@ -627,7 +627,7 @@ func TestExecuteQuery(t *testing.T) { run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.FilterToResourceType("sometype").MustFilterWithResourceIDPrefix("someprefix") }, - expectedSQL: "SELECT object_id, relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context FROM relationtuples WHERE ns = ? AND object_id LIKE ?", + expectedSQL: "SELECT object_id, relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE (expiration IS NULL OR expiration > NOW()) AND ns = ? AND object_id LIKE ?", expectedArgs: []any{"sometype", "someprefix%"}, }, { @@ -635,7 +635,7 @@ func TestExecuteQuery(t *testing.T) { run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.FilterToResourceType("sometype").MustFilterToResourceIDs([]string{"someobj", "anotherobj"}) }, - expectedSQL: "SELECT object_id, relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context FROM relationtuples WHERE ns = ? AND object_id IN (?,?)", + expectedSQL: "SELECT object_id, relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE (expiration IS NULL OR expiration > NOW()) AND ns = ? AND object_id IN (?,?)", expectedArgs: []any{"sometype", "someobj", "anotherobj"}, }, { @@ -643,7 +643,7 @@ func TestExecuteQuery(t *testing.T) { run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.FilterToResourceType("sometype").FilterToResourceID("someobj").FilterToRelation("somerel") }, - expectedSQL: "SELECT subject_ns, subject_object_id, subject_relation, caveat, caveat_context FROM relationtuples WHERE ns = ? AND object_id = ? AND relation = ?", + expectedSQL: "SELECT subject_ns, subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE (expiration IS NULL OR expiration > NOW()) AND ns = ? AND object_id = ? AND relation = ?", expectedArgs: []any{"sometype", "someobj", "somerel"}, }, { @@ -653,7 +653,7 @@ func TestExecuteQuery(t *testing.T) { SubjectType: "subns", }) }, - expectedSQL: "SELECT subject_object_id, subject_relation, caveat, caveat_context FROM relationtuples WHERE ns = ? AND object_id = ? AND relation = ? AND subject_ns = ?", + expectedSQL: "SELECT subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE (expiration IS NULL OR expiration > NOW()) AND ns = ? AND object_id = ? AND relation = ? AND subject_ns = ?", expectedArgs: []any{"sometype", "someobj", "somerel", "subns"}, }, { @@ -664,7 +664,7 @@ func TestExecuteQuery(t *testing.T) { OptionalSubjectId: "subid", }) }, - expectedSQL: "SELECT subject_relation, caveat, caveat_context FROM relationtuples WHERE ns = ? AND object_id = ? AND relation = ? AND subject_ns = ? AND subject_object_id = ?", + expectedSQL: "SELECT subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE (expiration IS NULL OR expiration > NOW()) AND ns = ? AND object_id = ? AND relation = ? AND subject_ns = ? AND subject_object_id = ?", expectedArgs: []any{"sometype", "someobj", "somerel", "subns", "subid"}, }, { @@ -678,7 +678,7 @@ func TestExecuteQuery(t *testing.T) { }, }) }, - expectedSQL: "SELECT caveat, caveat_context FROM relationtuples WHERE ns = ? AND object_id = ? AND relation = ? AND subject_ns = ? AND subject_object_id = ? AND subject_relation = ?", + expectedSQL: "SELECT caveat, caveat_context, expiration FROM relationtuples WHERE (expiration IS NULL OR expiration > NOW()) AND ns = ? AND object_id = ? AND relation = ? AND subject_ns = ? AND subject_object_id = ? AND subject_relation = ?", expectedArgs: []any{"sometype", "someobj", "somerel", "subns", "subid", "subrel"}, }, { @@ -696,8 +696,8 @@ func TestExecuteQuery(t *testing.T) { options.WithSkipCaveats(true), }, expectedSkipCaveats: true, - expectedSelectingNoColumns: true, - expectedSQL: "SELECT 1 FROM relationtuples WHERE ns = ? AND object_id = ? AND relation = ? AND subject_ns = ? AND subject_object_id = ? AND subject_relation = ?", + expectedSelectingNoColumns: false, + expectedSQL: "SELECT expiration FROM relationtuples WHERE (expiration IS NULL OR expiration > NOW()) AND ns = ? AND object_id = ? AND relation = ? AND subject_ns = ? AND subject_object_id = ? AND subject_relation = ?", expectedArgs: []any{"sometype", "someobj", "somerel", "subns", "subid", "subrel"}, }, { @@ -715,7 +715,7 @@ func TestExecuteQuery(t *testing.T) { options.WithSkipCaveats(true), }, expectedSkipCaveats: true, - expectedSQL: "SELECT object_id FROM relationtuples WHERE ns = ? AND object_id IN (?,?) AND relation = ? AND subject_ns = ? AND subject_object_id = ? AND subject_relation = ?", + expectedSQL: "SELECT object_id, expiration FROM relationtuples WHERE (expiration IS NULL OR expiration > NOW()) AND ns = ? AND object_id IN (?,?) AND relation = ? AND subject_ns = ? AND subject_object_id = ? AND subject_relation = ?", expectedArgs: []any{"sometype", "someobj", "anotherobj", "somerel", "subns", "subid", "subrel"}, }, { @@ -727,7 +727,7 @@ func TestExecuteQuery(t *testing.T) { options.WithSkipCaveats(true), }, expectedSkipCaveats: true, - expectedSQL: "SELECT object_id, relation, subject_ns, subject_object_id, subject_relation FROM relationtuples WHERE ns = ?", + expectedSQL: "SELECT object_id, relation, subject_ns, subject_object_id, subject_relation, expiration FROM relationtuples WHERE (expiration IS NULL OR expiration > NOW()) AND ns = ?", expectedArgs: []any{"sometype"}, }, { @@ -737,7 +737,7 @@ func TestExecuteQuery(t *testing.T) { SubjectType: "subns", }) }, - expectedSQL: "SELECT ns, object_id, relation, subject_object_id, subject_relation, caveat, caveat_context FROM relationtuples WHERE subject_ns = ?", + expectedSQL: "SELECT ns, object_id, relation, subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE (expiration IS NULL OR expiration > NOW()) AND subject_ns = ?", expectedArgs: []any{"subns"}, }, { @@ -748,7 +748,7 @@ func TestExecuteQuery(t *testing.T) { OptionalSubjectId: "subid", }) }, - expectedSQL: "SELECT ns, object_id, relation, subject_relation, caveat, caveat_context FROM relationtuples WHERE subject_ns = ? AND subject_object_id = ?", + expectedSQL: "SELECT ns, object_id, relation, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE (expiration IS NULL OR expiration > NOW()) AND subject_ns = ? AND subject_object_id = ?", expectedArgs: []any{"subns", "subid"}, }, { @@ -761,7 +761,7 @@ func TestExecuteQuery(t *testing.T) { }, }) }, - expectedSQL: "SELECT ns, object_id, relation, subject_object_id, caveat, caveat_context FROM relationtuples WHERE subject_ns = ? AND subject_relation = ?", + expectedSQL: "SELECT ns, object_id, relation, subject_object_id, caveat, caveat_context, expiration FROM relationtuples WHERE (expiration IS NULL OR expiration > NOW()) AND subject_ns = ? AND subject_relation = ?", expectedArgs: []any{"subns", "subrel"}, }, { @@ -775,7 +775,7 @@ func TestExecuteQuery(t *testing.T) { }, }) }, - expectedSQL: "SELECT ns, object_id, relation, caveat, caveat_context FROM relationtuples WHERE subject_ns = ? AND subject_object_id = ? AND subject_relation = ?", + expectedSQL: "SELECT ns, object_id, relation, caveat, caveat_context, expiration FROM relationtuples WHERE (expiration IS NULL OR expiration > NOW()) AND subject_ns = ? AND subject_object_id = ? AND subject_relation = ?", expectedArgs: []any{"subns", "subid", "subrel"}, }, { @@ -789,7 +789,7 @@ func TestExecuteQuery(t *testing.T) { OptionalSubjectId: "subid", }) }, - expectedSQL: "SELECT ns, object_id, relation, subject_ns, subject_relation, caveat, caveat_context FROM relationtuples WHERE subject_ns = ? AND subject_object_id = ? AND subject_ns = ? AND subject_object_id = ?", + expectedSQL: "SELECT ns, object_id, relation, subject_ns, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE (expiration IS NULL OR expiration > NOW()) AND subject_ns = ? AND subject_object_id = ? AND subject_ns = ? AND subject_object_id = ?", expectedArgs: []any{"subns", "subid", "anothersubns", "subid"}, }, { @@ -801,7 +801,7 @@ func TestExecuteQuery(t *testing.T) { OptionalSubjectType: "anothersubjectype", }) }, - expectedSQL: "SELECT ns, object_id, relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context FROM relationtuples WHERE ((subject_ns = ?) OR (subject_ns = ?))", + expectedSQL: "SELECT ns, object_id, relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE (expiration IS NULL OR expiration > NOW()) AND ((subject_ns = ?) OR (subject_ns = ?))", expectedArgs: []any{"somesubjectype", "anothersubjectype"}, }, { @@ -813,7 +813,7 @@ func TestExecuteQuery(t *testing.T) { OptionalSubjectType: "anothersubjectype", }).FilterToResourceType("sometype") }, - expectedSQL: "SELECT object_id, relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context FROM relationtuples WHERE ((subject_ns = ?) OR (subject_ns = ?)) AND ns = ?", + expectedSQL: "SELECT object_id, relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE (expiration IS NULL OR expiration > NOW()) AND ((subject_ns = ?) OR (subject_ns = ?)) AND ns = ?", expectedArgs: []any{"somesubjectype", "anothersubjectype", "sometype"}, }, } diff --git a/internal/datastore/mysql/datastore.go b/internal/datastore/mysql/datastore.go index da7780a48b..e8c5b6c903 100644 --- a/internal/datastore/mysql/datastore.go +++ b/internal/datastore/mysql/datastore.go @@ -435,12 +435,12 @@ type querier interface { QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) } -type wrappedTX struct { +type asQueryableTx struct { tx querier } -func (wtx wrappedTX) QueryFunc(ctx context.Context, f func(context.Context, common.Rows) error, sql string, args ...any) error { - rows, err := wtx.tx.QueryContext(ctx, sql, args...) +func (aqt asQueryableTx) QueryFunc(ctx context.Context, f func(context.Context, common.Rows) error, sql string, args ...any) error { + rows, err := aqt.tx.QueryContext(ctx, sql, args...) if err != nil { return err } @@ -470,7 +470,7 @@ func newMySQLExecutor(tx querier) common.ExecuteReadRelsQueryFunc { // short lifetime (e.g. to gracefully handle load-balancer connection drain) return func(ctx context.Context, queryInfo common.QueryInfo, sqlQuery string, args []interface{}) (datastore.RelationshipIterator, error) { span := trace.SpanFromContext(ctx) - return common.QueryRelationships[common.Rows, structpbWrapper](ctx, queryInfo, sqlQuery, args, span, wrappedTX{tx}, false) + return common.QueryRelationships[common.Rows, structpbWrapper](ctx, queryInfo, sqlQuery, args, span, asQueryableTx{tx}, false) } } From ae90b4d889d96d322f45ddae4c3b92488d9e6b79 Mon Sep 17 00:00:00 2001 From: Joseph Schorr Date: Wed, 11 Dec 2024 16:21:58 -0500 Subject: [PATCH 08/15] Implement a combined builder pattern for relationship SQL construction This moves the behavior out of Spanner datastore and into a common lib where possible --- internal/datastore/common/relationships.go | 55 +--- internal/datastore/common/schema.go | 122 +++++++++ internal/datastore/common/schema_options.go | 219 +++++++++++++++ internal/datastore/common/sql.go | 254 ++++++++++-------- internal/datastore/common/sql_test.go | 94 +++---- internal/datastore/crdb/crdb.go | 51 ++-- internal/datastore/crdb/reader.go | 2 +- .../datastore/dsfortesting/dsfortesting.go | 46 ++-- internal/datastore/mysql/datastore.go | 40 +-- internal/datastore/mysql/reader.go | 2 +- internal/datastore/postgres/common/pgx.go | 15 +- internal/datastore/postgres/postgres.go | 40 +-- internal/datastore/postgres/reader.go | 2 +- internal/datastore/spanner/reader.go | 51 ++-- internal/datastore/spanner/spanner.go | 38 +-- 15 files changed, 685 insertions(+), 346 deletions(-) create mode 100644 internal/datastore/common/schema.go create mode 100644 internal/datastore/common/schema_options.go diff --git a/internal/datastore/common/relationships.go b/internal/datastore/common/relationships.go index 7860b18650..3b6f8a5156 100644 --- a/internal/datastore/common/relationships.go +++ b/internal/datastore/common/relationships.go @@ -17,26 +17,6 @@ import ( const errUnableToQueryRels = "unable to query relationships: %w" -// StaticValueOrAddColumnForSelect adds a column to the list of columns to select if the value -// is not static, otherwise it sets the value to the static value. -func StaticValueOrAddColumnForSelect(colsToSelect []any, queryInfo QueryInfo, colName string, field *string) []any { - if queryInfo.Schema.ColumnOptimization == ColumnOptimizationOptionNone { - // If column optimization is disabled, always add the column to the list of columns to select. - colsToSelect = append(colsToSelect, field) - return colsToSelect - } - - // If the value is static, set the field to it and return. - if found, ok := queryInfo.FilteringValues[colName]; ok && found.SingleValue != nil { - *field = *found.SingleValue - return colsToSelect - } - - // Otherwise, add the column to the list of columns to select, as the value is not static. - colsToSelect = append(colsToSelect, field) - return colsToSelect -} - // Querier is an interface for querying the database. type Querier[R Rows] interface { QueryFunc(ctx context.Context, f func(context.Context, R) error, sql string, args ...any) error @@ -60,10 +40,14 @@ type closeRows interface { } // QueryRelationships queries relationships for the given query and transaction. -func QueryRelationships[R Rows, C ~map[string]any](ctx context.Context, queryInfo QueryInfo, sqlStatement string, args []any, span trace.Span, tx Querier[R], withIntegrity bool) (datastore.RelationshipIterator, error) { +func QueryRelationships[R Rows, C ~map[string]any](ctx context.Context, builder RelationshipsQueryBuilder, span trace.Span, tx Querier[R]) (datastore.RelationshipIterator, error) { defer span.End() - colsToSelect := make([]any, 0, 8) + sqlString, args, err := builder.SelectSQL() + if err != nil { + return nil, fmt.Errorf(errUnableToQueryRels, err) + } + var resourceObjectType string var resourceObjectID string var resourceRelation string @@ -78,26 +62,9 @@ func QueryRelationships[R Rows, C ~map[string]any](ctx context.Context, queryInf var integrityHash []byte var timestamp time.Time - colsToSelect = StaticValueOrAddColumnForSelect(colsToSelect, queryInfo, queryInfo.Schema.ColNamespace, &resourceObjectType) - colsToSelect = StaticValueOrAddColumnForSelect(colsToSelect, queryInfo, queryInfo.Schema.ColObjectID, &resourceObjectID) - colsToSelect = StaticValueOrAddColumnForSelect(colsToSelect, queryInfo, queryInfo.Schema.ColRelation, &resourceRelation) - colsToSelect = StaticValueOrAddColumnForSelect(colsToSelect, queryInfo, queryInfo.Schema.ColUsersetNamespace, &subjectObjectType) - colsToSelect = StaticValueOrAddColumnForSelect(colsToSelect, queryInfo, queryInfo.Schema.ColUsersetObjectID, &subjectObjectID) - colsToSelect = StaticValueOrAddColumnForSelect(colsToSelect, queryInfo, queryInfo.Schema.ColUsersetRelation, &subjectRelation) - - if !queryInfo.SkipCaveats || queryInfo.Schema.ColumnOptimization == ColumnOptimizationOptionNone { - colsToSelect = append(colsToSelect, &caveatName, &caveatCtx) - } - - colsToSelect = append(colsToSelect, &expiration) - - if withIntegrity { - colsToSelect = append(colsToSelect, &integrityKeyID, &integrityHash, ×tamp) - } - - if len(colsToSelect) == 0 { - var unused int - colsToSelect = append(colsToSelect, &unused) + colsToSelect, err := ColumnsToSelect(builder, &resourceObjectType, &resourceObjectID, &resourceRelation, &subjectObjectType, &subjectObjectID, &subjectRelation, &caveatName, &caveatCtx, &expiration, &integrityKeyID, &integrityHash, ×tamp) + if err != nil { + return nil, fmt.Errorf(errUnableToQueryRels, err) } return func(yield func(tuple.Relationship, error) bool) { @@ -117,7 +84,7 @@ func QueryRelationships[R Rows, C ~map[string]any](ctx context.Context, queryInf } var caveat *corev1.ContextualizedCaveat - if !queryInfo.SkipCaveats || queryInfo.Schema.ColumnOptimization == ColumnOptimizationOptionNone { + if !builder.SkipCaveats || builder.Schema.ColumnOptimization == ColumnOptimizationOptionNone { if caveatName.Valid { var err error caveat, err = ContextualizedCaveatFrom(caveatName.String, caveatCtx) @@ -171,7 +138,7 @@ func QueryRelationships[R Rows, C ~map[string]any](ctx context.Context, queryInf span.AddEvent("Rels loaded", trace.WithAttributes(attribute.Int("relCount", relCount))) return nil - }, sqlStatement, args...) + }, sqlString, args...) if err != nil { if !yield(tuple.Relationship{}, err) { return diff --git a/internal/datastore/common/schema.go b/internal/datastore/common/schema.go new file mode 100644 index 0000000000..8f31f929a6 --- /dev/null +++ b/internal/datastore/common/schema.go @@ -0,0 +1,122 @@ +package common + +import ( + sq "github.com/Masterminds/squirrel" + + "github.com/authzed/spicedb/pkg/spiceerrors" +) + +// SchemaInformation holds the schema information from the SQL datastore implementation. +// +//go:generate go run github.com/ecordell/optgen -output schema_options.go . SchemaInformation +type SchemaInformation struct { + RelationshipTableName string `debugmap:"visible"` + + ColNamespace string `debugmap:"visible"` + ColObjectID string `debugmap:"visible"` + ColRelation string `debugmap:"visible"` + ColUsersetNamespace string `debugmap:"visible"` + ColUsersetObjectID string `debugmap:"visible"` + ColUsersetRelation string `debugmap:"visible"` + ColCaveatName string `debugmap:"visible"` + ColCaveatContext string `debugmap:"visible"` + ColExpiration string `debugmap:"visible"` + + ColIntegrityKeyID string `debugmap:"visible"` + ColIntegrityHash string `debugmap:"visible"` + ColIntegrityTimestamp string `debugmap:"visible"` + + // PaginationFilterType is the type of pagination filter to use for this schema. + PaginationFilterType PaginationFilterType `debugmap:"visible"` + + // PlaceholderFormat is the format of placeholders to use for this schema. + PlaceholderFormat sq.PlaceholderFormat `debugmap:"visible"` + + // NowFunction is the function to use to get the current time in the datastore. + NowFunction string `debugmap:"visible"` + + // ColumnOptimization is the optimization to use for columns in the schema, if any. + ColumnOptimization ColumnOptimizationOption `debugmap:"visible"` + + // WithIntegrityColumns is a flag to indicate if the schema has integrity columns. + WithIntegrityColumns bool `debugmap:"visible"` +} + +func (si SchemaInformation) debugValidate() { + spiceerrors.DebugAssert(func() bool { + si.mustValidate() + return true + }, "SchemaInformation failed to validate") +} + +func (si SchemaInformation) mustValidate() { + if si.RelationshipTableName == "" { + panic("RelationshipTableName is required") + } + + if si.ColNamespace == "" { + panic("ColNamespace is required") + } + + if si.ColObjectID == "" { + panic("ColObjectID is required") + } + + if si.ColRelation == "" { + panic("ColRelation is required") + } + + if si.ColUsersetNamespace == "" { + panic("ColUsersetNamespace is required") + } + + if si.ColUsersetObjectID == "" { + panic("ColUsersetObjectID is required") + } + + if si.ColUsersetRelation == "" { + panic("ColUsersetRelation is required") + } + + if si.ColCaveatName == "" { + panic("ColCaveatName is required") + } + + if si.ColCaveatContext == "" { + panic("ColCaveatContext is required") + } + + if si.ColExpiration == "" { + panic("ColExpiration is required") + } + + if si.WithIntegrityColumns { + if si.ColIntegrityKeyID == "" { + panic("ColIntegrityKeyID is required") + } + + if si.ColIntegrityHash == "" { + panic("ColIntegrityHash is required") + } + + if si.ColIntegrityTimestamp == "" { + panic("ColIntegrityTimestamp is required") + } + } + + if si.NowFunction == "" { + panic("NowFunction is required") + } + + if si.ColumnOptimization == ColumnOptimizationOptionUnknown { + panic("ColumnOptimization is required") + } + + if si.PaginationFilterType == PaginationFilterTypeUnknown { + panic("PaginationFilterType is required") + } + + if si.PlaceholderFormat == nil { + panic("PlaceholderFormat is required") + } +} diff --git a/internal/datastore/common/schema_options.go b/internal/datastore/common/schema_options.go new file mode 100644 index 0000000000..3aed7f64e8 --- /dev/null +++ b/internal/datastore/common/schema_options.go @@ -0,0 +1,219 @@ +// Code generated by github.com/ecordell/optgen. DO NOT EDIT. +package common + +import ( + squirrel "github.com/Masterminds/squirrel" + defaults "github.com/creasty/defaults" + helpers "github.com/ecordell/optgen/helpers" +) + +type SchemaInformationOption func(s *SchemaInformation) + +// NewSchemaInformationWithOptions creates a new SchemaInformation with the passed in options set +func NewSchemaInformationWithOptions(opts ...SchemaInformationOption) *SchemaInformation { + s := &SchemaInformation{} + for _, o := range opts { + o(s) + } + return s +} + +// NewSchemaInformationWithOptionsAndDefaults creates a new SchemaInformation with the passed in options set starting from the defaults +func NewSchemaInformationWithOptionsAndDefaults(opts ...SchemaInformationOption) *SchemaInformation { + s := &SchemaInformation{} + defaults.MustSet(s) + for _, o := range opts { + o(s) + } + return s +} + +// ToOption returns a new SchemaInformationOption that sets the values from the passed in SchemaInformation +func (s *SchemaInformation) ToOption() SchemaInformationOption { + return func(to *SchemaInformation) { + to.RelationshipTableName = s.RelationshipTableName + to.ColNamespace = s.ColNamespace + to.ColObjectID = s.ColObjectID + to.ColRelation = s.ColRelation + to.ColUsersetNamespace = s.ColUsersetNamespace + to.ColUsersetObjectID = s.ColUsersetObjectID + to.ColUsersetRelation = s.ColUsersetRelation + to.ColCaveatName = s.ColCaveatName + to.ColCaveatContext = s.ColCaveatContext + to.ColExpiration = s.ColExpiration + to.ColIntegrityKeyID = s.ColIntegrityKeyID + to.ColIntegrityHash = s.ColIntegrityHash + to.ColIntegrityTimestamp = s.ColIntegrityTimestamp + to.PaginationFilterType = s.PaginationFilterType + to.PlaceholderFormat = s.PlaceholderFormat + to.NowFunction = s.NowFunction + to.ColumnOptimization = s.ColumnOptimization + to.WithIntegrityColumns = s.WithIntegrityColumns + } +} + +// DebugMap returns a map form of SchemaInformation for debugging +func (s SchemaInformation) DebugMap() map[string]any { + debugMap := map[string]any{} + debugMap["RelationshipTableName"] = helpers.DebugValue(s.RelationshipTableName, false) + debugMap["ColNamespace"] = helpers.DebugValue(s.ColNamespace, false) + debugMap["ColObjectID"] = helpers.DebugValue(s.ColObjectID, false) + debugMap["ColRelation"] = helpers.DebugValue(s.ColRelation, false) + debugMap["ColUsersetNamespace"] = helpers.DebugValue(s.ColUsersetNamespace, false) + debugMap["ColUsersetObjectID"] = helpers.DebugValue(s.ColUsersetObjectID, false) + debugMap["ColUsersetRelation"] = helpers.DebugValue(s.ColUsersetRelation, false) + debugMap["ColCaveatName"] = helpers.DebugValue(s.ColCaveatName, false) + debugMap["ColCaveatContext"] = helpers.DebugValue(s.ColCaveatContext, false) + debugMap["ColExpiration"] = helpers.DebugValue(s.ColExpiration, false) + debugMap["ColIntegrityKeyID"] = helpers.DebugValue(s.ColIntegrityKeyID, false) + debugMap["ColIntegrityHash"] = helpers.DebugValue(s.ColIntegrityHash, false) + debugMap["ColIntegrityTimestamp"] = helpers.DebugValue(s.ColIntegrityTimestamp, false) + debugMap["PaginationFilterType"] = helpers.DebugValue(s.PaginationFilterType, false) + debugMap["PlaceholderFormat"] = helpers.DebugValue(s.PlaceholderFormat, false) + debugMap["NowFunction"] = helpers.DebugValue(s.NowFunction, false) + debugMap["ColumnOptimization"] = helpers.DebugValue(s.ColumnOptimization, false) + debugMap["WithIntegrityColumns"] = helpers.DebugValue(s.WithIntegrityColumns, false) + return debugMap +} + +// SchemaInformationWithOptions configures an existing SchemaInformation with the passed in options set +func SchemaInformationWithOptions(s *SchemaInformation, opts ...SchemaInformationOption) *SchemaInformation { + for _, o := range opts { + o(s) + } + return s +} + +// WithOptions configures the receiver SchemaInformation with the passed in options set +func (s *SchemaInformation) WithOptions(opts ...SchemaInformationOption) *SchemaInformation { + for _, o := range opts { + o(s) + } + return s +} + +// WithRelationshipTableName returns an option that can set RelationshipTableName on a SchemaInformation +func WithRelationshipTableName(relationshipTableName string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.RelationshipTableName = relationshipTableName + } +} + +// WithColNamespace returns an option that can set ColNamespace on a SchemaInformation +func WithColNamespace(colNamespace string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.ColNamespace = colNamespace + } +} + +// WithColObjectID returns an option that can set ColObjectID on a SchemaInformation +func WithColObjectID(colObjectID string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.ColObjectID = colObjectID + } +} + +// WithColRelation returns an option that can set ColRelation on a SchemaInformation +func WithColRelation(colRelation string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.ColRelation = colRelation + } +} + +// WithColUsersetNamespace returns an option that can set ColUsersetNamespace on a SchemaInformation +func WithColUsersetNamespace(colUsersetNamespace string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.ColUsersetNamespace = colUsersetNamespace + } +} + +// WithColUsersetObjectID returns an option that can set ColUsersetObjectID on a SchemaInformation +func WithColUsersetObjectID(colUsersetObjectID string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.ColUsersetObjectID = colUsersetObjectID + } +} + +// WithColUsersetRelation returns an option that can set ColUsersetRelation on a SchemaInformation +func WithColUsersetRelation(colUsersetRelation string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.ColUsersetRelation = colUsersetRelation + } +} + +// WithColCaveatName returns an option that can set ColCaveatName on a SchemaInformation +func WithColCaveatName(colCaveatName string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.ColCaveatName = colCaveatName + } +} + +// WithColCaveatContext returns an option that can set ColCaveatContext on a SchemaInformation +func WithColCaveatContext(colCaveatContext string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.ColCaveatContext = colCaveatContext + } +} + +// WithColExpiration returns an option that can set ColExpiration on a SchemaInformation +func WithColExpiration(colExpiration string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.ColExpiration = colExpiration + } +} + +// WithColIntegrityKeyID returns an option that can set ColIntegrityKeyID on a SchemaInformation +func WithColIntegrityKeyID(colIntegrityKeyID string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.ColIntegrityKeyID = colIntegrityKeyID + } +} + +// WithColIntegrityHash returns an option that can set ColIntegrityHash on a SchemaInformation +func WithColIntegrityHash(colIntegrityHash string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.ColIntegrityHash = colIntegrityHash + } +} + +// WithColIntegrityTimestamp returns an option that can set ColIntegrityTimestamp on a SchemaInformation +func WithColIntegrityTimestamp(colIntegrityTimestamp string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.ColIntegrityTimestamp = colIntegrityTimestamp + } +} + +// WithPaginationFilterType returns an option that can set PaginationFilterType on a SchemaInformation +func WithPaginationFilterType(paginationFilterType PaginationFilterType) SchemaInformationOption { + return func(s *SchemaInformation) { + s.PaginationFilterType = paginationFilterType + } +} + +// WithPlaceholderFormat returns an option that can set PlaceholderFormat on a SchemaInformation +func WithPlaceholderFormat(placeholderFormat squirrel.PlaceholderFormat) SchemaInformationOption { + return func(s *SchemaInformation) { + s.PlaceholderFormat = placeholderFormat + } +} + +// WithNowFunction returns an option that can set NowFunction on a SchemaInformation +func WithNowFunction(nowFunction string) SchemaInformationOption { + return func(s *SchemaInformation) { + s.NowFunction = nowFunction + } +} + +// WithColumnOptimization returns an option that can set ColumnOptimization on a SchemaInformation +func WithColumnOptimization(columnOptimization ColumnOptimizationOption) SchemaInformationOption { + return func(s *SchemaInformation) { + s.ColumnOptimization = columnOptimization + } +} + +// WithWithIntegrityColumns returns an option that can set WithIntegrityColumns on a SchemaInformation +func WithWithIntegrityColumns(withIntegrityColumns bool) SchemaInformationOption { + return func(s *SchemaInformation) { + s.WithIntegrityColumns = withIntegrityColumns + } +} diff --git a/internal/datastore/common/sql.go b/internal/datastore/common/sql.go index 83fc2cbd8e..a0a92f65fd 100644 --- a/internal/datastore/common/sql.go +++ b/internal/datastore/common/sql.go @@ -2,8 +2,10 @@ package common import ( "context" + "maps" "math" "strings" + "time" sq "github.com/Masterminds/squirrel" v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" @@ -51,10 +53,12 @@ var ( type PaginationFilterType uint8 const ( + PaginationFilterTypeUnknown PaginationFilterType = iota + // TupleComparison uses a comparison with a compound key, // e.g. (namespace, object_id, relation) > ('ns', '123', 'viewer') // which is not compatible with all datastores. - TupleComparison PaginationFilterType = iota + TupleComparison // ExpandedLogicComparison comparison uses a nested tree of ANDs and ORs to properly // filter out already received relationships. Useful for databases that do not support @@ -66,80 +70,15 @@ const ( type ColumnOptimizationOption int const ( + ColumnOptimizationOptionUnknown ColumnOptimizationOption = iota + // ColumnOptimizationOptionNone is the default option, which does not optimize the static columns. - ColumnOptimizationOptionNone ColumnOptimizationOption = iota + ColumnOptimizationOptionNone // ColumnOptimizationOptionStaticValue is an option that optimizes the column for a static value. ColumnOptimizationOptionStaticValues ) -// SchemaInformation holds the schema information from the SQL datastore implementation. -type SchemaInformation struct { - RelationshipTableName string - ColNamespace string - ColObjectID string - ColRelation string - ColUsersetNamespace string - ColUsersetObjectID string - ColUsersetRelation string - ColCaveatName string - ColCaveatContext string - ColExpiration string - - // PaginationFilterType is the type of pagination filter to use for this schema. - PaginationFilterType PaginationFilterType - - // PlaceholderFormat is the format of placeholders to use for this schema. - PlaceholderFormat sq.PlaceholderFormat - - // NowFunction is the function to use to get the current time in the datastore. - NowFunction string - - // ColumnOptimization is the optimization to use for columns in the schema, if any. - ColumnOptimization ColumnOptimizationOption - - // ExtaFields are additional fields that are not part of the core schema, but are - // requested by the caller for this query. - ExtraFields []string -} - -// NewSchemaInformation creates a new SchemaInformation object for a query. -func NewSchemaInformation( - relationshipTableName, - colNamespace, - colObjectID, - colRelation, - colUsersetNamespace, - colUsersetObjectID, - colUsersetRelation, - colCaveatName string, - colCaveatContext string, - colExpiration string, - paginationFilterType PaginationFilterType, - placeholderFormat sq.PlaceholderFormat, - nowFunction string, - columnOptionizationOption ColumnOptimizationOption, - extraFields ...string, -) SchemaInformation { - return SchemaInformation{ - relationshipTableName, - colNamespace, - colObjectID, - colRelation, - colUsersetNamespace, - colUsersetObjectID, - colUsersetRelation, - colCaveatName, - colCaveatContext, - colExpiration, - paginationFilterType, - placeholderFormat, - nowFunction, - columnOptionizationOption, - extraFields, - } -} - type ColumnTracker struct { SingleValue *string } @@ -160,6 +99,8 @@ type SchemaQueryFilterer struct { // relationships. This method will automatically filter the columns retrieved from the database, only // selecting the columns that are not already specified with a single static value in the query. func NewSchemaQueryFiltererForRelationshipsSelect(schema SchemaInformation, filterMaximumIDCount uint16, extraFields ...string) SchemaQueryFilterer { + schema.debugValidate() + if filterMaximumIDCount == 0 { filterMaximumIDCount = 100 log.Warn().Msg("SchemaQueryFilterer: filterMaximumIDCount not set, defaulting to 100") @@ -186,6 +127,8 @@ func NewSchemaQueryFiltererForRelationshipsSelect(schema SchemaInformation, filt // relationships, with a custom starting query. Unlike NewSchemaQueryFiltererForRelationshipsSelect, // this method will not auto-filter the columns retrieved from the database. func NewSchemaQueryFiltererWithStartingQuery(schema SchemaInformation, startingQuery sq.SelectBuilder, filterMaximumIDCount uint16) SchemaQueryFilterer { + schema.debugValidate() + if filterMaximumIDCount == 0 { filterMaximumIDCount = 100 log.Warn().Msg("SchemaQueryFilterer: filterMaximumIDCount not set, defaulting to 100") @@ -665,13 +608,16 @@ func (sqf SchemaQueryFilterer) limit(limit uint64) SchemaQueryFilterer { return sqf } -// QueryExecutor is a tuple query runner shared by SQL implementations of the datastore. -type QueryExecutor struct { +// QueryRelationshipsExecutor is a relationships query runner shared by SQL implementations of the datastore. +type QueryRelationshipsExecutor struct { Executor ExecuteReadRelsQueryFunc } +// ExecuteReadRelsQueryFunc is a function that can be used to execute a single rendered SQL query. +type ExecuteReadRelsQueryFunc func(ctx context.Context, builder RelationshipsQueryBuilder) (datastore.RelationshipIterator, error) + // ExecuteQuery executes the query. -func (tqs QueryExecutor) ExecuteQuery( +func (exc QueryRelationshipsExecutor) ExecuteQuery( ctx context.Context, query SchemaQueryFilterer, opts ...options.QueryOptionsOption, @@ -682,8 +628,10 @@ func (tqs QueryExecutor) ExecuteQuery( queryOpts := options.NewQueryOptionsWithOptions(opts...) + // Add sort order. query = query.TupleOrder(queryOpts.Sort) + // Add cursor. if queryOpts.After != nil { if queryOpts.Sort == options.Unsorted { return nil, datastore.ErrCursorsWithoutSorting @@ -692,6 +640,7 @@ func (tqs QueryExecutor) ExecuteQuery( query = query.After(queryOpts.After, queryOpts.Sort) } + // Add limit. var limit uint64 // NOTE: we use a uint here because it lines up with the // assignments in this function, but we set it to MaxInt64 @@ -706,70 +655,149 @@ func (tqs QueryExecutor) ExecuteQuery( query = query.limit(limit) } - toExecute := query - - // Set the column names to select. - columnNamesToSelect := make([]string, 0, 8+len(query.extraFields)) + // Add FROM clause. + from := query.schema.RelationshipTableName + if query.fromSuffix != "" { + from += " " + query.fromSuffix + } - columnNamesToSelect = checkColumn(columnNamesToSelect, query.schema.ColumnOptimization, query.filteringColumnTracker, query.schema.ColNamespace) - columnNamesToSelect = checkColumn(columnNamesToSelect, query.schema.ColumnOptimization, query.filteringColumnTracker, query.schema.ColObjectID) - columnNamesToSelect = checkColumn(columnNamesToSelect, query.schema.ColumnOptimization, query.filteringColumnTracker, query.schema.ColRelation) - columnNamesToSelect = checkColumn(columnNamesToSelect, query.schema.ColumnOptimization, query.filteringColumnTracker, query.schema.ColUsersetNamespace) - columnNamesToSelect = checkColumn(columnNamesToSelect, query.schema.ColumnOptimization, query.filteringColumnTracker, query.schema.ColUsersetObjectID) - columnNamesToSelect = checkColumn(columnNamesToSelect, query.schema.ColumnOptimization, query.filteringColumnTracker, query.schema.ColUsersetRelation) + query.queryBuilder = query.queryBuilder.From(from) - if !queryOpts.SkipCaveats || query.schema.ColumnOptimization == ColumnOptimizationOptionNone { - columnNamesToSelect = append(columnNamesToSelect, query.schema.ColCaveatName, query.schema.ColCaveatContext) + builder := RelationshipsQueryBuilder{ + Schema: query.schema, + SkipCaveats: queryOpts.SkipCaveats, + filteringValues: query.filteringColumnTracker, + baseQueryBuilder: query, } - columnNamesToSelect = append(columnNamesToSelect, query.schema.ColExpiration) + return exc.Executor(ctx, builder) +} + +// RelationshipsQueryBuilder is a builder for producing the SQL and arguments necessary for reading +// relationships. +type RelationshipsQueryBuilder struct { + Schema SchemaInformation + SkipCaveats bool - selectingNoColumns := false - columnNamesToSelect = append(columnNamesToSelect, query.schema.ExtraFields...) - if len(columnNamesToSelect) == 0 { - columnNamesToSelect = append(columnNamesToSelect, "1") - selectingNoColumns = true + filteringValues map[string]ColumnTracker + baseQueryBuilder SchemaQueryFilterer +} + +// SelectSQL returns the SQL and arguments necessary for reading relationships. +func (b RelationshipsQueryBuilder) SelectSQL() (string, []any, error) { + // Set the column names to select. + columnCount := 9 + if b.Schema.WithIntegrityColumns { + columnCount += 3 } + columnNamesToSelect := make([]string, 0, columnCount) - toExecute.queryBuilder = toExecute.queryBuilder.Columns(columnNamesToSelect...) + columnNamesToSelect = b.checkColumn(columnNamesToSelect, b.Schema.ColNamespace) + columnNamesToSelect = b.checkColumn(columnNamesToSelect, b.Schema.ColObjectID) + columnNamesToSelect = b.checkColumn(columnNamesToSelect, b.Schema.ColRelation) + columnNamesToSelect = b.checkColumn(columnNamesToSelect, b.Schema.ColUsersetNamespace) + columnNamesToSelect = b.checkColumn(columnNamesToSelect, b.Schema.ColUsersetObjectID) + columnNamesToSelect = b.checkColumn(columnNamesToSelect, b.Schema.ColUsersetRelation) - from := query.schema.RelationshipTableName - if query.fromSuffix != "" { - from += " " + query.fromSuffix + if !b.SkipCaveats || b.Schema.ColumnOptimization == ColumnOptimizationOptionNone { + columnNamesToSelect = append(columnNamesToSelect, b.Schema.ColCaveatName, b.Schema.ColCaveatContext) } - toExecute.queryBuilder = toExecute.queryBuilder.From(from) + columnNamesToSelect = append(columnNamesToSelect, b.Schema.ColExpiration) - sql, args, err := toExecute.queryBuilder.ToSql() - if err != nil { - return nil, err + if b.Schema.WithIntegrityColumns { + columnNamesToSelect = append(columnNamesToSelect, b.Schema.ColIntegrityKeyID, b.Schema.ColIntegrityHash, b.Schema.ColIntegrityTimestamp) + } + + if len(columnNamesToSelect) == 0 { + columnNamesToSelect = append(columnNamesToSelect, "1") } - return tqs.Executor(ctx, QueryInfo{query.schema, query.filteringColumnTracker, queryOpts.SkipCaveats, selectingNoColumns}, sql, args) + sqlBuilder := b.baseQueryBuilder.queryBuilder + sqlBuilder = sqlBuilder.Columns(columnNamesToSelect...) + + return sqlBuilder.ToSql() } -func checkColumn(columns []string, option ColumnOptimizationOption, tracker map[string]ColumnTracker, colName string) []string { - if option == ColumnOptimizationOptionNone { +// FilteringValuesForTesting returns the filtering values. For test use only. +func (b RelationshipsQueryBuilder) FilteringValuesForTesting() map[string]ColumnTracker { + return maps.Clone(b.filteringValues) +} + +func (b RelationshipsQueryBuilder) checkColumn(columns []string, colName string) []string { + if b.Schema.ColumnOptimization == ColumnOptimizationOptionNone { return append(columns, colName) } - if r, ok := tracker[colName]; !ok || r.SingleValue == nil { + if r, ok := b.filteringValues[colName]; !ok || r.SingleValue == nil { return append(columns, colName) } return columns } -// QueryInfo holds the schema information and filtering values for a query. -type QueryInfo struct { - Schema SchemaInformation - FilteringValues map[string]ColumnTracker - SkipCaveats bool - SelectingNoColumns bool +func (b RelationshipsQueryBuilder) staticValueOrAddColumnForSelect(colsToSelect []any, colName string, field *string) []any { + if b.Schema.ColumnOptimization == ColumnOptimizationOptionNone { + // If column optimization is disabled, always add the column to the list of columns to select. + colsToSelect = append(colsToSelect, field) + return colsToSelect + } + + // If the value is static, set the field to it and return. + if found, ok := b.filteringValues[colName]; ok && found.SingleValue != nil { + *field = *found.SingleValue + return colsToSelect + } + + // Otherwise, add the column to the list of columns to select, as the value is not static. + colsToSelect = append(colsToSelect, field) + return colsToSelect } -// ExecuteReadRelsQueryFunc is a function that can be used to execute a single rendered SQL query. -type ExecuteReadRelsQueryFunc func(ctx context.Context, queryInfo QueryInfo, sql string, args []any) (datastore.RelationshipIterator, error) +// ColumnsToSelect returns the columns to select for a given query. The columns provided are +// the references to the slots in which the values for each relationship will be placed. +func ColumnsToSelect[CN any, CC any, EC any]( + b RelationshipsQueryBuilder, + resourceObjectType *string, + resourceObjectID *string, + resourceRelation *string, + subjectObjectType *string, + subjectObjectID *string, + subjectRelation *string, + caveatName *CN, + caveatCtx *CC, + expiration EC, + + integrityKeyID *string, + integrityHash *[]byte, + timestamp *time.Time, +) ([]any, error) { + columnCount := 9 + if b.Schema.WithIntegrityColumns { + columnCount += 3 + } + colsToSelect := make([]any, 0, columnCount) + + colsToSelect = b.staticValueOrAddColumnForSelect(colsToSelect, b.Schema.ColNamespace, resourceObjectType) + colsToSelect = b.staticValueOrAddColumnForSelect(colsToSelect, b.Schema.ColObjectID, resourceObjectID) + colsToSelect = b.staticValueOrAddColumnForSelect(colsToSelect, b.Schema.ColRelation, resourceRelation) + colsToSelect = b.staticValueOrAddColumnForSelect(colsToSelect, b.Schema.ColUsersetNamespace, subjectObjectType) + colsToSelect = b.staticValueOrAddColumnForSelect(colsToSelect, b.Schema.ColUsersetObjectID, subjectObjectID) + colsToSelect = b.staticValueOrAddColumnForSelect(colsToSelect, b.Schema.ColUsersetRelation, subjectRelation) + + if !b.SkipCaveats || b.Schema.ColumnOptimization == ColumnOptimizationOptionNone { + colsToSelect = append(colsToSelect, caveatName, caveatCtx) + } -// TxCleanupFunc is a function that should be executed when the caller of -// TransactionFactory is done with the transaction. -type TxCleanupFunc func(context.Context) + colsToSelect = append(colsToSelect, expiration) + + if b.Schema.WithIntegrityColumns { + colsToSelect = append(colsToSelect, integrityKeyID, integrityHash, timestamp) + } + + if len(colsToSelect) == 0 { + var unused int + colsToSelect = append(colsToSelect, &unused) + } + + return colsToSelect, nil +} diff --git a/internal/datastore/common/sql_test.go b/internal/datastore/common/sql_test.go index 35cd4cc833..6b2e2c40a4 100644 --- a/internal/datastore/common/sql_test.go +++ b/internal/datastore/common/sql_test.go @@ -558,23 +558,23 @@ func TestSchemaQueryFilterer(t *testing.T) { for _, test := range tests { test := test t.Run(test.name, func(t *testing.T) { - schema := NewSchemaInformation( - "relationtuples", - "ns", - "object_id", - "relation", - "subject_ns", - "subject_object_id", - "subject_relation", - "caveat", - "caveat_context", - "expiration", - TupleComparison, - sq.Question, - "NOW", - ColumnOptimizationOptionStaticValues, + schema := NewSchemaInformationWithOptions( + WithRelationshipTableName("relationtuples"), + WithColNamespace("ns"), + WithColObjectID("object_id"), + WithColRelation("relation"), + WithColUsersetNamespace("subject_ns"), + WithColUsersetObjectID("subject_object_id"), + WithColUsersetRelation("subject_relation"), + WithColCaveatName("caveat"), + WithColCaveatContext("caveat_context"), + WithColExpiration("expiration"), + WithPlaceholderFormat(sq.Question), + WithPaginationFilterType(TupleComparison), + WithColumnOptimization(ColumnOptimizationOptionStaticValues), + WithNowFunction("NOW"), ) - filterer := NewSchemaQueryFiltererForRelationshipsSelect(schema, 100) + filterer := NewSchemaQueryFiltererForRelationshipsSelect(*schema, 100) ran := test.run(filterer) foundStaticColumns := []string{} @@ -598,13 +598,12 @@ func TestSchemaQueryFilterer(t *testing.T) { func TestExecuteQuery(t *testing.T) { tcs := []struct { - name string - run func(filterer SchemaQueryFilterer) SchemaQueryFilterer - options []options.QueryOptionsOption - expectedSQL string - expectedArgs []any - expectedSelectingNoColumns bool - expectedSkipCaveats bool + name string + run func(filterer SchemaQueryFilterer) SchemaQueryFilterer + options []options.QueryOptionsOption + expectedSQL string + expectedArgs []any + expectedSkipCaveats bool }{ { name: "filter by static resource type", @@ -695,10 +694,9 @@ func TestExecuteQuery(t *testing.T) { options: []options.QueryOptionsOption{ options.WithSkipCaveats(true), }, - expectedSkipCaveats: true, - expectedSelectingNoColumns: false, - expectedSQL: "SELECT expiration FROM relationtuples WHERE (expiration IS NULL OR expiration > NOW()) AND ns = ? AND object_id = ? AND relation = ? AND subject_ns = ? AND subject_object_id = ? AND subject_relation = ?", - expectedArgs: []any{"sometype", "someobj", "somerel", "subns", "subid", "subrel"}, + expectedSkipCaveats: true, + expectedSQL: "SELECT expiration FROM relationtuples WHERE (expiration IS NULL OR expiration > NOW()) AND ns = ? AND object_id = ? AND relation = ? AND subject_ns = ? AND subject_object_id = ? AND subject_relation = ?", + expectedArgs: []any{"sometype", "someobj", "somerel", "subns", "subid", "subrel"}, }, { name: "filter by static everything (except one field) without caveats", @@ -820,33 +818,35 @@ func TestExecuteQuery(t *testing.T) { for _, tc := range tcs { t.Run(tc.name, func(t *testing.T) { - schema := NewSchemaInformation( - "relationtuples", - "ns", - "object_id", - "relation", - "subject_ns", - "subject_object_id", - "subject_relation", - "caveat", - "caveat_context", - "expiration", - TupleComparison, - sq.Question, - "NOW", - ColumnOptimizationOptionStaticValues, + schema := NewSchemaInformationWithOptions( + WithRelationshipTableName("relationtuples"), + WithColNamespace("ns"), + WithColObjectID("object_id"), + WithColRelation("relation"), + WithColUsersetNamespace("subject_ns"), + WithColUsersetObjectID("subject_object_id"), + WithColUsersetRelation("subject_relation"), + WithColCaveatName("caveat"), + WithColCaveatContext("caveat_context"), + WithColExpiration("expiration"), + WithPlaceholderFormat(sq.Question), + WithPaginationFilterType(TupleComparison), + WithColumnOptimization(ColumnOptimizationOptionStaticValues), + WithNowFunction("NOW"), ) - filterer := NewSchemaQueryFiltererForRelationshipsSelect(schema, 100) + filterer := NewSchemaQueryFiltererForRelationshipsSelect(*schema, 100) ran := tc.run(filterer) var wasRun bool - fake := QueryExecutor{ - Executor: func(ctx context.Context, queryInfo QueryInfo, sql string, args []any) (datastore.RelationshipIterator, error) { + fake := QueryRelationshipsExecutor{ + Executor: func(ctx context.Context, builder RelationshipsQueryBuilder) (datastore.RelationshipIterator, error) { + sql, args, err := builder.SelectSQL() + require.NoError(t, err) + wasRun = true require.Equal(t, tc.expectedSQL, sql) require.Equal(t, tc.expectedArgs, args) - require.Equal(t, tc.expectedSelectingNoColumns, queryInfo.SelectingNoColumns) - require.Equal(t, tc.expectedSkipCaveats, queryInfo.SkipCaveats) + require.Equal(t, tc.expectedSkipCaveats, builder.SkipCaveats) return nil, nil }, } diff --git a/internal/datastore/crdb/crdb.go b/internal/datastore/crdb/crdb.go index 5ee9109bb3..affe7264f6 100644 --- a/internal/datastore/crdb/crdb.go +++ b/internal/datastore/crdb/crdb.go @@ -200,33 +200,30 @@ func newCRDBDatastore(ctx context.Context, url string, options ...Option) (datas return nil, fmt.Errorf("invalid head migration found for cockroach: %w", err) } - var extraFields []string relTableName := tableTuple if config.withIntegrity { relTableName = tableTupleWithIntegrity - extraFields = []string{ - colIntegrityKeyID, - colIntegrityHash, - colTimestamp, - } } - schema := common.NewSchemaInformation( - relTableName, - colNamespace, - colObjectID, - colRelation, - colUsersetNamespace, - colUsersetObjectID, - colUsersetRelation, - colCaveatContextName, - colCaveatContext, - colExpiration, - common.ExpandedLogicComparison, - sq.Dollar, - "NOW", - config.columnOptimizationOption, - extraFields..., + schema := common.NewSchemaInformationWithOptions( + common.WithRelationshipTableName(relTableName), + common.WithColNamespace(colNamespace), + common.WithColObjectID(colObjectID), + common.WithColRelation(colRelation), + common.WithColUsersetNamespace(colUsersetNamespace), + common.WithColUsersetObjectID(colUsersetObjectID), + common.WithColUsersetRelation(colUsersetRelation), + common.WithColCaveatName(colCaveatContextName), + common.WithColCaveatContext(colCaveatContext), + common.WithColExpiration(colExpiration), + common.WithColIntegrityKeyID(colIntegrityKeyID), + common.WithColIntegrityHash(colIntegrityHash), + common.WithColIntegrityTimestamp(colTimestamp), + common.WithPaginationFilterType(common.ExpandedLogicComparison), + common.WithPlaceholderFormat(sq.Dollar), + common.WithNowFunction("NOW"), + common.WithColumnOptimization(config.columnOptimizationOption), + common.WithWithIntegrityColumns(config.withIntegrity), ) ds := &crdbDatastore{ @@ -250,7 +247,7 @@ func newCRDBDatastore(ctx context.Context, url string, options ...Option) (datas filterMaximumIDCount: config.filterMaximumIDCount, supportsIntegrity: config.withIntegrity, gcWindow: config.gcWindow, - schema: schema, + schema: *schema, } ds.RemoteClockRevisions.SetNowFunc(ds.headRevisionInternal) @@ -350,8 +347,8 @@ type crdbDatastore struct { } func (cds *crdbDatastore) SnapshotReader(rev datastore.Revision) datastore.Reader { - executor := common.QueryExecutor{ - Executor: pgxcommon.NewPGXExecutorWithIntegrityOption(cds.readPool, cds.supportsIntegrity), + executor := common.QueryRelationshipsExecutor{ + Executor: pgxcommon.NewPGXQueryRelationshipsExecutor(cds.readPool), } withAsOfSystemTime := func(query sq.SelectBuilder, tableName string) sq.SelectBuilder { @@ -376,8 +373,8 @@ func (cds *crdbDatastore) ReadWriteTx( err := cds.writePool.BeginFunc(ctx, func(tx pgx.Tx) error { querier := pgxcommon.QuerierFuncsFor(tx) - executor := common.QueryExecutor{ - Executor: pgxcommon.NewPGXExecutorWithIntegrityOption(querier, cds.supportsIntegrity), + executor := common.QueryRelationshipsExecutor{ + Executor: pgxcommon.NewPGXQueryRelationshipsExecutor(querier), } // Write metadata onto the transaction. diff --git a/internal/datastore/crdb/reader.go b/internal/datastore/crdb/reader.go index ce9a950b05..e252a16c92 100644 --- a/internal/datastore/crdb/reader.go +++ b/internal/datastore/crdb/reader.go @@ -39,7 +39,7 @@ var ( type crdbReader struct { query pgxcommon.DBFuncQuerier - executor common.QueryExecutor + executor common.QueryRelationshipsExecutor keyer overlapKeyer overlapKeySet keySet fromWithAsOfSystemTime func(query sq.SelectBuilder, tableName string) sq.SelectBuilder diff --git a/internal/datastore/dsfortesting/dsfortesting.go b/internal/datastore/dsfortesting/dsfortesting.go index 5e83d9ce44..b6fc7f5d9c 100644 --- a/internal/datastore/dsfortesting/dsfortesting.go +++ b/internal/datastore/dsfortesting/dsfortesting.go @@ -49,24 +49,24 @@ func (vr validatingReader) QueryRelationships( filter datastore.RelationshipsFilter, options ...options.QueryOptionsOption, ) (datastore.RelationshipIterator, error) { - schema := common.NewSchemaInformation( - "relationtuples", - "ns", - "object_id", - "relation", - "subject_ns", - "subject_object_id", - "subject_relation", - "caveat", - "caveat_context", - "expiration", - common.TupleComparison, - sq.Question, - "NOW", - common.ColumnOptimizationOptionStaticValues, + schema := common.NewSchemaInformationWithOptions( + common.WithRelationshipTableName("relationtuples"), + common.WithColNamespace("ns"), + common.WithColObjectID("object_id"), + common.WithColRelation("relation"), + common.WithColUsersetNamespace("subject_ns"), + common.WithColUsersetObjectID("subject_object_id"), + common.WithColUsersetRelation("subject_relation"), + common.WithColCaveatName("caveat"), + common.WithColCaveatContext("caveat_context"), + common.WithColExpiration("expiration"), + common.WithPlaceholderFormat(sq.Question), + common.WithPaginationFilterType(common.TupleComparison), + common.WithColumnOptimization(common.ColumnOptimizationOptionStaticValues), + common.WithNowFunction("NOW"), ) - qBuilder, err := common.NewSchemaQueryFiltererForRelationshipsSelect(schema, 100). + qBuilder, err := common.NewSchemaQueryFiltererForRelationshipsSelect(*schema, 100). FilterWithRelationshipsFilter(filter) if err != nil { return nil, err @@ -74,21 +74,21 @@ func (vr validatingReader) QueryRelationships( // Run the filter through the common SQL ellison system and ensure that any // relationships return have values matching the static fields, if applicable. - var queryInfo *common.QueryInfo - executor := common.QueryExecutor{ - Executor: func(ctx context.Context, qi common.QueryInfo, sql string, args []any) (datastore.RelationshipIterator, error) { - queryInfo = &qi + var builder *common.RelationshipsQueryBuilder + executor := common.QueryRelationshipsExecutor{ + Executor: func(ctx context.Context, b common.RelationshipsQueryBuilder) (datastore.RelationshipIterator, error) { + builder = &b return nil, nil }, } _, _ = executor.ExecuteQuery(ctx, qBuilder, options...) - if queryInfo == nil { - return nil, fmt.Errorf("no query info returned") + if builder == nil { + return nil, fmt.Errorf("no builder returned") } checkStaticField := func(returnedValue string, fieldName string) error { - if found, ok := queryInfo.FilteringValues[fieldName]; ok && found.SingleValue != nil { + if found, ok := builder.FilteringValuesForTesting()[fieldName]; ok && found.SingleValue != nil { if returnedValue != *found.SingleValue { return fmt.Errorf("static field `%s` does not match expected value `%s`: `%s", fieldName, returnedValue, *found.SingleValue) } diff --git a/internal/datastore/mysql/datastore.go b/internal/datastore/mysql/datastore.go index e8c5b6c903..ea39afd629 100644 --- a/internal/datastore/mysql/datastore.go +++ b/internal/datastore/mysql/datastore.go @@ -244,21 +244,21 @@ func newMySQLDatastore(ctx context.Context, uri string, replicaIndex int, option -1*config.gcWindow.Seconds(), ) - schema := common.NewSchemaInformation( - driver.RelationTuple(), - colNamespace, - colObjectID, - colRelation, - colUsersetNamespace, - colUsersetObjectID, - colUsersetRelation, - colCaveatName, - colCaveatContext, - colExpiration, - common.ExpandedLogicComparison, - sq.Question, - "NOW", - config.columnOptimizationOption, + schema := common.NewSchemaInformationWithOptions( + common.WithRelationshipTableName(driver.RelationTuple()), + common.WithColNamespace(colNamespace), + common.WithColObjectID(colObjectID), + common.WithColRelation(colRelation), + common.WithColUsersetNamespace(colUsersetNamespace), + common.WithColUsersetObjectID(colUsersetObjectID), + common.WithColUsersetRelation(colUsersetRelation), + common.WithColCaveatName(colCaveatName), + common.WithColCaveatContext(colCaveatContext), + common.WithColExpiration(colExpiration), + common.WithPaginationFilterType(common.ExpandedLogicComparison), + common.WithPlaceholderFormat(sq.Question), + common.WithNowFunction("NOW"), + common.WithColumnOptimization(config.columnOptimizationOption), ) store := &Datastore{ @@ -282,7 +282,7 @@ func newMySQLDatastore(ctx context.Context, uri string, replicaIndex int, option readTxOptions: &sql.TxOptions{Isolation: sql.LevelSerializable, ReadOnly: true}, maxRetries: config.maxRetries, analyzeBeforeStats: config.analyzeBeforeStats, - schema: schema, + schema: *schema, CachedOptimizedRevisions: revisions.NewCachedOptimizedRevisions( maxRevisionStaleness, ), @@ -332,7 +332,7 @@ func (mds *Datastore) SnapshotReader(rev datastore.Revision) datastore.Reader { return tx, tx.Rollback, nil } - executor := common.QueryExecutor{ + executor := common.QueryRelationshipsExecutor{ Executor: newMySQLExecutor(mds.db), } @@ -375,7 +375,7 @@ func (mds *Datastore) ReadWriteTx( return tx, noCleanup, nil } - executor := common.QueryExecutor{ + executor := common.QueryRelationshipsExecutor{ Executor: newMySQLExecutor(tx), } @@ -468,9 +468,9 @@ func newMySQLExecutor(tx querier) common.ExecuteReadRelsQueryFunc { // // Prepared statements are also not used given they perform poorly on environments where connections have // short lifetime (e.g. to gracefully handle load-balancer connection drain) - return func(ctx context.Context, queryInfo common.QueryInfo, sqlQuery string, args []interface{}) (datastore.RelationshipIterator, error) { + return func(ctx context.Context, builder common.RelationshipsQueryBuilder) (datastore.RelationshipIterator, error) { span := trace.SpanFromContext(ctx) - return common.QueryRelationships[common.Rows, structpbWrapper](ctx, queryInfo, sqlQuery, args, span, asQueryableTx{tx}, false) + return common.QueryRelationships[common.Rows, structpbWrapper](ctx, builder, span, asQueryableTx{tx}) } } diff --git a/internal/datastore/mysql/reader.go b/internal/datastore/mysql/reader.go index 592b844575..c963fb808f 100644 --- a/internal/datastore/mysql/reader.go +++ b/internal/datastore/mysql/reader.go @@ -23,7 +23,7 @@ type mysqlReader struct { *QueryBuilder txSource txFactory - executor common.QueryExecutor + executor common.QueryRelationshipsExecutor aliveFilter queryFilterer filterMaximumIDCount uint16 schema common.SchemaInformation diff --git a/internal/datastore/postgres/common/pgx.go b/internal/datastore/postgres/common/pgx.go index 3a0de6803a..012e908bcc 100644 --- a/internal/datastore/postgres/common/pgx.go +++ b/internal/datastore/postgres/common/pgx.go @@ -21,18 +21,11 @@ import ( "github.com/authzed/spicedb/pkg/datastore" ) -// NewPGXExecutor creates an executor that uses the pgx library to make the specified queries. -func NewPGXExecutor(querier DBFuncQuerier) common.ExecuteReadRelsQueryFunc { - return func(ctx context.Context, queryInfo common.QueryInfo, sql string, args []any) (datastore.RelationshipIterator, error) { +// NewPGXQueryRelationshipsExecutor creates an executor that uses the pgx library to make the specified queries. +func NewPGXQueryRelationshipsExecutor(querier DBFuncQuerier) common.ExecuteReadRelsQueryFunc { + return func(ctx context.Context, builder common.RelationshipsQueryBuilder) (datastore.RelationshipIterator, error) { span := trace.SpanFromContext(ctx) - return common.QueryRelationships[pgx.Rows, map[string]any](ctx, queryInfo, sql, args, span, querier, false) - } -} - -func NewPGXExecutorWithIntegrityOption(querier DBFuncQuerier, withIntegrity bool) common.ExecuteReadRelsQueryFunc { - return func(ctx context.Context, queryInfo common.QueryInfo, sql string, args []any) (datastore.RelationshipIterator, error) { - span := trace.SpanFromContext(ctx) - return common.QueryRelationships[pgx.Rows, map[string]any](ctx, queryInfo, sql, args, span, querier, withIntegrity) + return common.QueryRelationships[pgx.Rows, map[string]any](ctx, builder, span, querier) } } diff --git a/internal/datastore/postgres/postgres.go b/internal/datastore/postgres/postgres.go index 0cef2e43ac..e62c9f9dce 100644 --- a/internal/datastore/postgres/postgres.go +++ b/internal/datastore/postgres/postgres.go @@ -316,21 +316,21 @@ func newPostgresDatastore( maxRevisionStaleness := time.Duration(float64(config.revisionQuantization.Nanoseconds())* config.maxRevisionStalenessPercent) * time.Nanosecond - schema := common.NewSchemaInformation( - tableTuple, - colNamespace, - colObjectID, - colRelation, - colUsersetNamespace, - colUsersetObjectID, - colUsersetRelation, - colCaveatContextName, - colCaveatContext, - colExpiration, - common.TupleComparison, - sq.Dollar, - "NOW", - config.columnOptimizationOption, + schema := common.NewSchemaInformationWithOptions( + common.WithRelationshipTableName(tableTuple), + common.WithColNamespace(colNamespace), + common.WithColObjectID(colObjectID), + common.WithColRelation(colRelation), + common.WithColUsersetNamespace(colUsersetNamespace), + common.WithColUsersetObjectID(colUsersetObjectID), + common.WithColUsersetRelation(colUsersetRelation), + common.WithColCaveatName(colCaveatContextName), + common.WithColCaveatContext(colCaveatContext), + common.WithColExpiration(colExpiration), + common.WithPaginationFilterType(common.TupleComparison), + common.WithPlaceholderFormat(sq.Dollar), + common.WithNowFunction("NOW"), + common.WithColumnOptimization(config.columnOptimizationOption), ) datastore := &pgDatastore{ @@ -358,7 +358,7 @@ func newPostgresDatastore( isPrimary: isPrimary, inStrictReadMode: config.readStrictMode, filterMaximumIDCount: config.filterMaximumIDCount, - schema: schema, + schema: *schema, } if isPrimary && config.readStrictMode { @@ -435,8 +435,8 @@ func (pgd *pgDatastore) SnapshotReader(revRaw datastore.Revision) datastore.Read queryFuncs = strictReaderQueryFuncs{wrapped: queryFuncs, revision: rev} } - executor := common.QueryExecutor{ - Executor: pgxcommon.NewPGXExecutor(queryFuncs), + executor := common.QueryRelationshipsExecutor{ + Executor: pgxcommon.NewPGXQueryRelationshipsExecutor(queryFuncs), } return &pgReader{ @@ -478,8 +478,8 @@ func (pgd *pgDatastore) ReadWriteTx( } queryFuncs := pgxcommon.QuerierFuncsFor(pgd.readPool) - executor := common.QueryExecutor{ - Executor: pgxcommon.NewPGXExecutor(queryFuncs), + executor := common.QueryRelationshipsExecutor{ + Executor: pgxcommon.NewPGXQueryRelationshipsExecutor(queryFuncs), } rwt := &pgReadWriteTXN{ diff --git a/internal/datastore/postgres/reader.go b/internal/datastore/postgres/reader.go index 8ab3097e60..ed0ad792d1 100644 --- a/internal/datastore/postgres/reader.go +++ b/internal/datastore/postgres/reader.go @@ -17,7 +17,7 @@ import ( type pgReader struct { query pgxcommon.DBFuncQuerier - executor common.QueryExecutor + executor common.QueryRelationshipsExecutor aliveFilter queryFilterer filterMaximumIDCount uint16 schema common.SchemaInformation diff --git a/internal/datastore/spanner/reader.go b/internal/datastore/spanner/reader.go index fb23c0e06e..2dc8ee5022 100644 --- a/internal/datastore/spanner/reader.go +++ b/internal/datastore/spanner/reader.go @@ -31,7 +31,7 @@ type readTX interface { type txFactory func() readTX type spannerReader struct { - executor common.QueryExecutor + executor common.QueryRelationshipsExecutor txSource txFactory filterMaximumIDCount uint16 schema common.SchemaInformation @@ -173,10 +173,17 @@ func (sr spannerReader) ReverseQueryRelationships( var errStopIterator = fmt.Errorf("stop iteration") func queryExecutor(txSource txFactory) common.ExecuteReadRelsQueryFunc { - return func(ctx context.Context, queryInfo common.QueryInfo, sql string, args []any) (datastore.RelationshipIterator, error) { + return func(ctx context.Context, builder common.RelationshipsQueryBuilder) (datastore.RelationshipIterator, error) { return func(yield func(tuple.Relationship, error) bool) { span := trace.SpanFromContext(ctx) span.AddEvent("Query issued to database") + + sql, args, err := builder.SelectSQL() + if err != nil { + yield(tuple.Relationship{}, err) + return + } + iter := txSource().Query(ctx, statementFromSQL(sql, args)) defer iter.Stop() @@ -196,24 +203,28 @@ func queryExecutor(txSource txFactory) common.ExecuteReadRelsQueryFunc { var caveatCtx spanner.NullJSON var expirationOrNull spanner.NullTime - colsToSelect := make([]any, 0, 8) - - colsToSelect = common.StaticValueOrAddColumnForSelect(colsToSelect, queryInfo, queryInfo.Schema.ColNamespace, &resourceObjectType) - colsToSelect = common.StaticValueOrAddColumnForSelect(colsToSelect, queryInfo, queryInfo.Schema.ColObjectID, &resourceObjectID) - colsToSelect = common.StaticValueOrAddColumnForSelect(colsToSelect, queryInfo, queryInfo.Schema.ColRelation, &relation) - colsToSelect = common.StaticValueOrAddColumnForSelect(colsToSelect, queryInfo, queryInfo.Schema.ColUsersetNamespace, &subjectObjectType) - colsToSelect = common.StaticValueOrAddColumnForSelect(colsToSelect, queryInfo, queryInfo.Schema.ColUsersetObjectID, &subjectObjectID) - colsToSelect = common.StaticValueOrAddColumnForSelect(colsToSelect, queryInfo, queryInfo.Schema.ColUsersetRelation, &subjectRelation) - - if !queryInfo.SkipCaveats || queryInfo.Schema.ColumnOptimization == common.ColumnOptimizationOptionNone { - colsToSelect = append(colsToSelect, &caveatName, &caveatCtx) - } - - colsToSelect = append(colsToSelect, &expirationOrNull) - - if len(colsToSelect) == 0 { - var unused int64 - colsToSelect = append(colsToSelect, &unused) + // NOTE: these are unused in Spanner, but necessary for the ColumnsToSelect call. + var integrityKeyID string + var integrityHash []byte + var timestamp time.Time + + colsToSelect, err := common.ColumnsToSelect(builder, + &resourceObjectType, + &resourceObjectID, + &relation, + &subjectObjectType, + &subjectObjectID, + &subjectRelation, + &caveatName, + &caveatCtx, + &expirationOrNull, + &integrityKeyID, + &integrityHash, + ×tamp, + ) + if err != nil { + yield(tuple.Relationship{}, err) + return } if err := iter.Do(func(row *spanner.Row) error { diff --git a/internal/datastore/spanner/spanner.go b/internal/datastore/spanner/spanner.go index 38d4190576..e894de1cee 100644 --- a/internal/datastore/spanner/spanner.go +++ b/internal/datastore/spanner/spanner.go @@ -181,6 +181,23 @@ func NewSpannerDatastore(ctx context.Context, database string, opts ...Option) ( return nil, fmt.Errorf("invalid head migration found for spanner: %w", err) } + schema := common.NewSchemaInformationWithOptions( + common.WithRelationshipTableName(tableRelationship), + common.WithColNamespace(colNamespace), + common.WithColObjectID(colObjectID), + common.WithColRelation(colRelation), + common.WithColUsersetNamespace(colUsersetNamespace), + common.WithColUsersetObjectID(colUsersetObjectID), + common.WithColUsersetRelation(colUsersetRelation), + common.WithColCaveatName(colCaveatName), + common.WithColCaveatContext(colCaveatContext), + common.WithColExpiration(colExpiration), + common.WithPaginationFilterType(common.ExpandedLogicComparison), + common.WithPlaceholderFormat(sq.AtP), + common.WithNowFunction("CURRENT_TIMESTAMP"), + common.WithColumnOptimization(config.columnOptimizationOption), + ) + ds := &spannerDatastore{ RemoteClockRevisions: revisions.NewRemoteClockRevisions( defaultChangeStreamRetention, @@ -201,22 +218,7 @@ func NewSpannerDatastore(ctx context.Context, database string, opts ...Option) ( cachedEstimatedBytesPerRelationshipLock: sync.RWMutex{}, tableSizesStatsTable: tableSizesStatsTable, filterMaximumIDCount: config.filterMaximumIDCount, - schema: common.NewSchemaInformation( - tableRelationship, - colNamespace, - colObjectID, - colRelation, - colUsersetNamespace, - colUsersetObjectID, - colUsersetRelation, - colCaveatName, - colCaveatContext, - colExpiration, - common.ExpandedLogicComparison, - sq.AtP, - "CURRENT_TIMESTAMP", - config.columnOptimizationOption, - ), + schema: *schema, } // Optimized revision and revision checking use a stale read for the // current timestamp. @@ -264,7 +266,7 @@ func (sd *spannerDatastore) SnapshotReader(revisionRaw datastore.Revision) datas txSource := func() readTX { return &traceableRTX{delegate: sd.client.Single().WithTimestampBound(spanner.ReadTimestamp(r.Time()))} } - executor := common.QueryExecutor{Executor: queryExecutor(txSource)} + executor := common.QueryRelationshipsExecutor{Executor: queryExecutor(txSource)} return spannerReader{executor, txSource, sd.filterMaximumIDCount, sd.schema} } @@ -312,7 +314,7 @@ func (sd *spannerDatastore) ReadWriteTx(ctx context.Context, fn datastore.TxUser } } - executor := common.QueryExecutor{Executor: queryExecutor(txSource)} + executor := common.QueryRelationshipsExecutor{Executor: queryExecutor(txSource)} rwt := spannerReadWriteTXN{ spannerReader{executor, txSource, sd.filterMaximumIDCount, sd.schema}, spannerRWT, From 70ce48217517369fa904b87748bd935c1577e8bc Mon Sep 17 00:00:00 2001 From: Joseph Schorr Date: Fri, 13 Dec 2024 17:55:48 -0500 Subject: [PATCH 09/15] Elide the expiration column and checking when expiration is either disabled or the relationship cannot be marked as expiring --- internal/datastore/common/schema.go | 3 + internal/datastore/common/schema_options.go | 9 + internal/datastore/common/sql.go | 48 +- internal/datastore/common/sql_test.go | 563 ++++++++++-------- internal/datastore/crdb/crdb.go | 1 + internal/datastore/crdb/options.go | 12 +- internal/datastore/memdb/readonly.go | 29 +- internal/datastore/mysql/datastore.go | 1 + internal/datastore/mysql/gc.go | 4 + internal/datastore/mysql/options.go | 10 + internal/datastore/postgres/gc.go | 4 + internal/datastore/postgres/options.go | 10 +- internal/datastore/postgres/postgres.go | 1 + internal/datastore/spanner/options.go | 10 + internal/graph/check.go | 86 ++- .../steelthreadtesting/steelthread_test.go | 6 - internal/services/v1/experimental_test.go | 48 +- internal/services/v1/permissions_test.go | 38 +- internal/testfixtures/datastore.go | 26 +- .../testserver/datastore/config/config.go | 1 + pkg/cmd/datastore/datastore.go | 94 +-- pkg/cmd/datastore/zz_generated.options.go | 9 + pkg/cmd/server/server.go | 4 +- pkg/datastore/options/options.go | 9 +- .../options/zz_generated.query_options.go | 9 + pkg/datastore/test/relationships.go | 69 ++- 26 files changed, 651 insertions(+), 453 deletions(-) diff --git a/internal/datastore/common/schema.go b/internal/datastore/common/schema.go index 8f31f929a6..7cff99d578 100644 --- a/internal/datastore/common/schema.go +++ b/internal/datastore/common/schema.go @@ -40,6 +40,9 @@ type SchemaInformation struct { // WithIntegrityColumns is a flag to indicate if the schema has integrity columns. WithIntegrityColumns bool `debugmap:"visible"` + + // ExpirationDisabled is a flag to indicate whether expiration support is disabled. + ExpirationDisabled bool `debugmap:"visible"` } func (si SchemaInformation) debugValidate() { diff --git a/internal/datastore/common/schema_options.go b/internal/datastore/common/schema_options.go index 3aed7f64e8..fa7639776e 100644 --- a/internal/datastore/common/schema_options.go +++ b/internal/datastore/common/schema_options.go @@ -49,6 +49,7 @@ func (s *SchemaInformation) ToOption() SchemaInformationOption { to.NowFunction = s.NowFunction to.ColumnOptimization = s.ColumnOptimization to.WithIntegrityColumns = s.WithIntegrityColumns + to.ExpirationDisabled = s.ExpirationDisabled } } @@ -73,6 +74,7 @@ func (s SchemaInformation) DebugMap() map[string]any { debugMap["NowFunction"] = helpers.DebugValue(s.NowFunction, false) debugMap["ColumnOptimization"] = helpers.DebugValue(s.ColumnOptimization, false) debugMap["WithIntegrityColumns"] = helpers.DebugValue(s.WithIntegrityColumns, false) + debugMap["ExpirationDisabled"] = helpers.DebugValue(s.ExpirationDisabled, false) return debugMap } @@ -217,3 +219,10 @@ func WithWithIntegrityColumns(withIntegrityColumns bool) SchemaInformationOption s.WithIntegrityColumns = withIntegrityColumns } } + +// WithExpirationDisabled returns an option that can set ExpirationDisabled on a SchemaInformation +func WithExpirationDisabled(expirationDisabled bool) SchemaInformationOption { + return func(s *SchemaInformation) { + s.ExpirationDisabled = expirationDisabled + } +} diff --git a/internal/datastore/common/sql.go b/internal/datastore/common/sql.go index a0a92f65fd..f86131bcaa 100644 --- a/internal/datastore/common/sql.go +++ b/internal/datastore/common/sql.go @@ -106,13 +106,7 @@ func NewSchemaQueryFiltererForRelationshipsSelect(schema SchemaInformation, filt log.Warn().Msg("SchemaQueryFilterer: filterMaximumIDCount not set, defaulting to 100") } - // Filter out any expired relationships. - // TODO(jschorr): Make this depend on whether expiration is necessary. - queryBuilder := sq.StatementBuilder.PlaceholderFormat(schema.PlaceholderFormat).Select().Where(sq.Or{ - sq.Eq{schema.ColExpiration: nil}, - sq.Expr(schema.ColExpiration + " > " + schema.NowFunction + "()"), - }) - + queryBuilder := sq.StatementBuilder.PlaceholderFormat(schema.PlaceholderFormat).Select() return SchemaQueryFilterer{ schema: schema, queryBuilder: queryBuilder, @@ -134,13 +128,6 @@ func NewSchemaQueryFiltererWithStartingQuery(schema SchemaInformation, startingQ log.Warn().Msg("SchemaQueryFilterer: filterMaximumIDCount not set, defaulting to 100") } - // Filter out any expired relationships. - // TODO(jschorr): Make this depend on whether expiration is necessary. - startingQuery = startingQuery.Where(sq.Or{ - sq.Eq{schema.ColExpiration: nil}, - sq.Expr(schema.ColExpiration + " > " + schema.NowFunction + "()"), - }) - return SchemaQueryFilterer{ schema: schema, queryBuilder: startingQuery, @@ -179,7 +166,21 @@ func (sqf SchemaQueryFilterer) UnderlyingQueryBuilder() sq.SelectBuilder { spiceerrors.DebugAssert(func() bool { return sqf.isCustomQuery }, "UnderlyingQueryBuilder should only be called on custom queries") - return sqf.queryBuilder + return sqf.queryBuilderWithExpirationFilter(false) +} + +// queryBuilderWithExpirationFilter returns the query builder with the expiration filter applied, when necessary. +// Note that this adds the clause to the existing builder. +func (sqf SchemaQueryFilterer) queryBuilderWithExpirationFilter(skipExpiration bool) sq.SelectBuilder { + if sqf.schema.ExpirationDisabled || skipExpiration { + return sqf.queryBuilder + } + + // Filter out any expired relationships. + return sqf.queryBuilder.Where(sq.Or{ + sq.Eq{sqf.schema.ColExpiration: nil}, + sq.Expr(sqf.schema.ColExpiration + " > " + sqf.schema.NowFunction + "()"), + }) } func (sqf SchemaQueryFilterer) TupleOrder(order options.SortOrder) SchemaQueryFilterer { @@ -470,6 +471,7 @@ func (sqf SchemaQueryFilterer) FilterWithRelationshipsFilter(filter datastore.Re if filter.OptionalExpirationOption == datastore.ExpirationFilterOptionHasExpiration { csqf.queryBuilder = csqf.queryBuilder.Where(sq.NotEq{csqf.schema.ColExpiration: nil}) + spiceerrors.DebugAssert(func() bool { return !sqf.schema.ExpirationDisabled }, "expiration filter requested but schema does not support expiration") } else if filter.OptionalExpirationOption == datastore.ExpirationFilterOptionNoExpiration { csqf.queryBuilder = csqf.queryBuilder.Where(sq.Eq{csqf.schema.ColExpiration: nil}) } @@ -666,6 +668,7 @@ func (exc QueryRelationshipsExecutor) ExecuteQuery( builder := RelationshipsQueryBuilder{ Schema: query.schema, SkipCaveats: queryOpts.SkipCaveats, + SkipExpiration: queryOpts.SkipExpiration, filteringValues: query.filteringColumnTracker, baseQueryBuilder: query, } @@ -676,8 +679,9 @@ func (exc QueryRelationshipsExecutor) ExecuteQuery( // RelationshipsQueryBuilder is a builder for producing the SQL and arguments necessary for reading // relationships. type RelationshipsQueryBuilder struct { - Schema SchemaInformation - SkipCaveats bool + Schema SchemaInformation + SkipCaveats bool + SkipExpiration bool filteringValues map[string]ColumnTracker baseQueryBuilder SchemaQueryFilterer @@ -703,7 +707,9 @@ func (b RelationshipsQueryBuilder) SelectSQL() (string, []any, error) { columnNamesToSelect = append(columnNamesToSelect, b.Schema.ColCaveatName, b.Schema.ColCaveatContext) } - columnNamesToSelect = append(columnNamesToSelect, b.Schema.ColExpiration) + if !b.SkipExpiration && !b.Schema.ExpirationDisabled { + columnNamesToSelect = append(columnNamesToSelect, b.Schema.ColExpiration) + } if b.Schema.WithIntegrityColumns { columnNamesToSelect = append(columnNamesToSelect, b.Schema.ColIntegrityKeyID, b.Schema.ColIntegrityHash, b.Schema.ColIntegrityTimestamp) @@ -713,7 +719,7 @@ func (b RelationshipsQueryBuilder) SelectSQL() (string, []any, error) { columnNamesToSelect = append(columnNamesToSelect, "1") } - sqlBuilder := b.baseQueryBuilder.queryBuilder + sqlBuilder := b.baseQueryBuilder.queryBuilderWithExpirationFilter(b.SkipExpiration) sqlBuilder = sqlBuilder.Columns(columnNamesToSelect...) return sqlBuilder.ToSql() @@ -788,7 +794,9 @@ func ColumnsToSelect[CN any, CC any, EC any]( colsToSelect = append(colsToSelect, caveatName, caveatCtx) } - colsToSelect = append(colsToSelect, expiration) + if !b.SkipExpiration && !b.Schema.ExpirationDisabled { + colsToSelect = append(colsToSelect, expiration) + } if b.Schema.WithIntegrityColumns { colsToSelect = append(colsToSelect, integrityKeyID, integrityHash, timestamp) diff --git a/internal/datastore/common/sql_test.go b/internal/datastore/common/sql_test.go index 6b2e2c40a4..cee75c6ce4 100644 --- a/internal/datastore/common/sql_test.go +++ b/internal/datastore/common/sql_test.go @@ -19,245 +19,256 @@ var toCursor = options.ToCursor func TestSchemaQueryFilterer(t *testing.T) { tests := []struct { - name string - run func(filterer SchemaQueryFilterer) SchemaQueryFilterer - expectedSQL string - expectedArgs []any - expectedStaticColumns []string + name string + run func(filterer SchemaQueryFilterer) SchemaQueryFilterer + expectedSQL string + expectedArgs []any + expectedStaticColumns []string + withExpirationDisabled bool }{ { - "relation filter", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "relation filter", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + return filterer.FilterToRelation("somerelation") + }, + expectedSQL: "SELECT * WHERE relation = ? AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"somerelation"}, + expectedStaticColumns: []string{"relation"}, + }, + { + name: "relation filter without expiration", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.FilterToRelation("somerelation") }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND relation = ?", - []any{"somerelation"}, - []string{"relation"}, + expectedSQL: "SELECT * WHERE relation = ?", + expectedArgs: []any{"somerelation"}, + expectedStaticColumns: []string{"relation"}, + withExpirationDisabled: true, }, { - "resource ID filter", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "resource ID filter", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.FilterToResourceID("someresourceid") }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND object_id = ?", - []any{"someresourceid"}, - []string{"object_id"}, + expectedSQL: "SELECT * WHERE object_id = ? AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"someresourceid"}, + expectedStaticColumns: []string{"object_id"}, }, { - "resource IDs filter", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "resource IDs filter", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.MustFilterWithResourceIDPrefix("someprefix") }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND object_id LIKE ?", - []any{"someprefix%"}, - []string{}, + expectedSQL: "SELECT * WHERE object_id LIKE ? AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"someprefix%"}, + expectedStaticColumns: []string{}, }, { - "resource IDs prefix filter", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "resource IDs prefix filter", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.MustFilterToResourceIDs([]string{"someresourceid", "anotherresourceid"}) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND object_id IN (?,?)", - []any{"someresourceid", "anotherresourceid"}, - []string{}, + expectedSQL: "SELECT * WHERE object_id IN (?,?) AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"someresourceid", "anotherresourceid"}, + expectedStaticColumns: []string{}, }, { - "resource type filter", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "resource type filter", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.FilterToResourceType("sometype") }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ns = ?", - []any{"sometype"}, - []string{"ns"}, + expectedSQL: "SELECT * WHERE ns = ? AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"sometype"}, + expectedStaticColumns: []string{"ns"}, }, { - "resource filter", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "resource filter", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.FilterToResourceType("sometype").FilterToResourceID("someobj").FilterToRelation("somerel") }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ns = ? AND object_id = ? AND relation = ?", - []any{"sometype", "someobj", "somerel"}, - []string{"ns", "object_id", "relation"}, + expectedSQL: "SELECT * WHERE ns = ? AND object_id = ? AND relation = ? AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"sometype", "someobj", "somerel"}, + expectedStaticColumns: []string{"ns", "object_id", "relation"}, }, { - "relationships filter with no IDs or relations", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "relationships filter with no IDs or relations", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.MustFilterWithRelationshipsFilter(datastore.RelationshipsFilter{ OptionalResourceType: "sometype", }) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ns = ?", - []any{"sometype"}, - []string{"ns"}, + expectedSQL: "SELECT * WHERE ns = ? AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"sometype"}, + expectedStaticColumns: []string{"ns"}, }, { - "relationships filter with single ID", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "relationships filter with single ID", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.MustFilterWithRelationshipsFilter(datastore.RelationshipsFilter{ OptionalResourceType: "sometype", OptionalResourceIds: []string{"someid"}, }) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ns = ? AND object_id IN (?)", - []any{"sometype", "someid"}, - []string{"ns", "object_id"}, + expectedSQL: "SELECT * WHERE ns = ? AND object_id IN (?) AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"sometype", "someid"}, + expectedStaticColumns: []string{"ns", "object_id"}, }, { - "relationships filter with no IDs", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "relationships filter with no IDs", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.MustFilterWithRelationshipsFilter(datastore.RelationshipsFilter{ OptionalResourceType: "sometype", OptionalResourceIds: []string{}, }) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ns = ?", - []any{"sometype"}, - []string{"ns"}, + expectedSQL: "SELECT * WHERE ns = ? AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"sometype"}, + expectedStaticColumns: []string{"ns"}, }, { - "relationships filter with multiple IDs", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "relationships filter with multiple IDs", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.MustFilterWithRelationshipsFilter(datastore.RelationshipsFilter{ OptionalResourceType: "sometype", OptionalResourceIds: []string{"someid", "anotherid"}, }) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ns = ? AND object_id IN (?,?)", - []any{"sometype", "someid", "anotherid"}, - []string{"ns"}, + expectedSQL: "SELECT * WHERE ns = ? AND object_id IN (?,?) AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"sometype", "someid", "anotherid"}, + expectedStaticColumns: []string{"ns"}, }, { - "subjects filter with no IDs or relations", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "subjects filter with no IDs or relations", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.MustFilterWithSubjectsSelectors(datastore.SubjectsSelector{ OptionalSubjectType: "somesubjectype", }) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ((subject_ns = ?))", - []any{"somesubjectype"}, - []string{"subject_ns"}, + expectedSQL: "SELECT * WHERE ((subject_ns = ?)) AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"somesubjectype"}, + expectedStaticColumns: []string{"subject_ns"}, }, { - "multiple subjects filters with just types", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "multiple subjects filters with just types", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.MustFilterWithSubjectsSelectors(datastore.SubjectsSelector{ OptionalSubjectType: "somesubjectype", }, datastore.SubjectsSelector{ OptionalSubjectType: "anothersubjectype", }) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ((subject_ns = ?) OR (subject_ns = ?))", - []any{"somesubjectype", "anothersubjectype"}, - []string{}, + expectedSQL: "SELECT * WHERE ((subject_ns = ?) OR (subject_ns = ?)) AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"somesubjectype", "anothersubjectype"}, + expectedStaticColumns: []string{}, }, { - "subjects filter with single ID", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "subjects filter with single ID", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.MustFilterWithSubjectsSelectors(datastore.SubjectsSelector{ OptionalSubjectType: "somesubjectype", OptionalSubjectIds: []string{"somesubjectid"}, }) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ((subject_ns = ? AND subject_object_id IN (?)))", - []any{"somesubjectype", "somesubjectid"}, - []string{"subject_ns", "subject_object_id"}, + expectedSQL: "SELECT * WHERE ((subject_ns = ? AND subject_object_id IN (?))) AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"somesubjectype", "somesubjectid"}, + expectedStaticColumns: []string{"subject_ns", "subject_object_id"}, }, { - "subjects filter with single ID and no type", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "subjects filter with single ID and no type", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.MustFilterWithSubjectsSelectors(datastore.SubjectsSelector{ OptionalSubjectIds: []string{"somesubjectid"}, }) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ((subject_object_id IN (?)))", - []any{"somesubjectid"}, - []string{"subject_object_id"}, + expectedSQL: "SELECT * WHERE ((subject_object_id IN (?))) AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"somesubjectid"}, + expectedStaticColumns: []string{"subject_object_id"}, }, { - "empty subjects filter", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "empty subjects filter", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.MustFilterWithSubjectsSelectors(datastore.SubjectsSelector{}) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ((1=1))", - nil, - []string{}, + expectedSQL: "SELECT * WHERE ((1=1)) AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: nil, + expectedStaticColumns: []string{}, }, { - "subjects filter with multiple IDs", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "subjects filter with multiple IDs", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.MustFilterWithSubjectsSelectors(datastore.SubjectsSelector{ OptionalSubjectType: "somesubjectype", OptionalSubjectIds: []string{"somesubjectid", "anothersubjectid"}, }) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ((subject_ns = ? AND subject_object_id IN (?,?)))", - []any{"somesubjectype", "somesubjectid", "anothersubjectid"}, - []string{"subject_ns"}, + expectedSQL: "SELECT * WHERE ((subject_ns = ? AND subject_object_id IN (?,?))) AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"somesubjectype", "somesubjectid", "anothersubjectid"}, + expectedStaticColumns: []string{"subject_ns"}, }, { - "subjects filter with single ellipsis relation", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "subjects filter with single ellipsis relation", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.MustFilterWithSubjectsSelectors(datastore.SubjectsSelector{ OptionalSubjectType: "somesubjectype", RelationFilter: datastore.SubjectRelationFilter{}.WithEllipsisRelation(), }) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ((subject_ns = ? AND subject_relation = ?))", - []any{"somesubjectype", "..."}, - []string{"subject_ns", "subject_relation"}, + expectedSQL: "SELECT * WHERE ((subject_ns = ? AND subject_relation = ?)) AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"somesubjectype", "..."}, + expectedStaticColumns: []string{"subject_ns", "subject_relation"}, }, { - "subjects filter with single defined relation", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "subjects filter with single defined relation", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.MustFilterWithSubjectsSelectors(datastore.SubjectsSelector{ OptionalSubjectType: "somesubjectype", RelationFilter: datastore.SubjectRelationFilter{}.WithNonEllipsisRelation("somesubrel"), }) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ((subject_ns = ? AND subject_relation = ?))", - []any{"somesubjectype", "somesubrel"}, - []string{"subject_ns", "subject_relation"}, + expectedSQL: "SELECT * WHERE ((subject_ns = ? AND subject_relation = ?)) AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"somesubjectype", "somesubrel"}, + expectedStaticColumns: []string{"subject_ns", "subject_relation"}, }, { - "subjects filter with only non-ellipsis", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "subjects filter with only non-ellipsis", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.MustFilterWithSubjectsSelectors(datastore.SubjectsSelector{ OptionalSubjectType: "somesubjectype", RelationFilter: datastore.SubjectRelationFilter{}.WithOnlyNonEllipsisRelations(), }) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ((subject_ns = ? AND subject_relation <> ?))", - []any{"somesubjectype", "..."}, - []string{"subject_ns"}, + expectedSQL: "SELECT * WHERE ((subject_ns = ? AND subject_relation <> ?)) AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"somesubjectype", "..."}, + expectedStaticColumns: []string{"subject_ns"}, }, { - "subjects filter with defined relation and ellipsis", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "subjects filter with defined relation and ellipsis", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.MustFilterWithSubjectsSelectors(datastore.SubjectsSelector{ OptionalSubjectType: "somesubjectype", RelationFilter: datastore.SubjectRelationFilter{}.WithNonEllipsisRelation("somesubrel").WithEllipsisRelation(), }) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ((subject_ns = ? AND (subject_relation = ? OR subject_relation = ?)))", - []any{"somesubjectype", "...", "somesubrel"}, - []string{"subject_ns"}, + expectedSQL: "SELECT * WHERE ((subject_ns = ? AND (subject_relation = ? OR subject_relation = ?))) AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"somesubjectype", "...", "somesubrel"}, + expectedStaticColumns: []string{"subject_ns"}, }, { - "subjects filter", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "subjects filter", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.MustFilterWithSubjectsSelectors(datastore.SubjectsSelector{ OptionalSubjectType: "somesubjectype", OptionalSubjectIds: []string{"somesubjectid", "anothersubjectid"}, RelationFilter: datastore.SubjectRelationFilter{}.WithNonEllipsisRelation("somesubrel").WithEllipsisRelation(), }) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ((subject_ns = ? AND subject_object_id IN (?,?) AND (subject_relation = ? OR subject_relation = ?)))", - []any{"somesubjectype", "somesubjectid", "anothersubjectid", "...", "somesubrel"}, - []string{"subject_ns"}, + expectedSQL: "SELECT * WHERE ((subject_ns = ? AND subject_object_id IN (?,?) AND (subject_relation = ? OR subject_relation = ?))) AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"somesubjectype", "somesubjectid", "anothersubjectid", "...", "somesubrel"}, + expectedStaticColumns: []string{"subject_ns"}, }, { - "multiple subjects filter", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "multiple subjects filter", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.MustFilterWithSubjectsSelectors( datastore.SubjectsSelector{ OptionalSubjectType: "somesubjectype", @@ -275,36 +286,36 @@ func TestSchemaQueryFilterer(t *testing.T) { }, ) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ((subject_ns = ? AND subject_object_id IN (?,?) AND (subject_relation = ? OR subject_relation = ?)) OR (subject_ns = ? AND subject_object_id IN (?,?) AND (subject_relation = ? OR subject_relation = ?)) OR (subject_ns = ? AND subject_relation <> ?))", - []any{"somesubjectype", "a", "b", "...", "somesubrel", "anothersubjecttype", "b", "c", "...", "anotherrel", "thirdsubjectype", "..."}, - []string{}, + expectedSQL: "SELECT * WHERE ((subject_ns = ? AND subject_object_id IN (?,?) AND (subject_relation = ? OR subject_relation = ?)) OR (subject_ns = ? AND subject_object_id IN (?,?) AND (subject_relation = ? OR subject_relation = ?)) OR (subject_ns = ? AND subject_relation <> ?)) AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"somesubjectype", "a", "b", "...", "somesubrel", "anothersubjecttype", "b", "c", "...", "anotherrel", "thirdsubjectype", "..."}, + expectedStaticColumns: []string{}, }, { - "v1 subject filter with namespace", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "v1 subject filter with namespace", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.FilterToSubjectFilter(&v1.SubjectFilter{ SubjectType: "subns", }) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND subject_ns = ?", - []any{"subns"}, - []string{"subject_ns"}, + expectedSQL: "SELECT * WHERE subject_ns = ? AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"subns"}, + expectedStaticColumns: []string{"subject_ns"}, }, { - "v1 subject filter with subject id", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "v1 subject filter with subject id", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.FilterToSubjectFilter(&v1.SubjectFilter{ SubjectType: "subns", OptionalSubjectId: "subid", }) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND subject_ns = ? AND subject_object_id = ?", - []any{"subns", "subid"}, - []string{"subject_ns", "subject_object_id"}, + expectedSQL: "SELECT * WHERE subject_ns = ? AND subject_object_id = ? AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"subns", "subid"}, + expectedStaticColumns: []string{"subject_ns", "subject_object_id"}, }, { - "v1 subject filter with relation", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "v1 subject filter with relation", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.FilterToSubjectFilter(&v1.SubjectFilter{ SubjectType: "subns", OptionalRelation: &v1.SubjectFilter_RelationFilter{ @@ -312,13 +323,13 @@ func TestSchemaQueryFilterer(t *testing.T) { }, }) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND subject_ns = ? AND subject_relation = ?", - []any{"subns", "subrel"}, - []string{"subject_ns", "subject_relation"}, + expectedSQL: "SELECT * WHERE subject_ns = ? AND subject_relation = ? AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"subns", "subrel"}, + expectedStaticColumns: []string{"subject_ns", "subject_relation"}, }, { - "v1 subject filter with empty relation", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "v1 subject filter with empty relation", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.FilterToSubjectFilter(&v1.SubjectFilter{ SubjectType: "subns", OptionalRelation: &v1.SubjectFilter_RelationFilter{ @@ -326,13 +337,13 @@ func TestSchemaQueryFilterer(t *testing.T) { }, }) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND subject_ns = ? AND subject_relation = ?", - []any{"subns", "..."}, - []string{"subject_ns", "subject_relation"}, + expectedSQL: "SELECT * WHERE subject_ns = ? AND subject_relation = ? AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"subns", "..."}, + expectedStaticColumns: []string{"subject_ns", "subject_relation"}, }, { - "v1 subject filter", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "v1 subject filter", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.FilterToSubjectFilter(&v1.SubjectFilter{ SubjectType: "subns", OptionalSubjectId: "subid", @@ -341,22 +352,44 @@ func TestSchemaQueryFilterer(t *testing.T) { }, }) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND subject_ns = ? AND subject_object_id = ? AND subject_relation = ?", - []any{"subns", "subid", "somerel"}, - []string{"subject_ns", "subject_object_id", "subject_relation"}, + expectedSQL: "SELECT * WHERE subject_ns = ? AND subject_object_id = ? AND subject_relation = ? AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"subns", "subid", "somerel"}, + expectedStaticColumns: []string{"subject_ns", "subject_object_id", "subject_relation"}, }, { - "limit", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "limit", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.limit(100) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) LIMIT 100", - nil, - []string{}, + expectedSQL: "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) LIMIT 100", + expectedArgs: nil, + expectedStaticColumns: []string{}, + }, + { + name: "full resources filter", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + return filterer.MustFilterWithRelationshipsFilter( + datastore.RelationshipsFilter{ + OptionalResourceType: "someresourcetype", + OptionalResourceIds: []string{"someid", "anotherid"}, + OptionalResourceRelation: "somerelation", + OptionalSubjectsSelectors: []datastore.SubjectsSelector{ + { + OptionalSubjectType: "somesubjectype", + OptionalSubjectIds: []string{"somesubjectid", "anothersubjectid"}, + RelationFilter: datastore.SubjectRelationFilter{}.WithNonEllipsisRelation("somesubrel").WithEllipsisRelation(), + }, + }, + }, + ) + }, + expectedSQL: "SELECT * WHERE ns = ? AND relation = ? AND object_id IN (?,?) AND ((subject_ns = ? AND subject_object_id IN (?,?) AND (subject_relation = ? OR subject_relation = ?))) AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"someresourcetype", "somerelation", "someid", "anotherid", "somesubjectype", "somesubjectid", "anothersubjectid", "...", "somesubrel"}, + expectedStaticColumns: []string{"ns", "relation", "subject_ns"}, }, { - "full resources filter", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "full resources filter without expiration", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.MustFilterWithRelationshipsFilter( datastore.RelationshipsFilter{ OptionalResourceType: "someresourcetype", @@ -372,52 +405,53 @@ func TestSchemaQueryFilterer(t *testing.T) { }, ) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ns = ? AND relation = ? AND object_id IN (?,?) AND ((subject_ns = ? AND subject_object_id IN (?,?) AND (subject_relation = ? OR subject_relation = ?)))", - []any{"someresourcetype", "somerelation", "someid", "anotherid", "somesubjectype", "somesubjectid", "anothersubjectid", "...", "somesubrel"}, - []string{"ns", "relation", "subject_ns"}, + expectedSQL: "SELECT * WHERE ns = ? AND relation = ? AND object_id IN (?,?) AND ((subject_ns = ? AND subject_object_id IN (?,?) AND (subject_relation = ? OR subject_relation = ?)))", + expectedArgs: []any{"someresourcetype", "somerelation", "someid", "anotherid", "somesubjectype", "somesubjectid", "anothersubjectid", "...", "somesubrel"}, + expectedStaticColumns: []string{"ns", "relation", "subject_ns"}, + withExpirationDisabled: true, }, { - "order by", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "order by", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.MustFilterWithRelationshipsFilter( datastore.RelationshipsFilter{ OptionalResourceType: "someresourcetype", }, ).TupleOrder(options.ByResource) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ns = ? ORDER BY ns, object_id, relation, subject_ns, subject_object_id, subject_relation", - []any{"someresourcetype"}, - []string{"ns"}, + expectedSQL: "SELECT * WHERE ns = ? AND (expiration IS NULL OR expiration > NOW()) ORDER BY ns, object_id, relation, subject_ns, subject_object_id, subject_relation", + expectedArgs: []any{"someresourcetype"}, + expectedStaticColumns: []string{"ns"}, }, { - "after with just namespace", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "after with just namespace", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.MustFilterWithRelationshipsFilter( datastore.RelationshipsFilter{ OptionalResourceType: "someresourcetype", }, ).After(toCursor(tuple.MustParse("someresourcetype:foo#viewer@user:bar")), options.ByResource) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ns = ? AND (object_id,relation,subject_ns,subject_object_id,subject_relation) > (?,?,?,?,?)", - []any{"someresourcetype", "foo", "viewer", "user", "bar", "..."}, - []string{"ns"}, + expectedSQL: "SELECT * WHERE ns = ? AND (object_id,relation,subject_ns,subject_object_id,subject_relation) > (?,?,?,?,?) AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"someresourcetype", "foo", "viewer", "user", "bar", "..."}, + expectedStaticColumns: []string{"ns"}, }, { - "after with just relation", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "after with just relation", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.MustFilterWithRelationshipsFilter( datastore.RelationshipsFilter{ OptionalResourceRelation: "somerelation", }, ).After(toCursor(tuple.MustParse("someresourcetype:foo#viewer@user:bar")), options.ByResource) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND relation = ? AND (ns,object_id,subject_ns,subject_object_id,subject_relation) > (?,?,?,?,?)", - []any{"somerelation", "someresourcetype", "foo", "user", "bar", "..."}, - []string{"relation"}, + expectedSQL: "SELECT * WHERE relation = ? AND (ns,object_id,subject_ns,subject_object_id,subject_relation) > (?,?,?,?,?) AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"somerelation", "someresourcetype", "foo", "user", "bar", "..."}, + expectedStaticColumns: []string{"relation"}, }, { - "after with namespace and single resource id", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "after with namespace and single resource id", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.MustFilterWithRelationshipsFilter( datastore.RelationshipsFilter{ OptionalResourceType: "someresourcetype", @@ -425,26 +459,26 @@ func TestSchemaQueryFilterer(t *testing.T) { }, ).After(toCursor(tuple.MustParse("someresourcetype:foo#viewer@user:bar")), options.ByResource) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ns = ? AND object_id IN (?) AND (relation,subject_ns,subject_object_id,subject_relation) > (?,?,?,?)", - []any{"someresourcetype", "one", "viewer", "user", "bar", "..."}, - []string{"ns", "object_id"}, + expectedSQL: "SELECT * WHERE ns = ? AND object_id IN (?) AND (relation,subject_ns,subject_object_id,subject_relation) > (?,?,?,?) AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"someresourcetype", "one", "viewer", "user", "bar", "..."}, + expectedStaticColumns: []string{"ns", "object_id"}, }, { - "after with single resource id", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "after with single resource id", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.MustFilterWithRelationshipsFilter( datastore.RelationshipsFilter{ OptionalResourceIds: []string{"one"}, }, ).After(toCursor(tuple.MustParse("someresourcetype:foo#viewer@user:bar")), options.ByResource) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND object_id IN (?) AND (ns,relation,subject_ns,subject_object_id,subject_relation) > (?,?,?,?,?)", - []any{"one", "someresourcetype", "viewer", "user", "bar", "..."}, - []string{"object_id"}, + expectedSQL: "SELECT * WHERE object_id IN (?) AND (ns,relation,subject_ns,subject_object_id,subject_relation) > (?,?,?,?,?) AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"one", "someresourcetype", "viewer", "user", "bar", "..."}, + expectedStaticColumns: []string{"object_id"}, }, { - "after with namespace and resource ids", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "after with namespace and resource ids", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.MustFilterWithRelationshipsFilter( datastore.RelationshipsFilter{ OptionalResourceType: "someresourcetype", @@ -452,13 +486,13 @@ func TestSchemaQueryFilterer(t *testing.T) { }, ).After(toCursor(tuple.MustParse("someresourcetype:foo#viewer@user:bar")), options.ByResource) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ns = ? AND object_id IN (?,?) AND (object_id,relation,subject_ns,subject_object_id,subject_relation) > (?,?,?,?,?)", - []any{"someresourcetype", "one", "two", "foo", "viewer", "user", "bar", "..."}, - []string{"ns"}, + expectedSQL: "SELECT * WHERE ns = ? AND object_id IN (?,?) AND (object_id,relation,subject_ns,subject_object_id,subject_relation) > (?,?,?,?,?) AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"someresourcetype", "one", "two", "foo", "viewer", "user", "bar", "..."}, + expectedStaticColumns: []string{"ns"}, }, { - "after with namespace and relation", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "after with namespace and relation", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.MustFilterWithRelationshipsFilter( datastore.RelationshipsFilter{ OptionalResourceType: "someresourcetype", @@ -466,24 +500,24 @@ func TestSchemaQueryFilterer(t *testing.T) { }, ).After(toCursor(tuple.MustParse("someresourcetype:foo#viewer@user:bar")), options.ByResource) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ns = ? AND relation = ? AND (object_id,subject_ns,subject_object_id,subject_relation) > (?,?,?,?)", - []any{"someresourcetype", "somerelation", "foo", "user", "bar", "..."}, - []string{"ns", "relation"}, + expectedSQL: "SELECT * WHERE ns = ? AND relation = ? AND (object_id,subject_ns,subject_object_id,subject_relation) > (?,?,?,?) AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"someresourcetype", "somerelation", "foo", "user", "bar", "..."}, + expectedStaticColumns: []string{"ns", "relation"}, }, { - "after with subject namespace", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "after with subject namespace", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.MustFilterWithSubjectsSelectors(datastore.SubjectsSelector{ OptionalSubjectType: "somesubjectype", }).After(toCursor(tuple.MustParse("someresourcetype:foo#viewer@user:bar")), options.ByResource) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ((subject_ns = ?)) AND (ns,object_id,relation,subject_object_id,subject_relation) > (?,?,?,?,?)", - []any{"somesubjectype", "someresourcetype", "foo", "viewer", "bar", "..."}, - []string{"subject_ns"}, + expectedSQL: "SELECT * WHERE ((subject_ns = ?)) AND (ns,object_id,relation,subject_object_id,subject_relation) > (?,?,?,?,?) AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"somesubjectype", "someresourcetype", "foo", "viewer", "bar", "..."}, + expectedStaticColumns: []string{"subject_ns"}, }, { - "after with subject namespaces", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "after with subject namespaces", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { // NOTE: this isn't really valid (it'll return no results), but is a good test to ensure // the duplicate subject type results in the subject type being in the ORDER BY. return filterer.MustFilterWithSubjectsSelectors(datastore.SubjectsSelector{ @@ -492,66 +526,66 @@ func TestSchemaQueryFilterer(t *testing.T) { OptionalSubjectType: "anothersubjectype", }).After(toCursor(tuple.MustParse("someresourcetype:foo#viewer@user:bar")), options.ByResource) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ((subject_ns = ?)) AND ((subject_ns = ?)) AND (ns,object_id,relation,subject_ns,subject_object_id,subject_relation) > (?,?,?,?,?,?)", - []any{"somesubjectype", "anothersubjectype", "someresourcetype", "foo", "viewer", "user", "bar", "..."}, - []string{}, + expectedSQL: "SELECT * WHERE ((subject_ns = ?)) AND ((subject_ns = ?)) AND (ns,object_id,relation,subject_ns,subject_object_id,subject_relation) > (?,?,?,?,?,?) AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"somesubjectype", "anothersubjectype", "someresourcetype", "foo", "viewer", "user", "bar", "..."}, + expectedStaticColumns: []string{}, }, { - "after with resource ID prefix", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "after with resource ID prefix", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.MustFilterWithResourceIDPrefix("someprefix").After(toCursor(tuple.MustParse("someresourcetype:foo#viewer@user:bar")), options.ByResource) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND object_id LIKE ? AND (ns,object_id,relation,subject_ns,subject_object_id,subject_relation) > (?,?,?,?,?,?)", - []any{"someprefix%", "someresourcetype", "foo", "viewer", "user", "bar", "..."}, - []string{}, + expectedSQL: "SELECT * WHERE object_id LIKE ? AND (ns,object_id,relation,subject_ns,subject_object_id,subject_relation) > (?,?,?,?,?,?) AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"someprefix%", "someresourcetype", "foo", "viewer", "user", "bar", "..."}, + expectedStaticColumns: []string{}, }, { - "order by subject", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "order by subject", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.MustFilterWithRelationshipsFilter( datastore.RelationshipsFilter{ OptionalResourceType: "someresourcetype", }, ).TupleOrder(options.BySubject) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ns = ? ORDER BY subject_ns, subject_object_id, subject_relation, ns, object_id, relation", - []any{"someresourcetype"}, - []string{"ns"}, + expectedSQL: "SELECT * WHERE ns = ? AND (expiration IS NULL OR expiration > NOW()) ORDER BY subject_ns, subject_object_id, subject_relation, ns, object_id, relation", + expectedArgs: []any{"someresourcetype"}, + expectedStaticColumns: []string{"ns"}, }, { - "order by subject, after with subject namespace", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "order by subject, after with subject namespace", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.MustFilterWithSubjectsSelectors(datastore.SubjectsSelector{ OptionalSubjectType: "somesubjectype", }).After(toCursor(tuple.MustParse("someresourcetype:foo#viewer@user:bar")), options.BySubject) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ((subject_ns = ?)) AND (subject_object_id,ns,object_id,relation,subject_relation) > (?,?,?,?,?)", - []any{"somesubjectype", "bar", "someresourcetype", "foo", "viewer", "..."}, - []string{"subject_ns"}, + expectedSQL: "SELECT * WHERE ((subject_ns = ?)) AND (subject_object_id,ns,object_id,relation,subject_relation) > (?,?,?,?,?) AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"somesubjectype", "bar", "someresourcetype", "foo", "viewer", "..."}, + expectedStaticColumns: []string{"subject_ns"}, }, { - "order by subject, after with subject namespace and subject object id", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "order by subject, after with subject namespace and subject object id", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.MustFilterWithSubjectsSelectors(datastore.SubjectsSelector{ OptionalSubjectType: "somesubjectype", OptionalSubjectIds: []string{"foo"}, }).After(toCursor(tuple.MustParse("someresourcetype:someresource#viewer@user:bar")), options.BySubject) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ((subject_ns = ? AND subject_object_id IN (?))) AND (ns,object_id,relation,subject_relation) > (?,?,?,?)", - []any{"somesubjectype", "foo", "someresourcetype", "someresource", "viewer", "..."}, - []string{"subject_ns", "subject_object_id"}, + expectedSQL: "SELECT * WHERE ((subject_ns = ? AND subject_object_id IN (?))) AND (ns,object_id,relation,subject_relation) > (?,?,?,?) AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"somesubjectype", "foo", "someresourcetype", "someresource", "viewer", "..."}, + expectedStaticColumns: []string{"subject_ns", "subject_object_id"}, }, { - "order by subject, after with subject namespace and multiple subject object IDs", - func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + name: "order by subject, after with subject namespace and multiple subject object IDs", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.MustFilterWithSubjectsSelectors(datastore.SubjectsSelector{ OptionalSubjectType: "somesubjectype", OptionalSubjectIds: []string{"foo", "bar"}, }).After(toCursor(tuple.MustParse("someresourcetype:someresource#viewer@user:next")), options.BySubject) }, - "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) AND ((subject_ns = ? AND subject_object_id IN (?,?))) AND (subject_object_id,ns,object_id,relation,subject_relation) > (?,?,?,?,?)", - []any{"somesubjectype", "foo", "bar", "next", "someresourcetype", "someresource", "viewer", "..."}, - []string{"subject_ns"}, + expectedSQL: "SELECT * WHERE ((subject_ns = ? AND subject_object_id IN (?,?))) AND (subject_object_id,ns,object_id,relation,subject_relation) > (?,?,?,?,?) AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"somesubjectype", "foo", "bar", "next", "someresourcetype", "someresource", "viewer", "..."}, + expectedStaticColumns: []string{"subject_ns"}, }, } @@ -586,7 +620,7 @@ func TestSchemaQueryFilterer(t *testing.T) { require.ElementsMatch(t, test.expectedStaticColumns, foundStaticColumns) - ran.queryBuilder = ran.queryBuilder.Columns("*") + ran.queryBuilder = ran.queryBuilderWithExpirationFilter(test.withExpirationDisabled).Columns("*") sql, args, err := ran.queryBuilder.ToSql() require.NoError(t, err) @@ -598,19 +632,21 @@ func TestSchemaQueryFilterer(t *testing.T) { func TestExecuteQuery(t *testing.T) { tcs := []struct { - name string - run func(filterer SchemaQueryFilterer) SchemaQueryFilterer - options []options.QueryOptionsOption - expectedSQL string - expectedArgs []any - expectedSkipCaveats bool + name string + run func(filterer SchemaQueryFilterer) SchemaQueryFilterer + options []options.QueryOptionsOption + expectedSQL string + expectedArgs []any + expectedSkipCaveats bool + expectedSkipExpiration bool + withExpirationDisabled bool }{ { name: "filter by static resource type", run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.FilterToResourceType("sometype") }, - expectedSQL: "SELECT object_id, relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE (expiration IS NULL OR expiration > NOW()) AND ns = ?", + expectedSQL: "SELECT object_id, relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE ns = ? AND (expiration IS NULL OR expiration > NOW())", expectedArgs: []any{"sometype"}, }, { @@ -618,7 +654,7 @@ func TestExecuteQuery(t *testing.T) { run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.FilterToResourceType("sometype").FilterToResourceID("someobj") }, - expectedSQL: "SELECT relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE (expiration IS NULL OR expiration > NOW()) AND ns = ? AND object_id = ?", + expectedSQL: "SELECT relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE ns = ? AND object_id = ? AND (expiration IS NULL OR expiration > NOW())", expectedArgs: []any{"sometype", "someobj"}, }, { @@ -626,7 +662,7 @@ func TestExecuteQuery(t *testing.T) { run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.FilterToResourceType("sometype").MustFilterWithResourceIDPrefix("someprefix") }, - expectedSQL: "SELECT object_id, relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE (expiration IS NULL OR expiration > NOW()) AND ns = ? AND object_id LIKE ?", + expectedSQL: "SELECT object_id, relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE ns = ? AND object_id LIKE ? AND (expiration IS NULL OR expiration > NOW())", expectedArgs: []any{"sometype", "someprefix%"}, }, { @@ -634,7 +670,7 @@ func TestExecuteQuery(t *testing.T) { run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.FilterToResourceType("sometype").MustFilterToResourceIDs([]string{"someobj", "anotherobj"}) }, - expectedSQL: "SELECT object_id, relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE (expiration IS NULL OR expiration > NOW()) AND ns = ? AND object_id IN (?,?)", + expectedSQL: "SELECT object_id, relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE ns = ? AND object_id IN (?,?) AND (expiration IS NULL OR expiration > NOW())", expectedArgs: []any{"sometype", "someobj", "anotherobj"}, }, { @@ -642,7 +678,7 @@ func TestExecuteQuery(t *testing.T) { run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.FilterToResourceType("sometype").FilterToResourceID("someobj").FilterToRelation("somerel") }, - expectedSQL: "SELECT subject_ns, subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE (expiration IS NULL OR expiration > NOW()) AND ns = ? AND object_id = ? AND relation = ?", + expectedSQL: "SELECT subject_ns, subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE ns = ? AND object_id = ? AND relation = ? AND (expiration IS NULL OR expiration > NOW())", expectedArgs: []any{"sometype", "someobj", "somerel"}, }, { @@ -652,7 +688,7 @@ func TestExecuteQuery(t *testing.T) { SubjectType: "subns", }) }, - expectedSQL: "SELECT subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE (expiration IS NULL OR expiration > NOW()) AND ns = ? AND object_id = ? AND relation = ? AND subject_ns = ?", + expectedSQL: "SELECT subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE ns = ? AND object_id = ? AND relation = ? AND subject_ns = ? AND (expiration IS NULL OR expiration > NOW())", expectedArgs: []any{"sometype", "someobj", "somerel", "subns"}, }, { @@ -663,7 +699,7 @@ func TestExecuteQuery(t *testing.T) { OptionalSubjectId: "subid", }) }, - expectedSQL: "SELECT subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE (expiration IS NULL OR expiration > NOW()) AND ns = ? AND object_id = ? AND relation = ? AND subject_ns = ? AND subject_object_id = ?", + expectedSQL: "SELECT subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE ns = ? AND object_id = ? AND relation = ? AND subject_ns = ? AND subject_object_id = ? AND (expiration IS NULL OR expiration > NOW())", expectedArgs: []any{"sometype", "someobj", "somerel", "subns", "subid"}, }, { @@ -677,7 +713,7 @@ func TestExecuteQuery(t *testing.T) { }, }) }, - expectedSQL: "SELECT caveat, caveat_context, expiration FROM relationtuples WHERE (expiration IS NULL OR expiration > NOW()) AND ns = ? AND object_id = ? AND relation = ? AND subject_ns = ? AND subject_object_id = ? AND subject_relation = ?", + expectedSQL: "SELECT caveat, caveat_context, expiration FROM relationtuples WHERE ns = ? AND object_id = ? AND relation = ? AND subject_ns = ? AND subject_object_id = ? AND subject_relation = ? AND (expiration IS NULL OR expiration > NOW())", expectedArgs: []any{"sometype", "someobj", "somerel", "subns", "subid", "subrel"}, }, { @@ -695,7 +731,7 @@ func TestExecuteQuery(t *testing.T) { options.WithSkipCaveats(true), }, expectedSkipCaveats: true, - expectedSQL: "SELECT expiration FROM relationtuples WHERE (expiration IS NULL OR expiration > NOW()) AND ns = ? AND object_id = ? AND relation = ? AND subject_ns = ? AND subject_object_id = ? AND subject_relation = ?", + expectedSQL: "SELECT expiration FROM relationtuples WHERE ns = ? AND object_id = ? AND relation = ? AND subject_ns = ? AND subject_object_id = ? AND subject_relation = ? AND (expiration IS NULL OR expiration > NOW())", expectedArgs: []any{"sometype", "someobj", "somerel", "subns", "subid", "subrel"}, }, { @@ -713,7 +749,7 @@ func TestExecuteQuery(t *testing.T) { options.WithSkipCaveats(true), }, expectedSkipCaveats: true, - expectedSQL: "SELECT object_id, expiration FROM relationtuples WHERE (expiration IS NULL OR expiration > NOW()) AND ns = ? AND object_id IN (?,?) AND relation = ? AND subject_ns = ? AND subject_object_id = ? AND subject_relation = ?", + expectedSQL: "SELECT object_id, expiration FROM relationtuples WHERE ns = ? AND object_id IN (?,?) AND relation = ? AND subject_ns = ? AND subject_object_id = ? AND subject_relation = ? AND (expiration IS NULL OR expiration > NOW())", expectedArgs: []any{"sometype", "someobj", "anotherobj", "somerel", "subns", "subid", "subrel"}, }, { @@ -725,7 +761,7 @@ func TestExecuteQuery(t *testing.T) { options.WithSkipCaveats(true), }, expectedSkipCaveats: true, - expectedSQL: "SELECT object_id, relation, subject_ns, subject_object_id, subject_relation, expiration FROM relationtuples WHERE (expiration IS NULL OR expiration > NOW()) AND ns = ?", + expectedSQL: "SELECT object_id, relation, subject_ns, subject_object_id, subject_relation, expiration FROM relationtuples WHERE ns = ? AND (expiration IS NULL OR expiration > NOW())", expectedArgs: []any{"sometype"}, }, { @@ -735,7 +771,7 @@ func TestExecuteQuery(t *testing.T) { SubjectType: "subns", }) }, - expectedSQL: "SELECT ns, object_id, relation, subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE (expiration IS NULL OR expiration > NOW()) AND subject_ns = ?", + expectedSQL: "SELECT ns, object_id, relation, subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE subject_ns = ? AND (expiration IS NULL OR expiration > NOW())", expectedArgs: []any{"subns"}, }, { @@ -746,7 +782,7 @@ func TestExecuteQuery(t *testing.T) { OptionalSubjectId: "subid", }) }, - expectedSQL: "SELECT ns, object_id, relation, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE (expiration IS NULL OR expiration > NOW()) AND subject_ns = ? AND subject_object_id = ?", + expectedSQL: "SELECT ns, object_id, relation, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE subject_ns = ? AND subject_object_id = ? AND (expiration IS NULL OR expiration > NOW())", expectedArgs: []any{"subns", "subid"}, }, { @@ -759,7 +795,7 @@ func TestExecuteQuery(t *testing.T) { }, }) }, - expectedSQL: "SELECT ns, object_id, relation, subject_object_id, caveat, caveat_context, expiration FROM relationtuples WHERE (expiration IS NULL OR expiration > NOW()) AND subject_ns = ? AND subject_relation = ?", + expectedSQL: "SELECT ns, object_id, relation, subject_object_id, caveat, caveat_context, expiration FROM relationtuples WHERE subject_ns = ? AND subject_relation = ? AND (expiration IS NULL OR expiration > NOW())", expectedArgs: []any{"subns", "subrel"}, }, { @@ -773,7 +809,7 @@ func TestExecuteQuery(t *testing.T) { }, }) }, - expectedSQL: "SELECT ns, object_id, relation, caveat, caveat_context, expiration FROM relationtuples WHERE (expiration IS NULL OR expiration > NOW()) AND subject_ns = ? AND subject_object_id = ? AND subject_relation = ?", + expectedSQL: "SELECT ns, object_id, relation, caveat, caveat_context, expiration FROM relationtuples WHERE subject_ns = ? AND subject_object_id = ? AND subject_relation = ? AND (expiration IS NULL OR expiration > NOW())", expectedArgs: []any{"subns", "subid", "subrel"}, }, { @@ -787,7 +823,7 @@ func TestExecuteQuery(t *testing.T) { OptionalSubjectId: "subid", }) }, - expectedSQL: "SELECT ns, object_id, relation, subject_ns, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE (expiration IS NULL OR expiration > NOW()) AND subject_ns = ? AND subject_object_id = ? AND subject_ns = ? AND subject_object_id = ?", + expectedSQL: "SELECT ns, object_id, relation, subject_ns, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE subject_ns = ? AND subject_object_id = ? AND subject_ns = ? AND subject_object_id = ? AND (expiration IS NULL OR expiration > NOW())", expectedArgs: []any{"subns", "subid", "anothersubns", "subid"}, }, { @@ -799,7 +835,7 @@ func TestExecuteQuery(t *testing.T) { OptionalSubjectType: "anothersubjectype", }) }, - expectedSQL: "SELECT ns, object_id, relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE (expiration IS NULL OR expiration > NOW()) AND ((subject_ns = ?) OR (subject_ns = ?))", + expectedSQL: "SELECT ns, object_id, relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE ((subject_ns = ?) OR (subject_ns = ?)) AND (expiration IS NULL OR expiration > NOW())", expectedArgs: []any{"somesubjectype", "anothersubjectype"}, }, { @@ -811,9 +847,44 @@ func TestExecuteQuery(t *testing.T) { OptionalSubjectType: "anothersubjectype", }).FilterToResourceType("sometype") }, - expectedSQL: "SELECT object_id, relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE (expiration IS NULL OR expiration > NOW()) AND ((subject_ns = ?) OR (subject_ns = ?)) AND ns = ?", + expectedSQL: "SELECT object_id, relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE ((subject_ns = ?) OR (subject_ns = ?)) AND ns = ? AND (expiration IS NULL OR expiration > NOW())", expectedArgs: []any{"somesubjectype", "anothersubjectype", "sometype"}, }, + { + name: "filter by static resource type with expiration disabled", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + return filterer.FilterToResourceType("sometype") + }, + expectedSQL: "SELECT object_id, relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context FROM relationtuples WHERE ns = ?", + expectedArgs: []any{"sometype"}, + withExpirationDisabled: true, + }, + { + name: "filter by static resource type with expiration skipped", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + return filterer.FilterToResourceType("sometype") + }, + expectedSQL: "SELECT object_id, relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context FROM relationtuples WHERE ns = ?", + expectedArgs: []any{"sometype"}, + withExpirationDisabled: false, + expectedSkipExpiration: true, + options: []options.QueryOptionsOption{ + options.WithSkipExpiration(true), + }, + }, + { + name: "filter by static resource type with expiration skipped and disabled", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + return filterer.FilterToResourceType("sometype") + }, + expectedSQL: "SELECT object_id, relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context FROM relationtuples WHERE ns = ?", + expectedArgs: []any{"sometype"}, + withExpirationDisabled: true, + expectedSkipExpiration: true, + options: []options.QueryOptionsOption{ + options.WithSkipExpiration(true), + }, + }, } for _, tc := range tcs { @@ -833,6 +904,7 @@ func TestExecuteQuery(t *testing.T) { WithPaginationFilterType(TupleComparison), WithColumnOptimization(ColumnOptimizationOptionStaticValues), WithNowFunction("NOW"), + WithExpirationDisabled(tc.withExpirationDisabled), ) filterer := NewSchemaQueryFiltererForRelationshipsSelect(*schema, 100) ran := tc.run(filterer) @@ -847,6 +919,7 @@ func TestExecuteQuery(t *testing.T) { require.Equal(t, tc.expectedSQL, sql) require.Equal(t, tc.expectedArgs, args) require.Equal(t, tc.expectedSkipCaveats, builder.SkipCaveats) + require.Equal(t, tc.expectedSkipExpiration, builder.SkipExpiration) return nil, nil }, } diff --git a/internal/datastore/crdb/crdb.go b/internal/datastore/crdb/crdb.go index affe7264f6..3f27d68d12 100644 --- a/internal/datastore/crdb/crdb.go +++ b/internal/datastore/crdb/crdb.go @@ -224,6 +224,7 @@ func newCRDBDatastore(ctx context.Context, url string, options ...Option) (datas common.WithNowFunction("NOW"), common.WithColumnOptimization(config.columnOptimizationOption), common.WithWithIntegrityColumns(config.withIntegrity), + common.WithExpirationDisabled(config.expirationDisabled), ) ds := &crdbDatastore{ diff --git a/internal/datastore/crdb/options.go b/internal/datastore/crdb/options.go index 80ead5a4f4..7c62227085 100644 --- a/internal/datastore/crdb/options.go +++ b/internal/datastore/crdb/options.go @@ -28,9 +28,10 @@ type crdbOptions struct { filterMaximumIDCount uint16 enablePrometheusStats bool withIntegrity bool - includeQueryParametersInTraces bool - columnOptimizationOption common.ColumnOptimizationOption allowedMigrations []string + columnOptimizationOption common.ColumnOptimizationOption + includeQueryParametersInTraces bool + expirationDisabled bool } const ( @@ -60,6 +61,7 @@ const ( defaultWithIntegrity = false defaultColumnOptimizationOption = common.ColumnOptimizationOptionNone defaultIncludeQueryParametersInTraces = false + defaultExpirationDisabled = false ) // Option provides the facility to configure how clients within the CRDB @@ -85,6 +87,7 @@ func generateConfig(options []Option) (crdbOptions, error) { withIntegrity: defaultWithIntegrity, columnOptimizationOption: defaultColumnOptimizationOption, includeQueryParametersInTraces: defaultIncludeQueryParametersInTraces, + expirationDisabled: defaultExpirationDisabled, } for _, option := range options { @@ -368,3 +371,8 @@ func WithColumnOptimization(isEnabled bool) Option { } } } + +// WithExpirationDisabled configures the datastore to disable relationship expiration. +func WithExpirationDisabled(isDisabled bool) Option { + return func(po *crdbOptions) { po.expirationDisabled = isDisabled } +} diff --git a/internal/datastore/memdb/readonly.go b/internal/datastore/memdb/readonly.go index 7878b3c276..e348e2d773 100644 --- a/internal/datastore/memdb/readonly.go +++ b/internal/datastore/memdb/readonly.go @@ -151,11 +151,11 @@ func (r *memdbReader) QueryRelationships( fallthrough case options.ByResource: - iter := newMemdbTupleIterator(r.now, filteredIterator, queryOpts.Limit, queryOpts.SkipCaveats) + iter := newMemdbTupleIterator(r.now, filteredIterator, queryOpts.Limit, queryOpts.SkipCaveats, queryOpts.SkipExpiration) return iter, nil case options.BySubject: - return newSubjectSortedIterator(r.now, filteredIterator, queryOpts.Limit, queryOpts.SkipCaveats) + return newSubjectSortedIterator(r.now, filteredIterator, queryOpts.Limit, queryOpts.SkipCaveats, queryOpts.SkipExpiration) default: return nil, spiceerrors.MustBugf("unsupported sort order: %v", queryOpts.Sort) @@ -214,11 +214,11 @@ func (r *memdbReader) ReverseQueryRelationships( fallthrough case options.ByResource: - iter := newMemdbTupleIterator(r.now, filteredIterator, queryOpts.LimitForReverse, false) + iter := newMemdbTupleIterator(r.now, filteredIterator, queryOpts.LimitForReverse, false, false) return iter, nil case options.BySubject: - return newSubjectSortedIterator(r.now, filteredIterator, queryOpts.LimitForReverse, false) + return newSubjectSortedIterator(r.now, filteredIterator, queryOpts.LimitForReverse, false, false) default: return nil, spiceerrors.MustBugf("unsupported sort order: %v", queryOpts.SortForReverse) @@ -476,7 +476,7 @@ func makeCursorFilterFn(after options.Cursor, order options.SortOrder) func(tpl return noopCursorFilter } -func newSubjectSortedIterator(now time.Time, it memdb.ResultIterator, limit *uint64, skipCaveats bool) (datastore.RelationshipIterator, error) { +func newSubjectSortedIterator(now time.Time, it memdb.ResultIterator, limit *uint64, skipCaveats bool, skipExpiration bool) (datastore.RelationshipIterator, error) { results := make([]tuple.Relationship, 0) // Coalesce all of the results into memory @@ -494,6 +494,10 @@ func newSubjectSortedIterator(now time.Time, it memdb.ResultIterator, limit *uin return nil, spiceerrors.MustBugf("unexpected caveat in result for relationship: %v", rt) } + if skipExpiration && rt.OptionalExpiration != nil { + return nil, spiceerrors.MustBugf("unexpected expiration in result for relationship: %v", rt) + } + results = append(results, rt) } @@ -530,7 +534,7 @@ func eq(lhsNamespace, lhsObjectID, lhsRelation string, rhs tuple.ObjectAndRelati return lhsNamespace == rhs.ObjectType && lhsObjectID == rhs.ObjectID && lhsRelation == rhs.Relation } -func newMemdbTupleIterator(now time.Time, it memdb.ResultIterator, limit *uint64, skipCaveats bool) datastore.RelationshipIterator { +func newMemdbTupleIterator(now time.Time, it memdb.ResultIterator, limit *uint64, skipCaveats bool, skipExpiration bool) datastore.RelationshipIterator { var count uint64 return func(yield func(tuple.Relationship, error) bool) { for { @@ -551,15 +555,20 @@ func newMemdbTupleIterator(now time.Time, it memdb.ResultIterator, limit *uint64 continue } - if rt.OptionalExpiration != nil && rt.OptionalExpiration.Before(now) { - continue - } - if skipCaveats && rt.OptionalCaveat != nil { yield(rt, fmt.Errorf("unexpected caveat in result for relationship: %v", rt)) return } + if skipExpiration && rt.OptionalExpiration != nil { + yield(rt, fmt.Errorf("unexpected expiration in result for relationship: %v", rt)) + return + } + + if rt.OptionalExpiration != nil && rt.OptionalExpiration.Before(now) { + continue + } + if !yield(rt, err) { return } diff --git a/internal/datastore/mysql/datastore.go b/internal/datastore/mysql/datastore.go index ea39afd629..1462e6fbf6 100644 --- a/internal/datastore/mysql/datastore.go +++ b/internal/datastore/mysql/datastore.go @@ -259,6 +259,7 @@ func newMySQLDatastore(ctx context.Context, uri string, replicaIndex int, option common.WithPlaceholderFormat(sq.Question), common.WithNowFunction("NOW"), common.WithColumnOptimization(config.columnOptimizationOption), + common.WithExpirationDisabled(config.expirationDisabled), ) store := &Datastore{ diff --git a/internal/datastore/mysql/gc.go b/internal/datastore/mysql/gc.go index d4375eddbd..9558a93a34 100644 --- a/internal/datastore/mysql/gc.go +++ b/internal/datastore/mysql/gc.go @@ -108,6 +108,10 @@ func (mds *Datastore) DeleteBeforeTx( } func (mds *Datastore) DeleteExpiredRels(ctx context.Context) (int64, error) { + if mds.schema.ExpirationDisabled { + return 0, nil + } + now, err := mds.Now(ctx) if err != nil { return 0, err diff --git a/internal/datastore/mysql/options.go b/internal/datastore/mysql/options.go index 713e97e671..4405dbcaba 100644 --- a/internal/datastore/mysql/options.go +++ b/internal/datastore/mysql/options.go @@ -27,6 +27,7 @@ const ( defaultCredentialsProviderName = "" defaultFilterMaximumIDCount = 100 defaultColumnOptimizationOption = common.ColumnOptimizationOptionNone + defaultExpirationDisabled = false ) type mysqlOptions struct { @@ -50,6 +51,7 @@ type mysqlOptions struct { filterMaximumIDCount uint16 allowedMigrations []string columnOptimizationOption common.ColumnOptimizationOption + expirationDisabled bool } // Option provides the facility to configure how clients within the @@ -74,6 +76,7 @@ func generateConfig(options []Option) (mysqlOptions, error) { credentialsProviderName: defaultCredentialsProviderName, filterMaximumIDCount: defaultFilterMaximumIDCount, columnOptimizationOption: defaultColumnOptimizationOption, + expirationDisabled: defaultExpirationDisabled, } for _, option := range options { @@ -284,3 +287,10 @@ func WithColumnOptimization(isEnabled bool) Option { } } } + +// WithExpirationDisabled disables the expiration of relationships in the MySQL datastore. +func WithExpirationDisabled(isDisabled bool) Option { + return func(mo *mysqlOptions) { + mo.expirationDisabled = isDisabled + } +} diff --git a/internal/datastore/postgres/gc.go b/internal/datastore/postgres/gc.go index 674082171d..a7419f70d8 100644 --- a/internal/datastore/postgres/gc.go +++ b/internal/datastore/postgres/gc.go @@ -77,6 +77,10 @@ func (pgd *pgDatastore) TxIDBefore(ctx context.Context, before time.Time) (datas } func (pgd *pgDatastore) DeleteExpiredRels(ctx context.Context) (int64, error) { + if pgd.schema.ExpirationDisabled { + return 0, nil + } + now, err := pgd.Now(ctx) if err != nil { return 0, err diff --git a/internal/datastore/postgres/options.go b/internal/datastore/postgres/options.go index 0997c5be16..82a42de6af 100644 --- a/internal/datastore/postgres/options.go +++ b/internal/datastore/postgres/options.go @@ -29,6 +29,7 @@ type postgresOptions struct { analyzeBeforeStatistics bool gcEnabled bool readStrictMode bool + expirationDisabled bool columnOptimizationOption common.ColumnOptimizationOption includeQueryParametersInTraces bool @@ -72,6 +73,7 @@ const ( defaultFilterMaximumIDCount = 100 defaultColumnOptimizationOption = common.ColumnOptimizationOptionNone defaultIncludeQueryParametersInTraces = false + defaultExpirationDisabled = false ) // Option provides the facility to configure how clients within the @@ -94,8 +96,9 @@ func generateConfig(options []Option) (postgresOptions, error) { readStrictMode: defaultReadStrictMode, queryInterceptor: nil, filterMaximumIDCount: defaultFilterMaximumIDCount, - includeQueryParametersInTraces: defaultIncludeQueryParametersInTraces, columnOptimizationOption: defaultColumnOptimizationOption, + includeQueryParametersInTraces: defaultIncludeQueryParametersInTraces, + expirationDisabled: defaultExpirationDisabled, } for _, option := range options { @@ -400,3 +403,8 @@ func WithColumnOptimization(isEnabled bool) Option { } } } + +// WithExpirationDisabled disables support for relationship expiration. +func WithExpirationDisabled(isDisabled bool) Option { + return func(po *postgresOptions) { po.expirationDisabled = isDisabled } +} diff --git a/internal/datastore/postgres/postgres.go b/internal/datastore/postgres/postgres.go index e62c9f9dce..8f5d66b068 100644 --- a/internal/datastore/postgres/postgres.go +++ b/internal/datastore/postgres/postgres.go @@ -331,6 +331,7 @@ func newPostgresDatastore( common.WithPlaceholderFormat(sq.Dollar), common.WithNowFunction("NOW"), common.WithColumnOptimization(config.columnOptimizationOption), + common.WithExpirationDisabled(config.expirationDisabled), ) datastore := &pgDatastore{ diff --git a/internal/datastore/spanner/options.go b/internal/datastore/spanner/options.go index a0ae438a23..75f84cf4b9 100644 --- a/internal/datastore/spanner/options.go +++ b/internal/datastore/spanner/options.go @@ -28,6 +28,7 @@ type spannerOptions struct { allowedMigrations []string filterMaximumIDCount uint16 columnOptimizationOption common.ColumnOptimizationOption + expirationDisabled bool } type migrationPhase uint8 @@ -52,6 +53,7 @@ const ( maxRevisionQuantization = 24 * time.Hour defaultFilterMaximumIDCount = 100 defaultColumnOptimizationOption = common.ColumnOptimizationOptionNone + defaultExpirationDisabled = false ) // Option provides the facility to configure how clients within the Spanner @@ -76,6 +78,7 @@ func generateConfig(options []Option) (spannerOptions, error) { migrationPhase: "", // no migration filterMaximumIDCount: defaultFilterMaximumIDCount, columnOptimizationOption: defaultColumnOptimizationOption, + expirationDisabled: defaultExpirationDisabled, } for _, option := range options { @@ -240,3 +243,10 @@ func WithColumnOptimization(isEnabled bool) Option { } } } + +// WithExpirationDisabled disables relationship expiration support in the Spanner. +func WithExpirationDisabled(isDisabled bool) Option { + return func(po *spannerOptions) { + po.expirationDisabled = isDisabled + } +} diff --git a/internal/graph/check.go b/internal/graph/check.go index 63df50d0bf..f5001275db 100644 --- a/internal/graph/check.go +++ b/internal/graph/check.go @@ -325,8 +325,12 @@ func (cc *ConcurrentChecker) checkDirect(ctx context.Context, crc currentRequest hasNonTerminals := false hasDirectSubject := false hasWildcardSubject := false + directSubjectOrWildcardCanHaveCaveats := false + directSubjectOrWildcardCanHaveExpiration := false + nonTerminalsCanHaveCaveats := false + nonTerminalsCanHaveExpiration := false defer func() { if hasNonTerminals { @@ -355,6 +359,10 @@ func (cc *ConcurrentChecker) checkDirect(ctx context.Context, crc currentRequest if allowedDirectRelation.RequiredCaveat != nil { directSubjectOrWildcardCanHaveCaveats = true } + + if allowedDirectRelation.RequiredExpiration != nil { + directSubjectOrWildcardCanHaveExpiration = true + } } // If the relation found is not an ellipsis, then this is a nested relation that @@ -367,6 +375,9 @@ func (cc *ConcurrentChecker) checkDirect(ctx context.Context, crc currentRequest if allowedDirectRelation.RequiredCaveat != nil { nonTerminalsCanHaveCaveats = true } + if allowedDirectRelation.RequiredExpiration != nil { + nonTerminalsCanHaveExpiration = true + } } } @@ -405,7 +416,10 @@ func (cc *ConcurrentChecker) checkDirect(ctx context.Context, crc currentRequest OptionalSubjectsSelectors: subjectSelectors, } - it, err := ds.QueryRelationships(ctx, filter, options.WithSkipCaveats(!directSubjectOrWildcardCanHaveCaveats)) + it, err := ds.QueryRelationships(ctx, filter, + options.WithSkipCaveats(!directSubjectOrWildcardCanHaveCaveats), + options.WithSkipExpiration(!directSubjectOrWildcardCanHaveExpiration), + ) if err != nil { return checkResultError(NewCheckFailureErr(err), emptyMetadata) } @@ -454,7 +468,10 @@ func (cc *ConcurrentChecker) checkDirect(ctx context.Context, crc currentRequest }, } - it, err := ds.QueryRelationships(ctx, filter, options.WithSkipCaveats(!nonTerminalsCanHaveCaveats)) + it, err := ds.QueryRelationships(ctx, filter, + options.WithSkipCaveats(!nonTerminalsCanHaveCaveats), + options.WithSkipExpiration(!nonTerminalsCanHaveExpiration), + ) if err != nil { return checkResultError(NewCheckFailureErr(err), emptyMetadata) } @@ -648,6 +665,56 @@ func (cc *ConcurrentChecker) checkComputedUserset(ctx context.Context, crc curre return combineResultWithFoundResources(result, membershipSet) } +// queryOptionsForArrowRelation returns query options such as SkipCaveats and SkipExpiration if *none* of the subject +// types of the given relation support caveats or expiration. +func (cc *ConcurrentChecker) queryOptionsForArrowRelation(ctx context.Context, reader datastore.Reader, namespaceName string, relationName string) ([]options.QueryOptionsOption, error) { + // TODO(jschorr): Change to use the type system once we wire it through Check dispatch. + nsDefs, err := reader.LookupNamespacesWithNames(ctx, []string{namespaceName}) + if err != nil { + return nil, err + } + + if len(nsDefs) != 1 { + return nil, nil + } + + var relation *core.Relation + for _, rel := range nsDefs[0].Definition.Relation { + if rel.Name == relationName { + relation = rel + break + } + } + + if relation == nil || relation.TypeInformation == nil { + return nil, nil + } + + hasCaveats := false + hasExpiration := false + + for _, allowedDirectRelation := range relation.TypeInformation.GetAllowedDirectRelations() { + if allowedDirectRelation.RequiredCaveat != nil { + hasCaveats = true + } + + if allowedDirectRelation.RequiredExpiration != nil { + hasExpiration = true + } + } + + opts := make([]options.QueryOptionsOption, 0, 2) + if !hasCaveats { + opts = append(opts, options.WithSkipCaveats(true)) + } + + if !hasExpiration { + opts = append(opts, options.WithSkipExpiration(true)) + } + + return opts, nil +} + func filterForFoundMemberResource(resourceRelation *core.RelationReference, resourceIds []string, subject *core.ObjectAndRelation) (*MembershipSet, []string) { if resourceRelation.Namespace != subject.Namespace || resourceRelation.Relation != subject.Relation { return nil, resourceIds @@ -698,11 +765,16 @@ func checkIntersectionTupleToUserset( // Query for the subjects over which to walk the TTU. log.Ctx(ctx).Trace().Object("intersectionttu", crc.parentReq).Send() ds := datastoremw.MustFromContext(ctx).SnapshotReader(crc.parentReq.Revision) + queryOpts, err := cc.queryOptionsForArrowRelation(ctx, ds, crc.parentReq.ResourceRelation.Namespace, ttu.GetTupleset().GetRelation()) + if err != nil { + return checkResultError(NewCheckFailureErr(err), emptyMetadata) + } + it, err := ds.QueryRelationships(ctx, datastore.RelationshipsFilter{ OptionalResourceType: crc.parentReq.ResourceRelation.Namespace, OptionalResourceIds: crc.filteredResourceIDs, OptionalResourceRelation: ttu.GetTupleset().GetRelation(), - }) + }, queryOpts...) if err != nil { return checkResultError(NewCheckFailureErr(err), emptyMetadata) } @@ -859,11 +931,17 @@ func checkTupleToUserset[T relation]( log.Ctx(ctx).Trace().Object("ttu", crc.parentReq).Send() ds := datastoremw.MustFromContext(ctx).SnapshotReader(crc.parentReq.Revision) + + queryOpts, err := cc.queryOptionsForArrowRelation(ctx, ds, crc.parentReq.ResourceRelation.Namespace, ttu.GetTupleset().GetRelation()) + if err != nil { + return checkResultError(NewCheckFailureErr(err), emptyMetadata) + } + it, err := ds.QueryRelationships(ctx, datastore.RelationshipsFilter{ OptionalResourceType: crc.parentReq.ResourceRelation.Namespace, OptionalResourceIds: filteredResourceIDs, OptionalResourceRelation: ttu.GetTupleset().GetRelation(), - }) + }, queryOpts...) if err != nil { return checkResultError(NewCheckFailureErr(err), emptyMetadata) } diff --git a/internal/services/steelthreadtesting/steelthread_test.go b/internal/services/steelthreadtesting/steelthread_test.go index b34485c489..39733522a1 100644 --- a/internal/services/steelthreadtesting/steelthread_test.go +++ b/internal/services/steelthreadtesting/steelthread_test.go @@ -31,13 +31,7 @@ const defaultConnBufferSize = humanize.MiByte func TestMemdbSteelThreads(t *testing.T) { for _, tc := range steelThreadTestCases { t.Run(tc.name, func(t *testing.T) { -<<<<<<< HEAD - emptyDS, err := memdb.NewMemdbDatastore(0, 5*time.Second, 2*time.Hour) -======= - t.Parallel() - emptyDS, err := dsfortesting.NewMemDBDatastoreForTesting(0, 5*time.Second, 2*time.Hour) ->>>>>>> 963e4a60 (Change tests to use a new entrypoint for creating memdb for testing) require.NoError(t, err) runSteelThreadTest(t, tc, emptyDS) diff --git a/internal/services/v1/experimental_test.go b/internal/services/v1/experimental_test.go index 595878c3ea..dd84ebda37 100644 --- a/internal/services/v1/experimental_test.go +++ b/internal/services/v1/experimental_test.go @@ -10,7 +10,6 @@ import ( "strconv" "testing" - "github.com/authzed/authzed-go/pkg/responsemeta" v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" "github.com/authzed/grpcutil" "github.com/ccoveille/go-safecast" @@ -435,10 +434,9 @@ func TestBulkCheckPermission(t *testing.T) { defer cleanup() testCases := []struct { - name string - requests []string - response []bulkCheckTest - expectedDispatchCount int + name string + requests []string + response []bulkCheckTest }{ { name: "same resource and permission, different subjects", @@ -461,7 +459,6 @@ func TestBulkCheckPermission(t *testing.T) { resp: v1.CheckPermissionResponse_PERMISSIONSHIP_NO_PERMISSION, }, }, - expectedDispatchCount: 49, }, { name: "different resources, same permission and subject", @@ -484,7 +481,6 @@ func TestBulkCheckPermission(t *testing.T) { resp: v1.CheckPermissionResponse_PERMISSIONSHIP_NO_PERMISSION, }, }, - expectedDispatchCount: 18, }, { name: "some items fail", @@ -507,36 +503,6 @@ func TestBulkCheckPermission(t *testing.T) { err: namespace.NewNamespaceNotFoundErr("superfake"), }, }, - expectedDispatchCount: 17, - }, - { - name: "different caveat context is not clustered", - requests: []string{ - `document:masterplan#view@user:eng_lead[test:{"secret": "1234"}]`, - `document:companyplan#view@user:eng_lead[test:{"secret": "1234"}]`, - `document:masterplan#view@user:eng_lead[test:{"secret": "4321"}]`, - `document:masterplan#view@user:eng_lead`, - }, - response: []bulkCheckTest{ - { - req: `document:masterplan#view@user:eng_lead[test:{"secret": "1234"}]`, - resp: v1.CheckPermissionResponse_PERMISSIONSHIP_HAS_PERMISSION, - }, - { - req: `document:companyplan#view@user:eng_lead[test:{"secret": "1234"}]`, - resp: v1.CheckPermissionResponse_PERMISSIONSHIP_NO_PERMISSION, - }, - { - req: `document:masterplan#view@user:eng_lead[test:{"secret": "4321"}]`, - resp: v1.CheckPermissionResponse_PERMISSIONSHIP_NO_PERMISSION, - }, - { - req: `document:masterplan#view@user:eng_lead`, - resp: v1.CheckPermissionResponse_PERMISSIONSHIP_CONDITIONAL_PERMISSION, - partial: []string{"secret"}, - }, - }, - expectedDispatchCount: 50, }, { name: "namespace validation", @@ -554,7 +520,6 @@ func TestBulkCheckPermission(t *testing.T) { err: namespace.NewNamespaceNotFoundErr("fake"), }, }, - expectedDispatchCount: 1, }, { name: "chunking test", @@ -577,7 +542,6 @@ func TestBulkCheckPermission(t *testing.T) { return toReturn })(), - expectedDispatchCount: 11, }, { name: "chunking test with errors", @@ -607,7 +571,6 @@ func TestBulkCheckPermission(t *testing.T) { return toReturn })(), - expectedDispatchCount: 11, }, { name: "same resource and permission with same subject, repeated", @@ -625,7 +588,6 @@ func TestBulkCheckPermission(t *testing.T) { resp: v1.CheckPermissionResponse_PERMISSIONSHIP_HAS_PERMISSION, }, }, - expectedDispatchCount: 17, }, } @@ -694,10 +656,6 @@ func TestBulkCheckPermission(t *testing.T) { actual, err := client.BulkCheckPermission(context.Background(), &req, grpc.Trailer(&trailer)) require.NoError(t, err) - dispatchCount, err := responsemeta.GetIntResponseTrailerMetadata(trailer, responsemeta.DispatchedOperationsCount) - require.NoError(t, err) - require.Equal(t, tt.expectedDispatchCount, dispatchCount) - testutil.RequireProtoSlicesEqual(t, expected, actual.Pairs, nil, "response bulk check pairs did not match") }) } diff --git a/internal/services/v1/permissions_test.go b/internal/services/v1/permissions_test.go index 86b3a70168..7bb0c87b72 100644 --- a/internal/services/v1/permissions_test.go +++ b/internal/services/v1/permissions_test.go @@ -1027,9 +1027,9 @@ func TestCheckWithCaveats(t *testing.T) { AtLeastAsFresh: zedtoken.MustNewFromRevision(revision), }, }, - Resource: obj("document", "companyplan"), - Permission: "view", - Subject: sub("user", "owner", ""), + Resource: obj("document", "caveatedplan"), + Permission: "caveated_viewer", + Subject: sub("user", "caveatedguy", ""), } // caveat evaluated and returned false @@ -1774,10 +1774,9 @@ func TestCheckBulkPermissions(t *testing.T) { defer cleanup() testCases := []struct { - name string - requests []string - response []bulkCheckTest - expectedDispatchCount int + name string + requests []string + response []bulkCheckTest }{ { name: "same resource and permission, different subjects", @@ -1800,7 +1799,6 @@ func TestCheckBulkPermissions(t *testing.T) { resp: v1.CheckPermissionResponse_PERMISSIONSHIP_NO_PERMISSION, }, }, - expectedDispatchCount: 49, }, { name: "different resources, same permission and subject", @@ -1823,7 +1821,6 @@ func TestCheckBulkPermissions(t *testing.T) { resp: v1.CheckPermissionResponse_PERMISSIONSHIP_NO_PERMISSION, }, }, - expectedDispatchCount: 18, }, { name: "some items fail", @@ -1846,36 +1843,29 @@ func TestCheckBulkPermissions(t *testing.T) { err: namespace.NewNamespaceNotFoundErr("superfake"), }, }, - expectedDispatchCount: 17, }, { name: "different caveat context is not clustered", requests: []string{ - `document:masterplan#view@user:eng_lead[test:{"secret": "1234"}]`, - `document:companyplan#view@user:eng_lead[test:{"secret": "1234"}]`, - `document:masterplan#view@user:eng_lead[test:{"secret": "4321"}]`, - `document:masterplan#view@user:eng_lead`, + `document:caveatedplan#caveated_viewer@user:caveatedguy[test:{"secret": "1234"}]`, + `document:caveatedplan#caveated_viewer@user:caveatedguy[test:{"secret": "4321"}]`, + `document:caveatedplan#caveated_viewer@user:caveatedguy`, }, response: []bulkCheckTest{ { - req: `document:masterplan#view@user:eng_lead[test:{"secret": "1234"}]`, + req: `document:caveatedplan#caveated_viewer@user:caveatedguy[test:{"secret": "1234"}]`, resp: v1.CheckPermissionResponse_PERMISSIONSHIP_HAS_PERMISSION, }, { - req: `document:companyplan#view@user:eng_lead[test:{"secret": "1234"}]`, - resp: v1.CheckPermissionResponse_PERMISSIONSHIP_NO_PERMISSION, - }, - { - req: `document:masterplan#view@user:eng_lead[test:{"secret": "4321"}]`, + req: `document:caveatedplan#caveated_viewer@user:caveatedguy[test:{"secret": "4321"}]`, resp: v1.CheckPermissionResponse_PERMISSIONSHIP_NO_PERMISSION, }, { - req: `document:masterplan#view@user:eng_lead`, + req: `document:caveatedplan#caveated_viewer@user:caveatedguy`, resp: v1.CheckPermissionResponse_PERMISSIONSHIP_CONDITIONAL_PERMISSION, partial: []string{"secret"}, }, }, - expectedDispatchCount: 50, }, { name: "namespace validation", @@ -1893,7 +1883,6 @@ func TestCheckBulkPermissions(t *testing.T) { err: namespace.NewNamespaceNotFoundErr("fake"), }, }, - expectedDispatchCount: 1, }, { name: "chunking test", @@ -1916,7 +1905,6 @@ func TestCheckBulkPermissions(t *testing.T) { return toReturn })(), - expectedDispatchCount: 11, }, { name: "chunking test with errors", @@ -1946,7 +1934,6 @@ func TestCheckBulkPermissions(t *testing.T) { return toReturn })(), - expectedDispatchCount: 11, }, { name: "same resource and permission with same subject, repeated", @@ -1964,7 +1951,6 @@ func TestCheckBulkPermissions(t *testing.T) { resp: v1.CheckPermissionResponse_PERMISSIONSHIP_HAS_PERMISSION, }, }, - expectedDispatchCount: 17, }, } diff --git a/internal/testfixtures/datastore.go b/internal/testfixtures/datastore.go index b73d9235dc..cc448ed5e3 100644 --- a/internal/testfixtures/datastore.go +++ b/internal/testfixtures/datastore.go @@ -4,7 +4,6 @@ import ( "context" "github.com/stretchr/testify/require" - "google.golang.org/protobuf/types/known/structpb" "github.com/authzed/spicedb/internal/datastore/common" "github.com/authzed/spicedb/internal/namespace" @@ -141,6 +140,10 @@ var StandardRelationships = []string{ "document:ownerplan#viewer@user:owner#...", } +var StandardCaveatedRelationships = []string{ + "document:caveatedplan#caveated_viewer@user:caveatedguy#...[test:{\"expectedSecret\":\"1234\"}]", +} + // EmptyDatastore returns an empty datastore for testing. func EmptyDatastore(ds datastore.Datastore, require *require.Assertions) (datastore.Datastore, datastore.Revision) { rev, err := ds.HeadRevision(context.Background()) @@ -185,16 +188,17 @@ func StandardDatastoreWithCaveatedData(ds datastore.Datastore, require *require. }) require.NoError(err) - rels := make([]tuple.Relationship, 0, len(StandardRelationships)) + rels := make([]tuple.Relationship, 0, len(StandardRelationships)+len(StandardCaveatedRelationships)) for _, tupleStr := range StandardRelationships { rel, err := tuple.Parse(tupleStr) require.NoError(err) require.NotNil(rel) - - rel.OptionalCaveat = &core.ContextualizedCaveat{ - CaveatName: "test", - Context: mustProtoStruct(map[string]any{"expectedSecret": "1234"}), - } + rels = append(rels, rel) + } + for _, tupleStr := range StandardCaveatedRelationships { + rel, err := tuple.Parse(tupleStr) + require.NoError(err) + require.NotNil(rel) rels = append(rels, rel) } @@ -360,11 +364,3 @@ func (tc RelationshipChecker) NoRelationshipExists(ctx context.Context, rel tupl iter := tc.ExactRelationshipIterator(ctx, rel, rev) tc.VerifyIteratorResults(iter) } - -func mustProtoStruct(in map[string]any) *structpb.Struct { - out, err := structpb.NewStruct(in) - if err != nil { - panic(err) - } - return out -} diff --git a/internal/testserver/datastore/config/config.go b/internal/testserver/datastore/config/config.go index 960a22cadf..f745a73cdd 100644 --- a/internal/testserver/datastore/config/config.go +++ b/internal/testserver/datastore/config/config.go @@ -24,6 +24,7 @@ func DatastoreConfigInitFunc(t testing.TB, options ...dsconfig.ConfigOption) tes append(options, dsconfig.WithEngine(engine), dsconfig.WithEnableDatastoreMetrics(false), + dsconfig.WithEnableExperimentalRelationshipExpiration(true), dsconfig.WithURI(uri), )...) require.NoError(t, err) diff --git a/pkg/cmd/datastore/datastore.go b/pkg/cmd/datastore/datastore.go index 2248697e51..08ed140fc9 100644 --- a/pkg/cmd/datastore/datastore.go +++ b/pkg/cmd/datastore/datastore.go @@ -168,7 +168,8 @@ type Config struct { AllowedMigrations []string `debugmap:"visible"` // Expermimental - ExperimentalColumnOptimization bool `debugmap:"visible"` + ExperimentalColumnOptimization bool `debugmap:"visible"` + EnableExperimentalRelationshipExpiration bool `debugmap:"visible"` } //go:generate go run github.com/ecordell/optgen -sensitive-field-name-matches uri,secure -output zz_generated.relintegritykey.options.go . RelIntegrityKey @@ -281,49 +282,50 @@ func RegisterDatastoreFlagsWithPrefix(flagSet *pflag.FlagSet, prefix string, opt func DefaultDatastoreConfig() *Config { return &Config{ - Engine: MemoryEngine, - GCWindow: 24 * time.Hour, - LegacyFuzzing: -1, - RevisionQuantization: 5 * time.Second, - MaxRevisionStalenessPercent: .1, // 10% - ReadConnPool: *DefaultReadConnPool(), - WriteConnPool: *DefaultWriteConnPool(), - ReadReplicaConnPool: *DefaultReadConnPool(), - ReadReplicaURIs: []string{}, - ReadOnly: false, - MaxRetries: 10, - OverlapKey: "key", - OverlapStrategy: "static", - ConnectRate: 100 * time.Millisecond, - EnableConnectionBalancing: true, - GCInterval: 3 * time.Minute, - GCMaxOperationTime: 1 * time.Minute, - WatchBufferLength: 1024, - WatchBufferWriteTimeout: 1 * time.Second, - WatchConnectTimeout: 1 * time.Second, - EnableDatastoreMetrics: true, - DisableStats: false, - BootstrapFiles: []string{}, - BootstrapTimeout: 10 * time.Second, - BootstrapOverwrite: false, - RequestHedgingEnabled: false, - RequestHedgingInitialSlowValue: 10000000, - RequestHedgingMaxRequests: 1_000_000, - RequestHedgingQuantile: 0.95, - SpannerCredentialsFile: "", - SpannerEmulatorHost: "", - TablePrefix: "", - MigrationPhase: "", - FollowerReadDelay: 4_800 * time.Millisecond, - SpannerMinSessions: 100, - SpannerMaxSessions: 400, - FilterMaximumIDCount: 100, - RelationshipIntegrityEnabled: false, - RelationshipIntegrityCurrentKey: RelIntegrityKey{}, - RelationshipIntegrityExpiredKeys: []string{}, - AllowedMigrations: []string{}, - ExperimentalColumnOptimization: false, - IncludeQueryParametersInTraces: false, + Engine: MemoryEngine, + GCWindow: 24 * time.Hour, + LegacyFuzzing: -1, + RevisionQuantization: 5 * time.Second, + MaxRevisionStalenessPercent: .1, // 10% + ReadConnPool: *DefaultReadConnPool(), + WriteConnPool: *DefaultWriteConnPool(), + ReadReplicaConnPool: *DefaultReadConnPool(), + ReadReplicaURIs: []string{}, + ReadOnly: false, + MaxRetries: 10, + OverlapKey: "key", + OverlapStrategy: "static", + ConnectRate: 100 * time.Millisecond, + EnableConnectionBalancing: true, + GCInterval: 3 * time.Minute, + GCMaxOperationTime: 1 * time.Minute, + WatchBufferLength: 1024, + WatchBufferWriteTimeout: 1 * time.Second, + WatchConnectTimeout: 1 * time.Second, + EnableDatastoreMetrics: true, + DisableStats: false, + BootstrapFiles: []string{}, + BootstrapTimeout: 10 * time.Second, + BootstrapOverwrite: false, + RequestHedgingEnabled: false, + RequestHedgingInitialSlowValue: 10000000, + RequestHedgingMaxRequests: 1_000_000, + RequestHedgingQuantile: 0.95, + SpannerCredentialsFile: "", + SpannerEmulatorHost: "", + TablePrefix: "", + MigrationPhase: "", + FollowerReadDelay: 4_800 * time.Millisecond, + SpannerMinSessions: 100, + SpannerMaxSessions: 400, + FilterMaximumIDCount: 100, + RelationshipIntegrityEnabled: false, + RelationshipIntegrityCurrentKey: RelIntegrityKey{}, + RelationshipIntegrityExpiredKeys: []string{}, + AllowedMigrations: []string{}, + ExperimentalColumnOptimization: false, + IncludeQueryParametersInTraces: false, + EnableExperimentalRelationshipExpiration: false, } } @@ -520,6 +522,7 @@ func newCRDBDatastore(ctx context.Context, opts Config) (datastore.Datastore, er crdb.AllowedMigrations(opts.AllowedMigrations), crdb.WithColumnOptimization(opts.ExperimentalColumnOptimization), crdb.IncludeQueryParametersInTraces(opts.IncludeQueryParametersInTraces), + crdb.WithExpirationDisabled(!opts.EnableExperimentalRelationshipExpiration), ) } @@ -562,6 +565,7 @@ func commonPostgresDatastoreOptions(opts Config) ([]postgres.Option, error) { postgres.FilterMaximumIDCount(opts.FilterMaximumIDCount), postgres.WithColumnOptimization(opts.ExperimentalColumnOptimization), postgres.IncludeQueryParametersInTraces(opts.IncludeQueryParametersInTraces), + postgres.WithExpirationDisabled(!opts.EnableExperimentalRelationshipExpiration), }, nil } @@ -645,6 +649,7 @@ func newSpannerDatastore(ctx context.Context, opts Config) (datastore.Datastore, spanner.AllowedMigrations(opts.AllowedMigrations), spanner.FilterMaximumIDCount(opts.FilterMaximumIDCount), spanner.WithColumnOptimization(opts.ExperimentalColumnOptimization), + spanner.WithExpirationDisabled(!opts.EnableExperimentalRelationshipExpiration), ) } @@ -690,6 +695,7 @@ func commonMySQLDatastoreOptions(opts Config) ([]mysql.Option, error) { mysql.FilterMaximumIDCount(opts.FilterMaximumIDCount), mysql.AllowedMigrations(opts.AllowedMigrations), mysql.WithColumnOptimization(opts.ExperimentalColumnOptimization), + mysql.WithExpirationDisabled(!opts.EnableExperimentalRelationshipExpiration), }, nil } diff --git a/pkg/cmd/datastore/zz_generated.options.go b/pkg/cmd/datastore/zz_generated.options.go index ab9d39c92e..4e0c537307 100644 --- a/pkg/cmd/datastore/zz_generated.options.go +++ b/pkg/cmd/datastore/zz_generated.options.go @@ -79,6 +79,7 @@ func (c *Config) ToOption() ConfigOption { to.MigrationPhase = c.MigrationPhase to.AllowedMigrations = c.AllowedMigrations to.ExperimentalColumnOptimization = c.ExperimentalColumnOptimization + to.EnableExperimentalRelationshipExpiration = c.EnableExperimentalRelationshipExpiration } } @@ -132,6 +133,7 @@ func (c Config) DebugMap() map[string]any { debugMap["MigrationPhase"] = helpers.DebugValue(c.MigrationPhase, false) debugMap["AllowedMigrations"] = helpers.DebugValue(c.AllowedMigrations, false) debugMap["ExperimentalColumnOptimization"] = helpers.DebugValue(c.ExperimentalColumnOptimization, false) + debugMap["EnableExperimentalRelationshipExpiration"] = helpers.DebugValue(c.EnableExperimentalRelationshipExpiration, false) return debugMap } @@ -528,3 +530,10 @@ func WithExperimentalColumnOptimization(experimentalColumnOptimization bool) Con c.ExperimentalColumnOptimization = experimentalColumnOptimization } } + +// WithEnableExperimentalRelationshipExpiration returns an option that can set EnableExperimentalRelationshipExpiration on a Config +func WithEnableExperimentalRelationshipExpiration(enableExperimentalRelationshipExpiration bool) ConfigOption { + return func(c *Config) { + c.EnableExperimentalRelationshipExpiration = enableExperimentalRelationshipExpiration + } +} diff --git a/pkg/cmd/server/server.go b/pkg/cmd/server/server.go index 5f2ca4a37b..3ded6efb7a 100644 --- a/pkg/cmd/server/server.go +++ b/pkg/cmd/server/server.go @@ -226,7 +226,9 @@ func (c *Config) Complete(ctx context.Context) (RunnableServer, error) { ds, err = datastorecfg.NewDatastore(context.Background(), c.DatastoreConfig.ToOption(), // Datastore's filter maximum ID count is set to the max size, since the number of elements to be dispatched // are at most the number of elements returned from a datastore query - datastorecfg.WithFilterMaximumIDCount(c.DispatchChunkSize)) + datastorecfg.WithFilterMaximumIDCount(c.DispatchChunkSize), + datastorecfg.WithEnableExperimentalRelationshipExpiration(c.EnableExperimentalRelationshipExpiration), + ) if err != nil { return nil, spiceerrors.NewTerminationErrorBuilder(fmt.Errorf("failed to create datastore: %w", err)). Component("datastore"). diff --git a/pkg/datastore/options/options.go b/pkg/datastore/options/options.go index 6a55c0582d..749005e136 100644 --- a/pkg/datastore/options/options.go +++ b/pkg/datastore/options/options.go @@ -43,10 +43,11 @@ func ToRelationship(c Cursor) *tuple.Relationship { // QueryOptions are the options that can affect the results of a normal forward query. type QueryOptions struct { - Limit *uint64 `debugmap:"visible"` - Sort SortOrder `debugmap:"visible"` - After Cursor `debugmap:"visible"` - SkipCaveats bool `debugmap:"visible"` + Limit *uint64 `debugmap:"visible"` + Sort SortOrder `debugmap:"visible"` + After Cursor `debugmap:"visible"` + SkipCaveats bool `debugmap:"visible"` + SkipExpiration bool `debugmap:"visible"` } // ReverseQueryOptions are the options that can affect the results of a reverse query. diff --git a/pkg/datastore/options/zz_generated.query_options.go b/pkg/datastore/options/zz_generated.query_options.go index 79db45999e..493e1c73d2 100644 --- a/pkg/datastore/options/zz_generated.query_options.go +++ b/pkg/datastore/options/zz_generated.query_options.go @@ -35,6 +35,7 @@ func (q *QueryOptions) ToOption() QueryOptionsOption { to.Sort = q.Sort to.After = q.After to.SkipCaveats = q.SkipCaveats + to.SkipExpiration = q.SkipExpiration } } @@ -45,6 +46,7 @@ func (q QueryOptions) DebugMap() map[string]any { debugMap["Sort"] = helpers.DebugValue(q.Sort, false) debugMap["After"] = helpers.DebugValue(q.After, false) debugMap["SkipCaveats"] = helpers.DebugValue(q.SkipCaveats, false) + debugMap["SkipExpiration"] = helpers.DebugValue(q.SkipExpiration, false) return debugMap } @@ -92,6 +94,13 @@ func WithSkipCaveats(skipCaveats bool) QueryOptionsOption { } } +// WithSkipExpiration returns an option that can set SkipExpiration on a QueryOptions +func WithSkipExpiration(skipExpiration bool) QueryOptionsOption { + return func(q *QueryOptions) { + q.SkipExpiration = skipExpiration + } +} + type ReverseQueryOptionsOption func(r *ReverseQueryOptions) // NewReverseQueryOptionsWithOptions creates a new ReverseQueryOptions with the passed in options set diff --git a/pkg/datastore/test/relationships.go b/pkg/datastore/test/relationships.go index 19f534d3dd..95049c15e6 100644 --- a/pkg/datastore/test/relationships.go +++ b/pkg/datastore/test/relationships.go @@ -1113,11 +1113,12 @@ func RecreateRelationshipsAfterDeleteWithFilter(t *testing.T, tester DatastoreTe // QueryRelationshipsWithVariousFiltersTest tests various relationship filters for query relationships. func QueryRelationshipsWithVariousFiltersTest(t *testing.T, tester DatastoreTester) { tcs := []struct { - name string - filter datastore.RelationshipsFilter - withoutCaveats bool - relationships []string - expected []string + name string + filter datastore.RelationshipsFilter + withoutCaveats bool + withoutExpiration bool + relationships []string + expected []string }{ { name: "resource type", @@ -1507,7 +1508,6 @@ func QueryRelationshipsWithVariousFiltersTest(t *testing.T, tester DatastoreTest "document:third#viewer@user:tom[secondcaveat:{\"bar\":\"baz\"}]", }, }, - { name: "relationship expiration", filter: datastore.RelationshipsFilter{ @@ -1553,6 +1553,24 @@ func QueryRelationshipsWithVariousFiltersTest(t *testing.T, tester DatastoreTest "document:first#viewer@user:tom", }, }, + { + name: "no caveats and no expiration", + filter: datastore.RelationshipsFilter{ + OptionalResourceType: "document", + }, + relationships: []string{ + "document:first#viewer@user:tom", + "document:first#viewer@user:fred", + "document:first#viewer@user:sarah", + }, + expected: []string{ + "document:first#viewer@user:tom", + "document:first#viewer@user:fred", + "document:first#viewer@user:sarah", + }, + withoutCaveats: true, + withoutExpiration: true, + }, { name: "multiple subject IDs with subject type", filter: datastore.RelationshipsFilter{ @@ -1577,7 +1595,20 @@ func QueryRelationshipsWithVariousFiltersTest(t *testing.T, tester DatastoreTest }, }, { - name: "multiple subject filters", + name: "relationships with expiration", + filter: datastore.RelationshipsFilter{ + OptionalResourceType: "document", + }, + relationships: []string{ + "document:first#expiring_viewer@user:tom[expiration:2020-01-01T00:00:00Z]", + "document:first#expiring_viewer@user:fred[expiration:2321-01-01T00:00:00Z]", + }, + expected: []string{ + "document:first#expiring_viewer@user:fred[expiration:2321-01-01T00:00:00Z]", + }, + }, + { + name: "multiple subject filters with multiple ids", filter: datastore.RelationshipsFilter{ OptionalSubjectsSelectors: []datastore.SubjectsSelector{ { @@ -1607,19 +1638,6 @@ func QueryRelationshipsWithVariousFiltersTest(t *testing.T, tester DatastoreTest "folder:secondfolder#viewer@anotheruser:jerry", }, }, - { - name: "relationships with expiration", - filter: datastore.RelationshipsFilter{ - OptionalResourceType: "document", - }, - relationships: []string{ - "document:first#expiring_viewer@user:tom[expiration:2020-01-01T00:00:00Z]", - "document:first#expiring_viewer@user:fred[expiration:2321-01-01T00:00:00Z]", - }, - expected: []string{ - "document:first#expiring_viewer@user:fred[expiration:2321-01-01T00:00:00Z]", - }, - }, } for _, tc := range tcs { @@ -1643,7 +1661,7 @@ func QueryRelationshipsWithVariousFiltersTest(t *testing.T, tester DatastoreTest require.NoError(err) reader := ds.SnapshotReader(headRev) - iter, err := reader.QueryRelationships(ctx, tc.filter, options.WithSkipCaveats(tc.withoutCaveats)) + iter, err := reader.QueryRelationships(ctx, tc.filter, options.WithSkipCaveats(tc.withoutCaveats), options.WithSkipExpiration(tc.withoutExpiration)) require.NoError(err) var results []string @@ -1657,14 +1675,6 @@ func QueryRelationshipsWithVariousFiltersTest(t *testing.T, tester DatastoreTest } } -// TypedTouchAlreadyExistingTest tests touching a relationship twice, when valid type information is provided. -func TypedTouchAlreadyExistingTest(t *testing.T, tester DatastoreTester) { - require := require.New(t) - - rawDS, err := tester.New(0, veryLargeGCInterval, veryLargeGCWindow, 1) - require.NoError(err) - - ds, _ := testfixtures.StandardDatastoreWithData(rawDS, require) ctx := context.Background() tpl1, err := tuple.Parse("document:foo#viewer@user:tom") @@ -1676,7 +1686,6 @@ func TypedTouchAlreadyExistingTest(t *testing.T, tester DatastoreTester) { _, err = common.WriteRelationships(ctx, ds, tuple.UpdateOperationTouch, tpl1) require.NoError(err) - ensureRelationships(ctx, require, ds, tpl1) } // RelationshipExpirationTest tests expiration on relationships. From 56d513b084133a146c7e550c8a0c537991c99676 Mon Sep 17 00:00:00 2001 From: Joseph Schorr Date: Mon, 16 Dec 2024 14:17:11 -0500 Subject: [PATCH 10/15] Add testing for expanded comparison logic in SQL builder Also adds a datastore test to ensure the constructed cursor operates as expected --- internal/datastore/common/relationships.go | 8 +- internal/datastore/common/sql.go | 26 +- internal/datastore/common/sql_test.go | 593 ++++++++++++++------- pkg/datastore/test/datastore.go | 1 + pkg/datastore/test/pagination.go | 59 ++ pkg/datastore/test/relationships.go | 35 +- 6 files changed, 497 insertions(+), 225 deletions(-) diff --git a/internal/datastore/common/relationships.go b/internal/datastore/common/relationships.go index 3b6f8a5156..f780dca876 100644 --- a/internal/datastore/common/relationships.go +++ b/internal/datastore/common/relationships.go @@ -69,6 +69,7 @@ func QueryRelationships[R Rows, C ~map[string]any](ctx context.Context, builder return func(yield func(tuple.Relationship, error) bool) { err := tx.QueryFunc(ctx, func(ctx context.Context, rows R) error { + span.AddEvent("Query issued to database") var r Rows = rows if crwe, ok := r.(closeRowsWithError); ok { defer LogOnError(ctx, crwe.Close) @@ -76,9 +77,12 @@ func QueryRelationships[R Rows, C ~map[string]any](ctx context.Context, builder defer cr.Close() } - span.AddEvent("Query issued to database") relCount := 0 for rows.Next() { + if relCount == 0 { + span.AddEvent("First row returned") + } + if err := rows.Scan(colsToSelect...); err != nil { return fmt.Errorf(errUnableToQueryRels, fmt.Errorf("scan err: %w", err)) } @@ -132,11 +136,11 @@ func QueryRelationships[R Rows, C ~map[string]any](ctx context.Context, builder } } + span.AddEvent("Rels loaded", trace.WithAttributes(attribute.Int("relCount", relCount))) if err := rows.Err(); err != nil { return fmt.Errorf(errUnableToQueryRels, fmt.Errorf("rows err: %w", err)) } - span.AddEvent("Rels loaded", trace.WithAttributes(attribute.Int("relCount", relCount))) return nil }, sqlString, args...) if err != nil { diff --git a/internal/datastore/common/sql.go b/internal/datastore/common/sql.go index f86131bcaa..7e4bacb5d5 100644 --- a/internal/datastore/common/sql.go +++ b/internal/datastore/common/sql.go @@ -58,12 +58,12 @@ const ( // TupleComparison uses a comparison with a compound key, // e.g. (namespace, object_id, relation) > ('ns', '123', 'viewer') // which is not compatible with all datastores. - TupleComparison + TupleComparison = 1 // ExpandedLogicComparison comparison uses a nested tree of ANDs and ORs to properly // filter out already received relationships. Useful for databases that do not support // tuple comparison, or do not execute it efficiently - ExpandedLogicComparison + ExpandedLogicComparison = 2 ) // ColumnOptimizationOption is an enumerator for column optimization options. @@ -83,12 +83,21 @@ type ColumnTracker struct { SingleValue *string } +type columnTrackerMap map[string]ColumnTracker + +func (ctm columnTrackerMap) hasStaticValue(columnName string) bool { + if r, ok := ctm[columnName]; ok && r.SingleValue != nil { + return true + } + return false +} + // SchemaQueryFilterer wraps a SchemaInformation and SelectBuilder to give an opinionated // way to build query objects. type SchemaQueryFilterer struct { schema SchemaInformation queryBuilder sq.SelectBuilder - filteringColumnTracker map[string]ColumnTracker + filteringColumnTracker columnTrackerMap filterMaximumIDCount uint16 isCustomQuery bool extraFields []string @@ -269,7 +278,7 @@ func (sqf SchemaQueryFilterer) After(cursor options.Cursor, order options.SortOr comparisonSlotCount := 0 for _, cav := range columnsAndValues { - if r, ok := sqf.filteringColumnTracker[cav.name]; !ok || r.SingleValue == nil { + if !sqf.filteringColumnTracker.hasStaticValue(cav.name) { columnNames = append(columnNames, cav.name) valueSlots = append(valueSlots, cav.value) comparisonSlotCount++ @@ -289,10 +298,10 @@ func (sqf SchemaQueryFilterer) After(cursor options.Cursor, order options.SortOr orClause := sq.Or{} for index, cav := range columnsAndValues { - if r, ok := sqf.filteringColumnTracker[cav.name]; !ok || r.SingleValue != nil { + if !sqf.filteringColumnTracker.hasStaticValue(cav.name) { andClause := sq.And{} for _, previous := range columnsAndValues[0:index] { - if r, ok := sqf.filteringColumnTracker[previous.name]; !ok || r.SingleValue != nil { + if !sqf.filteringColumnTracker.hasStaticValue(previous.name) { andClause = append(andClause, sq.Eq{previous.name: previous.value}) } } @@ -683,7 +692,7 @@ type RelationshipsQueryBuilder struct { SkipCaveats bool SkipExpiration bool - filteringValues map[string]ColumnTracker + filteringValues columnTrackerMap baseQueryBuilder SchemaQueryFilterer } @@ -735,9 +744,10 @@ func (b RelationshipsQueryBuilder) checkColumn(columns []string, colName string) return append(columns, colName) } - if r, ok := b.filteringValues[colName]; !ok || r.SingleValue == nil { + if !b.filteringValues.hasStaticValue(colName) { return append(columns, colName) } + return columns } diff --git a/internal/datastore/common/sql_test.go b/internal/datastore/common/sql_test.go index cee75c6ce4..2e19772245 100644 --- a/internal/datastore/common/sql_test.go +++ b/internal/datastore/common/sql_test.go @@ -2,6 +2,7 @@ package common import ( "context" + "fmt" "testing" "github.com/authzed/spicedb/pkg/datastore/options" @@ -17,78 +18,97 @@ import ( var toCursor = options.ToCursor +type expected struct { + sql string + args []any + staticCols []string +} + func TestSchemaQueryFilterer(t *testing.T) { tests := []struct { name string run func(filterer SchemaQueryFilterer) SchemaQueryFilterer - expectedSQL string - expectedArgs []any - expectedStaticColumns []string withExpirationDisabled bool + expectedForTuple expected + expectedForExpanded expected }{ { name: "relation filter", run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.FilterToRelation("somerelation") }, - expectedSQL: "SELECT * WHERE relation = ? AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"somerelation"}, - expectedStaticColumns: []string{"relation"}, + expectedForTuple: expected{ + sql: "SELECT * WHERE relation = ? AND (expiration IS NULL OR expiration > NOW())", + args: []any{"somerelation"}, + staticCols: []string{"relation"}, + }, }, { name: "relation filter without expiration", run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.FilterToRelation("somerelation") }, - expectedSQL: "SELECT * WHERE relation = ?", - expectedArgs: []any{"somerelation"}, - expectedStaticColumns: []string{"relation"}, withExpirationDisabled: true, + expectedForTuple: expected{ + sql: "SELECT * WHERE relation = ?", + args: []any{"somerelation"}, + staticCols: []string{"relation"}, + }, }, { name: "resource ID filter", run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.FilterToResourceID("someresourceid") }, - expectedSQL: "SELECT * WHERE object_id = ? AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"someresourceid"}, - expectedStaticColumns: []string{"object_id"}, + expectedForTuple: expected{ + sql: "SELECT * WHERE object_id = ? AND (expiration IS NULL OR expiration > NOW())", + args: []any{"someresourceid"}, + staticCols: []string{"object_id"}, + }, }, { name: "resource IDs filter", run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.MustFilterWithResourceIDPrefix("someprefix") }, - expectedSQL: "SELECT * WHERE object_id LIKE ? AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"someprefix%"}, - expectedStaticColumns: []string{}, + expectedForTuple: expected{ + sql: "SELECT * WHERE object_id LIKE ? AND (expiration IS NULL OR expiration > NOW())", + args: []any{"someprefix%"}, + staticCols: []string{}, + }, }, { name: "resource IDs prefix filter", run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.MustFilterToResourceIDs([]string{"someresourceid", "anotherresourceid"}) }, - expectedSQL: "SELECT * WHERE object_id IN (?,?) AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"someresourceid", "anotherresourceid"}, - expectedStaticColumns: []string{}, + expectedForTuple: expected{ + sql: "SELECT * WHERE object_id IN (?,?) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"someresourceid", "anotherresourceid"}, + staticCols: []string{}, + }, }, { name: "resource type filter", run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.FilterToResourceType("sometype") }, - expectedSQL: "SELECT * WHERE ns = ? AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"sometype"}, - expectedStaticColumns: []string{"ns"}, + expectedForTuple: expected{ + sql: "SELECT * WHERE ns = ? AND (expiration IS NULL OR expiration > NOW())", + args: []any{"sometype"}, + staticCols: []string{"ns"}, + }, }, { name: "resource filter", run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.FilterToResourceType("sometype").FilterToResourceID("someobj").FilterToRelation("somerel") }, - expectedSQL: "SELECT * WHERE ns = ? AND object_id = ? AND relation = ? AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"sometype", "someobj", "somerel"}, - expectedStaticColumns: []string{"ns", "object_id", "relation"}, + expectedForTuple: expected{ + sql: "SELECT * WHERE ns = ? AND object_id = ? AND relation = ? AND (expiration IS NULL OR expiration > NOW())", + args: []any{"sometype", "someobj", "somerel"}, + staticCols: []string{"ns", "object_id", "relation"}, + }, }, { name: "relationships filter with no IDs or relations", @@ -97,9 +117,11 @@ func TestSchemaQueryFilterer(t *testing.T) { OptionalResourceType: "sometype", }) }, - expectedSQL: "SELECT * WHERE ns = ? AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"sometype"}, - expectedStaticColumns: []string{"ns"}, + expectedForTuple: expected{ + sql: "SELECT * WHERE ns = ? AND (expiration IS NULL OR expiration > NOW())", + args: []any{"sometype"}, + staticCols: []string{"ns"}, + }, }, { name: "relationships filter with single ID", @@ -109,9 +131,11 @@ func TestSchemaQueryFilterer(t *testing.T) { OptionalResourceIds: []string{"someid"}, }) }, - expectedSQL: "SELECT * WHERE ns = ? AND object_id IN (?) AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"sometype", "someid"}, - expectedStaticColumns: []string{"ns", "object_id"}, + expectedForTuple: expected{ + sql: "SELECT * WHERE ns = ? AND object_id IN (?) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"sometype", "someid"}, + staticCols: []string{"ns", "object_id"}, + }, }, { name: "relationships filter with no IDs", @@ -121,9 +145,11 @@ func TestSchemaQueryFilterer(t *testing.T) { OptionalResourceIds: []string{}, }) }, - expectedSQL: "SELECT * WHERE ns = ? AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"sometype"}, - expectedStaticColumns: []string{"ns"}, + expectedForTuple: expected{ + sql: "SELECT * WHERE ns = ? AND (expiration IS NULL OR expiration > NOW())", + args: []any{"sometype"}, + staticCols: []string{"ns"}, + }, }, { name: "relationships filter with multiple IDs", @@ -133,9 +159,11 @@ func TestSchemaQueryFilterer(t *testing.T) { OptionalResourceIds: []string{"someid", "anotherid"}, }) }, - expectedSQL: "SELECT * WHERE ns = ? AND object_id IN (?,?) AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"sometype", "someid", "anotherid"}, - expectedStaticColumns: []string{"ns"}, + expectedForTuple: expected{ + sql: "SELECT * WHERE ns = ? AND object_id IN (?,?) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"sometype", "someid", "anotherid"}, + staticCols: []string{"ns"}, + }, }, { name: "subjects filter with no IDs or relations", @@ -144,9 +172,11 @@ func TestSchemaQueryFilterer(t *testing.T) { OptionalSubjectType: "somesubjectype", }) }, - expectedSQL: "SELECT * WHERE ((subject_ns = ?)) AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"somesubjectype"}, - expectedStaticColumns: []string{"subject_ns"}, + expectedForTuple: expected{ + sql: "SELECT * WHERE ((subject_ns = ?)) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"somesubjectype"}, + staticCols: []string{"subject_ns"}, + }, }, { name: "multiple subjects filters with just types", @@ -157,9 +187,11 @@ func TestSchemaQueryFilterer(t *testing.T) { OptionalSubjectType: "anothersubjectype", }) }, - expectedSQL: "SELECT * WHERE ((subject_ns = ?) OR (subject_ns = ?)) AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"somesubjectype", "anothersubjectype"}, - expectedStaticColumns: []string{}, + expectedForTuple: expected{ + sql: "SELECT * WHERE ((subject_ns = ?) OR (subject_ns = ?)) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"somesubjectype", "anothersubjectype"}, + staticCols: []string{}, + }, }, { name: "subjects filter with single ID", @@ -169,9 +201,11 @@ func TestSchemaQueryFilterer(t *testing.T) { OptionalSubjectIds: []string{"somesubjectid"}, }) }, - expectedSQL: "SELECT * WHERE ((subject_ns = ? AND subject_object_id IN (?))) AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"somesubjectype", "somesubjectid"}, - expectedStaticColumns: []string{"subject_ns", "subject_object_id"}, + expectedForTuple: expected{ + sql: "SELECT * WHERE ((subject_ns = ? AND subject_object_id IN (?))) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"somesubjectype", "somesubjectid"}, + staticCols: []string{"subject_ns", "subject_object_id"}, + }, }, { name: "subjects filter with single ID and no type", @@ -180,18 +214,22 @@ func TestSchemaQueryFilterer(t *testing.T) { OptionalSubjectIds: []string{"somesubjectid"}, }) }, - expectedSQL: "SELECT * WHERE ((subject_object_id IN (?))) AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"somesubjectid"}, - expectedStaticColumns: []string{"subject_object_id"}, + expectedForTuple: expected{ + sql: "SELECT * WHERE ((subject_object_id IN (?))) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"somesubjectid"}, + staticCols: []string{"subject_object_id"}, + }, }, { name: "empty subjects filter", run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.MustFilterWithSubjectsSelectors(datastore.SubjectsSelector{}) }, - expectedSQL: "SELECT * WHERE ((1=1)) AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: nil, - expectedStaticColumns: []string{}, + expectedForTuple: expected{ + sql: "SELECT * WHERE ((1=1)) AND (expiration IS NULL OR expiration > NOW())", + args: nil, + staticCols: []string{}, + }, }, { name: "subjects filter with multiple IDs", @@ -201,9 +239,11 @@ func TestSchemaQueryFilterer(t *testing.T) { OptionalSubjectIds: []string{"somesubjectid", "anothersubjectid"}, }) }, - expectedSQL: "SELECT * WHERE ((subject_ns = ? AND subject_object_id IN (?,?))) AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"somesubjectype", "somesubjectid", "anothersubjectid"}, - expectedStaticColumns: []string{"subject_ns"}, + expectedForTuple: expected{ + sql: "SELECT * WHERE ((subject_ns = ? AND subject_object_id IN (?,?))) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"somesubjectype", "somesubjectid", "anothersubjectid"}, + staticCols: []string{"subject_ns"}, + }, }, { name: "subjects filter with single ellipsis relation", @@ -213,9 +253,11 @@ func TestSchemaQueryFilterer(t *testing.T) { RelationFilter: datastore.SubjectRelationFilter{}.WithEllipsisRelation(), }) }, - expectedSQL: "SELECT * WHERE ((subject_ns = ? AND subject_relation = ?)) AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"somesubjectype", "..."}, - expectedStaticColumns: []string{"subject_ns", "subject_relation"}, + expectedForTuple: expected{ + sql: "SELECT * WHERE ((subject_ns = ? AND subject_relation = ?)) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"somesubjectype", "..."}, + staticCols: []string{"subject_ns", "subject_relation"}, + }, }, { name: "subjects filter with single defined relation", @@ -225,9 +267,11 @@ func TestSchemaQueryFilterer(t *testing.T) { RelationFilter: datastore.SubjectRelationFilter{}.WithNonEllipsisRelation("somesubrel"), }) }, - expectedSQL: "SELECT * WHERE ((subject_ns = ? AND subject_relation = ?)) AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"somesubjectype", "somesubrel"}, - expectedStaticColumns: []string{"subject_ns", "subject_relation"}, + expectedForTuple: expected{ + sql: "SELECT * WHERE ((subject_ns = ? AND subject_relation = ?)) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"somesubjectype", "somesubrel"}, + staticCols: []string{"subject_ns", "subject_relation"}, + }, }, { name: "subjects filter with only non-ellipsis", @@ -237,9 +281,11 @@ func TestSchemaQueryFilterer(t *testing.T) { RelationFilter: datastore.SubjectRelationFilter{}.WithOnlyNonEllipsisRelations(), }) }, - expectedSQL: "SELECT * WHERE ((subject_ns = ? AND subject_relation <> ?)) AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"somesubjectype", "..."}, - expectedStaticColumns: []string{"subject_ns"}, + expectedForTuple: expected{ + sql: "SELECT * WHERE ((subject_ns = ? AND subject_relation <> ?)) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"somesubjectype", "..."}, + staticCols: []string{"subject_ns"}, + }, }, { name: "subjects filter with defined relation and ellipsis", @@ -249,9 +295,11 @@ func TestSchemaQueryFilterer(t *testing.T) { RelationFilter: datastore.SubjectRelationFilter{}.WithNonEllipsisRelation("somesubrel").WithEllipsisRelation(), }) }, - expectedSQL: "SELECT * WHERE ((subject_ns = ? AND (subject_relation = ? OR subject_relation = ?))) AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"somesubjectype", "...", "somesubrel"}, - expectedStaticColumns: []string{"subject_ns"}, + expectedForTuple: expected{ + sql: "SELECT * WHERE ((subject_ns = ? AND (subject_relation = ? OR subject_relation = ?))) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"somesubjectype", "...", "somesubrel"}, + staticCols: []string{"subject_ns"}, + }, }, { name: "subjects filter", @@ -262,9 +310,11 @@ func TestSchemaQueryFilterer(t *testing.T) { RelationFilter: datastore.SubjectRelationFilter{}.WithNonEllipsisRelation("somesubrel").WithEllipsisRelation(), }) }, - expectedSQL: "SELECT * WHERE ((subject_ns = ? AND subject_object_id IN (?,?) AND (subject_relation = ? OR subject_relation = ?))) AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"somesubjectype", "somesubjectid", "anothersubjectid", "...", "somesubrel"}, - expectedStaticColumns: []string{"subject_ns"}, + expectedForTuple: expected{ + sql: "SELECT * WHERE ((subject_ns = ? AND subject_object_id IN (?,?) AND (subject_relation = ? OR subject_relation = ?))) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"somesubjectype", "somesubjectid", "anothersubjectid", "...", "somesubrel"}, + staticCols: []string{"subject_ns"}, + }, }, { name: "multiple subjects filter", @@ -286,9 +336,11 @@ func TestSchemaQueryFilterer(t *testing.T) { }, ) }, - expectedSQL: "SELECT * WHERE ((subject_ns = ? AND subject_object_id IN (?,?) AND (subject_relation = ? OR subject_relation = ?)) OR (subject_ns = ? AND subject_object_id IN (?,?) AND (subject_relation = ? OR subject_relation = ?)) OR (subject_ns = ? AND subject_relation <> ?)) AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"somesubjectype", "a", "b", "...", "somesubrel", "anothersubjecttype", "b", "c", "...", "anotherrel", "thirdsubjectype", "..."}, - expectedStaticColumns: []string{}, + expectedForTuple: expected{ + sql: "SELECT * WHERE ((subject_ns = ? AND subject_object_id IN (?,?) AND (subject_relation = ? OR subject_relation = ?)) OR (subject_ns = ? AND subject_object_id IN (?,?) AND (subject_relation = ? OR subject_relation = ?)) OR (subject_ns = ? AND subject_relation <> ?)) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"somesubjectype", "a", "b", "...", "somesubrel", "anothersubjecttype", "b", "c", "...", "anotherrel", "thirdsubjectype", "..."}, + staticCols: []string{}, + }, }, { name: "v1 subject filter with namespace", @@ -297,9 +349,11 @@ func TestSchemaQueryFilterer(t *testing.T) { SubjectType: "subns", }) }, - expectedSQL: "SELECT * WHERE subject_ns = ? AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"subns"}, - expectedStaticColumns: []string{"subject_ns"}, + expectedForTuple: expected{ + sql: "SELECT * WHERE subject_ns = ? AND (expiration IS NULL OR expiration > NOW())", + args: []any{"subns"}, + staticCols: []string{"subject_ns"}, + }, }, { name: "v1 subject filter with subject id", @@ -309,9 +363,11 @@ func TestSchemaQueryFilterer(t *testing.T) { OptionalSubjectId: "subid", }) }, - expectedSQL: "SELECT * WHERE subject_ns = ? AND subject_object_id = ? AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"subns", "subid"}, - expectedStaticColumns: []string{"subject_ns", "subject_object_id"}, + expectedForTuple: expected{ + sql: "SELECT * WHERE subject_ns = ? AND subject_object_id = ? AND (expiration IS NULL OR expiration > NOW())", + args: []any{"subns", "subid"}, + staticCols: []string{"subject_ns", "subject_object_id"}, + }, }, { name: "v1 subject filter with relation", @@ -323,9 +379,11 @@ func TestSchemaQueryFilterer(t *testing.T) { }, }) }, - expectedSQL: "SELECT * WHERE subject_ns = ? AND subject_relation = ? AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"subns", "subrel"}, - expectedStaticColumns: []string{"subject_ns", "subject_relation"}, + expectedForTuple: expected{ + sql: "SELECT * WHERE subject_ns = ? AND subject_relation = ? AND (expiration IS NULL OR expiration > NOW())", + args: []any{"subns", "subrel"}, + staticCols: []string{"subject_ns", "subject_relation"}, + }, }, { name: "v1 subject filter with empty relation", @@ -337,9 +395,11 @@ func TestSchemaQueryFilterer(t *testing.T) { }, }) }, - expectedSQL: "SELECT * WHERE subject_ns = ? AND subject_relation = ? AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"subns", "..."}, - expectedStaticColumns: []string{"subject_ns", "subject_relation"}, + expectedForTuple: expected{ + sql: "SELECT * WHERE subject_ns = ? AND subject_relation = ? AND (expiration IS NULL OR expiration > NOW())", + args: []any{"subns", "..."}, + staticCols: []string{"subject_ns", "subject_relation"}, + }, }, { name: "v1 subject filter", @@ -352,18 +412,22 @@ func TestSchemaQueryFilterer(t *testing.T) { }, }) }, - expectedSQL: "SELECT * WHERE subject_ns = ? AND subject_object_id = ? AND subject_relation = ? AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"subns", "subid", "somerel"}, - expectedStaticColumns: []string{"subject_ns", "subject_object_id", "subject_relation"}, + expectedForTuple: expected{ + sql: "SELECT * WHERE subject_ns = ? AND subject_object_id = ? AND subject_relation = ? AND (expiration IS NULL OR expiration > NOW())", + args: []any{"subns", "subid", "somerel"}, + staticCols: []string{"subject_ns", "subject_object_id", "subject_relation"}, + }, }, { name: "limit", run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.limit(100) }, - expectedSQL: "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) LIMIT 100", - expectedArgs: nil, - expectedStaticColumns: []string{}, + expectedForTuple: expected{ + sql: "SELECT * WHERE (expiration IS NULL OR expiration > NOW()) LIMIT 100", + args: nil, + staticCols: []string{}, + }, }, { name: "full resources filter", @@ -383,9 +447,11 @@ func TestSchemaQueryFilterer(t *testing.T) { }, ) }, - expectedSQL: "SELECT * WHERE ns = ? AND relation = ? AND object_id IN (?,?) AND ((subject_ns = ? AND subject_object_id IN (?,?) AND (subject_relation = ? OR subject_relation = ?))) AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"someresourcetype", "somerelation", "someid", "anotherid", "somesubjectype", "somesubjectid", "anothersubjectid", "...", "somesubrel"}, - expectedStaticColumns: []string{"ns", "relation", "subject_ns"}, + expectedForTuple: expected{ + sql: "SELECT * WHERE ns = ? AND relation = ? AND object_id IN (?,?) AND ((subject_ns = ? AND subject_object_id IN (?,?) AND (subject_relation = ? OR subject_relation = ?))) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"someresourcetype", "somerelation", "someid", "anotherid", "somesubjectype", "somesubjectid", "anothersubjectid", "...", "somesubrel"}, + staticCols: []string{"ns", "relation", "subject_ns"}, + }, }, { name: "full resources filter without expiration", @@ -405,10 +471,12 @@ func TestSchemaQueryFilterer(t *testing.T) { }, ) }, - expectedSQL: "SELECT * WHERE ns = ? AND relation = ? AND object_id IN (?,?) AND ((subject_ns = ? AND subject_object_id IN (?,?) AND (subject_relation = ? OR subject_relation = ?)))", - expectedArgs: []any{"someresourcetype", "somerelation", "someid", "anotherid", "somesubjectype", "somesubjectid", "anothersubjectid", "...", "somesubrel"}, - expectedStaticColumns: []string{"ns", "relation", "subject_ns"}, withExpirationDisabled: true, + expectedForTuple: expected{ + sql: "SELECT * WHERE ns = ? AND relation = ? AND object_id IN (?,?) AND ((subject_ns = ? AND subject_object_id IN (?,?) AND (subject_relation = ? OR subject_relation = ?)))", + args: []any{"someresourcetype", "somerelation", "someid", "anotherid", "somesubjectype", "somesubjectid", "anothersubjectid", "...", "somesubrel"}, + staticCols: []string{"ns", "relation", "subject_ns"}, + }, }, { name: "order by", @@ -419,9 +487,11 @@ func TestSchemaQueryFilterer(t *testing.T) { }, ).TupleOrder(options.ByResource) }, - expectedSQL: "SELECT * WHERE ns = ? AND (expiration IS NULL OR expiration > NOW()) ORDER BY ns, object_id, relation, subject_ns, subject_object_id, subject_relation", - expectedArgs: []any{"someresourcetype"}, - expectedStaticColumns: []string{"ns"}, + expectedForTuple: expected{ + sql: "SELECT * WHERE ns = ? AND (expiration IS NULL OR expiration > NOW()) ORDER BY ns, object_id, relation, subject_ns, subject_object_id, subject_relation", + args: []any{"someresourcetype"}, + staticCols: []string{"ns"}, + }, }, { name: "after with just namespace", @@ -432,9 +502,16 @@ func TestSchemaQueryFilterer(t *testing.T) { }, ).After(toCursor(tuple.MustParse("someresourcetype:foo#viewer@user:bar")), options.ByResource) }, - expectedSQL: "SELECT * WHERE ns = ? AND (object_id,relation,subject_ns,subject_object_id,subject_relation) > (?,?,?,?,?) AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"someresourcetype", "foo", "viewer", "user", "bar", "..."}, - expectedStaticColumns: []string{"ns"}, + expectedForTuple: expected{ + sql: "SELECT * WHERE ns = ? AND (object_id,relation,subject_ns,subject_object_id,subject_relation) > (?,?,?,?,?) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"someresourcetype", "foo", "viewer", "user", "bar", "..."}, + staticCols: []string{"ns"}, + }, + expectedForExpanded: expected{ + sql: "SELECT * WHERE ns = ? AND ((object_id > ?) OR (object_id = ? AND relation > ?) OR (object_id = ? AND relation = ? AND subject_ns > ?) OR (object_id = ? AND relation = ? AND subject_ns = ? AND subject_object_id > ?) OR (object_id = ? AND relation = ? AND subject_ns = ? AND subject_object_id = ? AND subject_relation > ?)) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"someresourcetype", "foo", "foo", "viewer", "foo", "viewer", "user", "foo", "viewer", "user", "bar", "foo", "viewer", "user", "bar", "..."}, + staticCols: []string{"ns"}, + }, }, { name: "after with just relation", @@ -445,9 +522,16 @@ func TestSchemaQueryFilterer(t *testing.T) { }, ).After(toCursor(tuple.MustParse("someresourcetype:foo#viewer@user:bar")), options.ByResource) }, - expectedSQL: "SELECT * WHERE relation = ? AND (ns,object_id,subject_ns,subject_object_id,subject_relation) > (?,?,?,?,?) AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"somerelation", "someresourcetype", "foo", "user", "bar", "..."}, - expectedStaticColumns: []string{"relation"}, + expectedForTuple: expected{ + sql: "SELECT * WHERE relation = ? AND (ns,object_id,subject_ns,subject_object_id,subject_relation) > (?,?,?,?,?) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"somerelation", "someresourcetype", "foo", "user", "bar", "..."}, + staticCols: []string{"relation"}, + }, + expectedForExpanded: expected{ + sql: "SELECT * WHERE relation = ? AND ((ns > ?) OR (ns = ? AND object_id > ?) OR (ns = ? AND object_id = ? AND subject_ns > ?) OR (ns = ? AND object_id = ? AND subject_ns = ? AND subject_object_id > ?) OR (ns = ? AND object_id = ? AND subject_ns = ? AND subject_object_id = ? AND subject_relation > ?)) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"somerelation", "someresourcetype", "someresourcetype", "foo", "someresourcetype", "foo", "user", "someresourcetype", "foo", "user", "bar", "someresourcetype", "foo", "user", "bar", "..."}, + staticCols: []string{"relation"}, + }, }, { name: "after with namespace and single resource id", @@ -459,9 +543,16 @@ func TestSchemaQueryFilterer(t *testing.T) { }, ).After(toCursor(tuple.MustParse("someresourcetype:foo#viewer@user:bar")), options.ByResource) }, - expectedSQL: "SELECT * WHERE ns = ? AND object_id IN (?) AND (relation,subject_ns,subject_object_id,subject_relation) > (?,?,?,?) AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"someresourcetype", "one", "viewer", "user", "bar", "..."}, - expectedStaticColumns: []string{"ns", "object_id"}, + expectedForTuple: expected{ + sql: "SELECT * WHERE ns = ? AND object_id IN (?) AND (relation,subject_ns,subject_object_id,subject_relation) > (?,?,?,?) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"someresourcetype", "one", "viewer", "user", "bar", "..."}, + staticCols: []string{"ns", "object_id"}, + }, + expectedForExpanded: expected{ + sql: "SELECT * WHERE ns = ? AND object_id IN (?) AND ((relation > ?) OR (relation = ? AND subject_ns > ?) OR (relation = ? AND subject_ns = ? AND subject_object_id > ?) OR (relation = ? AND subject_ns = ? AND subject_object_id = ? AND subject_relation > ?)) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"someresourcetype", "one", "viewer", "viewer", "user", "viewer", "user", "bar", "viewer", "user", "bar", "..."}, + staticCols: []string{"ns", "object_id"}, + }, }, { name: "after with single resource id", @@ -472,9 +563,16 @@ func TestSchemaQueryFilterer(t *testing.T) { }, ).After(toCursor(tuple.MustParse("someresourcetype:foo#viewer@user:bar")), options.ByResource) }, - expectedSQL: "SELECT * WHERE object_id IN (?) AND (ns,relation,subject_ns,subject_object_id,subject_relation) > (?,?,?,?,?) AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"one", "someresourcetype", "viewer", "user", "bar", "..."}, - expectedStaticColumns: []string{"object_id"}, + expectedForTuple: expected{ + sql: "SELECT * WHERE object_id IN (?) AND (ns,relation,subject_ns,subject_object_id,subject_relation) > (?,?,?,?,?) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"one", "someresourcetype", "viewer", "user", "bar", "..."}, + staticCols: []string{"object_id"}, + }, + expectedForExpanded: expected{ + sql: "SELECT * WHERE object_id IN (?) AND ((ns > ?) OR (ns = ? AND relation > ?) OR (ns = ? AND relation = ? AND subject_ns > ?) OR (ns = ? AND relation = ? AND subject_ns = ? AND subject_object_id > ?) OR (ns = ? AND relation = ? AND subject_ns = ? AND subject_object_id = ? AND subject_relation > ?)) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"one", "someresourcetype", "someresourcetype", "viewer", "someresourcetype", "viewer", "user", "someresourcetype", "viewer", "user", "bar", "someresourcetype", "viewer", "user", "bar", "..."}, + staticCols: []string{"object_id"}, + }, }, { name: "after with namespace and resource ids", @@ -486,9 +584,16 @@ func TestSchemaQueryFilterer(t *testing.T) { }, ).After(toCursor(tuple.MustParse("someresourcetype:foo#viewer@user:bar")), options.ByResource) }, - expectedSQL: "SELECT * WHERE ns = ? AND object_id IN (?,?) AND (object_id,relation,subject_ns,subject_object_id,subject_relation) > (?,?,?,?,?) AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"someresourcetype", "one", "two", "foo", "viewer", "user", "bar", "..."}, - expectedStaticColumns: []string{"ns"}, + expectedForTuple: expected{ + sql: "SELECT * WHERE ns = ? AND object_id IN (?,?) AND (object_id,relation,subject_ns,subject_object_id,subject_relation) > (?,?,?,?,?) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"someresourcetype", "one", "two", "foo", "viewer", "user", "bar", "..."}, + staticCols: []string{"ns"}, + }, + expectedForExpanded: expected{ + sql: "SELECT * WHERE ns = ? AND object_id IN (?,?) AND ((object_id > ?) OR (object_id = ? AND relation > ?) OR (object_id = ? AND relation = ? AND subject_ns > ?) OR (object_id = ? AND relation = ? AND subject_ns = ? AND subject_object_id > ?) OR (object_id = ? AND relation = ? AND subject_ns = ? AND subject_object_id = ? AND subject_relation > ?)) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"someresourcetype", "one", "two", "foo", "foo", "viewer", "foo", "viewer", "user", "foo", "viewer", "user", "bar", "foo", "viewer", "user", "bar", "..."}, + staticCols: []string{"ns"}, + }, }, { name: "after with namespace and relation", @@ -500,9 +605,16 @@ func TestSchemaQueryFilterer(t *testing.T) { }, ).After(toCursor(tuple.MustParse("someresourcetype:foo#viewer@user:bar")), options.ByResource) }, - expectedSQL: "SELECT * WHERE ns = ? AND relation = ? AND (object_id,subject_ns,subject_object_id,subject_relation) > (?,?,?,?) AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"someresourcetype", "somerelation", "foo", "user", "bar", "..."}, - expectedStaticColumns: []string{"ns", "relation"}, + expectedForTuple: expected{ + sql: "SELECT * WHERE ns = ? AND relation = ? AND (object_id,subject_ns,subject_object_id,subject_relation) > (?,?,?,?) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"someresourcetype", "somerelation", "foo", "user", "bar", "..."}, + staticCols: []string{"ns", "relation"}, + }, + expectedForExpanded: expected{ + sql: "SELECT * WHERE ns = ? AND relation = ? AND ((object_id > ?) OR (object_id = ? AND subject_ns > ?) OR (object_id = ? AND subject_ns = ? AND subject_object_id > ?) OR (object_id = ? AND subject_ns = ? AND subject_object_id = ? AND subject_relation > ?)) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"someresourcetype", "somerelation", "foo", "foo", "user", "foo", "user", "bar", "foo", "user", "bar", "..."}, + staticCols: []string{"ns", "relation"}, + }, }, { name: "after with subject namespace", @@ -511,9 +623,16 @@ func TestSchemaQueryFilterer(t *testing.T) { OptionalSubjectType: "somesubjectype", }).After(toCursor(tuple.MustParse("someresourcetype:foo#viewer@user:bar")), options.ByResource) }, - expectedSQL: "SELECT * WHERE ((subject_ns = ?)) AND (ns,object_id,relation,subject_object_id,subject_relation) > (?,?,?,?,?) AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"somesubjectype", "someresourcetype", "foo", "viewer", "bar", "..."}, - expectedStaticColumns: []string{"subject_ns"}, + expectedForTuple: expected{ + sql: "SELECT * WHERE ((subject_ns = ?)) AND (ns,object_id,relation,subject_object_id,subject_relation) > (?,?,?,?,?) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"somesubjectype", "someresourcetype", "foo", "viewer", "bar", "..."}, + staticCols: []string{"subject_ns"}, + }, + expectedForExpanded: expected{ + sql: "SELECT * WHERE ((subject_ns = ?)) AND ((ns > ?) OR (ns = ? AND object_id > ?) OR (ns = ? AND object_id = ? AND relation > ?) OR (ns = ? AND object_id = ? AND relation = ? AND subject_object_id > ?) OR (ns = ? AND object_id = ? AND relation = ? AND subject_object_id = ? AND subject_relation > ?)) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"somesubjectype", "someresourcetype", "someresourcetype", "foo", "someresourcetype", "foo", "viewer", "someresourcetype", "foo", "viewer", "bar", "someresourcetype", "foo", "viewer", "bar", "..."}, + staticCols: []string{"subject_ns"}, + }, }, { name: "after with subject namespaces", @@ -526,18 +645,32 @@ func TestSchemaQueryFilterer(t *testing.T) { OptionalSubjectType: "anothersubjectype", }).After(toCursor(tuple.MustParse("someresourcetype:foo#viewer@user:bar")), options.ByResource) }, - expectedSQL: "SELECT * WHERE ((subject_ns = ?)) AND ((subject_ns = ?)) AND (ns,object_id,relation,subject_ns,subject_object_id,subject_relation) > (?,?,?,?,?,?) AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"somesubjectype", "anothersubjectype", "someresourcetype", "foo", "viewer", "user", "bar", "..."}, - expectedStaticColumns: []string{}, + expectedForTuple: expected{ + sql: "SELECT * WHERE ((subject_ns = ?)) AND ((subject_ns = ?)) AND (ns,object_id,relation,subject_ns,subject_object_id,subject_relation) > (?,?,?,?,?,?) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"somesubjectype", "anothersubjectype", "someresourcetype", "foo", "viewer", "user", "bar", "..."}, + staticCols: []string{}, + }, + expectedForExpanded: expected{ + sql: "SELECT * WHERE ((subject_ns = ?)) AND ((subject_ns = ?)) AND ((ns > ?) OR (ns = ? AND object_id > ?) OR (ns = ? AND object_id = ? AND relation > ?) OR (ns = ? AND object_id = ? AND relation = ? AND subject_ns > ?) OR (ns = ? AND object_id = ? AND relation = ? AND subject_ns = ? AND subject_object_id > ?) OR (ns = ? AND object_id = ? AND relation = ? AND subject_ns = ? AND subject_object_id = ? AND subject_relation > ?)) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"somesubjectype", "anothersubjectype", "someresourcetype", "someresourcetype", "foo", "someresourcetype", "foo", "viewer", "someresourcetype", "foo", "viewer", "user", "someresourcetype", "foo", "viewer", "user", "bar", "someresourcetype", "foo", "viewer", "user", "bar", "..."}, + staticCols: []string{}, + }, }, { name: "after with resource ID prefix", run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.MustFilterWithResourceIDPrefix("someprefix").After(toCursor(tuple.MustParse("someresourcetype:foo#viewer@user:bar")), options.ByResource) }, - expectedSQL: "SELECT * WHERE object_id LIKE ? AND (ns,object_id,relation,subject_ns,subject_object_id,subject_relation) > (?,?,?,?,?,?) AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"someprefix%", "someresourcetype", "foo", "viewer", "user", "bar", "..."}, - expectedStaticColumns: []string{}, + expectedForTuple: expected{ + sql: "SELECT * WHERE object_id LIKE ? AND (ns,object_id,relation,subject_ns,subject_object_id,subject_relation) > (?,?,?,?,?,?) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"someprefix%", "someresourcetype", "foo", "viewer", "user", "bar", "..."}, + staticCols: []string{}, + }, + expectedForExpanded: expected{ + sql: "SELECT * WHERE object_id LIKE ? AND ((ns > ?) OR (ns = ? AND object_id > ?) OR (ns = ? AND object_id = ? AND relation > ?) OR (ns = ? AND object_id = ? AND relation = ? AND subject_ns > ?) OR (ns = ? AND object_id = ? AND relation = ? AND subject_ns = ? AND subject_object_id > ?) OR (ns = ? AND object_id = ? AND relation = ? AND subject_ns = ? AND subject_object_id = ? AND subject_relation > ?)) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"someprefix%", "someresourcetype", "someresourcetype", "foo", "someresourcetype", "foo", "viewer", "someresourcetype", "foo", "viewer", "user", "someresourcetype", "foo", "viewer", "user", "bar", "someresourcetype", "foo", "viewer", "user", "bar", "..."}, + staticCols: []string{}, + }, }, { name: "order by subject", @@ -548,9 +681,11 @@ func TestSchemaQueryFilterer(t *testing.T) { }, ).TupleOrder(options.BySubject) }, - expectedSQL: "SELECT * WHERE ns = ? AND (expiration IS NULL OR expiration > NOW()) ORDER BY subject_ns, subject_object_id, subject_relation, ns, object_id, relation", - expectedArgs: []any{"someresourcetype"}, - expectedStaticColumns: []string{"ns"}, + expectedForTuple: expected{ + sql: "SELECT * WHERE ns = ? AND (expiration IS NULL OR expiration > NOW()) ORDER BY subject_ns, subject_object_id, subject_relation, ns, object_id, relation", + args: []any{"someresourcetype"}, + staticCols: []string{"ns"}, + }, }, { name: "order by subject, after with subject namespace", @@ -559,9 +694,16 @@ func TestSchemaQueryFilterer(t *testing.T) { OptionalSubjectType: "somesubjectype", }).After(toCursor(tuple.MustParse("someresourcetype:foo#viewer@user:bar")), options.BySubject) }, - expectedSQL: "SELECT * WHERE ((subject_ns = ?)) AND (subject_object_id,ns,object_id,relation,subject_relation) > (?,?,?,?,?) AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"somesubjectype", "bar", "someresourcetype", "foo", "viewer", "..."}, - expectedStaticColumns: []string{"subject_ns"}, + expectedForTuple: expected{ + sql: "SELECT * WHERE ((subject_ns = ?)) AND (subject_object_id,ns,object_id,relation,subject_relation) > (?,?,?,?,?) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"somesubjectype", "bar", "someresourcetype", "foo", "viewer", "..."}, + staticCols: []string{"subject_ns"}, + }, + expectedForExpanded: expected{ + sql: "SELECT * WHERE ((subject_ns = ?)) AND ((subject_object_id > ?) OR (subject_object_id = ? AND ns > ?) OR (subject_object_id = ? AND ns = ? AND object_id > ?) OR (subject_object_id = ? AND ns = ? AND object_id = ? AND relation > ?) OR (subject_object_id = ? AND ns = ? AND object_id = ? AND relation = ? AND subject_relation > ?)) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"somesubjectype", "bar", "bar", "someresourcetype", "bar", "someresourcetype", "foo", "bar", "someresourcetype", "foo", "viewer", "bar", "someresourcetype", "foo", "viewer", "..."}, + staticCols: []string{"subject_ns"}, + }, }, { name: "order by subject, after with subject namespace and subject object id", @@ -571,9 +713,16 @@ func TestSchemaQueryFilterer(t *testing.T) { OptionalSubjectIds: []string{"foo"}, }).After(toCursor(tuple.MustParse("someresourcetype:someresource#viewer@user:bar")), options.BySubject) }, - expectedSQL: "SELECT * WHERE ((subject_ns = ? AND subject_object_id IN (?))) AND (ns,object_id,relation,subject_relation) > (?,?,?,?) AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"somesubjectype", "foo", "someresourcetype", "someresource", "viewer", "..."}, - expectedStaticColumns: []string{"subject_ns", "subject_object_id"}, + expectedForTuple: expected{ + sql: "SELECT * WHERE ((subject_ns = ? AND subject_object_id IN (?))) AND (ns,object_id,relation,subject_relation) > (?,?,?,?) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"somesubjectype", "foo", "someresourcetype", "someresource", "viewer", "..."}, + staticCols: []string{"subject_ns", "subject_object_id"}, + }, + expectedForExpanded: expected{ + sql: "SELECT * WHERE ((subject_ns = ? AND subject_object_id IN (?))) AND ((ns > ?) OR (ns = ? AND object_id > ?) OR (ns = ? AND object_id = ? AND relation > ?) OR (ns = ? AND object_id = ? AND relation = ? AND subject_relation > ?)) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"somesubjectype", "foo", "someresourcetype", "someresourcetype", "someresource", "someresourcetype", "someresource", "viewer", "someresourcetype", "someresource", "viewer", "..."}, + staticCols: []string{"subject_ns", "subject_object_id"}, + }, }, { name: "order by subject, after with subject namespace and multiple subject object IDs", @@ -583,49 +732,85 @@ func TestSchemaQueryFilterer(t *testing.T) { OptionalSubjectIds: []string{"foo", "bar"}, }).After(toCursor(tuple.MustParse("someresourcetype:someresource#viewer@user:next")), options.BySubject) }, - expectedSQL: "SELECT * WHERE ((subject_ns = ? AND subject_object_id IN (?,?))) AND (subject_object_id,ns,object_id,relation,subject_relation) > (?,?,?,?,?) AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"somesubjectype", "foo", "bar", "next", "someresourcetype", "someresource", "viewer", "..."}, - expectedStaticColumns: []string{"subject_ns"}, + expectedForTuple: expected{ + sql: "SELECT * WHERE ((subject_ns = ? AND subject_object_id IN (?,?))) AND (subject_object_id,ns,object_id,relation,subject_relation) > (?,?,?,?,?) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"somesubjectype", "foo", "bar", "next", "someresourcetype", "someresource", "viewer", "..."}, + staticCols: []string{"subject_ns"}, + }, + expectedForExpanded: expected{ + sql: "SELECT * WHERE ((subject_ns = ? AND subject_object_id IN (?,?))) AND ((subject_object_id > ?) OR (subject_object_id = ? AND ns > ?) OR (subject_object_id = ? AND ns = ? AND object_id > ?) OR (subject_object_id = ? AND ns = ? AND object_id = ? AND relation > ?) OR (subject_object_id = ? AND ns = ? AND object_id = ? AND relation = ? AND subject_relation > ?)) AND (expiration IS NULL OR expiration > NOW())", + args: []any{"somesubjectype", "foo", "bar", "next", "next", "someresourcetype", "next", "someresourcetype", "someresource", "next", "someresourcetype", "someresource", "viewer", "next", "someresourcetype", "someresource", "viewer", "..."}, + staticCols: []string{"subject_ns"}, + }, + }, + { + name: "order by subject, after with subject namespace and multiple subject object IDs and no expiration", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + return filterer.MustFilterWithSubjectsSelectors(datastore.SubjectsSelector{ + OptionalSubjectType: "somesubjectype", + OptionalSubjectIds: []string{"foo", "bar"}, + }).After(toCursor(tuple.MustParse("someresourcetype:someresource#viewer@user:next")), options.BySubject) + }, + withExpirationDisabled: true, + expectedForTuple: expected{ + sql: "SELECT * WHERE ((subject_ns = ? AND subject_object_id IN (?,?))) AND (subject_object_id,ns,object_id,relation,subject_relation) > (?,?,?,?,?)", + args: []any{"somesubjectype", "foo", "bar", "next", "someresourcetype", "someresource", "viewer", "..."}, + staticCols: []string{"subject_ns"}, + }, + expectedForExpanded: expected{ + sql: "SELECT * WHERE ((subject_ns = ? AND subject_object_id IN (?,?))) AND ((subject_object_id > ?) OR (subject_object_id = ? AND ns > ?) OR (subject_object_id = ? AND ns = ? AND object_id > ?) OR (subject_object_id = ? AND ns = ? AND object_id = ? AND relation > ?) OR (subject_object_id = ? AND ns = ? AND object_id = ? AND relation = ? AND subject_relation > ?))", + args: []any{"somesubjectype", "foo", "bar", "next", "next", "someresourcetype", "next", "someresourcetype", "someresource", "next", "someresourcetype", "someresource", "viewer", "next", "someresourcetype", "someresource", "viewer", "..."}, + staticCols: []string{"subject_ns"}, + }, }, } for _, test := range tests { test := test t.Run(test.name, func(t *testing.T) { - schema := NewSchemaInformationWithOptions( - WithRelationshipTableName("relationtuples"), - WithColNamespace("ns"), - WithColObjectID("object_id"), - WithColRelation("relation"), - WithColUsersetNamespace("subject_ns"), - WithColUsersetObjectID("subject_object_id"), - WithColUsersetRelation("subject_relation"), - WithColCaveatName("caveat"), - WithColCaveatContext("caveat_context"), - WithColExpiration("expiration"), - WithPlaceholderFormat(sq.Question), - WithPaginationFilterType(TupleComparison), - WithColumnOptimization(ColumnOptimizationOptionStaticValues), - WithNowFunction("NOW"), - ) - filterer := NewSchemaQueryFiltererForRelationshipsSelect(*schema, 100) + for _, filterType := range []PaginationFilterType{TupleComparison, ExpandedLogicComparison} { + t.Run(fmt.Sprintf("filter type: %v", filterType), func(t *testing.T) { + schema := NewSchemaInformationWithOptions( + WithRelationshipTableName("relationtuples"), + WithColNamespace("ns"), + WithColObjectID("object_id"), + WithColRelation("relation"), + WithColUsersetNamespace("subject_ns"), + WithColUsersetObjectID("subject_object_id"), + WithColUsersetRelation("subject_relation"), + WithColCaveatName("caveat"), + WithColCaveatContext("caveat_context"), + WithColExpiration("expiration"), + WithPlaceholderFormat(sq.Question), + WithPaginationFilterType(filterType), + WithColumnOptimization(ColumnOptimizationOptionStaticValues), + WithNowFunction("NOW"), + ) + filterer := NewSchemaQueryFiltererForRelationshipsSelect(*schema, 100) - ran := test.run(filterer) - foundStaticColumns := []string{} - for col, tracker := range ran.filteringColumnTracker { - if tracker.SingleValue != nil { - foundStaticColumns = append(foundStaticColumns, col) - } - } + ran := test.run(filterer) + foundStaticColumns := []string{} + for col, tracker := range ran.filteringColumnTracker { + if tracker.SingleValue != nil { + foundStaticColumns = append(foundStaticColumns, col) + } + } + + expected := test.expectedForTuple + if filterType == ExpandedLogicComparison && test.expectedForExpanded.sql != "" { + expected = test.expectedForExpanded + } - require.ElementsMatch(t, test.expectedStaticColumns, foundStaticColumns) + require.ElementsMatch(t, expected.staticCols, foundStaticColumns) - ran.queryBuilder = ran.queryBuilderWithExpirationFilter(test.withExpirationDisabled).Columns("*") + ran.queryBuilder = ran.queryBuilderWithExpirationFilter(test.withExpirationDisabled).Columns("*") - sql, args, err := ran.queryBuilder.ToSql() - require.NoError(t, err) - require.Equal(t, test.expectedSQL, sql) - require.Equal(t, test.expectedArgs, args) + sql, args, err := ran.queryBuilder.ToSql() + require.NoError(t, err) + require.Equal(t, expected.sql, sql) + require.Equal(t, expected.args, args) + }) + } }) } } @@ -889,43 +1074,47 @@ func TestExecuteQuery(t *testing.T) { for _, tc := range tcs { t.Run(tc.name, func(t *testing.T) { - schema := NewSchemaInformationWithOptions( - WithRelationshipTableName("relationtuples"), - WithColNamespace("ns"), - WithColObjectID("object_id"), - WithColRelation("relation"), - WithColUsersetNamespace("subject_ns"), - WithColUsersetObjectID("subject_object_id"), - WithColUsersetRelation("subject_relation"), - WithColCaveatName("caveat"), - WithColCaveatContext("caveat_context"), - WithColExpiration("expiration"), - WithPlaceholderFormat(sq.Question), - WithPaginationFilterType(TupleComparison), - WithColumnOptimization(ColumnOptimizationOptionStaticValues), - WithNowFunction("NOW"), - WithExpirationDisabled(tc.withExpirationDisabled), - ) - filterer := NewSchemaQueryFiltererForRelationshipsSelect(*schema, 100) - ran := tc.run(filterer) + for _, filterType := range []PaginationFilterType{TupleComparison, ExpandedLogicComparison} { + t.Run(fmt.Sprintf("filter type: %v", filterType), func(t *testing.T) { + schema := NewSchemaInformationWithOptions( + WithRelationshipTableName("relationtuples"), + WithColNamespace("ns"), + WithColObjectID("object_id"), + WithColRelation("relation"), + WithColUsersetNamespace("subject_ns"), + WithColUsersetObjectID("subject_object_id"), + WithColUsersetRelation("subject_relation"), + WithColCaveatName("caveat"), + WithColCaveatContext("caveat_context"), + WithColExpiration("expiration"), + WithPlaceholderFormat(sq.Question), + WithPaginationFilterType(filterType), + WithColumnOptimization(ColumnOptimizationOptionStaticValues), + WithNowFunction("NOW"), + WithExpirationDisabled(tc.withExpirationDisabled), + ) + filterer := NewSchemaQueryFiltererForRelationshipsSelect(*schema, 100) + ran := tc.run(filterer) - var wasRun bool - fake := QueryRelationshipsExecutor{ - Executor: func(ctx context.Context, builder RelationshipsQueryBuilder) (datastore.RelationshipIterator, error) { - sql, args, err := builder.SelectSQL() - require.NoError(t, err) + var wasRun bool + fake := QueryRelationshipsExecutor{ + Executor: func(ctx context.Context, builder RelationshipsQueryBuilder) (datastore.RelationshipIterator, error) { + sql, args, err := builder.SelectSQL() + require.NoError(t, err) - wasRun = true - require.Equal(t, tc.expectedSQL, sql) - require.Equal(t, tc.expectedArgs, args) - require.Equal(t, tc.expectedSkipCaveats, builder.SkipCaveats) - require.Equal(t, tc.expectedSkipExpiration, builder.SkipExpiration) - return nil, nil - }, + wasRun = true + require.Equal(t, tc.expectedSQL, sql) + require.Equal(t, tc.expectedArgs, args) + require.Equal(t, tc.expectedSkipCaveats, builder.SkipCaveats) + require.Equal(t, tc.expectedSkipExpiration, builder.SkipExpiration) + return nil, nil + }, + } + _, err := fake.ExecuteQuery(context.Background(), ran, tc.options...) + require.NoError(t, err) + require.True(t, wasRun) + }) } - _, err := fake.ExecuteQuery(context.Background(), ran, tc.options...) - require.NoError(t, err) - require.True(t, wasRun) }) } } diff --git a/pkg/datastore/test/datastore.go b/pkg/datastore/test/datastore.go index 8617c12248..034e265a43 100644 --- a/pkg/datastore/test/datastore.go +++ b/pkg/datastore/test/datastore.go @@ -147,6 +147,7 @@ func AllWithExceptions(t *testing.T, tester DatastoreTester, except Categories, t.Run("TestOrderedLimit", runner(tester, OrderedLimitTest)) t.Run("TestResume", runner(tester, ResumeTest)) t.Run("TestReverseQueryCursor", runner(tester, ReverseQueryCursorTest)) + t.Run("TestReverseQueryFilteredCursor", runner(tester, ReverseQueryFilteredOverMultipleValuesCursorTest)) t.Run("TestRevisionQuantization", runner(tester, RevisionQuantizationTest)) t.Run("TestRevisionSerialization", runner(tester, RevisionSerializationTest)) diff --git a/pkg/datastore/test/pagination.go b/pkg/datastore/test/pagination.go index fd7b09edbd..bfe8b9399a 100644 --- a/pkg/datastore/test/pagination.go +++ b/pkg/datastore/test/pagination.go @@ -246,6 +246,65 @@ func ResumeTest(t *testing.T, tester DatastoreTester) { } } +func ReverseQueryFilteredOverMultipleValuesCursorTest(t *testing.T, tester DatastoreTester) { + rawDS, err := tester.New(0, veryLargeGCInterval, veryLargeGCWindow, 1) + require.NoError(t, err) + + // Create a datastore with the standard schema but no data. + ds, _ := testfixtures.StandardDatastoreWithSchema(rawDS, require.New(t)) + + // Add test relationships. + rev, err := ds.ReadWriteTx(context.Background(), func(ctx context.Context, rwt datastore.ReadWriteTransaction) error { + return rwt.WriteRelationships(ctx, []tuple.RelationshipUpdate{ + tuple.Create(tuple.MustParse("document:firstdoc#viewer@user:alice")), + tuple.Create(tuple.MustParse("document:firstdoc#viewer@user:tom")), + tuple.Create(tuple.MustParse("document:firstdoc#viewer@user:fred")), + tuple.Create(tuple.MustParse("document:seconddoc#viewer@user:alice")), + tuple.Create(tuple.MustParse("document:seconddoc#viewer@user:*")), + tuple.Create(tuple.MustParse("document:thirddoc#viewer@user:*")), + }) + }) + require.NoError(t, err) + + // Issue a reverse query call with a limit. + for _, sortBy := range []options.SortOrder{options.ByResource, options.BySubject} { + t.Run(fmt.Sprintf("SortBy-%d", sortBy), func(t *testing.T) { + reader := ds.SnapshotReader(rev) + + var limit uint64 = 2 + var cursor options.Cursor + + foundTuples := mapz.NewSet[string]() + + for i := 0; i < 5; i++ { + iter, err := reader.ReverseQueryRelationships(context.Background(), datastore.SubjectsFilter{ + SubjectType: testfixtures.UserNS.Name, + OptionalSubjectIds: []string{"alice", "tom", "fred", "*"}, + }, options.WithResRelation(&options.ResourceRelation{ + Namespace: "document", + Relation: "viewer", + }), options.WithSortForReverse(sortBy), options.WithLimitForReverse(&limit), options.WithAfterForReverse(cursor)) + require.NoError(t, err) + + encounteredTuples := mapz.NewSet[string]() + for rel, err := range iter { + require.NoError(t, err) + require.True(t, encounteredTuples.Add(tuple.MustString(rel))) + cursor = options.ToCursor(rel) + } + + require.LessOrEqual(t, encounteredTuples.Len(), 2) + foundTuples = foundTuples.Union(encounteredTuples) + if encounteredTuples.IsEmpty() { + break + } + } + + require.Equal(t, 6, foundTuples.Len()) + }) + } +} + func ReverseQueryCursorTest(t *testing.T, tester DatastoreTester) { rawDS, err := tester.New(0, veryLargeGCInterval, veryLargeGCWindow, 1) require.NoError(t, err) diff --git a/pkg/datastore/test/relationships.go b/pkg/datastore/test/relationships.go index 95049c15e6..d80cca3797 100644 --- a/pkg/datastore/test/relationships.go +++ b/pkg/datastore/test/relationships.go @@ -1675,19 +1675,6 @@ func QueryRelationshipsWithVariousFiltersTest(t *testing.T, tester DatastoreTest } } - ctx := context.Background() - - tpl1, err := tuple.Parse("document:foo#viewer@user:tom") - require.NoError(err) - - _, err = common.WriteRelationships(ctx, ds, tuple.UpdateOperationTouch, tpl1) - require.NoError(err) - ensureRelationships(ctx, require, ds, tpl1) - - _, err = common.WriteRelationships(ctx, ds, tuple.UpdateOperationTouch, tpl1) - require.NoError(err) -} - // RelationshipExpirationTest tests expiration on relationships. func RelationshipExpirationTest(t *testing.T, tester DatastoreTester) { require := require.New(t) @@ -1743,6 +1730,28 @@ func RelationshipExpirationTest(t *testing.T, tester DatastoreTester) { ensureReverseRelationships(ctx, require, ds, rel4) } +// TypedTouchAlreadyExistingTest tests touching a relationship twice, when valid type information is provided. +func TypedTouchAlreadyExistingTest(t *testing.T, tester DatastoreTester) { + require := require.New(t) + + rawDS, err := tester.New(0, veryLargeGCInterval, veryLargeGCWindow, 1) + require.NoError(err) + + ds, _ := testfixtures.StandardDatastoreWithData(rawDS, require) + ctx := context.Background() + + tpl1, err := tuple.Parse("document:foo#viewer@user:tom") + require.NoError(err) + + _, err = common.WriteRelationships(ctx, ds, tuple.UpdateOperationTouch, tpl1) + require.NoError(err) + ensureRelationships(ctx, require, ds, tpl1) + + _, err = common.WriteRelationships(ctx, ds, tuple.UpdateOperationTouch, tpl1) + require.NoError(err) + ensureRelationships(ctx, require, ds, tpl1) +} + // TypedTouchAlreadyExistingWithCaveatTest tests touching a relationship twice, when valid type information is provided. func TypedTouchAlreadyExistingWithCaveatTest(t *testing.T, tester DatastoreTester) { require := require.New(t) From a6584099a37287fd56b9f599190a5baa749902d2 Mon Sep 17 00:00:00 2001 From: Joseph Schorr Date: Fri, 20 Dec 2024 13:26:33 -0500 Subject: [PATCH 11/15] Add additional tracing to relationship querying --- internal/datastore/common/relationships.go | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/internal/datastore/common/relationships.go b/internal/datastore/common/relationships.go index f780dca876..a8f09569a9 100644 --- a/internal/datastore/common/relationships.go +++ b/internal/datastore/common/relationships.go @@ -62,14 +62,18 @@ func QueryRelationships[R Rows, C ~map[string]any](ctx context.Context, builder var integrityHash []byte var timestamp time.Time + span.AddEvent("Selecting columns") colsToSelect, err := ColumnsToSelect(builder, &resourceObjectType, &resourceObjectID, &resourceRelation, &subjectObjectType, &subjectObjectID, &subjectRelation, &caveatName, &caveatCtx, &expiration, &integrityKeyID, &integrityHash, ×tamp) if err != nil { return nil, fmt.Errorf(errUnableToQueryRels, err) } + span.AddEvent("Returning iterator", trace.WithAttributes(attribute.Int("column-count", len(colsToSelect)))) return func(yield func(tuple.Relationship, error) bool) { + span.AddEvent("Issuing query to database") err := tx.QueryFunc(ctx, func(ctx context.Context, rows R) error { span.AddEvent("Query issued to database") + var r Rows = rows if crwe, ok := r.(closeRowsWithError); ok { defer LogOnError(ctx, crwe.Close) @@ -87,6 +91,10 @@ func QueryRelationships[R Rows, C ~map[string]any](ctx context.Context, builder return fmt.Errorf(errUnableToQueryRels, fmt.Errorf("scan err: %w", err)) } + if relCount == 0 { + span.AddEvent("First row scanned") + } + var caveat *corev1.ContextualizedCaveat if !builder.SkipCaveats || builder.Schema.ColumnOptimization == ColumnOptimizationOptionNone { if caveatName.Valid { @@ -136,7 +144,7 @@ func QueryRelationships[R Rows, C ~map[string]any](ctx context.Context, builder } } - span.AddEvent("Rels loaded", trace.WithAttributes(attribute.Int("relCount", relCount))) + span.AddEvent("Relationships loaded", trace.WithAttributes(attribute.Int("relCount", relCount))) if err := rows.Err(); err != nil { return fmt.Errorf(errUnableToQueryRels, fmt.Errorf("rows err: %w", err)) } From c4b5867d8b56376674f60525b240a17df332dc85 Mon Sep 17 00:00:00 2001 From: Joseph Schorr Date: Fri, 3 Jan 2025 20:49:03 -0500 Subject: [PATCH 12/15] Address review feedback --- internal/datastore/common/sql.go | 34 +++++++++++--------------------- 1 file changed, 11 insertions(+), 23 deletions(-) diff --git a/internal/datastore/common/sql.go b/internal/datastore/common/sql.go index 7e4bacb5d5..95dc68af8c 100644 --- a/internal/datastore/common/sql.go +++ b/internal/datastore/common/sql.go @@ -147,28 +147,16 @@ func NewSchemaQueryFiltererWithStartingQuery(schema SchemaInformation, startingQ } } -// WithAdditionalFilter returns a new SchemaQueryFilterer with an additional filter applied to the query. +// WithAdditionalFilter returns the SchemaQueryFilterer with an additional filter applied to the query. func (sqf SchemaQueryFilterer) WithAdditionalFilter(filter func(original sq.SelectBuilder) sq.SelectBuilder) SchemaQueryFilterer { - return SchemaQueryFilterer{ - schema: sqf.schema, - queryBuilder: filter(sqf.queryBuilder), - filteringColumnTracker: sqf.filteringColumnTracker, - filterMaximumIDCount: sqf.filterMaximumIDCount, - isCustomQuery: sqf.isCustomQuery, - extraFields: sqf.extraFields, - } + sqf.queryBuilder = filter(sqf.queryBuilder) + return sqf } +// WithFromSuffix returns the SchemaQueryFilterer with a suffix added to the FROM clause. func (sqf SchemaQueryFilterer) WithFromSuffix(fromSuffix string) SchemaQueryFilterer { - return SchemaQueryFilterer{ - schema: sqf.schema, - queryBuilder: sqf.queryBuilder, - filteringColumnTracker: sqf.filteringColumnTracker, - filterMaximumIDCount: sqf.filterMaximumIDCount, - isCustomQuery: sqf.isCustomQuery, - extraFields: sqf.extraFields, - fromSuffix: fromSuffix, - } + sqf.fromSuffix = fromSuffix + return sqf } func (sqf SchemaQueryFilterer) UnderlyingQueryBuilder() sq.SelectBuilder { @@ -338,7 +326,7 @@ func (sqf SchemaQueryFilterer) recordColumnValue(colName string, colValue string } } -func (sqf SchemaQueryFilterer) recordMutableColumnValue(colName string) { +func (sqf SchemaQueryFilterer) recordVaryingColumnValue(colName string) { sqf.filteringColumnTracker[colName] = ColumnTracker{SingleValue: nil} } @@ -507,9 +495,9 @@ func (sqf SchemaQueryFilterer) FilterWithSubjectsSelectors(selectors ...datastor // can differ for each branch. // TODO(jschorr): Optimize this further where applicable. if len(selectors) > 1 { - sqf.recordMutableColumnValue(sqf.schema.ColUsersetNamespace) - sqf.recordMutableColumnValue(sqf.schema.ColUsersetObjectID) - sqf.recordMutableColumnValue(sqf.schema.ColUsersetRelation) + sqf.recordVaryingColumnValue(sqf.schema.ColUsersetNamespace) + sqf.recordVaryingColumnValue(sqf.schema.ColUsersetObjectID) + sqf.recordVaryingColumnValue(sqf.schema.ColUsersetRelation) } for _, selector := range selectors { @@ -551,7 +539,7 @@ func (sqf SchemaQueryFilterer) FilterWithSubjectsSelectors(selectors ...datastor if !selector.RelationFilter.IsEmpty() { if selector.RelationFilter.OnlyNonEllipsisRelations { selectorClause = append(selectorClause, sq.NotEq{sqf.schema.ColUsersetRelation: datastore.Ellipsis}) - sqf.recordMutableColumnValue(sqf.schema.ColUsersetRelation) + sqf.recordVaryingColumnValue(sqf.schema.ColUsersetRelation) } else { relations := make([]string, 0, 2) if selector.RelationFilter.IncludeEllipsisRelation { From 47387bc7a539ae603d96b2c6574a88a1f3fec654 Mon Sep 17 00:00:00 2001 From: Joseph Schorr Date: Mon, 6 Jan 2025 10:55:33 -0500 Subject: [PATCH 13/15] Move column count logic into helpers --- internal/datastore/common/schema.go | 15 +++++++-- internal/datastore/common/sql.go | 50 +++++++++++++++++++++-------- 2 files changed, 48 insertions(+), 17 deletions(-) diff --git a/internal/datastore/common/schema.go b/internal/datastore/common/schema.go index 7cff99d578..008c2857d2 100644 --- a/internal/datastore/common/schema.go +++ b/internal/datastore/common/schema.go @@ -6,6 +6,13 @@ import ( "github.com/authzed/spicedb/pkg/spiceerrors" ) +const ( + relationshipStandardColumnCount = 6 // ColNamespace, ColObjectID, ColRelation, ColUsersetNamespace, ColUsersetObjectID, ColUsersetRelation + relationshipCaveatColumnCount = 2 // ColCaveatName, ColCaveatContext + relationshipExpirationColumnCount = 1 // ColExpiration + relationshipIntegrityColumnCount = 3 // ColIntegrityKeyID, ColIntegrityHash, ColIntegrityTimestamp +) + // SchemaInformation holds the schema information from the SQL datastore implementation. // //go:generate go run github.com/ecordell/optgen -output schema_options.go . SchemaInformation @@ -18,9 +25,11 @@ type SchemaInformation struct { ColUsersetNamespace string `debugmap:"visible"` ColUsersetObjectID string `debugmap:"visible"` ColUsersetRelation string `debugmap:"visible"` - ColCaveatName string `debugmap:"visible"` - ColCaveatContext string `debugmap:"visible"` - ColExpiration string `debugmap:"visible"` + + ColCaveatName string `debugmap:"visible"` + ColCaveatContext string `debugmap:"visible"` + + ColExpiration string `debugmap:"visible"` ColIntegrityKeyID string `debugmap:"visible"` ColIntegrityHash string `debugmap:"visible"` diff --git a/internal/datastore/common/sql.go b/internal/datastore/common/sql.go index 95dc68af8c..8bdc0f2653 100644 --- a/internal/datastore/common/sql.go +++ b/internal/datastore/common/sql.go @@ -684,14 +684,40 @@ type RelationshipsQueryBuilder struct { baseQueryBuilder SchemaQueryFilterer } +// withCaveats returns true if caveats should be included in the query. +func (b RelationshipsQueryBuilder) withCaveats() bool { + return !b.SkipCaveats || b.Schema.ColumnOptimization == ColumnOptimizationOptionNone +} + +// withExpiration returns true if expiration should be included in the query. +func (b RelationshipsQueryBuilder) withExpiration() bool { + return !b.SkipExpiration && !b.Schema.ExpirationDisabled +} + +// withIntegrityColumns returns true if integrity columns should be included in the query. +func (b RelationshipsQueryBuilder) withIntegrityColumns() bool { + return b.Schema.WithIntegrityColumns +} + +// columnCount returns the number of columns that will be selected in the query. +func (b RelationshipsQueryBuilder) columnCount() int { + columnCount := relationshipStandardColumnCount + if b.withCaveats() { + columnCount += relationshipCaveatColumnCount + } + if b.withExpiration() { + columnCount += relationshipExpirationColumnCount + } + if b.withIntegrityColumns() { + columnCount += relationshipIntegrityColumnCount + } + return columnCount +} + // SelectSQL returns the SQL and arguments necessary for reading relationships. func (b RelationshipsQueryBuilder) SelectSQL() (string, []any, error) { // Set the column names to select. - columnCount := 9 - if b.Schema.WithIntegrityColumns { - columnCount += 3 - } - columnNamesToSelect := make([]string, 0, columnCount) + columnNamesToSelect := make([]string, 0, b.columnCount()) columnNamesToSelect = b.checkColumn(columnNamesToSelect, b.Schema.ColNamespace) columnNamesToSelect = b.checkColumn(columnNamesToSelect, b.Schema.ColObjectID) @@ -700,11 +726,11 @@ func (b RelationshipsQueryBuilder) SelectSQL() (string, []any, error) { columnNamesToSelect = b.checkColumn(columnNamesToSelect, b.Schema.ColUsersetObjectID) columnNamesToSelect = b.checkColumn(columnNamesToSelect, b.Schema.ColUsersetRelation) - if !b.SkipCaveats || b.Schema.ColumnOptimization == ColumnOptimizationOptionNone { + if b.withCaveats() { columnNamesToSelect = append(columnNamesToSelect, b.Schema.ColCaveatName, b.Schema.ColCaveatContext) } - if !b.SkipExpiration && !b.Schema.ExpirationDisabled { + if b.withExpiration() { columnNamesToSelect = append(columnNamesToSelect, b.Schema.ColExpiration) } @@ -775,11 +801,7 @@ func ColumnsToSelect[CN any, CC any, EC any]( integrityHash *[]byte, timestamp *time.Time, ) ([]any, error) { - columnCount := 9 - if b.Schema.WithIntegrityColumns { - columnCount += 3 - } - colsToSelect := make([]any, 0, columnCount) + colsToSelect := make([]any, 0, b.columnCount()) colsToSelect = b.staticValueOrAddColumnForSelect(colsToSelect, b.Schema.ColNamespace, resourceObjectType) colsToSelect = b.staticValueOrAddColumnForSelect(colsToSelect, b.Schema.ColObjectID, resourceObjectID) @@ -788,11 +810,11 @@ func ColumnsToSelect[CN any, CC any, EC any]( colsToSelect = b.staticValueOrAddColumnForSelect(colsToSelect, b.Schema.ColUsersetObjectID, subjectObjectID) colsToSelect = b.staticValueOrAddColumnForSelect(colsToSelect, b.Schema.ColUsersetRelation, subjectRelation) - if !b.SkipCaveats || b.Schema.ColumnOptimization == ColumnOptimizationOptionNone { + if b.withCaveats() { colsToSelect = append(colsToSelect, caveatName, caveatCtx) } - if !b.SkipExpiration && !b.Schema.ExpirationDisabled { + if b.withExpiration() { colsToSelect = append(colsToSelect, expiration) } From 52991f7af789b0d8febbeb91c3f9cb3aa811794e Mon Sep 17 00:00:00 2001 From: Joseph Schorr Date: Tue, 7 Jan 2025 17:39:13 -0500 Subject: [PATCH 14/15] Address review feedback on SQL generator --- internal/datastore/common/schema.go | 8 +- internal/datastore/common/sql.go | 40 +- internal/datastore/common/sql_test.go | 366 ++++++++++++++++-- ...ions.go => zz_generated.schema_options.go} | 10 +- internal/datastore/crdb/crdb.go | 2 +- 5 files changed, 356 insertions(+), 70 deletions(-) rename internal/datastore/common/{schema_options.go => zz_generated.schema_options.go} (95%) diff --git a/internal/datastore/common/schema.go b/internal/datastore/common/schema.go index 008c2857d2..542dc5dbf3 100644 --- a/internal/datastore/common/schema.go +++ b/internal/datastore/common/schema.go @@ -15,7 +15,7 @@ const ( // SchemaInformation holds the schema information from the SQL datastore implementation. // -//go:generate go run github.com/ecordell/optgen -output schema_options.go . SchemaInformation +//go:generate go run github.com/ecordell/optgen -output zz_generated.schema_options.go . SchemaInformation type SchemaInformation struct { RelationshipTableName string `debugmap:"visible"` @@ -47,8 +47,8 @@ type SchemaInformation struct { // ColumnOptimization is the optimization to use for columns in the schema, if any. ColumnOptimization ColumnOptimizationOption `debugmap:"visible"` - // WithIntegrityColumns is a flag to indicate if the schema has integrity columns. - WithIntegrityColumns bool `debugmap:"visible"` + // IntegrityEnabled is a flag to indicate if the schema has integrity columns. + IntegrityEnabled bool `debugmap:"visible"` // ExpirationDisabled is a flag to indicate whether expiration support is disabled. ExpirationDisabled bool `debugmap:"visible"` @@ -102,7 +102,7 @@ func (si SchemaInformation) mustValidate() { panic("ColExpiration is required") } - if si.WithIntegrityColumns { + if si.IntegrityEnabled { if si.ColIntegrityKeyID == "" { panic("ColIntegrityKeyID is required") } diff --git a/internal/datastore/common/sql.go b/internal/datastore/common/sql.go index 8bdc0f2653..db6b277df2 100644 --- a/internal/datastore/common/sql.go +++ b/internal/datastore/common/sql.go @@ -75,15 +75,15 @@ const ( // ColumnOptimizationOptionNone is the default option, which does not optimize the static columns. ColumnOptimizationOptionNone - // ColumnOptimizationOptionStaticValue is an option that optimizes the column for a static value. + // ColumnOptimizationOptionStaticValues is an option that optimizes columns for static values. ColumnOptimizationOptionStaticValues ) -type ColumnTracker struct { +type columnTracker struct { SingleValue *string } -type columnTrackerMap map[string]ColumnTracker +type columnTrackerMap map[string]columnTracker func (ctm columnTrackerMap) hasStaticValue(columnName string) bool { if r, ok := ctm[columnName]; ok && r.SingleValue != nil { @@ -119,7 +119,7 @@ func NewSchemaQueryFiltererForRelationshipsSelect(schema SchemaInformation, filt return SchemaQueryFilterer{ schema: schema, queryBuilder: queryBuilder, - filteringColumnTracker: map[string]ColumnTracker{}, + filteringColumnTracker: map[string]columnTracker{}, filterMaximumIDCount: filterMaximumIDCount, isCustomQuery: false, extraFields: extraFields, @@ -140,7 +140,7 @@ func NewSchemaQueryFiltererWithStartingQuery(schema SchemaInformation, startingQ return SchemaQueryFilterer{ schema: schema, queryBuilder: startingQuery, - filteringColumnTracker: map[string]ColumnTracker{}, + filteringColumnTracker: map[string]columnTracker{}, filterMaximumIDCount: filterMaximumIDCount, isCustomQuery: true, extraFields: nil, @@ -163,12 +163,12 @@ func (sqf SchemaQueryFilterer) UnderlyingQueryBuilder() sq.SelectBuilder { spiceerrors.DebugAssert(func() bool { return sqf.isCustomQuery }, "UnderlyingQueryBuilder should only be called on custom queries") - return sqf.queryBuilderWithExpirationFilter(false) + return sqf.queryBuilderWithMaybeExpirationFilter(false) } -// queryBuilderWithExpirationFilter returns the query builder with the expiration filter applied, when necessary. +// queryBuilderWithMaybeExpirationFilter returns the query builder with the expiration filter applied, when necessary. // Note that this adds the clause to the existing builder. -func (sqf SchemaQueryFilterer) queryBuilderWithExpirationFilter(skipExpiration bool) sq.SelectBuilder { +func (sqf SchemaQueryFilterer) queryBuilderWithMaybeExpirationFilter(skipExpiration bool) sq.SelectBuilder { if sqf.schema.ExpirationDisabled || skipExpiration { return sqf.queryBuilder } @@ -319,15 +319,15 @@ func (sqf SchemaQueryFilterer) recordColumnValue(colName string, colValue string existing, ok := sqf.filteringColumnTracker[colName] if ok { if existing.SingleValue != nil && *existing.SingleValue != colValue { - sqf.filteringColumnTracker[colName] = ColumnTracker{SingleValue: nil} + sqf.filteringColumnTracker[colName] = columnTracker{SingleValue: nil} } } else { - sqf.filteringColumnTracker[colName] = ColumnTracker{SingleValue: &colValue} + sqf.filteringColumnTracker[colName] = columnTracker{SingleValue: &colValue} } } func (sqf SchemaQueryFilterer) recordVaryingColumnValue(colName string) { - sqf.filteringColumnTracker[colName] = ColumnTracker{SingleValue: nil} + sqf.filteringColumnTracker[colName] = columnTracker{SingleValue: nil} } // FilterToResourceID returns a new SchemaQueryFilterer that is limited to resources with the @@ -491,7 +491,7 @@ func (sqf SchemaQueryFilterer) MustFilterWithSubjectsSelectors(selectors ...data func (sqf SchemaQueryFilterer) FilterWithSubjectsSelectors(selectors ...datastore.SubjectsSelector) (SchemaQueryFilterer, error) { selectorsOrClause := sq.Or{} - // If there is more than a single filter, record all the subjects as mutable, as the subjects returned + // If there is more than a single filter, record all the subjects as varying, as the subjects returned // can differ for each branch. // TODO(jschorr): Optimize this further where applicable. if len(selectors) > 1 { @@ -694,9 +694,9 @@ func (b RelationshipsQueryBuilder) withExpiration() bool { return !b.SkipExpiration && !b.Schema.ExpirationDisabled } -// withIntegrityColumns returns true if integrity columns should be included in the query. -func (b RelationshipsQueryBuilder) withIntegrityColumns() bool { - return b.Schema.WithIntegrityColumns +// integrityEnabled returns true if integrity columns should be included in the query. +func (b RelationshipsQueryBuilder) integrityEnabled() bool { + return b.Schema.IntegrityEnabled } // columnCount returns the number of columns that will be selected in the query. @@ -708,7 +708,7 @@ func (b RelationshipsQueryBuilder) columnCount() int { if b.withExpiration() { columnCount += relationshipExpirationColumnCount } - if b.withIntegrityColumns() { + if b.integrityEnabled() { columnCount += relationshipIntegrityColumnCount } return columnCount @@ -734,7 +734,7 @@ func (b RelationshipsQueryBuilder) SelectSQL() (string, []any, error) { columnNamesToSelect = append(columnNamesToSelect, b.Schema.ColExpiration) } - if b.Schema.WithIntegrityColumns { + if b.integrityEnabled() { columnNamesToSelect = append(columnNamesToSelect, b.Schema.ColIntegrityKeyID, b.Schema.ColIntegrityHash, b.Schema.ColIntegrityTimestamp) } @@ -742,14 +742,14 @@ func (b RelationshipsQueryBuilder) SelectSQL() (string, []any, error) { columnNamesToSelect = append(columnNamesToSelect, "1") } - sqlBuilder := b.baseQueryBuilder.queryBuilderWithExpirationFilter(b.SkipExpiration) + sqlBuilder := b.baseQueryBuilder.queryBuilderWithMaybeExpirationFilter(b.SkipExpiration) sqlBuilder = sqlBuilder.Columns(columnNamesToSelect...) return sqlBuilder.ToSql() } // FilteringValuesForTesting returns the filtering values. For test use only. -func (b RelationshipsQueryBuilder) FilteringValuesForTesting() map[string]ColumnTracker { +func (b RelationshipsQueryBuilder) FilteringValuesForTesting() map[string]columnTracker { return maps.Clone(b.filteringValues) } @@ -818,7 +818,7 @@ func ColumnsToSelect[CN any, CC any, EC any]( colsToSelect = append(colsToSelect, expiration) } - if b.Schema.WithIntegrityColumns { + if b.Schema.IntegrityEnabled { colsToSelect = append(colsToSelect, integrityKeyID, integrityHash, timestamp) } diff --git a/internal/datastore/common/sql_test.go b/internal/datastore/common/sql_test.go index 2e19772245..62d6fbeea1 100644 --- a/internal/datastore/common/sql_test.go +++ b/internal/datastore/common/sql_test.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "testing" + "time" "github.com/authzed/spicedb/pkg/datastore/options" @@ -803,7 +804,7 @@ func TestSchemaQueryFilterer(t *testing.T) { require.ElementsMatch(t, expected.staticCols, foundStaticColumns) - ran.queryBuilder = ran.queryBuilderWithExpirationFilter(test.withExpirationDisabled).Columns("*") + ran.queryBuilder = ran.queryBuilderWithMaybeExpirationFilter(test.withExpirationDisabled).Columns("*") sql, args, err := ran.queryBuilder.ToSql() require.NoError(t, err) @@ -822,49 +823,58 @@ func TestExecuteQuery(t *testing.T) { options []options.QueryOptionsOption expectedSQL string expectedArgs []any + expectedStaticColCount int expectedSkipCaveats bool expectedSkipExpiration bool withExpirationDisabled bool + withIntegrityEnabled bool + fromSuffix string + limit uint64 }{ { name: "filter by static resource type", run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.FilterToResourceType("sometype") }, - expectedSQL: "SELECT object_id, relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE ns = ? AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"sometype"}, + expectedSQL: "SELECT object_id, relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE ns = ? AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"sometype"}, + expectedStaticColCount: 1, }, { name: "filter by static resource type and resource ID", run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.FilterToResourceType("sometype").FilterToResourceID("someobj") }, - expectedSQL: "SELECT relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE ns = ? AND object_id = ? AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"sometype", "someobj"}, + expectedSQL: "SELECT relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE ns = ? AND object_id = ? AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"sometype", "someobj"}, + expectedStaticColCount: 2, }, { name: "filter by static resource type and resource ID prefix", run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.FilterToResourceType("sometype").MustFilterWithResourceIDPrefix("someprefix") }, - expectedSQL: "SELECT object_id, relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE ns = ? AND object_id LIKE ? AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"sometype", "someprefix%"}, + expectedSQL: "SELECT object_id, relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE ns = ? AND object_id LIKE ? AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"sometype", "someprefix%"}, + expectedStaticColCount: 1, }, { name: "filter by static resource type and resource IDs", run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.FilterToResourceType("sometype").MustFilterToResourceIDs([]string{"someobj", "anotherobj"}) }, - expectedSQL: "SELECT object_id, relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE ns = ? AND object_id IN (?,?) AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"sometype", "someobj", "anotherobj"}, + expectedSQL: "SELECT object_id, relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE ns = ? AND object_id IN (?,?) AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"sometype", "someobj", "anotherobj"}, + expectedStaticColCount: 1, }, { name: "filter by static resource type, resource ID and relation", run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { return filterer.FilterToResourceType("sometype").FilterToResourceID("someobj").FilterToRelation("somerel") }, - expectedSQL: "SELECT subject_ns, subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE ns = ? AND object_id = ? AND relation = ? AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"sometype", "someobj", "somerel"}, + expectedSQL: "SELECT subject_ns, subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE ns = ? AND object_id = ? AND relation = ? AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"sometype", "someobj", "somerel"}, + expectedStaticColCount: 3, }, { name: "filter by static resource type, resource ID, relation and subject type", @@ -873,8 +883,9 @@ func TestExecuteQuery(t *testing.T) { SubjectType: "subns", }) }, - expectedSQL: "SELECT subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE ns = ? AND object_id = ? AND relation = ? AND subject_ns = ? AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"sometype", "someobj", "somerel", "subns"}, + expectedSQL: "SELECT subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE ns = ? AND object_id = ? AND relation = ? AND subject_ns = ? AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"sometype", "someobj", "somerel", "subns"}, + expectedStaticColCount: 4, }, { name: "filter by static resource type, resource ID, relation, subject type and subject ID", @@ -884,8 +895,9 @@ func TestExecuteQuery(t *testing.T) { OptionalSubjectId: "subid", }) }, - expectedSQL: "SELECT subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE ns = ? AND object_id = ? AND relation = ? AND subject_ns = ? AND subject_object_id = ? AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"sometype", "someobj", "somerel", "subns", "subid"}, + expectedSQL: "SELECT subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE ns = ? AND object_id = ? AND relation = ? AND subject_ns = ? AND subject_object_id = ? AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"sometype", "someobj", "somerel", "subns", "subid"}, + expectedStaticColCount: 5, }, { name: "filter by static resource type, resource ID, relation, subject type, subject ID and subject relation", @@ -898,8 +910,9 @@ func TestExecuteQuery(t *testing.T) { }, }) }, - expectedSQL: "SELECT caveat, caveat_context, expiration FROM relationtuples WHERE ns = ? AND object_id = ? AND relation = ? AND subject_ns = ? AND subject_object_id = ? AND subject_relation = ? AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"sometype", "someobj", "somerel", "subns", "subid", "subrel"}, + expectedSQL: "SELECT caveat, caveat_context, expiration FROM relationtuples WHERE ns = ? AND object_id = ? AND relation = ? AND subject_ns = ? AND subject_object_id = ? AND subject_relation = ? AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"sometype", "someobj", "somerel", "subns", "subid", "subrel"}, + expectedStaticColCount: 6, }, { name: "filter by static everything without caveats", @@ -915,9 +928,10 @@ func TestExecuteQuery(t *testing.T) { options: []options.QueryOptionsOption{ options.WithSkipCaveats(true), }, - expectedSkipCaveats: true, - expectedSQL: "SELECT expiration FROM relationtuples WHERE ns = ? AND object_id = ? AND relation = ? AND subject_ns = ? AND subject_object_id = ? AND subject_relation = ? AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"sometype", "someobj", "somerel", "subns", "subid", "subrel"}, + expectedSkipCaveats: true, + expectedSQL: "SELECT expiration FROM relationtuples WHERE ns = ? AND object_id = ? AND relation = ? AND subject_ns = ? AND subject_object_id = ? AND subject_relation = ? AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"sometype", "someobj", "somerel", "subns", "subid", "subrel"}, + expectedStaticColCount: 6, }, { name: "filter by static everything (except one field) without caveats", @@ -933,9 +947,10 @@ func TestExecuteQuery(t *testing.T) { options: []options.QueryOptionsOption{ options.WithSkipCaveats(true), }, - expectedSkipCaveats: true, - expectedSQL: "SELECT object_id, expiration FROM relationtuples WHERE ns = ? AND object_id IN (?,?) AND relation = ? AND subject_ns = ? AND subject_object_id = ? AND subject_relation = ? AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"sometype", "someobj", "anotherobj", "somerel", "subns", "subid", "subrel"}, + expectedSkipCaveats: true, + expectedSQL: "SELECT object_id, expiration FROM relationtuples WHERE ns = ? AND object_id IN (?,?) AND relation = ? AND subject_ns = ? AND subject_object_id = ? AND subject_relation = ? AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"sometype", "someobj", "anotherobj", "somerel", "subns", "subid", "subrel"}, + expectedStaticColCount: 5, }, { name: "filter by static resource type with no caveats", @@ -945,9 +960,10 @@ func TestExecuteQuery(t *testing.T) { options: []options.QueryOptionsOption{ options.WithSkipCaveats(true), }, - expectedSkipCaveats: true, - expectedSQL: "SELECT object_id, relation, subject_ns, subject_object_id, subject_relation, expiration FROM relationtuples WHERE ns = ? AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"sometype"}, + expectedSkipCaveats: true, + expectedSQL: "SELECT object_id, relation, subject_ns, subject_object_id, subject_relation, expiration FROM relationtuples WHERE ns = ? AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"sometype"}, + expectedStaticColCount: 1, }, { name: "filter by just subject type", @@ -956,8 +972,9 @@ func TestExecuteQuery(t *testing.T) { SubjectType: "subns", }) }, - expectedSQL: "SELECT ns, object_id, relation, subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE subject_ns = ? AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"subns"}, + expectedSQL: "SELECT ns, object_id, relation, subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE subject_ns = ? AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"subns"}, + expectedStaticColCount: 1, }, { name: "filter by just subject type and subject ID", @@ -967,8 +984,9 @@ func TestExecuteQuery(t *testing.T) { OptionalSubjectId: "subid", }) }, - expectedSQL: "SELECT ns, object_id, relation, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE subject_ns = ? AND subject_object_id = ? AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"subns", "subid"}, + expectedSQL: "SELECT ns, object_id, relation, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE subject_ns = ? AND subject_object_id = ? AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"subns", "subid"}, + expectedStaticColCount: 2, }, { name: "filter by just subject type and subject relation", @@ -980,8 +998,9 @@ func TestExecuteQuery(t *testing.T) { }, }) }, - expectedSQL: "SELECT ns, object_id, relation, subject_object_id, caveat, caveat_context, expiration FROM relationtuples WHERE subject_ns = ? AND subject_relation = ? AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"subns", "subrel"}, + expectedSQL: "SELECT ns, object_id, relation, subject_object_id, caveat, caveat_context, expiration FROM relationtuples WHERE subject_ns = ? AND subject_relation = ? AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"subns", "subrel"}, + expectedStaticColCount: 2, }, { name: "filter by just subject type and subject ID and relation", @@ -994,8 +1013,9 @@ func TestExecuteQuery(t *testing.T) { }, }) }, - expectedSQL: "SELECT ns, object_id, relation, caveat, caveat_context, expiration FROM relationtuples WHERE subject_ns = ? AND subject_object_id = ? AND subject_relation = ? AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"subns", "subid", "subrel"}, + expectedSQL: "SELECT ns, object_id, relation, caveat, caveat_context, expiration FROM relationtuples WHERE subject_ns = ? AND subject_object_id = ? AND subject_relation = ? AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"subns", "subid", "subrel"}, + expectedStaticColCount: 3, }, { name: "filter by multiple subject types, but static subject ID", @@ -1008,8 +1028,9 @@ func TestExecuteQuery(t *testing.T) { OptionalSubjectId: "subid", }) }, - expectedSQL: "SELECT ns, object_id, relation, subject_ns, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE subject_ns = ? AND subject_object_id = ? AND subject_ns = ? AND subject_object_id = ? AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"subns", "subid", "anothersubns", "subid"}, + expectedSQL: "SELECT ns, object_id, relation, subject_ns, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE subject_ns = ? AND subject_object_id = ? AND subject_ns = ? AND subject_object_id = ? AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"subns", "subid", "anothersubns", "subid"}, + expectedStaticColCount: 1, }, { name: "multiple subjects filters with just types", @@ -1020,8 +1041,9 @@ func TestExecuteQuery(t *testing.T) { OptionalSubjectType: "anothersubjectype", }) }, - expectedSQL: "SELECT ns, object_id, relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE ((subject_ns = ?) OR (subject_ns = ?)) AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"somesubjectype", "anothersubjectype"}, + expectedSQL: "SELECT ns, object_id, relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE ((subject_ns = ?) OR (subject_ns = ?)) AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"somesubjectype", "anothersubjectype"}, + expectedStaticColCount: 0, }, { name: "multiple subjects filters with just types and static resource type", @@ -1032,8 +1054,9 @@ func TestExecuteQuery(t *testing.T) { OptionalSubjectType: "anothersubjectype", }).FilterToResourceType("sometype") }, - expectedSQL: "SELECT object_id, relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE ((subject_ns = ?) OR (subject_ns = ?)) AND ns = ? AND (expiration IS NULL OR expiration > NOW())", - expectedArgs: []any{"somesubjectype", "anothersubjectype", "sometype"}, + expectedSQL: "SELECT object_id, relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context, expiration FROM relationtuples WHERE ((subject_ns = ?) OR (subject_ns = ?)) AND ns = ? AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"somesubjectype", "anothersubjectype", "sometype"}, + expectedStaticColCount: 1, }, { name: "filter by static resource type with expiration disabled", @@ -1043,6 +1066,7 @@ func TestExecuteQuery(t *testing.T) { expectedSQL: "SELECT object_id, relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context FROM relationtuples WHERE ns = ?", expectedArgs: []any{"sometype"}, withExpirationDisabled: true, + expectedStaticColCount: 1, }, { name: "filter by static resource type with expiration skipped", @@ -1056,6 +1080,7 @@ func TestExecuteQuery(t *testing.T) { options: []options.QueryOptionsOption{ options.WithSkipExpiration(true), }, + expectedStaticColCount: 1, }, { name: "filter by static resource type with expiration skipped and disabled", @@ -1069,6 +1094,182 @@ func TestExecuteQuery(t *testing.T) { options: []options.QueryOptionsOption{ options.WithSkipExpiration(true), }, + expectedStaticColCount: 1, + }, + { + name: "with from suffix", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + return filterer.FilterToResourceType("sometype") + }, + expectedSQL: "SELECT object_id, relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context FROM relationtuples as of tomorrow WHERE ns = ?", + expectedArgs: []any{"sometype"}, + withExpirationDisabled: true, + fromSuffix: "as of tomorrow", + expectedStaticColCount: 1, + }, + { + name: "with limit", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + return filterer.FilterToResourceType("sometype") + }, + expectedSQL: "SELECT object_id, relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context FROM relationtuples WHERE ns = ? LIMIT 65", + expectedArgs: []any{"sometype"}, + withExpirationDisabled: true, + limit: 65, + expectedStaticColCount: 1, + }, + { + name: "with integrity", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + return filterer.FilterToResourceType("sometype") + }, + expectedSQL: "SELECT object_id, relation, subject_ns, subject_object_id, subject_relation, caveat, caveat_context, integrity_key_id, integrity_hash, integrity_timestamp FROM relationtuples WHERE ns = ?", + expectedArgs: []any{"sometype"}, + withExpirationDisabled: true, + withIntegrityEnabled: true, + expectedStaticColCount: 1, + }, + { + name: "all columns static with caveats", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + return filterer.FilterToResourceType("sometype"). + FilterToResourceID("someobj"). + FilterToRelation("somerel"). + FilterToSubjectFilter(&v1.SubjectFilter{ + SubjectType: "subns", + OptionalSubjectId: "subid", + OptionalRelation: &v1.SubjectFilter_RelationFilter{ + Relation: "subrel", + }, + }) + }, + expectedSQL: "SELECT caveat, caveat_context FROM relationtuples WHERE ns = ? AND object_id = ? AND relation = ? AND subject_ns = ? AND subject_object_id = ? AND subject_relation = ?", + expectedArgs: []any{"sometype", "someobj", "somerel", "subns", "subid", "subrel"}, + withExpirationDisabled: true, + expectedSkipExpiration: true, + options: []options.QueryOptionsOption{ + options.WithSkipExpiration(true), + }, + expectedStaticColCount: 6, + }, + { + name: "all columns static with expiration", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + return filterer.FilterToResourceType("sometype"). + FilterToResourceID("someobj"). + FilterToRelation("somerel"). + FilterToSubjectFilter(&v1.SubjectFilter{ + SubjectType: "subns", + OptionalSubjectId: "subid", + OptionalRelation: &v1.SubjectFilter_RelationFilter{ + Relation: "subrel", + }, + }) + }, + expectedSQL: "SELECT expiration FROM relationtuples WHERE ns = ? AND object_id = ? AND relation = ? AND subject_ns = ? AND subject_object_id = ? AND subject_relation = ? AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"sometype", "someobj", "somerel", "subns", "subid", "subrel"}, + expectedSkipCaveats: true, + options: []options.QueryOptionsOption{ + options.WithSkipCaveats(true), + }, + expectedStaticColCount: 6, + }, + { + name: "all columns static with caveats and expiration", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + return filterer.FilterToResourceType("sometype"). + FilterToResourceID("someobj"). + FilterToRelation("somerel"). + FilterToSubjectFilter(&v1.SubjectFilter{ + SubjectType: "subns", + OptionalSubjectId: "subid", + OptionalRelation: &v1.SubjectFilter_RelationFilter{ + Relation: "subrel", + }, + }) + }, + expectedSQL: "SELECT caveat, caveat_context, expiration FROM relationtuples WHERE ns = ? AND object_id = ? AND relation = ? AND subject_ns = ? AND subject_object_id = ? AND subject_relation = ? AND (expiration IS NULL OR expiration > NOW())", + expectedArgs: []any{"sometype", "someobj", "somerel", "subns", "subid", "subrel"}, + expectedStaticColCount: 6, + }, + { + name: "all columns static without caveats", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + return filterer.FilterToResourceType("sometype"). + FilterToResourceID("someobj"). + FilterToRelation("somerel"). + FilterToSubjectFilter(&v1.SubjectFilter{ + SubjectType: "subns", + OptionalSubjectId: "subid", + OptionalRelation: &v1.SubjectFilter_RelationFilter{ + Relation: "subrel", + }, + }) + }, + expectedSQL: "SELECT 1 FROM relationtuples WHERE ns = ? AND object_id = ? AND relation = ? AND subject_ns = ? AND subject_object_id = ? AND subject_relation = ?", + expectedArgs: []any{"sometype", "someobj", "somerel", "subns", "subid", "subrel"}, + withExpirationDisabled: true, + expectedSkipExpiration: true, + expectedSkipCaveats: true, + options: []options.QueryOptionsOption{ + options.WithSkipCaveats(true), + options.WithSkipExpiration(true), + }, + expectedStaticColCount: -1, + }, + { + name: "one column not static", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + f := filterer.FilterToResourceType("sometype"). + FilterToRelation("somerel"). + FilterToSubjectFilter(&v1.SubjectFilter{ + SubjectType: "subns", + OptionalSubjectId: "subid", + OptionalRelation: &v1.SubjectFilter_RelationFilter{ + Relation: "subrel", + }, + }) + + f2, _ := f.FilterToResourceIDs([]string{"foo", "bar"}) + return f2 + }, + expectedSQL: "SELECT object_id FROM relationtuples WHERE ns = ? AND relation = ? AND subject_ns = ? AND subject_object_id = ? AND subject_relation = ? AND object_id IN (?,?)", + expectedArgs: []any{"sometype", "somerel", "subns", "subid", "subrel", "foo", "bar"}, + withExpirationDisabled: true, + expectedSkipExpiration: true, + expectedSkipCaveats: true, + options: []options.QueryOptionsOption{ + options.WithSkipCaveats(true), + options.WithSkipExpiration(true), + }, + expectedStaticColCount: 5, + }, + { + name: "resource ID prefix", + run: func(filterer SchemaQueryFilterer) SchemaQueryFilterer { + f := filterer.FilterToResourceType("sometype"). + FilterToRelation("somerel"). + FilterToSubjectFilter(&v1.SubjectFilter{ + SubjectType: "subns", + OptionalSubjectId: "subid", + OptionalRelation: &v1.SubjectFilter_RelationFilter{ + Relation: "subrel", + }, + }) + + f2, _ := f.FilterWithResourceIDPrefix("foo") + return f2 + }, + expectedSQL: "SELECT object_id FROM relationtuples WHERE ns = ? AND relation = ? AND subject_ns = ? AND subject_object_id = ? AND subject_relation = ? AND object_id LIKE ?", + expectedArgs: []any{"sometype", "somerel", "subns", "subid", "subrel", "foo%"}, + withExpirationDisabled: true, + expectedSkipExpiration: true, + expectedSkipCaveats: true, + options: []options.QueryOptionsOption{ + options.WithSkipCaveats(true), + options.WithSkipExpiration(true), + }, + expectedStaticColCount: 5, }, } @@ -1087,13 +1288,22 @@ func TestExecuteQuery(t *testing.T) { WithColCaveatName("caveat"), WithColCaveatContext("caveat_context"), WithColExpiration("expiration"), + WithColIntegrityHash("integrity_hash"), + WithColIntegrityKeyID("integrity_key_id"), + WithColIntegrityTimestamp("integrity_timestamp"), WithPlaceholderFormat(sq.Question), WithPaginationFilterType(filterType), WithColumnOptimization(ColumnOptimizationOptionStaticValues), WithNowFunction("NOW"), + WithIntegrityEnabled(tc.withIntegrityEnabled), WithExpirationDisabled(tc.withExpirationDisabled), ) filterer := NewSchemaQueryFiltererForRelationshipsSelect(*schema, 100) + filterer = filterer.WithFromSuffix(tc.fromSuffix) + if tc.limit > 0 { + filterer = filterer.limit(tc.limit) + } + ran := tc.run(filterer) var wasRun bool @@ -1107,6 +1317,46 @@ func TestExecuteQuery(t *testing.T) { require.Equal(t, tc.expectedArgs, args) require.Equal(t, tc.expectedSkipCaveats, builder.SkipCaveats) require.Equal(t, tc.expectedSkipExpiration, builder.SkipExpiration) + + // 6 standard columns for relationships: + // ns, object_id, relation, subject_ns, subject_object_id, subject_relation + expectedColCount := 6 - tc.expectedStaticColCount + if !tc.expectedSkipCaveats { + // caveat, caveat_context + expectedColCount += 2 + } + if !tc.expectedSkipExpiration && !tc.withExpirationDisabled { + // expiration + expectedColCount++ + } + if tc.withIntegrityEnabled { + // integrity_key_id, integrity_hash, integrity_timestamp + expectedColCount += 3 + } + + if tc.expectedStaticColCount == -1 { + // SELECT 1 + expectedColCount = 1 + } + + var resourceObjectType string + var resourceObjectID string + var resourceRelation string + var subjectObjectType string + var subjectObjectID string + var subjectRelation string + var caveatName *string + var caveatCtx map[string]any + var expiration *time.Time + + var integrityKeyID string + var integrityHash []byte + var timestamp time.Time + + colsToSelect, err := ColumnsToSelect(builder, &resourceObjectType, &resourceObjectID, &resourceRelation, &subjectObjectType, &subjectObjectID, &subjectRelation, &caveatName, &caveatCtx, &expiration, &integrityKeyID, &integrityHash, ×tamp) + require.NoError(t, err) + require.Equal(t, expectedColCount, len(colsToSelect)) + return nil, nil }, } @@ -1118,3 +1368,39 @@ func TestExecuteQuery(t *testing.T) { }) } } + +func TestNewSchemaQueryFiltererWithStartingQuery(t *testing.T) { + schema := NewSchemaInformationWithOptions( + WithRelationshipTableName("relationtuples"), + WithColNamespace("ns"), + WithColObjectID("object_id"), + WithColRelation("relation"), + WithColUsersetNamespace("subject_ns"), + WithColUsersetObjectID("subject_object_id"), + WithColUsersetRelation("subject_relation"), + WithColCaveatName("caveat"), + WithColCaveatContext("caveat_context"), + WithColExpiration("expiration"), + WithPlaceholderFormat(sq.Question), + WithPaginationFilterType(TupleComparison), + WithColumnOptimization(ColumnOptimizationOptionStaticValues), + WithNowFunction("NOW"), + WithExpirationDisabled(true), + ) + + sql := sq.StatementBuilder.PlaceholderFormat(sq.AtP) + query := sql.Select("COUNT(*)").From("sometable") + filterer := NewSchemaQueryFiltererWithStartingQuery(*schema, query, 50) + filterer = filterer.MustFilterToResourceIDs([]string{"someid"}) + filterer = filterer.WithAdditionalFilter(func(original sq.SelectBuilder) sq.SelectBuilder { + return original.Where("somecoolclause") + }) + + sqlQuery, args, err := filterer.UnderlyingQueryBuilder().ToSql() + require.NoError(t, err) + + expectedSQL := "SELECT COUNT(*) FROM sometable WHERE object_id IN (@p1) AND somecoolclause" + expectedArgs := []any{"someid"} + require.Equal(t, expectedSQL, sqlQuery) + require.Equal(t, expectedArgs, args) +} diff --git a/internal/datastore/common/schema_options.go b/internal/datastore/common/zz_generated.schema_options.go similarity index 95% rename from internal/datastore/common/schema_options.go rename to internal/datastore/common/zz_generated.schema_options.go index fa7639776e..04b6088a36 100644 --- a/internal/datastore/common/schema_options.go +++ b/internal/datastore/common/zz_generated.schema_options.go @@ -48,7 +48,7 @@ func (s *SchemaInformation) ToOption() SchemaInformationOption { to.PlaceholderFormat = s.PlaceholderFormat to.NowFunction = s.NowFunction to.ColumnOptimization = s.ColumnOptimization - to.WithIntegrityColumns = s.WithIntegrityColumns + to.IntegrityEnabled = s.IntegrityEnabled to.ExpirationDisabled = s.ExpirationDisabled } } @@ -73,7 +73,7 @@ func (s SchemaInformation) DebugMap() map[string]any { debugMap["PlaceholderFormat"] = helpers.DebugValue(s.PlaceholderFormat, false) debugMap["NowFunction"] = helpers.DebugValue(s.NowFunction, false) debugMap["ColumnOptimization"] = helpers.DebugValue(s.ColumnOptimization, false) - debugMap["WithIntegrityColumns"] = helpers.DebugValue(s.WithIntegrityColumns, false) + debugMap["IntegrityEnabled"] = helpers.DebugValue(s.IntegrityEnabled, false) debugMap["ExpirationDisabled"] = helpers.DebugValue(s.ExpirationDisabled, false) return debugMap } @@ -213,10 +213,10 @@ func WithColumnOptimization(columnOptimization ColumnOptimizationOption) SchemaI } } -// WithWithIntegrityColumns returns an option that can set WithIntegrityColumns on a SchemaInformation -func WithWithIntegrityColumns(withIntegrityColumns bool) SchemaInformationOption { +// WithIntegrityEnabled returns an option that can set IntegrityEnabled on a SchemaInformation +func WithIntegrityEnabled(integrityEnabled bool) SchemaInformationOption { return func(s *SchemaInformation) { - s.WithIntegrityColumns = withIntegrityColumns + s.IntegrityEnabled = integrityEnabled } } diff --git a/internal/datastore/crdb/crdb.go b/internal/datastore/crdb/crdb.go index 3f27d68d12..f4929967ed 100644 --- a/internal/datastore/crdb/crdb.go +++ b/internal/datastore/crdb/crdb.go @@ -223,7 +223,7 @@ func newCRDBDatastore(ctx context.Context, url string, options ...Option) (datas common.WithPlaceholderFormat(sq.Dollar), common.WithNowFunction("NOW"), common.WithColumnOptimization(config.columnOptimizationOption), - common.WithWithIntegrityColumns(config.withIntegrity), + common.WithIntegrityEnabled(config.withIntegrity), common.WithExpirationDisabled(config.expirationDisabled), ) From f71404b70848e8e3d4ecc1601d672908bf29a8bf Mon Sep 17 00:00:00 2001 From: Joseph Schorr Date: Wed, 8 Jan 2025 13:40:12 -0500 Subject: [PATCH 15/15] Address further feedback: - Have QueryRelationships add events to the parent span directly - Cleanup CRDB system time handling and add debug-time assertions for "as of system time" - Additional testing --- internal/datastore/common/relationships.go | 5 +- internal/datastore/common/sql.go | 13 +- internal/datastore/crdb/caveat.go | 6 +- internal/datastore/crdb/crdb.go | 40 ++--- internal/datastore/crdb/reader.go | 94 ++++++++--- internal/datastore/crdb/stats.go | 2 +- internal/datastore/mysql/datastore.go | 8 +- internal/datastore/postgres/common/pgx.go | 4 +- internal/graph/check.go | 90 +++++++---- internal/graph/check_isolated_test.go | 148 ++++++++++++++++++ internal/services/v1/permissions_test.go | 4 - pkg/datastore/options/options.go | 3 + .../options/zz_generated.query_options.go | 9 ++ pkg/spiceerrors/assert_off.go | 2 + pkg/spiceerrors/assert_on.go | 2 + 15 files changed, 336 insertions(+), 94 deletions(-) create mode 100644 internal/graph/check_isolated_test.go diff --git a/internal/datastore/common/relationships.go b/internal/datastore/common/relationships.go index a8f09569a9..4543a46013 100644 --- a/internal/datastore/common/relationships.go +++ b/internal/datastore/common/relationships.go @@ -40,9 +40,8 @@ type closeRows interface { } // QueryRelationships queries relationships for the given query and transaction. -func QueryRelationships[R Rows, C ~map[string]any](ctx context.Context, builder RelationshipsQueryBuilder, span trace.Span, tx Querier[R]) (datastore.RelationshipIterator, error) { - defer span.End() - +func QueryRelationships[R Rows, C ~map[string]any](ctx context.Context, builder RelationshipsQueryBuilder, tx Querier[R]) (datastore.RelationshipIterator, error) { + span := trace.SpanFromContext(ctx) sqlString, args, err := builder.SelectSQL() if err != nil { return nil, fmt.Errorf(errUnableToQueryRels, err) diff --git a/internal/datastore/common/sql.go b/internal/datastore/common/sql.go index db6b277df2..733ca30bba 100644 --- a/internal/datastore/common/sql.go +++ b/internal/datastore/common/sql.go @@ -666,6 +666,7 @@ func (exc QueryRelationshipsExecutor) ExecuteQuery( Schema: query.schema, SkipCaveats: queryOpts.SkipCaveats, SkipExpiration: queryOpts.SkipExpiration, + sqlAssertion: queryOpts.SQLAssertion, filteringValues: query.filteringColumnTracker, baseQueryBuilder: query, } @@ -682,6 +683,7 @@ type RelationshipsQueryBuilder struct { filteringValues columnTrackerMap baseQueryBuilder SchemaQueryFilterer + sqlAssertion options.Assertion } // withCaveats returns true if caveats should be included in the query. @@ -745,7 +747,16 @@ func (b RelationshipsQueryBuilder) SelectSQL() (string, []any, error) { sqlBuilder := b.baseQueryBuilder.queryBuilderWithMaybeExpirationFilter(b.SkipExpiration) sqlBuilder = sqlBuilder.Columns(columnNamesToSelect...) - return sqlBuilder.ToSql() + sql, args, err := sqlBuilder.ToSql() + if err != nil { + return "", nil, err + } + + if b.sqlAssertion != nil { + b.sqlAssertion(sql) + } + + return sql, args, nil } // FilteringValuesForTesting returns the filtering values. For test use only. diff --git a/internal/datastore/crdb/caveat.go b/internal/datastore/crdb/caveat.go index ebaae37301..94c74f6553 100644 --- a/internal/datastore/crdb/caveat.go +++ b/internal/datastore/crdb/caveat.go @@ -35,11 +35,12 @@ const ( ) func (cr *crdbReader) ReadCaveatByName(ctx context.Context, name string) (*core.CaveatDefinition, datastore.Revision, error) { - query := cr.fromWithAsOfSystemTime(readCaveat.Where(sq.Eq{colCaveatName: name}), tableCaveat) + query := cr.addFromToQuery(readCaveat.Where(sq.Eq{colCaveatName: name}), tableCaveat) sql, args, err := query.ToSql() if err != nil { return nil, datastore.NoRevision, fmt.Errorf(errReadCaveat, name, err) } + cr.assertHasExpectedAsOfSystemTime(sql) var definitionBytes []byte var timestamp time.Time @@ -79,7 +80,7 @@ type bytesAndTimestamp struct { } func (cr *crdbReader) lookupCaveats(ctx context.Context, caveatNames []string) ([]datastore.RevisionedCaveat, error) { - caveatsWithNames := cr.fromWithAsOfSystemTime(listCaveat, tableCaveat) + caveatsWithNames := cr.addFromToQuery(listCaveat, tableCaveat) if len(caveatNames) > 0 { caveatsWithNames = caveatsWithNames.Where(sq.Eq{colCaveatName: caveatNames}) } @@ -88,6 +89,7 @@ func (cr *crdbReader) lookupCaveats(ctx context.Context, caveatNames []string) ( if err != nil { return nil, fmt.Errorf(errListCaveats, err) } + cr.assertHasExpectedAsOfSystemTime(sql) var allDefinitionBytes []bytesAndTimestamp diff --git a/internal/datastore/crdb/crdb.go b/internal/datastore/crdb/crdb.go index f4929967ed..517afe0c2c 100644 --- a/internal/datastore/crdb/crdb.go +++ b/internal/datastore/crdb/crdb.go @@ -351,13 +351,16 @@ func (cds *crdbDatastore) SnapshotReader(rev datastore.Revision) datastore.Reade executor := common.QueryRelationshipsExecutor{ Executor: pgxcommon.NewPGXQueryRelationshipsExecutor(cds.readPool), } - - withAsOfSystemTime := func(query sq.SelectBuilder, tableName string) sq.SelectBuilder { - return query.From(tableName + " AS OF SYSTEM TIME " + rev.String()) + return &crdbReader{ + schema: cds.schema, + query: cds.readPool, + executor: executor, + keyer: noOverlapKeyer, + overlapKeySet: nil, + filterMaximumIDCount: cds.filterMaximumIDCount, + withIntegrity: cds.supportsIntegrity, + atSpecificRevision: rev.String(), } - - asOfSystemTimeSuffix := "AS OF SYSTEM TIME " + rev.String() - return &crdbReader{cds.readPool, executor, noOverlapKeyer, nil, withAsOfSystemTime, asOfSystemTimeSuffix, cds.filterMaximumIDCount, cds.schema, cds.supportsIntegrity} } func (cds *crdbDatastore) ReadWriteTx( @@ -399,20 +402,19 @@ func (cds *crdbDatastore) ReadWriteTx( return fmt.Errorf("error writing metadata: %w", err) } + reader := &crdbReader{ + schema: cds.schema, + query: querier, + executor: executor, + keyer: cds.writeOverlapKeyer, + overlapKeySet: cds.overlapKeyInit(ctx), + filterMaximumIDCount: cds.filterMaximumIDCount, + withIntegrity: cds.supportsIntegrity, + atSpecificRevision: "", // No AS OF SYSTEM TIME for writes + } + rwt := &crdbReadWriteTXN{ - &crdbReader{ - querier, - executor, - cds.writeOverlapKeyer, - cds.overlapKeyInit(ctx), - func(query sq.SelectBuilder, tableName string) sq.SelectBuilder { - return query.From(tableName) - }, - "", // No AS OF SYSTEM TIME for writes - cds.filterMaximumIDCount, - cds.schema, - cds.supportsIntegrity, - }, + reader, tx, 0, } diff --git a/internal/datastore/crdb/reader.go b/internal/datastore/crdb/reader.go index e252a16c92..db44aaadd6 100644 --- a/internal/datastore/crdb/reader.go +++ b/internal/datastore/crdb/reader.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "strings" "time" sq "github.com/Masterminds/squirrel" @@ -16,6 +17,7 @@ import ( "github.com/authzed/spicedb/pkg/datastore" "github.com/authzed/spicedb/pkg/datastore/options" core "github.com/authzed/spicedb/pkg/proto/core/v1" + "github.com/authzed/spicedb/pkg/spiceerrors" ) const ( @@ -38,15 +40,42 @@ var ( ) type crdbReader struct { - query pgxcommon.DBFuncQuerier - executor common.QueryRelationshipsExecutor - keyer overlapKeyer - overlapKeySet keySet - fromWithAsOfSystemTime func(query sq.SelectBuilder, tableName string) sq.SelectBuilder - asOfSystemTimeSuffix string - filterMaximumIDCount uint16 - schema common.SchemaInformation - withIntegrity bool + schema common.SchemaInformation + query pgxcommon.DBFuncQuerier + executor common.QueryRelationshipsExecutor + keyer overlapKeyer + overlapKeySet keySet + filterMaximumIDCount uint16 + withIntegrity bool + atSpecificRevision string +} + +const asOfSystemTime = "AS OF SYSTEM TIME" + +func (cr *crdbReader) addFromToQuery(query sq.SelectBuilder, tableName string) sq.SelectBuilder { + if cr.atSpecificRevision == "" { + return query.From(tableName) + } + + return query.From(tableName + " " + asOfSystemTime + " " + cr.atSpecificRevision) +} + +func (cr *crdbReader) fromSuffix() string { + if cr.atSpecificRevision == "" { + return "" + } + + return " " + asOfSystemTime + " " + cr.atSpecificRevision +} + +func (cr *crdbReader) assertHasExpectedAsOfSystemTime(sql string) { + spiceerrors.DebugAssert(func() bool { + if cr.atSpecificRevision == "" { + return !strings.Contains(sql, "AS OF SYSTEM TIME") + } else { + return strings.Contains(sql, "AS OF SYSTEM TIME") + } + }, "mismatch in AS OF SYSTEM TIME in query: %s", sql) } func (cr *crdbReader) CountRelationships(ctx context.Context, name string) (int, error) { @@ -64,7 +93,7 @@ func (cr *crdbReader) CountRelationships(ctx context.Context, name string) (int, return 0, err } - query := cr.fromWithAsOfSystemTime(countRels, cr.schema.RelationshipTableName) + query := cr.addFromToQuery(countRels, cr.schema.RelationshipTableName) builder, err := common.NewSchemaQueryFiltererWithStartingQuery(cr.schema, query, cr.filterMaximumIDCount).FilterWithRelationshipsFilter(relFilter) if err != nil { return 0, err @@ -74,6 +103,7 @@ func (cr *crdbReader) CountRelationships(ctx context.Context, name string) (int, if err != nil { return 0, err } + cr.assertHasExpectedAsOfSystemTime(sql) var count int err = cr.query.QueryRowFunc(ctx, func(ctx context.Context, row pgx.Row) error { @@ -93,7 +123,7 @@ func (cr *crdbReader) LookupCounters(ctx context.Context) ([]datastore.Relations } func (cr *crdbReader) lookupCounters(ctx context.Context, optionalFilterName string) ([]datastore.RelationshipCounter, error) { - query := cr.fromWithAsOfSystemTime(queryCounters, tableRelationshipCounter) + query := cr.addFromToQuery(queryCounters, tableRelationshipCounter) if optionalFilterName != noFilterOnCounterName { query = query.Where(sq.Eq{colCounterName: optionalFilterName}) } @@ -102,6 +132,7 @@ func (cr *crdbReader) lookupCounters(ctx context.Context, optionalFilterName str if err != nil { return nil, err } + cr.assertHasExpectedAsOfSystemTime(sql) var counters []datastore.RelationshipCounter err = cr.query.QueryFunc(ctx, func(ctx context.Context, rows pgx.Rows) error { @@ -165,10 +196,11 @@ func (cr *crdbReader) ReadNamespaceByName( } func (cr *crdbReader) ListAllNamespaces(ctx context.Context) ([]datastore.RevisionedNamespace, error) { - nsDefs, err := loadAllNamespaces(ctx, cr.query, cr.fromWithAsOfSystemTime) + nsDefs, sql, err := loadAllNamespaces(ctx, cr.query, cr.addFromToQuery) if err != nil { return nil, fmt.Errorf(errUnableToListNamespaces, err) } + cr.assertHasExpectedAsOfSystemTime(sql) return nsDefs, nil } @@ -188,11 +220,15 @@ func (cr *crdbReader) QueryRelationships( filter datastore.RelationshipsFilter, opts ...options.QueryOptionsOption, ) (iter datastore.RelationshipIterator, err error) { - qBuilder, err := common.NewSchemaQueryFiltererForRelationshipsSelect(cr.schema, cr.filterMaximumIDCount).WithFromSuffix(cr.asOfSystemTimeSuffix).FilterWithRelationshipsFilter(filter) + qBuilder, err := common.NewSchemaQueryFiltererForRelationshipsSelect(cr.schema, cr.filterMaximumIDCount).WithFromSuffix(cr.fromSuffix()).FilterWithRelationshipsFilter(filter) if err != nil { return nil, err } + if spiceerrors.DebugAssertionsEnabled { + opts = append(opts, options.WithSQLAssertion(cr.assertHasExpectedAsOfSystemTime)) + } + return cr.executor.ExecuteQuery(ctx, qBuilder, opts...) } @@ -202,34 +238,43 @@ func (cr *crdbReader) ReverseQueryRelationships( opts ...options.ReverseQueryOptionsOption, ) (iter datastore.RelationshipIterator, err error) { qBuilder, err := common.NewSchemaQueryFiltererForRelationshipsSelect(cr.schema, cr.filterMaximumIDCount). - WithFromSuffix(cr.asOfSystemTimeSuffix). + WithFromSuffix(cr.fromSuffix()). FilterWithSubjectsSelectors(subjectsFilter.AsSelector()) if err != nil { return nil, err } queryOpts := options.NewReverseQueryOptionsWithOptions(opts...) - if queryOpts.ResRelation != nil { qBuilder = qBuilder. FilterToResourceType(queryOpts.ResRelation.Namespace). FilterToRelation(queryOpts.ResRelation.Relation) } + eopts := []options.QueryOptionsOption{ + options.WithLimit(queryOpts.LimitForReverse), + options.WithAfter(queryOpts.AfterForReverse), + options.WithSort(queryOpts.SortForReverse), + } + + if spiceerrors.DebugAssertionsEnabled { + eopts = append(eopts, options.WithSQLAssertion(cr.assertHasExpectedAsOfSystemTime)) + } + return cr.executor.ExecuteQuery( ctx, qBuilder, - options.WithLimit(queryOpts.LimitForReverse), - options.WithAfter(queryOpts.AfterForReverse), - options.WithSort(queryOpts.SortForReverse)) + eopts..., + ) } func (cr crdbReader) loadNamespace(ctx context.Context, tx pgxcommon.DBFuncQuerier, nsName string) (*core.NamespaceDefinition, time.Time, error) { - query := cr.fromWithAsOfSystemTime(queryReadNamespace, tableNamespace).Where(sq.Eq{colNamespace: nsName}) + query := cr.addFromToQuery(queryReadNamespace, tableNamespace).Where(sq.Eq{colNamespace: nsName}) sql, args, err := query.ToSql() if err != nil { return nil, time.Time{}, err } + cr.assertHasExpectedAsOfSystemTime(sql) var config []byte var timestamp time.Time @@ -258,11 +303,12 @@ func (cr crdbReader) lookupNamespaces(ctx context.Context, tx pgxcommon.DBFuncQu clause = append(clause, sq.Eq{colNamespace: nsName}) } - query := cr.fromWithAsOfSystemTime(queryReadNamespace, tableNamespace).Where(clause) + query := cr.addFromToQuery(queryReadNamespace, tableNamespace).Where(clause) sql, args, err := query.ToSql() if err != nil { return nil, err } + cr.assertHasExpectedAsOfSystemTime(sql) var nsDefs []datastore.RevisionedNamespace @@ -297,11 +343,11 @@ func (cr crdbReader) lookupNamespaces(ctx context.Context, tx pgxcommon.DBFuncQu return nsDefs, nil } -func loadAllNamespaces(ctx context.Context, tx pgxcommon.DBFuncQuerier, fromBuilder func(sq.SelectBuilder, string) sq.SelectBuilder) ([]datastore.RevisionedNamespace, error) { +func loadAllNamespaces(ctx context.Context, tx pgxcommon.DBFuncQuerier, fromBuilder func(sq.SelectBuilder, string) sq.SelectBuilder) ([]datastore.RevisionedNamespace, string, error) { query := fromBuilder(queryReadNamespace, tableNamespace) sql, args, err := query.ToSql() if err != nil { - return nil, err + return nil, sql, err } var nsDefs []datastore.RevisionedNamespace @@ -331,10 +377,10 @@ func loadAllNamespaces(ctx context.Context, tx pgxcommon.DBFuncQuerier, fromBuil return nil }, sql, args...) if err != nil { - return nil, err + return nil, sql, err } - return nsDefs, nil + return nsDefs, sql, nil } func (cr *crdbReader) addOverlapKey(namespace string) { diff --git a/internal/datastore/crdb/stats.go b/internal/datastore/crdb/stats.go index b01a1f3722..c468747b1a 100644 --- a/internal/datastore/crdb/stats.go +++ b/internal/datastore/crdb/stats.go @@ -44,7 +44,7 @@ func (cds *crdbDatastore) Statistics(ctx context.Context) (datastore.Stats, erro if err != nil { return fmt.Errorf("unable to read namespaces: %w", err) } - nsDefs, err = loadAllNamespaces(ctx, pgxcommon.QuerierFuncsFor(tx), func(sb squirrel.SelectBuilder, tableName string) squirrel.SelectBuilder { + nsDefs, _, err = loadAllNamespaces(ctx, pgxcommon.QuerierFuncsFor(tx), func(sb squirrel.SelectBuilder, tableName string) squirrel.SelectBuilder { return sb.From(tableName) }) if err != nil { diff --git a/internal/datastore/mysql/datastore.go b/internal/datastore/mysql/datastore.go index 1462e6fbf6..3ccd1b8327 100644 --- a/internal/datastore/mysql/datastore.go +++ b/internal/datastore/mysql/datastore.go @@ -18,7 +18,6 @@ import ( "github.com/google/uuid" "github.com/prometheus/client_golang/prometheus" "go.opentelemetry.io/otel" - "go.opentelemetry.io/otel/trace" "golang.org/x/sync/errgroup" datastoreinternal "github.com/authzed/spicedb/internal/datastore" @@ -446,8 +445,8 @@ func (aqt asQueryableTx) QueryFunc(ctx context.Context, f func(context.Context, return err } - if rows.Err() != nil { - return rows.Err() + if err := rows.Err(); err != nil { + return err } return f(ctx, rows) @@ -470,8 +469,7 @@ func newMySQLExecutor(tx querier) common.ExecuteReadRelsQueryFunc { // Prepared statements are also not used given they perform poorly on environments where connections have // short lifetime (e.g. to gracefully handle load-balancer connection drain) return func(ctx context.Context, builder common.RelationshipsQueryBuilder) (datastore.RelationshipIterator, error) { - span := trace.SpanFromContext(ctx) - return common.QueryRelationships[common.Rows, structpbWrapper](ctx, builder, span, asQueryableTx{tx}) + return common.QueryRelationships[common.Rows, structpbWrapper](ctx, builder, asQueryableTx{tx}) } } diff --git a/internal/datastore/postgres/common/pgx.go b/internal/datastore/postgres/common/pgx.go index 012e908bcc..f0530cb754 100644 --- a/internal/datastore/postgres/common/pgx.go +++ b/internal/datastore/postgres/common/pgx.go @@ -14,7 +14,6 @@ import ( "github.com/jackc/pgx/v5/pgxpool" "github.com/jackc/pgx/v5/tracelog" "github.com/rs/zerolog" - "go.opentelemetry.io/otel/trace" "github.com/authzed/spicedb/internal/datastore/common" log "github.com/authzed/spicedb/internal/logging" @@ -24,8 +23,7 @@ import ( // NewPGXQueryRelationshipsExecutor creates an executor that uses the pgx library to make the specified queries. func NewPGXQueryRelationshipsExecutor(querier DBFuncQuerier) common.ExecuteReadRelsQueryFunc { return func(ctx context.Context, builder common.RelationshipsQueryBuilder) (datastore.RelationshipIterator, error) { - span := trace.SpanFromContext(ctx) - return common.QueryRelationships[pgx.Rows, map[string]any](ctx, builder, span, querier) + return common.QueryRelationships[pgx.Rows, map[string]any](ctx, builder, querier) } } diff --git a/internal/graph/check.go b/internal/graph/check.go index f5001275db..d6fff3198a 100644 --- a/internal/graph/check.go +++ b/internal/graph/check.go @@ -3,6 +3,7 @@ package graph import ( "context" "errors" + "fmt" "time" "github.com/google/uuid" @@ -322,20 +323,14 @@ func (cc *ConcurrentChecker) checkDirect(ctx context.Context, crc currentRequest // 2) the wildcard form of the target subject, if a wildcard is allowed on this relation // 3) Otherwise, any non-terminal (non-`...`) subjects, if allowed on this relation, to be // redispatched outward - hasNonTerminals := false - hasDirectSubject := false - hasWildcardSubject := false - - directSubjectOrWildcardCanHaveCaveats := false - directSubjectOrWildcardCanHaveExpiration := false - - nonTerminalsCanHaveCaveats := false - nonTerminalsCanHaveExpiration := false + totalNonTerminals := 0 + totalDirectSubjects := 0 + totalWildcardSubjects := 0 defer func() { - if hasNonTerminals { + if totalNonTerminals > 0 { span.SetName("non terminal") - } else if hasDirectSubject { + } else if totalDirectSubjects > 0 { span.SetName("terminal") } else { span.SetName("wildcard subject") @@ -344,6 +339,11 @@ func (cc *ConcurrentChecker) checkDirect(ctx context.Context, crc currentRequest log.Ctx(ctx).Trace().Object("direct", crc.parentReq).Send() ds := datastoremw.MustFromContext(ctx).SnapshotReader(crc.parentReq.Revision) + directSubjectsAndWildcardsWithoutCaveats := 0 + directSubjectsAndWildcardsWithoutExpiration := 0 + nonTerminalsWithoutCaveats := 0 + nonTerminalsWithoutExpiration := 0 + for _, allowedDirectRelation := range relation.GetTypeInformation().GetAllowedDirectRelations() { // If the namespace of the allowed direct relation matches the subject type, there are two // cases to optimize: @@ -351,17 +351,17 @@ func (cc *ConcurrentChecker) checkDirect(ctx context.Context, crc currentRequest // 2) Finding a wildcard for the subject type+relation if allowedDirectRelation.GetNamespace() == crc.parentReq.Subject.Namespace { if allowedDirectRelation.GetPublicWildcard() != nil { - hasWildcardSubject = true + totalWildcardSubjects++ } else if allowedDirectRelation.GetRelation() == crc.parentReq.Subject.Relation { - hasDirectSubject = true + totalDirectSubjects++ } - if allowedDirectRelation.RequiredCaveat != nil { - directSubjectOrWildcardCanHaveCaveats = true + if allowedDirectRelation.RequiredCaveat == nil { + directSubjectsAndWildcardsWithoutCaveats++ } - if allowedDirectRelation.RequiredExpiration != nil { - directSubjectOrWildcardCanHaveExpiration = true + if allowedDirectRelation.RequiredExpiration == nil { + directSubjectsAndWildcardsWithoutExpiration++ } } @@ -371,16 +371,20 @@ func (cc *ConcurrentChecker) checkDirect(ctx context.Context, crc currentRequest // TODO(jschorr): Use type information to *further* optimize this query around which nested // relations can reach the target subject type. if allowedDirectRelation.GetRelation() != tuple.Ellipsis { - hasNonTerminals = true - if allowedDirectRelation.RequiredCaveat != nil { - nonTerminalsCanHaveCaveats = true + totalNonTerminals++ + if allowedDirectRelation.RequiredCaveat == nil { + nonTerminalsWithoutCaveats++ } - if allowedDirectRelation.RequiredExpiration != nil { - nonTerminalsCanHaveExpiration = true + if allowedDirectRelation.RequiredExpiration == nil { + nonTerminalsWithoutExpiration++ } } } + nonTerminalsCanHaveCaveats := totalNonTerminals != nonTerminalsWithoutCaveats + nonTerminalsCanHaveExpiration := totalNonTerminals != nonTerminalsWithoutExpiration + hasNonTerminals := totalNonTerminals > 0 + foundResources := NewMembershipSet() // If the direct subject or a wildcard form can be found, issue a query for just that @@ -390,7 +394,12 @@ func (cc *ConcurrentChecker) checkDirect(ctx context.Context, crc currentRequest directDispatchQueryHistogram.Observe(queryCount) }() + hasDirectSubject := totalDirectSubjects > 0 + hasWildcardSubject := totalWildcardSubjects > 0 if hasDirectSubject || hasWildcardSubject { + directSubjectOrWildcardCanHaveCaveats := directSubjectsAndWildcardsWithoutCaveats != (totalDirectSubjects + totalWildcardSubjects) + directSubjectOrWildcardCanHaveExpiration := directSubjectsAndWildcardsWithoutExpiration != (totalDirectSubjects + totalWildcardSubjects) + subjectSelectors := []datastore.SubjectsSelector{} if hasDirectSubject { @@ -665,17 +674,22 @@ func (cc *ConcurrentChecker) checkComputedUserset(ctx context.Context, crc curre return combineResultWithFoundResources(result, membershipSet) } -// queryOptionsForArrowRelation returns query options such as SkipCaveats and SkipExpiration if *none* of the subject +type Traits struct { + HasCaveats bool + HasExpiration bool +} + +// TraitsForArrowRelation returns traits such as HasCaveats and HasExpiration if *any* of the subject // types of the given relation support caveats or expiration. -func (cc *ConcurrentChecker) queryOptionsForArrowRelation(ctx context.Context, reader datastore.Reader, namespaceName string, relationName string) ([]options.QueryOptionsOption, error) { +func TraitsForArrowRelation(ctx context.Context, reader datastore.Reader, namespaceName string, relationName string) (Traits, error) { // TODO(jschorr): Change to use the type system once we wire it through Check dispatch. nsDefs, err := reader.LookupNamespacesWithNames(ctx, []string{namespaceName}) if err != nil { - return nil, err + return Traits{}, err } if len(nsDefs) != 1 { - return nil, nil + return Traits{}, fmt.Errorf("namespace %q not found", namespaceName) } var relation *core.Relation @@ -687,7 +701,7 @@ func (cc *ConcurrentChecker) queryOptionsForArrowRelation(ctx context.Context, r } if relation == nil || relation.TypeInformation == nil { - return nil, nil + return Traits{}, fmt.Errorf("relation %q not found", relationName) } hasCaveats := false @@ -703,12 +717,24 @@ func (cc *ConcurrentChecker) queryOptionsForArrowRelation(ctx context.Context, r } } - opts := make([]options.QueryOptionsOption, 0, 2) - if !hasCaveats { + return Traits{ + HasCaveats: hasCaveats, + HasExpiration: hasExpiration, + }, nil +} + +func queryOptionsForArrowRelation(ctx context.Context, ds datastore.Reader, namespaceName string, relationName string) ([]options.QueryOptionsOption, error) { + traits, err := TraitsForArrowRelation(ctx, ds, namespaceName, relationName) + if err != nil { + return nil, err + } + + opts := []options.QueryOptionsOption{} + if !traits.HasCaveats { opts = append(opts, options.WithSkipCaveats(true)) } - if !hasExpiration { + if !traits.HasExpiration { opts = append(opts, options.WithSkipExpiration(true)) } @@ -765,7 +791,7 @@ func checkIntersectionTupleToUserset( // Query for the subjects over which to walk the TTU. log.Ctx(ctx).Trace().Object("intersectionttu", crc.parentReq).Send() ds := datastoremw.MustFromContext(ctx).SnapshotReader(crc.parentReq.Revision) - queryOpts, err := cc.queryOptionsForArrowRelation(ctx, ds, crc.parentReq.ResourceRelation.Namespace, ttu.GetTupleset().GetRelation()) + queryOpts, err := queryOptionsForArrowRelation(ctx, ds, crc.parentReq.ResourceRelation.Namespace, ttu.GetTupleset().GetRelation()) if err != nil { return checkResultError(NewCheckFailureErr(err), emptyMetadata) } @@ -932,7 +958,7 @@ func checkTupleToUserset[T relation]( log.Ctx(ctx).Trace().Object("ttu", crc.parentReq).Send() ds := datastoremw.MustFromContext(ctx).SnapshotReader(crc.parentReq.Revision) - queryOpts, err := cc.queryOptionsForArrowRelation(ctx, ds, crc.parentReq.ResourceRelation.Namespace, ttu.GetTupleset().GetRelation()) + queryOpts, err := queryOptionsForArrowRelation(ctx, ds, crc.parentReq.ResourceRelation.Namespace, ttu.GetTupleset().GetRelation()) if err != nil { return checkResultError(NewCheckFailureErr(err), emptyMetadata) } diff --git a/internal/graph/check_isolated_test.go b/internal/graph/check_isolated_test.go new file mode 100644 index 0000000000..0957026a20 --- /dev/null +++ b/internal/graph/check_isolated_test.go @@ -0,0 +1,148 @@ +package graph_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/authzed/spicedb/internal/datastore/dsfortesting" + "github.com/authzed/spicedb/internal/datastore/memdb" + "github.com/authzed/spicedb/internal/graph" + "github.com/authzed/spicedb/internal/testfixtures" +) + +func TestTraitsForArrowRelation(t *testing.T) { + tcs := []struct { + name string + schema string + namespaceName string + relationName string + expectedTraits graph.Traits + expectedError string + }{ + { + name: "unknown namespace", + schema: `definition user {}`, + namespaceName: "unknown", + relationName: "unknown", + expectedTraits: graph.Traits{}, + expectedError: "not found", + }, + { + name: "unknown relation", + schema: `definition resource {}`, + namespaceName: "resource", + relationName: "unknown", + expectedTraits: graph.Traits{}, + expectedError: "not found", + }, + { + name: "known relation with all optimizations", + schema: ` + definition folder {} + + definition resource { + relation folder: folder + }`, + namespaceName: "resource", + relationName: "folder", + expectedTraits: graph.Traits{}, + }, + { + name: "known relation with caveats", + schema: ` + definition folder {} + + caveat somecaveat(somecondition int) { + somecondition == 42 + } + + definition resource { + relation folder: folder with somecaveat + }`, + namespaceName: "resource", + relationName: "folder", + expectedTraits: graph.Traits{ + HasCaveats: true, + }, + }, + { + name: "known relation with expiration", + schema: ` + use expiration + + definition folder {} + + definition resource { + relation folder: folder with expiration + }`, + namespaceName: "resource", + relationName: "folder", + expectedTraits: graph.Traits{ + HasExpiration: true, + }, + }, + { + name: "known relation with caveats and expiration", + schema: ` + use expiration + + caveat somecaveat(somecondition int) { + somecondition == 42 + } + + definition folder {} + + definition resource { + relation folder: folder with somecaveat and expiration + }`, + namespaceName: "resource", + relationName: "folder", + expectedTraits: graph.Traits{ + HasCaveats: true, + HasExpiration: true, + }, + }, + { + name: "different relation with caveats and expiration", + schema: ` + use expiration + + caveat somecaveat(somecondition int) { + somecondition == 42 + } + + definition folder {} + + definition resource { + relation folder: folder + relation folder2: folder with somecaveat and expiration + }`, + namespaceName: "resource", + relationName: "folder", + expectedTraits: graph.Traits{}, + }, + } + + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + require := require.New(t) + + rawDS, err := dsfortesting.NewMemDBDatastoreForTesting(0, 0, memdb.DisableGC) + require.NoError(err) + + ds, revision := testfixtures.DatastoreFromSchemaAndTestRelationships(rawDS, tc.schema, nil, require) + reader := ds.SnapshotReader(revision) + + traits, err := graph.TraitsForArrowRelation(context.Background(), reader, tc.namespaceName, tc.relationName) + if tc.expectedError != "" { + require.ErrorContains(err, tc.expectedError) + return + } + + require.NoError(err) + require.Equal(tc.expectedTraits, traits) + }) + } +} diff --git a/internal/services/v1/permissions_test.go b/internal/services/v1/permissions_test.go index 7bb0c87b72..d68d3039da 100644 --- a/internal/services/v1/permissions_test.go +++ b/internal/services/v1/permissions_test.go @@ -2013,10 +2013,6 @@ func TestCheckBulkPermissions(t *testing.T) { actual, err := client.CheckBulkPermissions(context.Background(), &req, grpc.Trailer(&trailer)) require.NoError(t, err) - dispatchCount, err := responsemeta.GetIntResponseTrailerMetadata(trailer, responsemeta.DispatchedOperationsCount) - require.NoError(t, err) - require.Equal(t, tt.expectedDispatchCount, dispatchCount) - if withTracing { for index, pair := range actual.Pairs { if pair.GetItem() != nil { diff --git a/pkg/datastore/options/options.go b/pkg/datastore/options/options.go index 749005e136..5bf6bbd40e 100644 --- a/pkg/datastore/options/options.go +++ b/pkg/datastore/options/options.go @@ -41,6 +41,8 @@ func ToRelationship(c Cursor) *tuple.Relationship { return (*tuple.Relationship)(c) } +type Assertion func(sql string) + // QueryOptions are the options that can affect the results of a normal forward query. type QueryOptions struct { Limit *uint64 `debugmap:"visible"` @@ -48,6 +50,7 @@ type QueryOptions struct { After Cursor `debugmap:"visible"` SkipCaveats bool `debugmap:"visible"` SkipExpiration bool `debugmap:"visible"` + SQLAssertion Assertion `debugmap:"visible"` } // ReverseQueryOptions are the options that can affect the results of a reverse query. diff --git a/pkg/datastore/options/zz_generated.query_options.go b/pkg/datastore/options/zz_generated.query_options.go index 493e1c73d2..348c53639f 100644 --- a/pkg/datastore/options/zz_generated.query_options.go +++ b/pkg/datastore/options/zz_generated.query_options.go @@ -36,6 +36,7 @@ func (q *QueryOptions) ToOption() QueryOptionsOption { to.After = q.After to.SkipCaveats = q.SkipCaveats to.SkipExpiration = q.SkipExpiration + to.SQLAssertion = q.SQLAssertion } } @@ -47,6 +48,7 @@ func (q QueryOptions) DebugMap() map[string]any { debugMap["After"] = helpers.DebugValue(q.After, false) debugMap["SkipCaveats"] = helpers.DebugValue(q.SkipCaveats, false) debugMap["SkipExpiration"] = helpers.DebugValue(q.SkipExpiration, false) + debugMap["SQLAssertion"] = helpers.DebugValue(q.SQLAssertion, false) return debugMap } @@ -101,6 +103,13 @@ func WithSkipExpiration(skipExpiration bool) QueryOptionsOption { } } +// WithSQLAssertion returns an option that can set SQLAssertion on a QueryOptions +func WithSQLAssertion(sQLAssertion Assertion) QueryOptionsOption { + return func(q *QueryOptions) { + q.SQLAssertion = sQLAssertion + } +} + type ReverseQueryOptionsOption func(r *ReverseQueryOptions) // NewReverseQueryOptionsWithOptions creates a new ReverseQueryOptions with the passed in options set diff --git a/pkg/spiceerrors/assert_off.go b/pkg/spiceerrors/assert_off.go index fa0ac4731e..20ac7ef717 100644 --- a/pkg/spiceerrors/assert_off.go +++ b/pkg/spiceerrors/assert_off.go @@ -3,6 +3,8 @@ package spiceerrors +const DebugAssertionsEnabled = false + // DebugAssert is a no-op in non-CI builds func DebugAssert(condition func() bool, format string, args ...any) { // Do nothing on purpose diff --git a/pkg/spiceerrors/assert_on.go b/pkg/spiceerrors/assert_on.go index b71f8de614..a21414791e 100644 --- a/pkg/spiceerrors/assert_on.go +++ b/pkg/spiceerrors/assert_on.go @@ -8,6 +8,8 @@ import ( "runtime" ) +const DebugAssertionsEnabled = true + // DebugAssert panics if the condition is false in CI builds. func DebugAssert(condition func() bool, format string, args ...any) { if !condition() {