Skip to content

Commit

Permalink
[ENH] Implement GetCollectionWithSegments endpoint for SysDB (#3243)
Browse files Browse the repository at this point in the history
## Description of changes

*Summarize the changes made by this PR.*
 - Improvements & Bug fixes
   - N/A
 - New functionality
   - Implement the `GetCollectionWithSegments` interface for SysDB, which should return information about a consistent snapshot for a collection and its segments

## Test plan
*How are these changes tested?*

- [x] Tests pass locally with `pytest` for python, `yarn test` for js, `cargo test` for rust

## Documentation Changes
*Are all docstrings for user-facing APIs updated if required? Do we need to make documentation changes in the [docs repository](https://github.com/chroma-core/docs)?*
  • Loading branch information
Sicheng-Pan authored Dec 17, 2024
1 parent 6558172 commit 433cc37
Show file tree
Hide file tree
Showing 9 changed files with 258 additions and 43 deletions.
62 changes: 33 additions & 29 deletions chromadb/proto/coordinator_pb2.py

Large diffs are not rendered by default.

14 changes: 14 additions & 0 deletions chromadb/proto/coordinator_pb2.pyi

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

33 changes: 33 additions & 0 deletions chromadb/proto/coordinator_pb2_grpc.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions go/pkg/sysdb/coordinator/coordinator.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,10 @@ func (s *Coordinator) GetCollections(ctx context.Context, collectionID types.Uni
return s.catalog.GetCollections(ctx, collectionID, collectionName, tenantID, databaseName, limit, offset)
}

// GetCollectionWithSegments looks up a collection and its segments by
// delegating to the catalog's method of the same name; the catalog's results
// and error are handed back to the caller untouched.
func (s *Coordinator) GetCollectionWithSegments(ctx context.Context, collectionID types.UniqueID) (*model.Collection, []*model.Segment, error) {
	collection, segments, err := s.catalog.GetCollectionWithSegments(ctx, collectionID)
	return collection, segments, err
}

// CheckCollection delegates to the catalog's CheckCollection for the given
// collection ID and returns the catalog's boolean result and error as-is.
func (s *Coordinator) CheckCollection(ctx context.Context, collectionID types.UniqueID) (bool, error) {
	ok, err := s.catalog.CheckCollection(ctx, collectionID)
	return ok, err
}
Expand Down
15 changes: 15 additions & 0 deletions go/pkg/sysdb/coordinator/coordinator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,21 @@ func (suite *APIsTestSuite) TestCreateCollectionAndSegments() {
suite.Equal(segment.ID, segmentResult[0].ID)
}

// The same information should be returned by the GetCollectionWithSegments endpoint
collection, collection_segments, error := suite.coordinator.GetCollectionWithSegments(ctx, newCollection.ID)
suite.NoError(error)
suite.Equal(newCollection.ID, collection.ID)
suite.Equal(newCollection.Name, collection.Name)
expected_ids, actual_ids := []types.UniqueID{}, []types.UniqueID{}
for _, segment := range segments {
expected_ids = append(expected_ids, segment.ID)
}
for _, segment := range collection_segments {
suite.Equal(collection.ID, segment.CollectionID)
actual_ids = append(actual_ids, segment.ID)
}
suite.ElementsMatch(expected_ids, actual_ids)

// Attempt to create a duplicate collection (should fail)
_, _, err = suite.coordinator.CreateCollectionAndSegments(ctx, newCollection, segments)
suite.Error(err)
Expand Down
37 changes: 37 additions & 0 deletions go/pkg/sysdb/coordinator/table_catalog.go
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,43 @@ func (tc *Catalog) GetCollections(ctx context.Context, collectionID types.Unique
return collections, nil
}

// GetCollectionWithSegments returns the collection identified by collectionID
// together with the segments that belong to it. Both lookups are issued from
// inside a single tc.txImpl.Transaction callback, using the context that the
// transaction passes to that callback.
//
// Errors:
//   - common.ErrCollectionNotFound when no collection matches collectionID.
//   - common.ErrCollectionUniqueConstraintViolation when more than one matches.
//   - any error returned by the underlying lookups or by the transaction.
//
// On error both the collection and the segment slice are returned as nil.
func (tc *Catalog) GetCollectionWithSegments(ctx context.Context, collectionID types.UniqueID) (*model.Collection, []*model.Segment, error) {
	tracer := otel.Tracer
	if tracer != nil {
		// Name the span after this method so traces are not attributed to
		// Catalog.GetCollections.
		_, span := tracer.Start(ctx, "Catalog.GetCollectionWithSegments")
		defer span.End()
	}

	var collection *model.Collection
	var segments []*model.Segment

	err := tc.txImpl.Transaction(ctx, func(txCtx context.Context) error {
		// Pass txCtx (the context handed to this callback by the transaction)
		// rather than the outer ctx, so that both reads can take part in the
		// same transaction. With the outer ctx the callback's context was
		// never used at all.
		collections, e := tc.GetCollections(txCtx, collectionID, nil, "", "", nil, nil)
		if e != nil {
			return e
		}
		if len(collections) == 0 {
			return common.ErrCollectionNotFound
		}
		if len(collections) > 1 {
			return common.ErrCollectionUniqueConstraintViolation
		}
		collection = collections[0]

		segments, e = tc.GetSegments(txCtx, types.NilUniqueID(), nil, nil, collectionID)
		if e != nil {
			return e
		}

		return nil
	})
	if err != nil {
		return nil, nil, err
	}

	return collection, segments, nil
}

func (tc *Catalog) DeleteCollection(ctx context.Context, deleteCollection *model.DeleteCollection, softDelete bool) error {
if softDelete {
return tc.softDeleteCollection(ctx, deleteCollection)
Expand Down
47 changes: 47 additions & 0 deletions go/pkg/sysdb/grpc/collection_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package grpc
import (
"context"
"encoding/json"
"fmt"

"github.com/chroma-core/chroma/go/pkg/grpcutils"

Expand Down Expand Up @@ -169,6 +170,52 @@ func (s *Server) CheckCollections(ctx context.Context, req *coordinatorpb.CheckC
return res, nil
}

// GetCollectionWithSegments serves the SysDB GetCollectionWithSegments RPC.
//
// It parses req.Id as a collection ID, asks the coordinator for the collection
// and its segments, converts both to their protobuf forms, and then validates
// the segment set: there must be exactly three segments and the METADATA,
// RECORD and VECTOR scopes must each be present. Every failure is logged and
// reported to the caller as an internal gRPC error, alongside the response
// object as populated so far.
func (s *Server) GetCollectionWithSegments(ctx context.Context, req *coordinatorpb.GetCollectionWithSegmentsRequest) (*coordinatorpb.GetCollectionWithSegmentsResponse, error) {
	rawID := req.Id
	res := &coordinatorpb.GetCollectionWithSegmentsResponse{}

	id, err := types.ToUniqueID(&rawID)
	if err != nil {
		log.Error("GetCollectionWithSegments failed. collection id format error", zap.Error(err), zap.String("collection_id", rawID))
		return res, grpcutils.BuildInternalGrpcError(err.Error())
	}

	collection, segments, err := s.coordinator.GetCollectionWithSegments(ctx, id)
	if err != nil {
		log.Error("GetCollectionWithSegments failed. ", zap.Error(err), zap.String("collection_id", rawID))
		return res, grpcutils.BuildInternalGrpcError(err.Error())
	}

	// The collection is attached before the segment checks run, so it is
	// already set on the response if one of those checks fails.
	res.Collection = convertCollectionToProto(collection)

	// Convert every segment and remember which scopes were encountered.
	protoSegments := make([]*coordinatorpb.Segment, len(segments))
	seenScopes := make(map[coordinatorpb.SegmentScope]struct{}, len(segments))
	for i, segment := range segments {
		pb := convertSegmentToProto(segment)
		protoSegments[i] = pb
		seenScopes[pb.GetScope()] = struct{}{}
	}

	if count := len(protoSegments); count != 3 {
		log.Error("GetCollectionWithSegments failed. Unexpected number of collection segments", zap.String("collection_id", rawID))
		return res, grpcutils.BuildInternalGrpcError(fmt.Sprintf("Unexpected number of segments for collection %s: %d", rawID, count))
	}

	requiredScopes := []coordinatorpb.SegmentScope{
		coordinatorpb.SegmentScope_METADATA,
		coordinatorpb.SegmentScope_RECORD,
		coordinatorpb.SegmentScope_VECTOR,
	}
	for _, required := range requiredScopes {
		if _, found := seenScopes[required]; found {
			continue
		}
		log.Error("GetCollectionWithSegments failed. Collection segment scope not found", zap.String("collection_id", rawID), zap.String("missing_scope", required.String()))
		return res, grpcutils.BuildInternalGrpcError(fmt.Sprintf("Missing segment scope for collection %s: %s", rawID, required.String()))
	}

	res.Segments = protoSegments

	log.Info("GetCollectionWithSegments succeeded", zap.String("request", req.String()), zap.String("response", res.String()))
	return res, nil
}

func (s *Server) DeleteCollection(ctx context.Context, req *coordinatorpb.DeleteCollectionRequest) (*coordinatorpb.DeleteCollectionResponse, error) {
collectionID := req.GetId()
res := &coordinatorpb.DeleteCollectionResponse{}
Expand Down
79 changes: 65 additions & 14 deletions go/pkg/sysdb/grpc/collection_service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,16 +81,19 @@ func testCollection(t *rapid.T) {
var collectionsWithErrors []*coordinatorpb.Collection

t.Repeat(map[string]func(*rapid.T){
"create_collection": func(t *rapid.T) {
"create_get_collection": func(t *rapid.T) {
stringValue := generateStringMetadataValue(t)
intValue := generateInt64MetadataValue(t)
floatValue := generateFloat64MetadataValue(t)
getOrCreate := false

collectionId := rapid.StringMatching(`[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}`).Draw(t, "collection_id")
collectionName := rapid.String().Draw(t, "collection_name")

createCollectionRequest := rapid.Custom[*coordinatorpb.CreateCollectionRequest](func(t *rapid.T) *coordinatorpb.CreateCollectionRequest {
return &coordinatorpb.CreateCollectionRequest{
Id: rapid.StringMatching(`[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}`).Draw(t, "collection_id"),
Name: rapid.String().Draw(t, "collection_name"),
Id: collectionId,
Name: collectionName,
Metadata: &coordinatorpb.UpdateMetadata{
Metadata: map[string]*coordinatorpb.UpdateMetadataValue{
"string_value": stringValue,
Expand All @@ -99,6 +102,26 @@ func testCollection(t *rapid.T) {
},
},
GetOrCreate: &getOrCreate,
Segments: []*coordinatorpb.Segment{
{
Id: rapid.StringMatching(`[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}`).Draw(t, "metadata_segment_id"),
Type: "metadata_segment_type",
Scope: coordinatorpb.SegmentScope_METADATA,
Collection: collectionId,
},
{
Id: rapid.StringMatching(`[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}`).Draw(t, "record_segment_id"),
Type: "record_segment_type",
Scope: coordinatorpb.SegmentScope_RECORD,
Collection: collectionId,
},
{
Id: rapid.StringMatching(`[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}`).Draw(t, "vector_segment_id"),
Type: "vector_segment_type",
Scope: coordinatorpb.SegmentScope_VECTOR,
Collection: collectionId,
},
},
}
}).Draw(t, "create_collection_request")

Expand All @@ -114,29 +137,57 @@ func testCollection(t *rapid.T) {
}
}

getCollectionsRequest := coordinatorpb.GetCollectionsRequest{
Id: &createCollectionRequest.Id,
}
if err == nil {
getCollectionsRequest := coordinatorpb.GetCollectionsRequest{
Id: &createCollectionRequest.Id,
}
// verify the correctness
GetCollectionsResponse, err := s.GetCollections(ctx, &getCollectionsRequest)
getCollectionsResponse, err := s.GetCollections(ctx, &getCollectionsRequest)
if err != nil {
t.Fatalf("error getting collections: %v", err)
}
collectionList := GetCollectionsResponse.GetCollections()
collectionList := getCollectionsResponse.GetCollections()
if len(collectionList) != 1 {
t.Fatalf("More than 1 collection with the same collection id")
t.Fatalf("there should be exactly one matching collection given the collection id")
}
if collectionList[0].Id != createCollectionRequest.Id {
t.Fatalf("collection id mismatch")
}

getCollectionWithSegmentsRequest := coordinatorpb.GetCollectionWithSegmentsRequest{
Id: createCollectionRequest.Id,
}

getCollectionWithSegmentsResponse, err := s.GetCollectionWithSegments(ctx, &getCollectionWithSegmentsRequest)
if err != nil {
t.Fatalf("error getting collection with segments: %v", err)
}

if getCollectionWithSegmentsResponse.Collection.Id != res.Collection.Id {
t.Fatalf("collection id mismatch")
}

if len(getCollectionWithSegmentsResponse.Segments) != 3 {
t.Fatalf("unexpected number of segments in collection: %v", getCollectionWithSegmentsResponse.Segments)
}
for _, collection := range collectionList {
if collection.Id != createCollectionRequest.Id {
t.Fatalf("collection id is the right value")

scopeToSegmentMap := map[coordinatorpb.SegmentScope]*coordinatorpb.Segment{}
for _, segment := range getCollectionWithSegmentsResponse.Segments {
if segment.Collection != res.Collection.Id {
t.Fatalf("invalid collection id in segment")
}
scopeToSegmentMap[segment.GetScope()] = segment
}
scopes := []coordinatorpb.SegmentScope{coordinatorpb.SegmentScope_METADATA, coordinatorpb.SegmentScope_RECORD, coordinatorpb.SegmentScope_VECTOR}
for _, scope := range scopes {
if _, exists := scopeToSegmentMap[scope]; !exists {
t.Fatalf("collection segment scope not found: %s", scope.String())
}
}

state = append(state, res.Collection)
}
},
"get_collections": func(t *rapid.T) {
},
})
}

Expand Down
10 changes: 10 additions & 0 deletions idl/chromadb/proto/coordinator.proto
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,15 @@ message GetCollectionsResponse {
reserved "status";
}

// Request message for the SysDB GetCollectionWithSegments RPC.
message GetCollectionWithSegmentsRequest {
// ID of the collection to fetch, as a string.
string id = 1;
}

// Response message for the SysDB GetCollectionWithSegments RPC: one collection
// plus the list of its segments.
message GetCollectionWithSegmentsResponse {
// The requested collection.
Collection collection = 1;
// The segments returned alongside the collection.
repeated Segment segments = 2;
}

message CheckCollectionsRequest {
repeated string collection_ids = 1;
}
Expand Down Expand Up @@ -219,6 +228,7 @@ service SysDB {
rpc CreateCollection(CreateCollectionRequest) returns (CreateCollectionResponse) {}
rpc DeleteCollection(DeleteCollectionRequest) returns (DeleteCollectionResponse) {}
rpc GetCollections(GetCollectionsRequest) returns (GetCollectionsResponse) {}
rpc GetCollectionWithSegments(GetCollectionWithSegmentsRequest) returns (GetCollectionWithSegmentsResponse) {}
rpc CheckCollections(CheckCollectionsRequest) returns (CheckCollectionsResponse) {}
rpc UpdateCollection(UpdateCollectionRequest) returns (UpdateCollectionResponse) {}
rpc ResetState(google.protobuf.Empty) returns (ResetStateResponse) {}
Expand Down

0 comments on commit 433cc37

Please sign in to comment.