From b99e53e1e49ff4dbf2ce655d32febc22e621291d Mon Sep 17 00:00:00 2001 From: Philippe Daouadi Date: Wed, 15 Mar 2017 16:46:09 +0100 Subject: [PATCH] Add json (un)marshalling methods on NullUUID --- uuid.go | 24 ++++++++++++++++++++++++ uuid_test.go | 44 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 68 insertions(+) diff --git a/uuid.go b/uuid.go index 295f3fc..da191fd 100644 --- a/uuid.go +++ b/uuid.go @@ -32,6 +32,7 @@ import ( "database/sql/driver" "encoding/binary" "encoding/hex" + "encoding/json" "fmt" "hash" "net" @@ -344,6 +345,29 @@ func (u *NullUUID) Scan(src interface{}) error { return u.UUID.Scan(src) } +// MarshalJSON marshalls the NullUUID as nil or the nested UUID +func (u NullUUID) MarshalJSON() ([]byte, error) { + if u.Valid == false { + return json.Marshal(nil) + } + return json.Marshal(u.UUID) +} + +// UnmarshalJSON unmarshalls a NullUUID +func (u *NullUUID) UnmarshalJSON(b []byte) error { + if bytes.Equal(b, []byte("null")) { + u.UUID, u.Valid = Nil, false + return nil + } + + if err := json.Unmarshal(b, &u.UUID); err != nil { + return err + } + u.Valid = true + + return nil +} + // FromBytes returns UUID converted from raw byte slice input. // It will return error if the slice isn't 16 bytes long. func FromBytes(input []byte) (u UUID, err error) { diff --git a/uuid_test.go b/uuid_test.go index 5650480..de471dc 100644 --- a/uuid_test.go +++ b/uuid_test.go @@ -23,6 +23,7 @@ package uuid import ( "bytes" + "encoding/json" "testing" ) @@ -631,3 +632,46 @@ func TestNewV5(t *testing.T) { t.Errorf("UUIDv3 generated same UUIDs for sane names in different namespaces: %s and %s", u1, u4) } } + +func TestMarshalNullUUID(t *testing.T) { + u := NullUUID{UUID: NewV4(), Valid: true} + j, err := json.Marshal(u) + if err != nil { + t.Error("Couldn't marshal a valid NullUUID: ", err) + } + + if string(j) != "\""+u.UUID.String()+"\"" { + t.Error("Marshaled NullUUID is incorrect: ", string(j)) + } + + nu := NullUUID{Valid: false} + j, err = json.Marshal(nu) + if err != nil { + t.Error("Couldn't marshal an invalid NullUUID: ", err) + } + + if string(j) != "null" { + t.Error("Marshaled NullUUID is incorrect: ", string(j)) + } +} + +func TestUnmarshalNullUUID(t *testing.T) { + var u NullUUID + err := json.Unmarshal([]byte("null"), &u) + if err != nil { + t.Error("Couldn't Unmarshal an invalid NullUUID: ", err) + } + + if u.Valid != false { + t.Error("Unmarshaled NullUUID is valid but shouldn't") + } + + err = json.Unmarshal([]byte("\"886313e1-3b8a-5372-9b90-0c9aee199e5d\""), &u) + if err != nil { + t.Error("Couldn't Unmarshal an invalid NullUUID: ", err) + } + + if u.Valid != true || u.UUID.String() != "886313e1-3b8a-5372-9b90-0c9aee199e5d" { + t.Error("Unmarshaled NullUUID is incorrect: ", u.Valid, u.UUID) + } +}