Skip to content

Commit

Permalink
[PECO-1112] Added decimal handling (#167)
Browse files Browse the repository at this point in the history
We need to dynamically set the actual values of decimals, and this
should be the smallest value that could hypothetically encompass the
decimal string.
  • Loading branch information
nithinkdb authored Sep 27, 2023
2 parents c08cf71 + c20e62b commit 250160b
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 3 deletions.
4 changes: 2 additions & 2 deletions parameter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func TestParameter_Inference(t *testing.T) {
assert.Equal(t, string("TIMESTAMP"), *parameters[1].Type)
assert.Equal(t, &cli_service.TSparkParameterValue{StringValue: strPtr("5")}, parameters[2].Value)
assert.Equal(t, string("true"), *parameters[3].Value.StringValue)
assert.Equal(t, string("DECIMAL"), *parameters[4].Type)
assert.Equal(t, string("DECIMAL(2,1)"), *parameters[4].Type)
assert.Equal(t, string("6.2"), *parameters[4].Value.StringValue)
})
}
Expand All @@ -31,6 +31,6 @@ func TestParameters_Names(t *testing.T) {
assert.Equal(t, cli_service.TSparkParameterValue{StringValue: strPtr("26")}, *parameters[0].Value)
assert.Equal(t, string("2"), *parameters[1].Name)
assert.Equal(t, cli_service.TSparkParameterValue{StringValue: strPtr("6.2")}, *parameters[1].Value)
assert.Equal(t, string("DECIMAL"), *parameters[1].Type)
assert.Equal(t, string("DECIMAL(2,1)"), *parameters[1].Type)
})
}
27 changes: 26 additions & 1 deletion parameters.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"database/sql/driver"
"fmt"
"strconv"
"strings"
"time"

"github.com/databricks/databricks-sql-go/internal/cli_service"
Expand Down Expand Up @@ -142,9 +143,33 @@ func convertNamedValuesToSparkParams(values []driver.NamedValue) []*cli_service.
for i := range sqlParams {
sqlParam := sqlParams[i]
sparkParamValue := sqlParam.Value.(string)
sparkParamType := sqlParam.Type.String()
var sparkParamType string
if sqlParam.Type == Decimal {
sparkParamType = inferDecimalType(sparkParamValue)
} else {
sparkParamType = sqlParam.Type.String()
}
sparkParam := cli_service.TSparkParameter{Name: &sqlParam.Name, Type: &sparkParamType, Value: &cli_service.TSparkParameterValue{StringValue: &sparkParamValue}}
sparkParams = append(sparkParams, &sparkParam)
}
return sparkParams
}

func inferDecimalType(d string) (t string) {
var overall int
var after int
if strings.HasPrefix(d, "0.") {
// Less than one
overall = len(d) - 2
after = len(d) - 2
} else if !strings.Contains(d, ".") {
// Less than one
overall = len(d)
after = 0
} else {
components := strings.Split(d, ".")
overall, after = len(components[0])+len(components[1]), len(components[1])
}

return fmt.Sprintf("DECIMAL(%d,%d)", overall, after)
}

0 comments on commit 250160b

Please sign in to comment.