Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: Support nullable types in Always Encrypted #179

Merged
merged 4 commits into from
Mar 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -427,9 +427,8 @@ If the correct key provider is included in your application, decryption of encry

Encryption of parameters passed to `Exec` and `Query` variants requires an extra round trip per query to fetch the encryption metadata. If the error returned by a query attempt indicates a type mismatch between the parameter and the destination table, most likely your input type is not a strict match for the SQL Server data type of the destination. You may be using a Go `string` when you need to use one of the driver-specific aliases like `VarChar` or `NVarCharMax`.

*** NOTE *** - Currently `char` and `varchar` types do not include a collation parameter component so can't be used for inserting encrypted values. Also, using a nullable sql package type like `sql.NullableInt32` to pass a `NULL` value for an encrypted column will not work unless the encrypted column type is `nvarchar`.
*** NOTE *** - Currently `char` and `varchar` types do not include a collation parameter component so can't be used for inserting encrypted values.
https://github.com/microsoft/go-mssqldb/issues/129
https://github.com/microsoft/go-mssqldb/issues/130


### Local certificate AE key provider
Expand Down
17 changes: 13 additions & 4 deletions alwaysencrypted_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"crypto/rand"
"database/sql"
"database/sql/driver"
"fmt"
"math/big"
"strings"
Expand Down Expand Up @@ -65,8 +66,10 @@ func TestAlwaysEncryptedE2E(t *testing.T) {
{"datetimeoffset(7)", "DATETIMEOFFSET", ColumnEncryptionRandomized, dt},
{"datetime2(7)", "DATETIME2", ColumnEncryptionDeterministic, civil.DateTimeOf(dt)},
{"nvarchar(max)", "NVARCHAR", ColumnEncryptionRandomized, NVarCharMax("nvarcharmaxval")},
// TODO: The driver throws away type information about Valuer implementations and sends nil as nvarchar(1). Fix that.
// {"int", "INT", ColumnEncryptionDeterministic, sql.NullInt32{Valid: false}},
{"int", "INT", ColumnEncryptionDeterministic, sql.NullInt32{Valid: false}},
{"bigint", "BIGINT", ColumnEncryptionDeterministic, sql.NullInt64{Int64: 128, Valid: true}},
{"uniqueidentifier", "UNIQUEIDENTIFIER", ColumnEncryptionRandomized, UniqueIdentifier{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF}},
{"uniqueidentifier", "UNIQUEIDENTIFIER", ColumnEncryptionRandomized, NullUniqueIdentifier{Valid: false}},
}
for _, test := range providerTests {
// turn off key caching
Expand Down Expand Up @@ -230,13 +233,19 @@ func comparisonValueFromObject(object interface{}) string {
case time.Time:
return civil.DateTimeOf(v).String()
//return v.Format(time.RFC3339)
case fmt.Stringer:
return v.String()
case bool:
if v == true {
return "1"
}
return "0"
case driver.Valuer:
val, _ := v.Value()
if val == nil {
return "<nil>"
}
return comparisonValueFromObject(val)
case fmt.Stringer:
return v.String()
default:
return fmt.Sprintf("%v", v)
}
Expand Down
34 changes: 34 additions & 0 deletions mssql.go
Original file line number Diff line number Diff line change
Expand Up @@ -982,7 +982,37 @@ func (s *Stmt) makeParam(val driver.Value) (res param, err error) {
res.ti.Size = 0
return
}
switch valuer := val.(type) {
case UniqueIdentifier:
case NullUniqueIdentifier:
default:
break
case driver.Valuer:
// If the value has a non-nil value, call MakeParam on its Value
val, e := driver.DefaultParameterConverter.ConvertValue(valuer)
if e != nil {
err = e
return
}
if val != nil {
return s.makeParam(val)
}
}
switch val := val.(type) {
case UniqueIdentifier:
res.ti.TypeId = typeGuid
res.ti.Size = 16
guid, _ := val.Value()
res.buffer = guid.([]byte)
case NullUniqueIdentifier:
res.ti.TypeId = typeGuid
res.ti.Size = 16
if val.Valid {
guid, _ := val.Value()
res.buffer = guid.([]byte)
} else {
res.buffer = []byte{}
}
case int:
res.ti.TypeId = typeIntN
// Rather than guess if the caller intends to pass a 32bit int from a 64bit app based on the
Expand Down Expand Up @@ -1021,6 +1051,10 @@ func (s *Stmt) makeParam(val driver.Value) (res param, err error) {
res.ti.TypeId = typeIntN
res.ti.Size = 8
res.buffer = []byte{}
case sql.NullInt32:
res.ti.TypeId = typeIntN
res.ti.Size = 4
res.buffer = []byte{}
case byte:
res.ti.TypeId = typeIntN
res.buffer = []byte{val}
Expand Down
2 changes: 2 additions & 0 deletions mssql_go19.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ func convertInputParameter(val interface{}) (interface{}, error) {
// return nil
case float32:
return val, nil
case driver.Valuer:
return val, nil
default:
return driver.DefaultParameterConverter.ConvertValue(v)
}
Expand Down
27 changes: 27 additions & 0 deletions queries_go19_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"fmt"
"reflect"
"regexp"
"strings"
"testing"
"time"

Expand All @@ -31,6 +32,32 @@ func TestOutputParam(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

t.Run("varchar(max) to sql.NullString", func(t *testing.T) {
sqltextcreate := `CREATE PROCEDURE [GetTask]
@strparam varchar(max) = NULL OUTPUT
AS
SELECT @strparam = REPLICATE('a', 8000)
RETURN 0`
sqltextdrop := `drop procedure GetTask`
sqltextrun := `GetTask`
_, _ = db.ExecContext(ctx, sqltextdrop)
_, err = db.ExecContext(ctx, sqltextcreate)
if err != nil {
t.Fatal(err)
}
defer db.ExecContext(ctx, sqltextdrop)
nullstr := sql.NullString{}
_, err := db.ExecContext(ctx, sqltextrun,
sql.Named("strparam", sql.Out{Dest: &nullstr}),
)
if err != nil {
t.Error(err)
}
defer db.ExecContext(ctx, sqltextdrop)
if nullstr.String != strings.Repeat("a", 8000) {
t.Error("Got incorrect NullString of length:", len(nullstr.String))
}
})
t.Run("sp with rows", func(t *testing.T) {
sqltextcreate := `
CREATE PROCEDURE spwithrows
Expand Down
13 changes: 13 additions & 0 deletions queries_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,19 @@ func TestSelect(t *testing.T) {
}
})
})
t.Run("scan into sql.NullString", func(t *testing.T) {
row := conn.QueryRow("SELECT REPLICATE('a', 8000)")
var out sql.NullString
err := row.Scan(&out)
if err != nil {
t.Error("Scan to NullString failed", err.Error())
return
}

if out.String != strings.Repeat("a", 8000) {
t.Error("got back a string with count:", len(out.String))
}
})
}

func TestSelectDateTimeOffset(t *testing.T) {
Expand Down
Loading