From d726827bf762cdd26468d69f9e3f7b4d2ee4df90 Mon Sep 17 00:00:00 2001 From: Shivam Rathore Date: Thu, 20 Jun 2019 10:07:06 +0530 Subject: [PATCH] feature: inbuilt support for scanner and valuer in pq.Array for int32/float32/[]byte slices --- .gitignore | 2 + array.go | 139 ++++++++++++++++++++ array_test.go | 341 ++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 482 insertions(+) diff --git a/.gitignore b/.gitignore index 0f1d00e1..3243952a 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,5 @@ *.test *~ *.swp +.idea +.vscode \ No newline at end of file diff --git a/array.go b/array.go index e4933e22..405da236 100644 --- a/array.go +++ b/array.go @@ -35,19 +35,31 @@ func Array(a interface{}) interface { return (*BoolArray)(&a) case []float64: return (*Float64Array)(&a) + case []float32: + return (*Float32Array)(&a) case []int64: return (*Int64Array)(&a) + case []int32: + return (*Int32Array)(&a) case []string: return (*StringArray)(&a) + case [][]byte: + return (*ByteaArray)(&a) case *[]bool: return (*BoolArray)(a) case *[]float64: return (*Float64Array)(a) + case *[]float32: + return (*Float32Array)(a) case *[]int64: return (*Int64Array)(a) + case *[]int32: + return (*Int32Array)(a) case *[]string: return (*StringArray)(a) + case *[][]byte: + return (*ByteaArray)(a) } return GenericArray{a} @@ -267,6 +279,70 @@ func (a Float64Array) Value() (driver.Value, error) { return "{}", nil } +// Float32Array represents a one-dimensional array of the PostgreSQL double +// precision type. +type Float32Array []float32 + +// Scan implements the sql.Scanner interface. +func (a *Float32Array) Scan(src interface{}) error { + switch src := src.(type) { + case []byte: + return a.scanBytes(src) + case string: + return a.scanBytes([]byte(src)) + case nil: + *a = nil + return nil + } + + return fmt.Errorf("pq: cannot convert %T to Float32Array", src) +} + +func (a *Float32Array) scanBytes(src []byte) error { + elems, err := scanLinearArray(src, []byte{','}, "Float32Array") + if err != nil { + return err + } + if *a != nil && len(elems) == 0 { + *a = (*a)[:0] + } else { + b := make(Float32Array, len(elems)) + for i, v := range elems { + var x float64 + if x, err = strconv.ParseFloat(string(v), 32); err != nil { + return fmt.Errorf("pq: parsing array element index %d: %v", i, err) + } + b[i] = float32(x) + } + *a = b + } + return nil +} + +// Value implements the driver.Valuer interface. +func (a Float32Array) Value() (driver.Value, error) { + if a == nil { + return nil, nil + } + + if n := len(a); n > 0 { + // There will be at least two curly brackets, N bytes of values, + // and N-1 bytes of delimiters. + b := make([]byte, 1, 1+2*n) + b[0] = '{' + + b = strconv.AppendFloat(b, float64(a[0]), 'f', -1, 32) + for i := 1; i < n; i++ { + b = append(b, ',') + b = strconv.AppendFloat(b, float64(a[i]), 'f', -1, 32) + } + + return string(append(b, '}')), nil + } + + return "{}", nil +} + // GenericArray implements the driver.Valuer and sql.Scanner interfaces for // an array or slice of any dimension. type GenericArray struct{ A interface{} } @@ -483,6 +559,69 @@ func (a Int64Array) Value() (driver.Value, error) { return "{}", nil } +// Int32Array represents a one-dimensional array of the PostgreSQL integer types. +type Int32Array []int32 + +// Scan implements the sql.Scanner interface. +func (a *Int32Array) Scan(src interface{}) error { + switch src := src.(type) { + case []byte: + return a.scanBytes(src) + case string: + return a.scanBytes([]byte(src)) + case nil: + *a = nil + return nil + } + + return fmt.Errorf("pq: cannot convert %T to Int32Array", src) +} + +func (a *Int32Array) scanBytes(src []byte) error { + elems, err := scanLinearArray(src, []byte{','}, "Int32Array") + if err != nil { + return err + } + if *a != nil && len(elems) == 0 { + *a = (*a)[:0] + } else { + b := make(Int32Array, len(elems)) + for i, v := range elems { + var x int + if x, err = strconv.Atoi(string(v)); err != nil { + return fmt.Errorf("pq: parsing array element index %d: %v", i, err) + } + b[i] = int32(x) + } + *a = b + } + return nil +} + +// Value implements the driver.Valuer interface. +func (a Int32Array) Value() (driver.Value, error) { + if a == nil { + return nil, nil + } + + if n := len(a); n > 0 { + // There will be at least two curly brackets, N bytes of values, + // and N-1 bytes of delimiters. + b := make([]byte, 1, 1+2*n) + b[0] = '{' + + b = strconv.AppendInt(b, int64(a[0]), 10) + for i := 1; i < n; i++ { + b = append(b, ',') + b = strconv.AppendInt(b, int64(a[i]), 10) + } + + return string(append(b, '}')), nil + } + + return "{}", nil +} + // StringArray represents a one-dimensional array of the PostgreSQL character types. type StringArray []string diff --git a/array_test.go b/array_test.go index f724bcd8..5ca9f7a5 100644 --- a/array_test.go +++ b/array_test.go @@ -104,16 +104,33 @@ func TestArrayScanner(t *testing.T) { t.Errorf("Expected *Int64Array, got %T", s) } + s = Array(&[]float32{}) + if _, ok := s.(*Float32Array); !ok { + t.Errorf("Expected *Float32Array, got %T", s) + } + + s = Array(&[]int32{}) + if _, ok := s.(*Int32Array); !ok { + t.Errorf("Expected *Int32Array, got %T", s) + } + s = Array(&[]string{}) if _, ok := s.(*StringArray); !ok { t.Errorf("Expected *StringArray, got %T", s) } + s = Array(&[][]byte{}) + if _, ok := s.(*ByteaArray); !ok { + t.Errorf("Expected *ByteaArray, got %T", s) + } + for _, tt := range []interface{}{ &[]sql.Scanner{}, &[][]bool{}, &[][]float64{}, &[][]int64{}, + &[][]float32{}, + &[][]int32{}, &[][]string{}, } { s = Array(tt) @@ -139,17 +156,34 @@ func TestArrayValuer(t *testing.T) { t.Errorf("Expected *Int64Array, got %T", v) } + v = Array([]float32{}) + if _, ok := v.(*Float32Array); !ok { + t.Errorf("Expected *Float32Array, got %T", v) + } + + v = Array([]int32{}) + if _, ok := v.(*Int32Array); !ok { + t.Errorf("Expected *Int32Array, got %T", v) + } + v = Array([]string{}) if _, ok := v.(*StringArray); !ok { t.Errorf("Expected *StringArray, got %T", v) } + v = Array([][]byte{}) + if _, ok := v.(*ByteaArray); !ok { + t.Errorf("Expected *ByteaArray, got %T", v) + } + for _, tt := range []interface{}{ nil, []driver.Value{}, [][]bool{}, [][]float64{}, [][]int64{}, + [][]float32{}, + [][]int32{}, [][]string{}, } { v = Array(tt) @@ -773,6 +807,313 @@ func BenchmarkInt64ArrayValue(b *testing.B) { } } +func TestFloat32ArrayScanUnsupported(t *testing.T) { + var arr Float32Array + err := arr.Scan(true) + + if err == nil { + t.Fatal("Expected error when scanning from bool") + } + if !strings.Contains(err.Error(), "bool to Float32Array") { + t.Errorf("Expected type to be mentioned when scanning, got %q", err) + } +} + +func TestFloat32ArrayScanEmpty(t *testing.T) { + var arr Float32Array + err := arr.Scan(`{}`) + + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if arr == nil || len(arr) != 0 { + t.Errorf("Expected empty, got %#v", arr) + } +} + +func TestFloat32ArrayScanNil(t *testing.T) { + arr := Float32Array{5, 5, 5} + err := arr.Scan(nil) + + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if arr != nil { + t.Errorf("Expected nil, got %+v", arr) + } +} + +var Float32ArrayStringTests = []struct { + str string + arr Float32Array +}{ + {`{}`, Float32Array{}}, + {`{1.2}`, Float32Array{1.2}}, + {`{3.456,7.89}`, Float32Array{3.456, 7.89}}, + {`{3,1,2}`, Float32Array{3, 1, 2}}, +} + +func TestFloat32ArrayScanBytes(t *testing.T) { + for _, tt := range Float32ArrayStringTests { + bytes := []byte(tt.str) + arr := Float32Array{5, 5, 5} + err := arr.Scan(bytes) + + if err != nil { + t.Fatalf("Expected no error for %q, got %v", bytes, err) + } + if !reflect.DeepEqual(arr, tt.arr) { + t.Errorf("Expected %+v for %q, got %+v", tt.arr, bytes, arr) + } + } +} + +func BenchmarkFloat32ArrayScanBytes(b *testing.B) { + var a Float32Array + var x interface{} = []byte(`{1.2,3.4,5.6,7.8,9.01,2.34,5.67,8.90,1.234,5.678}`) + + for i := 0; i < b.N; i++ { + a = Float32Array{} + a.Scan(x) + } +} + +func TestFloat32ArrayScanString(t *testing.T) { + for _, tt := range Float32ArrayStringTests { + arr := Float32Array{5, 5, 5} + err := arr.Scan(tt.str) + + if err != nil { + t.Fatalf("Expected no error for %q, got %v", tt.str, err) + } + if !reflect.DeepEqual(arr, tt.arr) { + t.Errorf("Expected %+v for %q, got %+v", tt.arr, tt.str, arr) + } + } +} + +func TestFloat32ArrayScanError(t *testing.T) { + for _, tt := range []struct { + input, err string + }{ + {``, "unable to parse array"}, + {`{`, "unable to parse array"}, + {`{{5.6},{7.8}}`, "cannot convert ARRAY[2][1] to Float32Array"}, + {`{NULL}`, "parsing array element index 0:"}, + {`{a}`, "parsing array element index 0:"}, + {`{5.6,a}`, "parsing array element index 1:"}, + {`{5.6,7.8,a}`, "parsing array element index 2:"}, + } { + arr := Float32Array{5, 5, 5} + err := arr.Scan(tt.input) + + if err == nil { + t.Fatalf("Expected error for %q, got none", tt.input) + } + if !strings.Contains(err.Error(), tt.err) { + t.Errorf("Expected error to contain %q for %q, got %q", tt.err, tt.input, err) + } + if !reflect.DeepEqual(arr, Float32Array{5, 5, 5}) { + t.Errorf("Expected destination not to change for %q, got %+v", tt.input, arr) + } + } +} + +func TestFloat32ArrayValue(t *testing.T) { + result, err := Float32Array(nil).Value() + + if err != nil { + t.Fatalf("Expected no error for nil, got %v", err) + } + if result != nil { + t.Errorf("Expected nil, got %q", result) + } + + result, err = Float32Array([]float32{}).Value() + + if err != nil { + t.Fatalf("Expected no error for empty, got %v", err) + } + if expected := `{}`; !reflect.DeepEqual(result, expected) { + t.Errorf("Expected empty, got %q", result) + } + + result, err = Float32Array([]float32{1.2, 3.4, 5.6}).Value() + + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if expected := `{1.2,3.4,5.6}`; !reflect.DeepEqual(result, expected) { + t.Errorf("Expected %q, got %q", expected, result) + } +} + +func BenchmarkFloat32ArrayValue(b *testing.B) { + rand.Seed(1) + x := make([]float32, 10) + for i := 0; i < len(x); i++ { + x[i] = rand.Float32() + } + a := Float32Array(x) + + for i := 0; i < b.N; i++ { + a.Value() + } +} + +func TestInt32ArrayScanUnsupported(t *testing.T) { + var arr Int32Array + err := arr.Scan(true) + + if err == nil { + t.Fatal("Expected error when scanning from bool") + } + if !strings.Contains(err.Error(), "bool to Int32Array") { + t.Errorf("Expected type to be mentioned when scanning, got %q", err) + } +} + +func TestInt32ArrayScanEmpty(t *testing.T) { + var arr Int32Array + err := arr.Scan(`{}`) + + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if arr == nil || len(arr) != 0 { + t.Errorf("Expected empty, got %#v", arr) + } +} + +func TestInt32ArrayScanNil(t *testing.T) { + arr := Int32Array{5, 5, 5} + err := arr.Scan(nil) + + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if arr != nil { + t.Errorf("Expected nil, got %+v", arr) + } +} + +var Int32ArrayStringTests = []struct { + str string + arr Int32Array +}{ + {`{}`, Int32Array{}}, + {`{12}`, Int32Array{12}}, + {`{345,678}`, Int32Array{345, 678}}, +} + +func TestInt32ArrayScanBytes(t *testing.T) { + for _, tt := range Int32ArrayStringTests { + bytes := []byte(tt.str) + arr := Int32Array{5, 5, 5} + err := arr.Scan(bytes) + + if err != nil { + t.Fatalf("Expected no error for %q, got %v", bytes, err) + } + if !reflect.DeepEqual(arr, tt.arr) { + t.Errorf("Expected %+v for %q, got %+v", tt.arr, bytes, arr) + } + } +} + +func BenchmarkInt32ArrayScanBytes(b *testing.B) { + var a Int32Array + var x interface{} = []byte(`{1,2,3,4,5,6,7,8,9,0}`) + + for i := 0; i < b.N; i++ { + a = Int32Array{} + a.Scan(x) + } +} + +func TestInt32ArrayScanString(t *testing.T) { + for _, tt := range Int32ArrayStringTests { + arr := Int32Array{5, 5, 5} + err := arr.Scan(tt.str) + + if err != nil { + t.Fatalf("Expected no error for %q, got %v", tt.str, err) + } + if !reflect.DeepEqual(arr, tt.arr) { + t.Errorf("Expected %+v for %q, got %+v", tt.arr, tt.str, arr) + } + } +} + +func TestInt32ArrayScanError(t *testing.T) { + for _, tt := range []struct { + input, err string + }{ + {``, "unable to parse array"}, + {`{`, "unable to parse array"}, + {`{{5},{6}}`, "cannot convert ARRAY[2][1] to Int32Array"}, + {`{NULL}`, "parsing array element index 0:"}, + {`{a}`, "parsing array element index 0:"}, + {`{5,a}`, "parsing array element index 1:"}, + {`{5,6,a}`, "parsing array element index 2:"}, + } { + arr := Int32Array{5, 5, 5} + err := arr.Scan(tt.input) + + if err == nil { + t.Fatalf("Expected error for %q, got none", tt.input) + } + if !strings.Contains(err.Error(), tt.err) { + t.Errorf("Expected error to contain %q for %q, got %q", tt.err, tt.input, err) + } + if !reflect.DeepEqual(arr, Int32Array{5, 5, 5}) { + t.Errorf("Expected destination not to change for %q, got %+v", tt.input, arr) + } + } +} + +func TestInt32ArrayValue(t *testing.T) { + result, err := Int32Array(nil).Value() + + if err != nil { + t.Fatalf("Expected no error for nil, got %v", err) + } + if result != nil { + t.Errorf("Expected nil, got %q", result) + } + + result, err = Int32Array([]int32{}).Value() + + if err != nil { + t.Fatalf("Expected no error for empty, got %v", err) + } + if expected := `{}`; !reflect.DeepEqual(result, expected) { + t.Errorf("Expected empty, got %q", result) + } + + result, err = Int32Array([]int32{1, 2, 3}).Value() + + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if expected := `{1,2,3}`; !reflect.DeepEqual(result, expected) { + t.Errorf("Expected %q, got %q", expected, result) + } +} + +func BenchmarkInt32ArrayValue(b *testing.B) { + rand.Seed(1) + x := make([]int32, 10) + for i := 0; i < len(x); i++ { + x[i] = rand.Int31() + } + a := Int32Array(x) + + for i := 0; i < b.N; i++ { + a.Value() + } +} + func TestStringArrayScanUnsupported(t *testing.T) { var arr StringArray err := arr.Scan(true)