diff --git a/halfvec.go b/halfvec.go index 8cc7d2b..25730e4 100644 --- a/halfvec.go +++ b/halfvec.go @@ -5,6 +5,7 @@ import ( "database/sql/driver" "encoding/json" "fmt" + "slices" "strconv" "strings" ) @@ -26,17 +27,9 @@ func (v HalfVector) Slice() []float32 { // String returns a string representation of the half vector. func (v HalfVector) String() string { - buf := make([]byte, 0, 2+16*len(v.vec)) - buf = append(buf, '[') - - for i := 0; i < len(v.vec); i++ { - if i > 0 { - buf = append(buf, ',') - } - buf = strconv.AppendFloat(buf, float64(v.vec[i]), 'f', -1, 32) - } - - buf = append(buf, ']') + // should never throw an error + // but returning an empty string is fine if it does + buf, _ := v.EncodeText(nil) return string(buf) } @@ -54,6 +47,20 @@ func (v *HalfVector) Parse(s string) error { return nil } +// EncodeText encodes a text representation of the half vector. +func (v HalfVector) EncodeText(buf []byte) (newBuf []byte, err error) { + buf = slices.Grow(buf, 2+16*len(v.vec)) + buf = append(buf, '[') + for i := 0; i < len(v.vec); i++ { + if i > 0 { + buf = append(buf, ',') + } + buf = strconv.AppendFloat(buf, float64(v.vec[i]), 'f', -1, 32) + } + buf = append(buf, ']') + return buf, nil +} + // statically assert that HalfVector implements sql.Scanner. var _ sql.Scanner = (*HalfVector)(nil) diff --git a/pgx/halfvec.go b/pgx/halfvec.go new file mode 100644 index 0000000..f3854d8 --- /dev/null +++ b/pgx/halfvec.go @@ -0,0 +1,83 @@ +package pgx + +import ( + "database/sql/driver" + "fmt" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgtype" + "github.com/pgvector/pgvector-go" +) + +type HalfVectorCodec struct{} + +func (HalfVectorCodec) FormatSupported(format int16) bool { + return format == pgx.TextFormatCode +} + +func (HalfVectorCodec) PreferredFormat() int16 { + return pgx.TextFormatCode +} + +func (HalfVectorCodec) PlanEncode(m *pgtype.Map, oid uint32, format int16, value any) pgtype.EncodePlan { + _, ok := value.(pgvector.HalfVector) + if !ok { + return nil + } + + if format == pgx.TextFormatCode { + return encodePlanHalfVectorCodecText{} + } + + return nil +} + +type encodePlanHalfVectorCodecText struct{} + +func (encodePlanHalfVectorCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) { + v := value.(pgvector.HalfVector) + return v.EncodeText(buf) +} + +type scanPlanHalfVectorCodecText struct{} + +func (HalfVectorCodec) PlanScan(m *pgtype.Map, oid uint32, format int16, target any) pgtype.ScanPlan { + _, ok := target.(*pgvector.HalfVector) + if !ok { + return nil + } + + if format == pgx.TextFormatCode { + return scanPlanHalfVectorCodecText{} + } + + return nil +} + +func (scanPlanHalfVectorCodecText) Scan(src []byte, dst any) error { + v := (dst).(*pgvector.HalfVector) + return v.Scan(src) +} + +func (c HalfVectorCodec) DecodeDatabaseSQLValue(m *pgtype.Map, oid uint32, format int16, src []byte) (driver.Value, error) { + return c.DecodeValue(m, oid, format, src) +} + +func (c HalfVectorCodec) DecodeValue(m *pgtype.Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + var vec pgvector.HalfVector + scanPlan := c.PlanScan(m, oid, format, &vec) + if scanPlan == nil { + return nil, fmt.Errorf("Unable to decode halfvec type") + } + + err := scanPlan.Scan(src, &vec) + if err != nil { + return nil, err + } + + return vec, nil +} diff --git a/pgx/register.go b/pgx/register.go index 3b39bae..2581ef5 100644 --- a/pgx/register.go +++ b/pgx/register.go @@ -10,8 +10,9 @@ import ( func RegisterTypes(ctx context.Context, conn *pgx.Conn) error { var vectorOid *uint32 + var halfvecOid *uint32 var sparsevecOid *uint32 - err := conn.QueryRow(ctx, "SELECT to_regtype('vector')::oid, to_regtype('sparsevec')::oid").Scan(&vectorOid, &sparsevecOid) + err := conn.QueryRow(ctx, "SELECT to_regtype('vector')::oid, to_regtype('halfvec')::oid, to_regtype('sparsevec')::oid").Scan(&vectorOid, &halfvecOid, &sparsevecOid) if err != nil { return err } @@ -23,6 +24,10 @@ func RegisterTypes(ctx context.Context, conn *pgx.Conn) error { tm := conn.TypeMap() tm.RegisterType(&pgtype.Type{Name: "vector", OID: *vectorOid, Codec: &VectorCodec{}}) + if halfvecOid != nil { + tm.RegisterType(&pgtype.Type{Name: "halfvec", OID: *halfvecOid, Codec: &HalfVectorCodec{}}) + } + if sparsevecOid != nil { tm.RegisterType(&pgtype.Type{Name: "sparsevec", OID: *sparsevecOid, Codec: &SparseVectorCodec{}}) }