Skip to content

Commit

Permalink
Refactored dbsqlparam
Browse files Browse the repository at this point in the history
Signed-off-by: nithinkdb <nithin.krishnamurthi@databricks.com>
  • Loading branch information
nithinkdb committed Sep 29, 2023
1 parent e338383 commit 1ba2c5e
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 17 deletions.
6 changes: 3 additions & 3 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -596,9 +596,9 @@ func (c *conn) pollOperation(ctx context.Context, opHandle *cli_service.TOperati

func (c *conn) CheckNamedValue(nv *driver.NamedValue) error {
var err error
if dbsqlParam, ok := nv.Value.(DBSqlParam); ok {
nv.Name = dbsqlParam.Name
dbsqlParam.Value, err = driver.DefaultParameterConverter.ConvertValue(dbsqlParam.Value)
if parameter, ok := nv.Value.(Parameter); ok {
nv.Name = parameter.Name
parameter.Value, err = driver.DefaultParameterConverter.ConvertValue(parameter.Value)
return err
}

Expand Down
10 changes: 5 additions & 5 deletions examples/parameters/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,11 @@ func main() {
:p_double AS col_double,
:p_float AS col_float,
:p_date AS col_date`,
dbsql.DBSqlParam{Name: "p_bool", Value: true},
dbsql.DBSqlParam{Name: "p_int", Value: int(1234)},
dbsql.DBSqlParam{Name: "p_double", Type: dbsql.SqlDouble, Value: "3.14"},
dbsql.DBSqlParam{Name: "p_float", Type: dbsql.SqlFloat, Value: "3.14"},
dbsql.DBSqlParam{Name: "p_date", Type: dbsql.SqlDate, Value: "2017-07-23 00:00:00"}).Scan(&p_bool, &p_int, &p_double, &p_float, &p_date)
dbsql.Parameter{Name: "p_bool", Value: true},
dbsql.Parameter{Name: "p_int", Value: int(1234)},
dbsql.Parameter{Name: "p_double", Type: dbsql.SqlDouble, Value: "3.14"},
dbsql.Parameter{Name: "p_float", Type: dbsql.SqlFloat, Value: "3.14"},
dbsql.Parameter{Name: "p_date", Type: dbsql.SqlDate, Value: "2017-07-23 00:00:00"}).Scan(&p_bool, &p_int, &p_double, &p_float, &p_date)

if err1 != nil {
if err1 == sql.ErrNoRows {
Expand Down
4 changes: 2 additions & 2 deletions parameter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (

func TestParameter_Inference(t *testing.T) {
t.Run("Should infer types correctly", func(t *testing.T) {
values := [5]driver.NamedValue{{Name: "", Value: float32(5.1)}, {Name: "", Value: time.Now()}, {Name: "", Value: int64(5)}, {Name: "", Value: true}, {Name: "", Value: DBSqlParam{Value: "6.2", Type: SqlDecimal}}}
values := [5]driver.NamedValue{{Name: "", Value: float32(5.1)}, {Name: "", Value: time.Now()}, {Name: "", Value: int64(5)}, {Name: "", Value: true}, {Name: "", Value: Parameter{Value: "6.2", Type: SqlDecimal}}}
parameters := convertNamedValuesToSparkParams(values[:])
assert.Equal(t, strconv.FormatFloat(float64(5.1), 'f', -1, 64), *parameters[0].Value.StringValue)
assert.NotNil(t, parameters[1].Value.StringValue)
Expand All @@ -25,7 +25,7 @@ func TestParameter_Inference(t *testing.T) {
}
func TestParameters_Names(t *testing.T) {
t.Run("Should infer types correctly", func(t *testing.T) {
values := [2]driver.NamedValue{{Name: "1", Value: int(26)}, {Name: "", Value: DBSqlParam{Name: "2", Type: SqlDecimal, Value: "6.2"}}}
values := [2]driver.NamedValue{{Name: "1", Value: int(26)}, {Name: "", Value: Parameter{Name: "2", Type: SqlDecimal, Value: "6.2"}}}
parameters := convertNamedValuesToSparkParams(values[:])
assert.Equal(t, string("1"), *parameters[0].Name)
assert.Equal(t, cli_service.TSparkParameterValue{StringValue: strPtr("26")}, *parameters[0].Value)
Expand Down
14 changes: 7 additions & 7 deletions parameters.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
"github.com/databricks/databricks-sql-go/internal/cli_service"
)

type DBSqlParam struct {
type Parameter struct {
Name string
Type SqlType
Value any
Expand Down Expand Up @@ -67,12 +67,12 @@ func (s SqlType) String() string {
return "unknown"
}

func valuesToDBSQLParams(namedValues []driver.NamedValue) []DBSqlParam {
var params []DBSqlParam
func valuesToParameters(namedValues []driver.NamedValue) []Parameter {
var params []Parameter
for i := range namedValues {
newParam := *new(DBSqlParam)
newParam := *new(Parameter)
namedValue := namedValues[i]
param, ok := namedValue.Value.(DBSqlParam)
param, ok := namedValue.Value.(Parameter)
if ok {
newParam.Name = param.Name
newParam.Value = param.Value
Expand All @@ -86,7 +86,7 @@ func valuesToDBSQLParams(namedValues []driver.NamedValue) []DBSqlParam {
return params
}

func inferTypes(params []DBSqlParam) {
func inferTypes(params []Parameter) {
for i := range params {
param := &params[i]
if param.Type == SqlUnkown {
Expand Down Expand Up @@ -144,7 +144,7 @@ func inferTypes(params []DBSqlParam) {
func convertNamedValuesToSparkParams(values []driver.NamedValue) []*cli_service.TSparkParameter {
var sparkParams []*cli_service.TSparkParameter

sqlParams := valuesToDBSQLParams(values)
sqlParams := valuesToParameters(values)
inferTypes(sqlParams)
for i := range sqlParams {
sqlParam := sqlParams[i]
Expand Down

0 comments on commit 1ba2c5e

Please sign in to comment.