diff --git a/.release-please-manifest-individual.json b/.release-please-manifest-individual.json index 1a5cdadf7c90..101f7c5be41f 100644 --- a/.release-please-manifest-individual.json +++ b/.release-please-manifest-individual.json @@ -1,7 +1,7 @@ { "ai": "0.8.2", "aiplatform": "1.68.0", - "auth": "0.9.0", + "auth": "0.9.1", "auth/oauth2adapt": "0.2.4", "bigquery": "1.62.0", "bigtable": "1.29.0", @@ -14,5 +14,5 @@ "pubsublite": "1.8.2", "spanner": "1.67.0", "storage": "1.43.0", - "vertexai": "0.12.0" + "vertexai": "0.13.0" } diff --git a/auth/CHANGES.md b/auth/CHANGES.md index 8042bdae809b..ea6df0cafa7a 100644 --- a/auth/CHANGES.md +++ b/auth/CHANGES.md @@ -1,5 +1,12 @@ # Changelog +## [0.9.1](https://github.com/googleapis/google-cloud-go/compare/auth/v0.9.0...auth/v0.9.1) (2024-08-22) + + +### Bug Fixes + +* **auth:** Setting expireEarly to default when the value is 0 ([#10732](https://github.com/googleapis/google-cloud-go/issues/10732)) ([5e67869](https://github.com/googleapis/google-cloud-go/commit/5e67869a31e9e8ecb4eeebd2cfa11a761c3b1948)) + ## [0.9.0](https://github.com/googleapis/google-cloud-go/compare/auth/v0.8.1...auth/v0.9.0) (2024-08-16) diff --git a/auth/auth.go b/auth/auth.go index 41e03f293546..2eb78d7b076a 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -258,7 +258,7 @@ func (ctpo *CachedTokenProviderOptions) autoRefresh() bool { } func (ctpo *CachedTokenProviderOptions) expireEarly() time.Duration { - if ctpo == nil { + if ctpo == nil || ctpo.ExpireEarly == 0 { return defaultExpiryDelta } return ctpo.ExpireEarly diff --git a/bigtable/bigtable.go b/bigtable/bigtable.go index 323ff9b13409..e12035db4e7a 100644 --- a/bigtable/bigtable.go +++ b/bigtable/bigtable.go @@ -108,9 +108,7 @@ func NewClientWithConfig(ctx context.Context, project, instance string, config C ) // Allow non-default service account in DirectPath. - o = append(o, - internaloption.AllowNonDefaultServiceAccount(true), - internaloption.EnableNewAuthLibrary()) + o = append(o, internaloption.AllowNonDefaultServiceAccount(true)) o = append(o, opts...) connPool, err := gtransport.DialPool(ctx, o...) if err != nil { diff --git a/bigtable/conformance_test.sh b/bigtable/conformance_test.sh index 35a126e2f94d..bf6f520a6b0c 100755 --- a/bigtable/conformance_test.sh +++ b/bigtable/conformance_test.sh @@ -50,10 +50,10 @@ trap cleanup EXIT # Run the conformance tests cd $conformanceTestsHome -# Tests in https://github.com/googleapis/cloud-bigtable-clients-test/tree/main/tests can only be run on go1.20.2 -go install golang.org/dl/go1.20.2@latest -go1.20.2 download -go1.20.2 test -v -proxy_addr=:$testProxyPort | tee -a $sponge_log +# Tests in https://github.com/googleapis/cloud-bigtable-clients-test/tree/main/tests can only be run on go1.22.5 +go install golang.org/dl/go1.22.5@latest +go1.22.5 download +go1.22.5 test -v -proxy_addr=:$testProxyPort | tee -a $sponge_log RETURN_CODE=$? echo "exiting with ${RETURN_CODE}" diff --git a/bigtable/type.go b/bigtable/type.go index 59f954f081f7..b99bb4a251d1 100644 --- a/bigtable/type.go +++ b/bigtable/type.go @@ -16,7 +16,11 @@ limitations under the License. package bigtable -import btapb "cloud.google.com/go/bigtable/admin/apiv2/adminpb" +import ( + btapb "cloud.google.com/go/bigtable/admin/apiv2/adminpb" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/proto" +) // Type wraps the protobuf representation of a type. See the protobuf definition // for more details on types. @@ -24,6 +28,28 @@ type Type interface { proto() *btapb.Type } +var marshalOptions = protojson.MarshalOptions{AllowPartial: true, UseEnumNumbers: true} +var unmarshalOptions = protojson.UnmarshalOptions{AllowPartial: true} + +// MarshalJSON returns the string representation of the Type protobuf. +func MarshalJSON(t Type) ([]byte, error) { + return marshalOptions.Marshal(t.proto()) +} + +// UnmarshalJSON returns a Type object from json bytes. +func UnmarshalJSON(data []byte) (Type, error) { + result := &btapb.Type{} + if err := unmarshalOptions.Unmarshal(data, result); err != nil { + return nil, err + } + return ProtoToType(result), nil +} + +// Equal compares Type objects. +func Equal(a, b Type) bool { + return proto.Equal(a.proto(), b.proto()) +} + type unknown[T interface{}] struct { wrapped *T } @@ -205,6 +231,8 @@ func ProtoToType(pb *btapb.Type) Type { return int64ProtoToType(t.Int64Type) case *btapb.Type_BytesType: return bytesProtoToType(t.BytesType) + case *btapb.Type_StringType: + return stringProtoToType(t.StringType) case *btapb.Type_AggregateType: return aggregateProtoToType(t.AggregateType) default: @@ -229,6 +257,23 @@ func bytesProtoToType(b *btapb.Type_Bytes) BytesType { return BytesType{Encoding: bytesEncodingProtoToType(b.Encoding)} } +func stringEncodingProtoToType(se *btapb.Type_String_Encoding) StringEncoding { + if se == nil { + return unknown[btapb.Type_String_Encoding]{wrapped: se} + } + + switch se.Encoding.(type) { + case *btapb.Type_String_Encoding_Utf8Raw_: + return StringUtf8Encoding{} + default: + return unknown[btapb.Type_String_Encoding]{wrapped: se} + } +} + +func stringProtoToType(s *btapb.Type_String) Type { + return StringType{Encoding: stringEncodingProtoToType(s.Encoding)} +} + func int64EncodingProtoToEncoding(ie *btapb.Type_Int64_Encoding) Int64Encoding { if ie == nil { return unknown[btapb.Type_Int64_Encoding]{wrapped: ie} @@ -246,7 +291,7 @@ func int64ProtoToType(i *btapb.Type_Int64) Type { return Int64Type{Encoding: int64EncodingProtoToEncoding(i.Encoding)} } -func aggregateProtoToType(agg *btapb.Type_Aggregate) Type { +func aggregateProtoToType(agg *btapb.Type_Aggregate) AggregateType { if agg == nil { return AggregateType{Input: nil, Aggregator: unknownAggregator{wrapped: agg}} } diff --git a/bigtable/type_test.go b/bigtable/type_test.go index e80525a698c4..9961ee410c7f 100644 --- a/bigtable/type_test.go +++ b/bigtable/type_test.go @@ -20,6 +20,8 @@ import ( "testing" btapb "cloud.google.com/go/bigtable/admin/apiv2/adminpb" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "google.golang.org/protobuf/proto" ) @@ -37,12 +39,25 @@ func aggregateProto() *btapb.Type { } } +func TestUnknown(t *testing.T) { + unsupportedType := &btapb.Type{ + Kind: &btapb.Type_Float64Type{ + Float64Type: &btapb.Type_Float64{}, + }, + } + got, ok := ProtoToType(unsupportedType).(unknown[btapb.Type]) + if !ok { + t.Errorf("got: %T, wanted unknown[btapb.Type]", got) + } + + assertType(t, got, unsupportedType) +} + func TestInt64Proto(t *testing.T) { want := aggregateProto() - got := Int64Type{}.proto() - if !proto.Equal(got, want) { - t.Errorf("got type %v, want: %v", got, want) - } + it := Int64Type{Encoding: BigEndianBytesEncoding{}} + + assertType(t, it, want) } func TestStringProto(t *testing.T) { @@ -55,39 +70,9 @@ func TestStringProto(t *testing.T) { }, }, } + st := StringType{Encoding: StringUtf8Encoding{}} - got := StringType{}.proto() - if !proto.Equal(got, want) { - t.Errorf("got type %v, want: %v", got, want) - } -} - -func TestSumAggregateProto(t *testing.T) { - want := &btapb.Type{ - Kind: &btapb.Type_AggregateType{ - AggregateType: &btapb.Type_Aggregate{ - InputType: &btapb.Type{ - Kind: &btapb.Type_Int64Type{ - Int64Type: &btapb.Type_Int64{ - Encoding: &btapb.Type_Int64_Encoding{ - Encoding: &btapb.Type_Int64_Encoding_BigEndianBytes_{ - BigEndianBytes: &btapb.Type_Int64_Encoding_BigEndianBytes{}, - }, - }, - }, - }, - }, - Aggregator: &btapb.Type_Aggregate_Sum_{ - Sum: &btapb.Type_Aggregate_Sum{}, - }, - }, - }, - } - - got := AggregateType{Input: Int64Type{}, Aggregator: SumAggregator{}}.proto() - if !proto.Equal(got, want) { - t.Errorf("got type %v, want: %v", got, want) - } + assertType(t, st, want) } func TestProtoBijection(t *testing.T) { @@ -98,88 +83,100 @@ func TestProtoBijection(t *testing.T) { } } -func TestMinAggregateProto(t *testing.T) { - want := &btapb.Type{ - Kind: &btapb.Type_AggregateType{ - AggregateType: &btapb.Type_Aggregate{ - InputType: &btapb.Type{ - Kind: &btapb.Type_Int64Type{ - Int64Type: &btapb.Type_Int64{ - Encoding: &btapb.Type_Int64_Encoding{ - Encoding: &btapb.Type_Int64_Encoding_BigEndianBytes_{ - BigEndianBytes: &btapb.Type_Int64_Encoding_BigEndianBytes{}, - }, - }, - }, +func TestAggregateProto(t *testing.T) { + intType := &btapb.Type{ + Kind: &btapb.Type_Int64Type{ + Int64Type: &btapb.Type_Int64{ + Encoding: &btapb.Type_Int64_Encoding{ + Encoding: &btapb.Type_Int64_Encoding_BigEndianBytes_{ + BigEndianBytes: &btapb.Type_Int64_Encoding_BigEndianBytes{}, }, }, - Aggregator: &btapb.Type_Aggregate_Min_{ - Min: &btapb.Type_Aggregate_Min{}, - }, }, }, } - got := AggregateType{Input: Int64Type{}, Aggregator: MinAggregator{}}.proto() - if !proto.Equal(got, want) { - t.Errorf("got type %v, want: %v", got, want) - } -} - -func TestMaxAggregateProto(t *testing.T) { - want := &btapb.Type{ - Kind: &btapb.Type_AggregateType{ - AggregateType: &btapb.Type_Aggregate{ - InputType: &btapb.Type{ - Kind: &btapb.Type_Int64Type{ - Int64Type: &btapb.Type_Int64{ - Encoding: &btapb.Type_Int64_Encoding{ - Encoding: &btapb.Type_Int64_Encoding_BigEndianBytes_{ - BigEndianBytes: &btapb.Type_Int64_Encoding_BigEndianBytes{}, - }, - }, - }, - }, + testCases := []struct { + name string + agg Aggregator + protoAgg btapb.Type_Aggregate + }{ + { + name: "hll", + agg: HllppUniqueCountAggregator{}, + protoAgg: btapb.Type_Aggregate{ + InputType: intType, + Aggregator: &btapb.Type_Aggregate_HllppUniqueCount{ + HllppUniqueCount: &btapb.Type_Aggregate_HyperLogLogPlusPlusUniqueCount{}, + }, + }, + }, + { + name: "min", + agg: MinAggregator{}, + protoAgg: btapb.Type_Aggregate{ + InputType: intType, + Aggregator: &btapb.Type_Aggregate_Min_{ + Min: &btapb.Type_Aggregate_Min{}, }, + }, + }, + { + name: "max", + agg: MaxAggregator{}, + protoAgg: btapb.Type_Aggregate{ + InputType: intType, Aggregator: &btapb.Type_Aggregate_Max_{ Max: &btapb.Type_Aggregate_Max{}, }, }, }, - } + { + name: "sum", + agg: SumAggregator{}, + protoAgg: btapb.Type_Aggregate{ + InputType: intType, + Aggregator: &btapb.Type_Aggregate_Sum_{ + Sum: &btapb.Type_Aggregate_Sum{}, + }, + }, + }} + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + want := &btapb.Type{ + Kind: &btapb.Type_AggregateType{ + AggregateType: &tc.protoAgg, + }, + } + at := AggregateType{Input: Int64Type{Encoding: BigEndianBytesEncoding{}}, Aggregator: tc.agg} - got := AggregateType{Input: Int64Type{}, Aggregator: MaxAggregator{}}.proto() - if !proto.Equal(got, want) { - t.Errorf("got type %v, want: %v", got, want) + assertType(t, at, want) + }) } } -func TestHllAggregateProto(t *testing.T) { - want := &btapb.Type{ - Kind: &btapb.Type_AggregateType{ - AggregateType: &btapb.Type_Aggregate{ - InputType: &btapb.Type{ - Kind: &btapb.Type_Int64Type{ - Int64Type: &btapb.Type_Int64{ - Encoding: &btapb.Type_Int64_Encoding{ - Encoding: &btapb.Type_Int64_Encoding_BigEndianBytes_{ - BigEndianBytes: &btapb.Type_Int64_Encoding_BigEndianBytes{}, - }, - }, - }, - }, - }, - Aggregator: &btapb.Type_Aggregate_HllppUniqueCount{ - HllppUniqueCount: &btapb.Type_Aggregate_HyperLogLogPlusPlusUniqueCount{}, - }, - }, - }, - } +func assertType(t *testing.T, ty Type, want *btapb.Type) { + t.Helper() - got := AggregateType{Input: Int64Type{}, Aggregator: HllppUniqueCountAggregator{}}.proto() + got := ty.proto() if !proto.Equal(got, want) { t.Errorf("got type %v, want: %v", got, want) } + + gotJSON, err := MarshalJSON(ty) + if err != nil { + t.Fatalf("Error calling MarshalJSON: %v", err) + } + result, err := UnmarshalJSON(gotJSON) + if err != nil { + t.Fatalf("Error calling UnmarshalJSON: %v", err) + } + if diff := cmp.Diff(result, ty, cmpopts.IgnoreUnexported(unknown[btapb.Type]{})); diff != "" { + t.Errorf("Unexpected diff: \n%s", diff) + } + if !Equal(result, ty) { + t.Errorf("Unexpected result. Got %#v, want %#v", result, ty) + } } func TestNilChecks(t *testing.T) { @@ -208,10 +205,7 @@ func TestNilChecks(t *testing.T) { } // aggregateProtoToType - aggType1, ok := aggregateProtoToType(nil).(AggregateType) - if !ok { - t.Fatalf("got: %T, wanted AggregateType", aggType1) - } + aggType1 := aggregateProtoToType(nil) if val, ok := aggType1.Aggregator.(unknownAggregator); !ok { t.Errorf("got: %T, wanted unknownAggregator", val) } @@ -219,10 +213,7 @@ func TestNilChecks(t *testing.T) { t.Errorf("got: %v, wanted nil", aggType1.Input) } - aggType2, ok := aggregateProtoToType(&btapb.Type_Aggregate{}).(AggregateType) - if !ok { - t.Fatalf("got: %T, wanted AggregateType", aggType2) - } + aggType2 := aggregateProtoToType(&btapb.Type_Aggregate{}) if val, ok := aggType2.Aggregator.(unknownAggregator); !ok { t.Errorf("got: %T, wanted unknownAggregator", val) } diff --git a/datastore/client.go b/datastore/client.go index 4f376efc3aac..9f2744639a37 100644 --- a/datastore/client.go +++ b/datastore/client.go @@ -20,12 +20,12 @@ import ( "net/url" "time" + pb "cloud.google.com/go/datastore/apiv1/datastorepb" "cloud.google.com/go/datastore/internal" cloudinternal "cloud.google.com/go/internal" "cloud.google.com/go/internal/trace" "cloud.google.com/go/internal/version" gax "github.com/googleapis/gax-go/v2" - pb "google.golang.org/genproto/googleapis/datastore/v1" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" diff --git a/datastore/datastore.go b/datastore/datastore.go index 2ec5cb52917c..44a44f67b840 100644 --- a/datastore/datastore.go +++ b/datastore/datastore.go @@ -23,12 +23,12 @@ import ( "reflect" "time" + pb "cloud.google.com/go/datastore/apiv1/datastorepb" "cloud.google.com/go/internal/trace" "google.golang.org/api/option" "google.golang.org/api/option/internaloption" "google.golang.org/api/transport" gtransport "google.golang.org/api/transport/grpc" - pb "google.golang.org/genproto/googleapis/datastore/v1" "google.golang.org/grpc" "google.golang.org/protobuf/types/known/timestamppb" ) diff --git a/datastore/datastore_test.go b/datastore/datastore_test.go index e483936c3e58..18413459e388 100644 --- a/datastore/datastore_test.go +++ b/datastore/datastore_test.go @@ -22,11 +22,11 @@ import ( "testing" "time" + pb "cloud.google.com/go/datastore/apiv1/datastorepb" "cloud.google.com/go/internal/testutil" "github.com/google/go-cmp/cmp" "google.golang.org/api/option" "google.golang.org/api/transport/grpc" - pb "google.golang.org/genproto/googleapis/datastore/v1" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/timestamppb" ) diff --git a/datastore/integration_test.go b/datastore/integration_test.go index 1d7cf3f99180..d2ebd836b82e 100644 --- a/datastore/integration_test.go +++ b/datastore/integration_test.go @@ -29,13 +29,13 @@ import ( "testing" "time" + pb "cloud.google.com/go/datastore/apiv1/datastorepb" "cloud.google.com/go/internal/testutil" "cloud.google.com/go/internal/uid" "cloud.google.com/go/rpcreplay" "github.com/google/go-cmp/cmp/cmpopts" "google.golang.org/api/iterator" "google.golang.org/api/option" - pb "google.golang.org/genproto/googleapis/datastore/v1" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" diff --git a/datastore/key.go b/datastore/key.go index 7f97d2bff2ce..2f99836ddfdf 100644 --- a/datastore/key.go +++ b/datastore/key.go @@ -23,7 +23,7 @@ import ( "strconv" "strings" - pb "google.golang.org/genproto/googleapis/datastore/v1" + pb "cloud.google.com/go/datastore/apiv1/datastorepb" "google.golang.org/protobuf/proto" ) diff --git a/datastore/load.go b/datastore/load.go index 68c57c96d694..9eb3c18d7cac 100644 --- a/datastore/load.go +++ b/datastore/load.go @@ -21,8 +21,8 @@ import ( "time" "cloud.google.com/go/civil" + pb "cloud.google.com/go/datastore/apiv1/datastorepb" "cloud.google.com/go/internal/fields" - pb "google.golang.org/genproto/googleapis/datastore/v1" ) var ( diff --git a/datastore/load_test.go b/datastore/load_test.go index bb4c3f1667b2..381871edb6cb 100644 --- a/datastore/load_test.go +++ b/datastore/load_test.go @@ -22,9 +22,9 @@ import ( "time" "cloud.google.com/go/civil" + pb "cloud.google.com/go/datastore/apiv1/datastorepb" "cloud.google.com/go/internal/testutil" "github.com/google/go-cmp/cmp/cmpopts" - pb "google.golang.org/genproto/googleapis/datastore/v1" "google.golang.org/protobuf/types/known/timestamppb" ) diff --git a/datastore/mock_test.go b/datastore/mock_test.go index adc57485a7fe..3c0fad973aae 100644 --- a/datastore/mock_test.go +++ b/datastore/mock_test.go @@ -27,9 +27,9 @@ import ( "reflect" "testing" + pb "cloud.google.com/go/datastore/apiv1/datastorepb" "cloud.google.com/go/internal/testutil" "google.golang.org/api/option" - pb "google.golang.org/genproto/googleapis/datastore/v1" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" "google.golang.org/protobuf/encoding/prototext" diff --git a/datastore/mutation.go b/datastore/mutation.go index 09ab9326d7af..0adad9097570 100644 --- a/datastore/mutation.go +++ b/datastore/mutation.go @@ -17,7 +17,7 @@ package datastore import ( "fmt" - pb "google.golang.org/genproto/googleapis/datastore/v1" + pb "cloud.google.com/go/datastore/apiv1/datastorepb" ) // A Mutation represents a change to a Datastore entity. diff --git a/datastore/mutation_test.go b/datastore/mutation_test.go index 0d21141ad91e..81e1817cfc07 100644 --- a/datastore/mutation_test.go +++ b/datastore/mutation_test.go @@ -17,8 +17,8 @@ package datastore import ( "testing" + pb "cloud.google.com/go/datastore/apiv1/datastorepb" "cloud.google.com/go/internal/testutil" - pb "google.golang.org/genproto/googleapis/datastore/v1" ) func TestMutationProtos(t *testing.T) { diff --git a/datastore/prop_test.go b/datastore/prop_test.go index ceadd82aabfc..790db3d65de2 100644 --- a/datastore/prop_test.go +++ b/datastore/prop_test.go @@ -19,8 +19,8 @@ import ( "strings" "testing" + pb "cloud.google.com/go/datastore/apiv1/datastorepb" "cloud.google.com/go/internal/testutil" - pb "google.golang.org/genproto/googleapis/datastore/v1" ) func TestLoadSavePLS(t *testing.T) { diff --git a/datastore/query.go b/datastore/query.go index 8815bffe999b..f5dd66d81236 100644 --- a/datastore/query.go +++ b/datastore/query.go @@ -25,10 +25,10 @@ import ( "strings" "time" + pb "cloud.google.com/go/datastore/apiv1/datastorepb" "cloud.google.com/go/internal/protostruct" "cloud.google.com/go/internal/trace" "google.golang.org/api/iterator" - pb "google.golang.org/genproto/googleapis/datastore/v1" wrapperspb "google.golang.org/protobuf/types/known/wrapperspb" ) diff --git a/datastore/query_test.go b/datastore/query_test.go index 9c6d60ab1162..414583c7acdc 100644 --- a/datastore/query_test.go +++ b/datastore/query_test.go @@ -23,9 +23,9 @@ import ( "strings" "testing" + pb "cloud.google.com/go/datastore/apiv1/datastorepb" "cloud.google.com/go/internal/testutil" "github.com/google/go-cmp/cmp" - pb "google.golang.org/genproto/googleapis/datastore/v1" "google.golang.org/grpc" "google.golang.org/protobuf/proto" ) diff --git a/datastore/save.go b/datastore/save.go index de9b836d5ac7..9642c84da5ea 100644 --- a/datastore/save.go +++ b/datastore/save.go @@ -22,7 +22,7 @@ import ( "unicode/utf8" "cloud.google.com/go/civil" - pb "google.golang.org/genproto/googleapis/datastore/v1" + pb "cloud.google.com/go/datastore/apiv1/datastorepb" llpb "google.golang.org/genproto/googleapis/type/latlng" timepb "google.golang.org/protobuf/types/known/timestamppb" ) diff --git a/datastore/save_test.go b/datastore/save_test.go index 5071e6054d5e..5dbf332006cc 100644 --- a/datastore/save_test.go +++ b/datastore/save_test.go @@ -20,8 +20,8 @@ import ( "time" "cloud.google.com/go/civil" + pb "cloud.google.com/go/datastore/apiv1/datastorepb" "cloud.google.com/go/internal/testutil" - pb "google.golang.org/genproto/googleapis/datastore/v1" ) func TestInterfaceToProtoNil(t *testing.T) { diff --git a/datastore/transaction.go b/datastore/transaction.go index 8f533dec32f3..309ade4fe9dd 100644 --- a/datastore/transaction.go +++ b/datastore/transaction.go @@ -20,9 +20,9 @@ import ( "sync" "time" + pb "cloud.google.com/go/datastore/apiv1/datastorepb" "cloud.google.com/go/internal/trace" gax "github.com/googleapis/gax-go/v2" - pb "google.golang.org/genproto/googleapis/datastore/v1" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/timestamppb" diff --git a/datastore/transaction_test.go b/datastore/transaction_test.go index f5df8930d49d..63544d19ebb0 100644 --- a/datastore/transaction_test.go +++ b/datastore/transaction_test.go @@ -21,8 +21,8 @@ import ( "testing" "time" + pb "cloud.google.com/go/datastore/apiv1/datastorepb" gax "github.com/googleapis/gax-go/v2" - pb "google.golang.org/genproto/googleapis/datastore/v1" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/proto" diff --git a/datastore/util_test.go b/datastore/util_test.go index 64467b63f17f..76d143b2be54 100644 --- a/datastore/util_test.go +++ b/datastore/util_test.go @@ -24,9 +24,9 @@ import ( "testing" "time" + pb "cloud.google.com/go/datastore/apiv1/datastorepb" "cloud.google.com/go/internal/testutil" "github.com/google/go-cmp/cmp" - pb "google.golang.org/genproto/googleapis/datastore/v1" "google.golang.org/grpc" ) diff --git a/internal/aliasfix/mappings.go b/internal/aliasfix/mappings.go index dbbe5436e7d3..12cadfab175e 100644 --- a/internal/aliasfix/mappings.go +++ b/internal/aliasfix/mappings.go @@ -710,6 +710,10 @@ var GenprotoPkgMigration map[string]Pkg = map[string]Pkg{ ImportPath: "cloud.google.com/go/datastore/admin/apiv1/adminpb", Status: StatusMigrated, }, + "google.golang.org/genproto/googleapis/datastore/v1": { + ImportPath: "cloud.google.com/go/datastore/apiv1/datastorepb", + Status: StatusMigrated, + }, "google.golang.org/genproto/googleapis/devtools/artifactregistry/v1": { ImportPath: "cloud.google.com/go/artifactregistry/apiv1/artifactregistrypb", Status: StatusMigrated, diff --git a/storage/bucket.go b/storage/bucket.go index d582a60d0e83..3eded017831e 100644 --- a/storage/bucket.go +++ b/storage/bucket.go @@ -416,6 +416,10 @@ type BucketAttrs struct { // This field is read-only. Created time.Time + // Updated is the time at which the bucket was last modified. + // This field is read-only. + Updated time.Time + // VersioningEnabled reports whether this bucket has versioning enabled. VersioningEnabled bool @@ -824,6 +828,7 @@ func newBucket(b *raw.Bucket) (*BucketAttrs, error) { DefaultEventBasedHold: b.DefaultEventBasedHold, StorageClass: b.StorageClass, Created: convertTime(b.TimeCreated), + Updated: convertTime(b.Updated), VersioningEnabled: b.Versioning != nil && b.Versioning.Enabled, ACL: toBucketACLRules(b.Acl), DefaultObjectACL: toObjectACLRules(b.DefaultObjectAcl), @@ -861,6 +866,7 @@ func newBucketFromProto(b *storagepb.Bucket) *BucketAttrs { DefaultEventBasedHold: b.GetDefaultEventBasedHold(), StorageClass: b.GetStorageClass(), Created: b.GetCreateTime().AsTime(), + Updated: b.GetUpdateTime().AsTime(), VersioningEnabled: b.GetVersioning().GetEnabled(), ACL: toBucketACLRulesFromProto(b.GetAcl()), DefaultObjectACL: toObjectACLRulesFromProto(b.GetDefaultObjectAcl()), diff --git a/storage/bucket_test.go b/storage/bucket_test.go index fd03c1e41907..7126be8331a6 100644 --- a/storage/bucket_test.go +++ b/storage/bucket_test.go @@ -605,6 +605,7 @@ func TestNewBucket(t *testing.T) { Metageneration: 3, StorageClass: "sc", TimeCreated: "2017-10-23T04:05:06Z", + Updated: "2024-08-21T17:24:53Z", Versioning: &raw.BucketVersioning{Enabled: true}, Labels: labels, Billing: &raw.BucketBilling{RequesterPays: true}, @@ -676,6 +677,7 @@ func TestNewBucket(t *testing.T) { MetaGeneration: 3, StorageClass: "sc", Created: time.Date(2017, 10, 23, 4, 5, 6, 0, time.UTC), + Updated: time.Date(2024, 8, 21, 17, 24, 53, 0, time.UTC), VersioningEnabled: true, Labels: labels, Etag: "Zkyw9ACJZUvcYmlFaKGChzhmtnE/dt1zHSfweiWpwzdGsqXwuJZqiD0", @@ -767,6 +769,7 @@ func TestNewBucketFromProto(t *testing.T) { Rpo: rpoAsyncTurbo, Metageneration: int64(39), CreateTime: toProtoTimestamp(time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)), + UpdateTime: toProtoTimestamp(time.Date(2024, 1, 2, 3, 4, 5, 6, time.UTC)), Labels: map[string]string{"label": "value"}, Cors: []*storagepb.Bucket_Cors{ { @@ -820,6 +823,7 @@ func TestNewBucketFromProto(t *testing.T) { RPO: RPOAsyncTurbo, MetaGeneration: 39, Created: time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC), + Updated: time.Date(2024, 1, 2, 3, 4, 5, 6, time.UTC), Labels: map[string]string{"label": "value"}, CORS: []CORS{ { diff --git a/storage/integration_test.go b/storage/integration_test.go index 941a534aa30e..aa3dd8b73088 100644 --- a/storage/integration_test.go +++ b/storage/integration_test.go @@ -589,6 +589,9 @@ func TestIntegration_BucketUpdate(t *testing.T) { if !testutil.Equal(attrs.Labels, wantLabels) { t.Fatalf("add labels: got %v, want %v", attrs.Labels, wantLabels) } + if !attrs.Created.Before(attrs.Updated) { + t.Errorf("got attrs.Updated %v before attrs.Created %v, want Attrs.Updated to be after", attrs.Updated, attrs.Created) + } // Turn off versioning again; add and remove some more labels. ua = BucketAttrsToUpdate{VersioningEnabled: false} @@ -607,6 +610,9 @@ func TestIntegration_BucketUpdate(t *testing.T) { if !testutil.Equal(attrs.Labels, wantLabels) { t.Fatalf("got %v, want %v", attrs.Labels, wantLabels) } + if !attrs.Created.Before(attrs.Updated) { + t.Errorf("got attrs.Updated %v before attrs.Created %v, want Attrs.Updated to be after", attrs.Updated, attrs.Created) + } // Configure a lifecycle wantLifecycle := Lifecycle{ @@ -626,6 +632,9 @@ func TestIntegration_BucketUpdate(t *testing.T) { if !testutil.Equal(attrs.Lifecycle, wantLifecycle) { t.Fatalf("got %v, want %v", attrs.Lifecycle, wantLifecycle) } + if !attrs.Created.Before(attrs.Updated) { + t.Errorf("got attrs.Updated %v before attrs.Created %v, want Attrs.Updated to be after", attrs.Updated, attrs.Created) + } // Check that StorageClass has "STANDARD" value for unset field by default // before passing new value. wantStorageClass := "STANDARD" @@ -638,6 +647,9 @@ func TestIntegration_BucketUpdate(t *testing.T) { if !testutil.Equal(attrs.StorageClass, wantStorageClass) { t.Fatalf("got %v, want %v", attrs.StorageClass, wantStorageClass) } + if !attrs.Created.Before(attrs.Updated) { + t.Errorf("got attrs.Updated %v before attrs.Created %v, want Attrs.Updated to be after", attrs.Updated, attrs.Created) + } // Empty update should succeed without changing the bucket. gotAttrs, err := b.Update(ctx, BucketAttrsToUpdate{}) @@ -647,6 +659,9 @@ func TestIntegration_BucketUpdate(t *testing.T) { if !testutil.Equal(attrs, gotAttrs) { t.Fatalf("empty update: got %v, want %v", gotAttrs, attrs) } + if !attrs.Created.Before(attrs.Updated) { + t.Errorf("got attrs.Updated %v before attrs.Created %v, want Attrs.Updated to be after", attrs.Updated, attrs.Created) + } }) } diff --git a/vertexai/CHANGES.md b/vertexai/CHANGES.md index a3ea6d6aeeea..e90afa1f724f 100644 --- a/vertexai/CHANGES.md +++ b/vertexai/CHANGES.md @@ -1,5 +1,21 @@ # Changelog +## [0.13.0](https://github.com/googleapis/google-cloud-go/compare/vertexai/v0.12.0...vertexai/v0.13.0) (2024-08-22) + + +### Features + +* **vertexai/genai:** Add WithClientInfo option ([#10535](https://github.com/googleapis/google-cloud-go/issues/10535)) ([265963b](https://github.com/googleapis/google-cloud-go/commit/265963bd5b91c257b3c3d3c1f52cdf2b5f4c9d1a)) +* **vertexai:** Update tokenizer documentation and pull new code ([#10718](https://github.com/googleapis/google-cloud-go/issues/10718)) ([0ee1430](https://github.com/googleapis/google-cloud-go/commit/0ee1430154f4d51d84b5d5927b1b477f6beb0fc1)) + + +### Bug Fixes + +* **vertexai:** Bump google.golang.org/api@v0.187.0 ([8fa9e39](https://github.com/googleapis/google-cloud-go/commit/8fa9e398e512fd8533fd49060371e61b5725a85b)) +* **vertexai:** Bump google.golang.org/grpc@v1.64.1 ([8ecc4e9](https://github.com/googleapis/google-cloud-go/commit/8ecc4e9622e5bbe9b90384d5848ab816027226c5)) +* **vertexai:** Update dependencies ([257c40b](https://github.com/googleapis/google-cloud-go/commit/257c40bd6d7e59730017cf32bda8823d7a232758)) +* **vertexai:** Update google.golang.org/api to v0.191.0 ([5b32644](https://github.com/googleapis/google-cloud-go/commit/5b32644eb82eb6bd6021f80b4fad471c60fb9d73)) + ## [0.12.0](https://github.com/googleapis/google-cloud-go/compare/vertexai/v0.11.0...vertexai/v0.12.0) (2024-06-12) diff --git a/vertexai/genai/tokenizer/tokenizer.go b/vertexai/genai/tokenizer/tokenizer.go index b51a03db0267..d17495677816 100644 --- a/vertexai/genai/tokenizer/tokenizer.go +++ b/vertexai/genai/tokenizer/tokenizer.go @@ -14,7 +14,7 @@ // Package tokenizer provides local token counting for Gemini models. This // tokenizer downloads its model from the web, but otherwise doesn't require -// an API call for every CountTokens invocation. +// an API call for every [CountTokens] invocation. package tokenizer import ( @@ -43,7 +43,7 @@ var supportedModels = map[string]bool{ // Tokenizer is a local tokenizer for text. type Tokenizer struct { - encoder *sentencepiece.Encoder + processor *sentencepiece.Processor } // CountTokensResponse is the response of [Tokenizer.CountTokens]. @@ -63,12 +63,12 @@ func New(modelName string) (*Tokenizer, error) { return nil, fmt.Errorf("loading model: %w", err) } - encoder, err := sentencepiece.NewEncoder(bytes.NewReader(data)) + processor, err := sentencepiece.NewProcessor(bytes.NewReader(data)) if err != nil { - return nil, fmt.Errorf("creating encoder: %w", err) + return nil, fmt.Errorf("creating processor: %w", err) } - return &Tokenizer{encoder: encoder}, nil + return &Tokenizer{processor: processor}, nil } // CountTokens counts the tokens in all the given parts and returns their @@ -79,7 +79,7 @@ func (tok *Tokenizer) CountTokens(parts ...genai.Part) (*CountTokensResponse, er for _, part := range parts { if t, ok := part.(genai.Text); ok { - toks := tok.encoder.Encode(string(t)) + toks := tok.processor.Encode(string(t)) sum += len(toks) } else { return nil, fmt.Errorf("Tokenizer.CountTokens only supports Text parts") diff --git a/vertexai/internal/import-go-sentencepiece.sh b/vertexai/internal/import-go-sentencepiece.sh index 50ba93f6541c..4a7496d32c3d 100755 --- a/vertexai/internal/import-go-sentencepiece.sh +++ b/vertexai/internal/import-go-sentencepiece.sh @@ -29,13 +29,15 @@ TEMP_DIR=$(mktemp -d) # Clone the repository with --depth 1 to get only the latest files git clone --depth 1 https://github.com/eliben/go-sentencepiece.git "$TEMP_DIR/go-sentencepiece" -# Copy the repository contents to here, excluding the .git directory rm -rf sentencepiece mkdir -p sentencepiece rsync -av \ --exclude='.git' \ + --exclude='.github' \ --exclude='go.mod' \ --exclude='go.sum' \ + --exclude='wasm' \ + --exclude='doc' \ --exclude='test' \ --exclude='*_test.go' \ "$TEMP_DIR/go-sentencepiece/" sentencepiece diff --git a/vertexai/internal/sentencepiece/README.md b/vertexai/internal/sentencepiece/README.md index 724221784780..8348eaefe18c 100644 --- a/vertexai/internal/sentencepiece/README.md +++ b/vertexai/internal/sentencepiece/README.md @@ -1,12 +1,19 @@ # go-sentencepiece +

+ Logo +

+ +---- + [![Go Reference](https://pkg.go.dev/badge/github.com/eliben/go-sentencepiece.svg)](https://pkg.go.dev/github.com/eliben/go-sentencepiece) -This is a pure Go implementation of encoding text with +This is a pure Go implementation of encoding and decoding text with the [SentencePiece tokenizer](https://github.com/google/sentencepiece). "Encoding" is the operation used to split text into tokens, using -a trained tokenizer model. +a trained tokenizer model. "Decoding" is the reverse process - converting +a list of tokens into the original text. SentencePiece is a general family of tokenizers that is configured by a protobuf configuration file. This repository currently focuses @@ -35,7 +42,7 @@ other configuration information. It is not part of this repository. Please fetch it from the [official Gemma implementation repository](https://github.com/google/gemma_pytorch/tree/main/tokenizer). -`NewEncoder*` constructors will expect to read this file. +`NewProcessor*` constructors will expect to read this file. ## Developing @@ -54,3 +61,11 @@ The configuration protobuf itself is obtained as described in the [Tokenizer configuration](#tokenizer-configuration) section. All tests require the `MODELPATH` env var to point to a local copy of the tokenizer configuration file. + +## Online demo + +To see an in-browser demo of this tokenizer in action, visit +https://eliben.github.io/go-sentencepiece/ + +The Go code is compiled to WebAssembly and loaded from a small +JS program to allow interactive encoding of text. diff --git a/vertexai/internal/sentencepiece/internal/cmd/dumper/main.go b/vertexai/internal/sentencepiece/internal/cmd/dumper/main.go index 74bafa991fdf..a5689e9cd110 100644 --- a/vertexai/internal/sentencepiece/internal/cmd/dumper/main.go +++ b/vertexai/internal/sentencepiece/internal/cmd/dumper/main.go @@ -34,6 +34,7 @@ import ( func main() { fDumpAll := flag.Bool("dumpall", false, "dump entire model proto") fFindUni := flag.Bool("finduni", false, "find unicode runes not in pieces") + fFindBytes := flag.Bool("findbytes", false, "show all byte pieces with their IDs") fEncodeFile := flag.String("encodefile", "", "file name to open and encode") flag.Parse() @@ -47,17 +48,24 @@ func main() { log.Fatal(err) } - var model model.ModelProto - err = proto.Unmarshal(b, &model) + var protomodel model.ModelProto + err = proto.Unmarshal(b, &protomodel) if err != nil { log.Fatal(err) } if *fDumpAll { - fmt.Println(prototext.Format(&model)) + fmt.Println(prototext.Format(&protomodel)) + } else if *fFindBytes { + for i, piece := range protomodel.GetPieces() { + if piece.GetType() == model.ModelProto_SentencePiece_BYTE { + fmt.Printf("%5d: %s\n", i, piece.GetPiece()) + } + } + } else if *fFindUni { pieces := make(map[string]int) - for i, piece := range model.GetPieces() { + for i, piece := range protomodel.GetPieces() { pieces[piece.GetPiece()] = i } @@ -69,7 +77,7 @@ func main() { } } } else if *fEncodeFile != "" { - enc, err := sentencepiece.NewEncoderFromPath(modelPath) + proc, err := sentencepiece.NewProcessorFromPath(modelPath) if err != nil { log.Fatal(err) } @@ -79,7 +87,7 @@ func main() { log.Fatal(err) } - tokens := enc.Encode(string(b)) + tokens := proc.Encode(string(b)) for _, t := range tokens { fmt.Println(t.ID) } diff --git a/vertexai/internal/sentencepiece/normalize.go b/vertexai/internal/sentencepiece/normalize.go index 6fb4f8674675..bbffd54b5f0d 100644 --- a/vertexai/internal/sentencepiece/normalize.go +++ b/vertexai/internal/sentencepiece/normalize.go @@ -24,11 +24,19 @@ import "strings" // normalizer that does none of this. These options can be added in the future // if needed. func normalize(text string) string { - return replaceSeparator(text) + return replaceSpacesBySeparator(text) } -// replaceSeparator replaces spaces by the whitespace separator used by +const whitespaceSeparator = "▁" + +// replaceSpacesBySeparator replaces spaces by the whitespace separator used by // the model. -func replaceSeparator(text string) string { - return strings.ReplaceAll(text, " ", "▁") +func replaceSpacesBySeparator(text string) string { + return strings.ReplaceAll(text, " ", whitespaceSeparator) +} + +// replaceSeparatorsBySpace replaces the whitespace separator used by +// the model back with spaces. +func replaceSeparatorsBySpace(text string) string { + return strings.ReplaceAll(text, whitespaceSeparator, " ") } diff --git a/vertexai/internal/sentencepiece/encoder.go b/vertexai/internal/sentencepiece/processor.go similarity index 68% rename from vertexai/internal/sentencepiece/encoder.go rename to vertexai/internal/sentencepiece/processor.go index 89c8b7785761..5287a249f9dc 100644 --- a/vertexai/internal/sentencepiece/encoder.go +++ b/vertexai/internal/sentencepiece/processor.go @@ -30,12 +30,15 @@ import ( const debugEncode = false -// Encoder represents a SentencePiece encoder (tokenizer). -// An Encoder converts input text into a sequence of tokens LLMs use. +// Processor represents a SentencePiece processor (tokenizer). +// A Processor converts input text into a sequence of tokens LLMs use, and back. // The mapping between token IDs and the text they represent is read from the // model proto (provided to the constructor); it's the same between all calls // to the Encode method. -type Encoder struct { +// +// The term "processor" comes from the original C++ SentencePiece library and +// its Python bindings. +type Processor struct { model *model.ModelProto pieces map[string]int @@ -48,23 +51,26 @@ type Encoder struct { // "user-defined" type in the model proto. userDefinedMatcher *prefixmatcher.PrefixMatcher - // byteTokens is a cache of byte values and the tokens they represent - byteTokens map[byte]Token + // byte2Token is a cache of byte values and the tokens they represent + byte2Token map[byte]Token + + // idToByte maps IDs to byte values they represent + idToByte map[int]byte } -// NewEncoderFromPath creates a new Encoder from a file path to the protobuf +// NewProcessorFromPath creates a new Processor from a file path to the protobuf // data. -func NewEncoderFromPath(protoFile string) (*Encoder, error) { +func NewProcessorFromPath(protoFile string) (*Processor, error) { f, err := os.Open(protoFile) if err != nil { return nil, fmt.Errorf("unable to read %q: %v", protoFile, err) } defer f.Close() - return NewEncoder(f) + return NewProcessor(f) } -// NewEncoder creates a new Encoder from a reader with the protobuf data. -func NewEncoder(protoReader io.Reader) (*Encoder, error) { +// NewProcessor creates a new Processor from a reader with the protobuf data. +func NewProcessor(protoReader io.Reader) (*Processor, error) { b, err := io.ReadAll(protoReader) if err != nil { return nil, fmt.Errorf("unable to read protobuf data: %v", err) @@ -81,10 +87,16 @@ func NewEncoder(protoReader io.Reader) (*Encoder, error) { return nil, fmt.Errorf("model type %s not supported", tspec.GetModelType()) } + nspec := mp.GetNormalizerSpec() + if *nspec.AddDummyPrefix || *nspec.RemoveExtraWhitespaces { + return nil, fmt.Errorf("normalizer spec options not supported: %s", nspec) + } + userDefined := make(map[string]bool) pieces := make(map[string]int) reserved := make(map[string]int) - byteTokens := make(map[byte]Token) + byte2Token := make(map[byte]Token) + idToByte := make(map[int]byte) unkID := -1 for i, piece := range mp.GetPieces() { @@ -111,7 +123,8 @@ func NewEncoder(protoReader io.Reader) (*Encoder, error) { } bv := convertHexValue(piece.GetPiece()) if bv >= 0 && bv < 256 { - byteTokens[byte(bv)] = Token{ID: i, Text: piece.GetPiece()} + byte2Token[byte(bv)] = Token{ID: i, Text: piece.GetPiece()} + idToByte[i] = byte(bv) } } } @@ -124,16 +137,17 @@ func NewEncoder(protoReader io.Reader) (*Encoder, error) { // values were found. if tspec.GetByteFallback() { for i := 0; i < 256; i++ { - if _, found := byteTokens[byte(i)]; !found { + if _, found := byte2Token[byte(i)]; !found { return nil, fmt.Errorf("byte value 0x%02X not found", i) } } } - return &Encoder{ + return &Processor{ model: &mp, userDefinedMatcher: prefixmatcher.NewFromSet(userDefined), - byteTokens: byteTokens, + byte2Token: byte2Token, + idToByte: idToByte, unknownID: unkID, pieces: pieces, reserved: reserved, @@ -141,7 +155,7 @@ func NewEncoder(protoReader io.Reader) (*Encoder, error) { } // Encode tokenizes the input text and returns a list of Tokens. -func (enc *Encoder) Encode(text string) []Token { +func (proc *Processor) Encode(text string) []Token { text = normalize(text) // We begin by having each symbol a single Unicode character (or a @@ -165,7 +179,7 @@ func (enc *Encoder) Encode(text string) []Token { for { // Match the next symbol in text - slen, found := enc.symbolMatch(text) + slen, found := proc.symbolMatch(text) // Append a list element for this symbol; note that this element will be // at index len(symList), so prev/next are set up accordingly. @@ -226,12 +240,12 @@ func (enc *Encoder) Encode(text string) []Token { } mergedSymbol := symList[left].symbol + symList[right].symbol - if id, found := enc.pieces[mergedSymbol]; found { + if id, found := proc.pieces[mergedSymbol]; found { mergeQueue.Insert(mergeCandidate{ left: left, right: right, length: len(mergedSymbol), - score: enc.model.GetPieces()[id].GetScore(), + score: proc.model.GetPieces()[id].GetScore(), }) } } @@ -278,13 +292,13 @@ func (enc *Encoder) Encode(text string) []Token { tokens := make([]Token, 0, len(symList)) for i := 0; i >= 0; i = symList[i].next { symbol := symList[i].symbol - id := enc.symbolToID(symbol) + id := proc.symbolToID(symbol) - if id == enc.unknownID && enc.model.GetTrainerSpec().GetByteFallback() { + if id == proc.unknownID && proc.model.GetTrainerSpec().GetByteFallback() { // Decompose this symbol into bytes, and report each byte as a separate // token. for i := 0; i < len(symbol); i++ { - tokens = append(tokens, enc.byteTokens[symbol[i]]) + tokens = append(tokens, proc.byte2Token[symbol[i]]) } } else { tokens = append(tokens, Token{ID: id, Text: symbol}) @@ -297,8 +311,8 @@ func (enc *Encoder) Encode(text string) []Token { // symbolMatch finds the length of the first symbol in text. A symbol is either // a user-defined symbol from the proto or a single rune. The second return // value is true iff a user-defined symbol was matched. -func (enc *Encoder) symbolMatch(text string) (int, bool) { - prefixLen := enc.userDefinedMatcher.FindPrefixLen(text) +func (proc *Processor) symbolMatch(text string) (int, bool) { + prefixLen := proc.userDefinedMatcher.FindPrefixLen(text) if prefixLen > 0 { return prefixLen, true } @@ -308,15 +322,15 @@ func (enc *Encoder) symbolMatch(text string) (int, bool) { } // symbolToID finds the right ID for the given textual symbol, or returns -// enc.unknownID if the symbol is unknown. -func (enc *Encoder) symbolToID(symbol string) int { - if id, found := enc.reserved[symbol]; found { +// proc.unknownID if the symbol is unknown. +func (proc *Processor) symbolToID(symbol string) int { + if id, found := proc.reserved[symbol]; found { return id } - if id, found := enc.pieces[symbol]; found { + if id, found := proc.pieces[symbol]; found { return id } - return enc.unknownID + return proc.unknownID } // convertHexValue converts strings of the form "<0xXY>" to the (unsigned) @@ -330,3 +344,77 @@ func convertHexValue(bv string) int { } return int(n) } + +// Decode translates a list of IDs produced by [Encode] back into the string +// it represents. +func (proc *Processor) Decode(ids []int) string { + var sb strings.Builder + + for i := 0; i < len(ids); { + // Find a run of IDs that represent single bytes starting at i. + nextNonByte := i + for nextNonByte < len(ids) && proc.isByteID(ids[nextNonByte]) { + nextNonByte++ + } + numBytes := nextNonByte - i + + // Handle a run of numBytes IDs, by decoding them into utf8 runes. + if numBytes > 0 { + buf := make([]byte, 0, numBytes) + for bi := i; bi < nextNonByte; bi++ { + buf = append(buf, proc.idToByte[ids[bi]]) + } + + for len(buf) > 0 { + // DecodeRune returns utf8.RuneError ('\uFFFD') for bad UTF8 encodings, + // and this is exactly what SentencePiece is supposed to emit for them. + // So we don't do any special handling for UTF8 decode errors here. + r, size := utf8.DecodeRune(buf) + sb.WriteRune(r) + buf = buf[size:] + } + } + + if nextNonByte >= len(ids) { + break + } + // Here nextNonByte is the index of an ID that's not a single byte. + id := ids[nextNonByte] + if proc.isControlID(id) { + // Don't emit anything for control IDs + } else if id == proc.unknownID { + // Special "unk_surface" string for unknown IDs + sb.WriteString(proc.model.GetTrainerSpec().GetUnkSurface()) + } else { + piece := proc.model.GetPieces()[id].GetPiece() + sb.WriteString(replaceSeparatorsBySpace(piece)) + } + i = nextNonByte + 1 + } + + return sb.String() +} + +// DecodeTokens is a convenience wrapper around [Decode], accepting a list of +// tokens as returned by [Encode]. It only uses the ID fields of tokens to +// decode the text. +func (proc *Processor) DecodeTokens(tokens []Token) string { + ids := make([]int, len(tokens)) + for i, t := range tokens { + ids[i] = t.ID + } + return proc.Decode(ids) +} + +// VocabularySize returns the vocabulary size from the proto model. +func (proc *Processor) VocabularySize() int { + return len(proc.model.GetPieces()) +} + +func (proc *Processor) isByteID(id int) bool { + return proc.model.GetPieces()[id].GetType() == model.ModelProto_SentencePiece_BYTE +} + +func (proc *Processor) isControlID(id int) bool { + return proc.model.GetPieces()[id].GetType() == model.ModelProto_SentencePiece_CONTROL +} diff --git a/vertexai/internal/version.go b/vertexai/internal/version.go index 1bcfd8a58abc..d30714fbcf9c 100644 --- a/vertexai/internal/version.go +++ b/vertexai/internal/version.go @@ -15,4 +15,4 @@ package internal // Version is the current tagged release of the library. -const Version = "0.12.0" +const Version = "0.13.0"