From aeb5e5d885adfb4a4caf3dc2008c03d46a8da8de Mon Sep 17 00:00:00 2001 From: Esdras Beleza Date: Tue, 16 Apr 2024 10:32:37 +0100 Subject: [PATCH] Fix Spark parameter creation when passing a `nil`-value named parameter to a query (#199) - When a `sql.NamedValue` has the field `Value` set to `nil`, the resulting `cli_service.TSparkParameter` will also have the value `nil` instead of `*cli_service.TSparkParameterValue{StringValue: *"%!s(<nil>)"}`. - Add the type `SqlVoid`, following the conventions used in the [NodeJS connector](https://github.com/databricks/databricks-sql-nodejs/blob/main/lib/DBSQLParameter.ts#L43-L51) and the [Python driver](https://github.com/databricks/databricks-sql-python/blob/f6fd7a7956a4dbc78ad36b5e079fe8d74176a0f1/src/databricks/sql/parameters/native.py#L319-L323). Fix #193. --------- Signed-off-by: Esdras Beleza Signed-off-by: Levko Kravets Signed-off-by: candiduslynx Co-authored-by: Levko Kravets Co-authored-by: Mahdi Dibaiee Co-authored-by: Alex Shcherbakov --- CHANGELOG.md | 3 ++- parameter_test.go | 7 +++++-- parameters.go | 19 ++++++++++++++++--- 3 files changed, 23 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b899222..48e7db5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,7 @@ # Release History -- Fix formatting of *float64 parameters +- Bug fix for issue 193: convertNamedValuesToSparkParams was incorrectly creating a Spark parameter value as "%!s(<nil>)" when a named param was nil (databricks/databricks-sql-go#199 by @esdrasbeleza) +- Fix formatting of *float64 parameters (databricks/databricks-sql-go#215 by @esdrasbeleza) ## v1.5.4 (2024-04-10) diff --git a/parameter_test.go b/parameter_test.go index 3d4b10b..2c994cf 100644 --- a/parameter_test.go +++ b/parameter_test.go @@ -12,12 +12,13 @@ import ( func TestParameter_Inference(t *testing.T) { t.Run("Should infer types correctly", func(t *testing.T) { - values := [6]driver.NamedValue{ + values := [7]driver.NamedValue{ {Name: "", Value: float32(5.1)}, 
{Name: "", Value: time.Now()}, {Name: "", Value: int64(5)}, {Name: "", Value: true}, {Name: "", Value: Parameter{Value: "6.2", Type: SqlDecimal}}, + {Name: "", Value: nil}, {Name: "", Value: Parameter{Value: float64Ptr(6.2), Type: SqlUnkown}}, } parameters := convertNamedValuesToSparkParams(values[:]) @@ -28,7 +29,9 @@ func TestParameter_Inference(t *testing.T) { assert.Equal(t, string("true"), *parameters[3].Value.StringValue) assert.Equal(t, string("DECIMAL(2,1)"), *parameters[4].Type) assert.Equal(t, string("6.2"), *parameters[4].Value.StringValue) - assert.Equal(t, &cli_service.TSparkParameterValue{StringValue: strPtr("6.2")}, parameters[5].Value) + assert.Equal(t, string("VOID"), *parameters[5].Type) + assert.Nil(t, parameters[5].Value) + assert.Equal(t, &cli_service.TSparkParameterValue{StringValue: strPtr("6.2")}, parameters[6].Value) }) } func TestParameters_Names(t *testing.T) { diff --git a/parameters.go b/parameters.go index 7b780f4..0652d80 100644 --- a/parameters.go +++ b/parameters.go @@ -34,6 +34,7 @@ const ( SqlBoolean SqlIntervalMonth SqlIntervalDay + SqlVoid ) func (s SqlType) String() string { @@ -64,6 +65,8 @@ func (s SqlType) String() string { return "INTERVAL MONTH" case SqlIntervalDay: return "INTERVAL DAY" + case SqlVoid: + return "VOID" } return "unknown" } @@ -149,6 +152,9 @@ func inferType(param *Parameter) { case time.Time: param.Value = value.Format(time.RFC3339Nano) param.Type = SqlTimestamp + case nil: + param.Value = nil + param.Type = SqlVoid default: s := fmt.Sprintf("%s", param.Value) param.Value = s @@ -163,14 +169,21 @@ func convertNamedValuesToSparkParams(values []driver.NamedValue) []*cli_service. 
inferTypes(sqlParams) for i := range sqlParams { sqlParam := sqlParams[i] - sparkParamValue := sqlParam.Value.(string) + sparkValue := new(cli_service.TSparkParameterValue) + if sqlParam.Type == SqlVoid { + sparkValue = nil + } else { + stringValue := sqlParam.Value.(string) + sparkValue = &cli_service.TSparkParameterValue{StringValue: &stringValue} + } + var sparkParamType string if sqlParam.Type == SqlDecimal { - sparkParamType = inferDecimalType(sparkParamValue) + sparkParamType = inferDecimalType(sparkValue.GetStringValue()) } else { sparkParamType = sqlParam.Type.String() } - sparkParam := cli_service.TSparkParameter{Name: &sqlParam.Name, Type: &sparkParamType, Value: &cli_service.TSparkParameterValue{StringValue: &sparkParamValue}} + sparkParam := cli_service.TSparkParameter{Name: &sqlParam.Name, Type: &sparkParamType, Value: sparkValue} sparkParams = append(sparkParams, &sparkParam) } return sparkParams