Skip to content

Commit

Permalink
feature: inbuilt support for scanner and valuer in pq.Array for int32…
Browse files Browse the repository at this point in the history
…/float32/[]byte slices
  • Loading branch information
Shivam010 committed Nov 30, 2020
1 parent 11a44e2 commit d726827
Show file tree
Hide file tree
Showing 3 changed files with 482 additions and 0 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,5 @@
*.test
*~
*.swp
.idea
.vscode
139 changes: 139 additions & 0 deletions array.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,19 +35,31 @@ func Array(a interface{}) interface {
return (*BoolArray)(&a)
case []float64:
return (*Float64Array)(&a)
case []float32:
return (*Float32Array)(&a)
case []int64:
return (*Int64Array)(&a)
case []int32:
return (*Int32Array)(&a)
case []string:
return (*StringArray)(&a)
case [][]byte:
return (*ByteaArray)(&a)

case *[]bool:
return (*BoolArray)(a)
case *[]float64:
return (*Float64Array)(a)
case *[]float32:
return (*Float32Array)(a)
case *[]int64:
return (*Int64Array)(a)
case *[]int32:
return (*Int32Array)(a)
case *[]string:
return (*StringArray)(a)
case *[][]byte:
return (*ByteaArray)(a)
}

return GenericArray{a}
Expand Down Expand Up @@ -267,6 +279,70 @@ func (a Float64Array) Value() (driver.Value, error) {
return "{}", nil
}

// Float32Array represents a one-dimensional array of the PostgreSQL double
// precision type.
type Float32Array []float32

// Scan implements the sql.Scanner interface.
func (a *Float32Array) Scan(src interface{}) error {
switch src := src.(type) {
case []byte:
return a.scanBytes(src)
case string:
return a.scanBytes([]byte(src))
case nil:
*a = nil
return nil
}

return fmt.Errorf("pq: cannot convert %T to Float32Array", src)
}

func (a *Float32Array) scanBytes(src []byte) error {
elems, err := scanLinearArray(src, []byte{','}, "Float32Array")
if err != nil {
return err
}
if *a != nil && len(elems) == 0 {
*a = (*a)[:0]
} else {
b := make(Float32Array, len(elems))
for i, v := range elems {
var x float64
if x, err = strconv.ParseFloat(string(v), 32); err != nil {
return fmt.Errorf("pq: parsing array element index %d: %v", i, err)
}
b[i] = float32(x)
}
*a = b
}
return nil
}

// Value implements the driver.Valuer interface.
func (a Float32Array) Value() (driver.Value, error) {
if a == nil {
return nil, nil
}

if n := len(a); n > 0 {
// There will be at least two curly brackets, N bytes of values,
// and N-1 bytes of delimiters.
b := make([]byte, 1, 1+2*n)
b[0] = '{'

b = strconv.AppendFloat(b, float64(a[0]), 'f', -1, 32)
for i := 1; i < n; i++ {
b = append(b, ',')
b = strconv.AppendFloat(b, float64(a[i]), 'f', -1, 32)
}

return string(append(b, '}')), nil
}

return "{}", nil
}

// GenericArray implements the driver.Valuer and sql.Scanner interfaces for
// an array or slice of any dimension.
type GenericArray struct{ A interface{} }
Expand Down Expand Up @@ -483,6 +559,69 @@ func (a Int64Array) Value() (driver.Value, error) {
return "{}", nil
}

// Int32Array represents a one-dimensional array of the PostgreSQL integer types.
type Int32Array []int32

// Scan implements the sql.Scanner interface.
func (a *Int32Array) Scan(src interface{}) error {
switch src := src.(type) {
case []byte:
return a.scanBytes(src)
case string:
return a.scanBytes([]byte(src))
case nil:
*a = nil
return nil
}

return fmt.Errorf("pq: cannot convert %T to Int32Array", src)
}

func (a *Int32Array) scanBytes(src []byte) error {
elems, err := scanLinearArray(src, []byte{','}, "Int32Array")
if err != nil {
return err
}
if *a != nil && len(elems) == 0 {
*a = (*a)[:0]
} else {
b := make(Int32Array, len(elems))
for i, v := range elems {
var x int
if x, err = strconv.Atoi(string(v)); err != nil {
return fmt.Errorf("pq: parsing array element index %d: %v", i, err)
}
b[i] = int32(x)
}
*a = b
}
return nil
}

// Value implements the driver.Valuer interface.
func (a Int32Array) Value() (driver.Value, error) {
if a == nil {
return nil, nil
}

if n := len(a); n > 0 {
// There will be at least two curly brackets, N bytes of values,
// and N-1 bytes of delimiters.
b := make([]byte, 1, 1+2*n)
b[0] = '{'

b = strconv.AppendInt(b, int64(a[0]), 10)
for i := 1; i < n; i++ {
b = append(b, ',')
b = strconv.AppendInt(b, int64(a[i]), 10)
}

return string(append(b, '}')), nil
}

return "{}", nil
}

// StringArray represents a one-dimensional array of the PostgreSQL character types.
type StringArray []string

Expand Down
Loading

0 comments on commit d726827

Please sign in to comment.