From 329cb45913691e9e0369a94fa2cf059c237307a7 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Jan 2025 10:47:57 -0600 Subject: [PATCH 1/2] XMLCodec: fix DecodeValue to return a []byte Previously, DecodeValue would always return nil with the default Unmarshal function. fixes https://github.com/jackc/pgx/issues/2227 --- pgtype/xml.go | 6 +++--- pgtype/xml_test.go | 29 +++++++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 3 deletions(-) diff --git a/pgtype/xml.go b/pgtype/xml.go index 79e3698a4..1d159cf1b 100644 --- a/pgtype/xml.go +++ b/pgtype/xml.go @@ -192,7 +192,7 @@ func (c *XMLCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (an return nil, nil } - var dst any - err := c.Unmarshal(src, &dst) - return dst, err + dstBuf := make([]byte, len(src)) + copy(dstBuf, src) + return dstBuf, nil } diff --git a/pgtype/xml_test.go b/pgtype/xml_test.go index 0f755e96f..211859147 100644 --- a/pgtype/xml_test.go +++ b/pgtype/xml_test.go @@ -97,3 +97,32 @@ func TestXMLCodecPointerToPointerToString(t *testing.T) { require.Nil(t, s) }) } + +func TestXMLCodecDecodeValue(t *testing.T) { + skipCockroachDB(t, "CockroachDB does not support XML.") + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + for _, tt := range []struct { + sql string + expected any + }{ + { + sql: `select 'bar'::xml`, + expected: []byte("bar"), + }, + } { + t.Run(tt.sql, func(t *testing.T) { + rows, err := conn.Query(ctx, tt.sql) + require.NoError(t, err) + + for rows.Next() { + values, err := rows.Values() + require.NoError(t, err) + require.Len(t, values, 1) + require.Equal(t, tt.expected, values[0]) + } + + require.NoError(t, rows.Err()) + }) + } + }) +} From 03f08abda3e096464b7ae894e8c12c322e3f73a3 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Jan 2025 11:26:46 -0600 Subject: [PATCH 2/2] Fix in Unmarshal function rather than DecodeValue This preserves backwards compatibility in the unlikely event someone is using an alternative XML unmarshaler that does support unmarshalling into *any. --- pgtype/pgtype_default.go | 20 +++++++++++++++++++- pgtype/xml.go | 6 +++--- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/pgtype/pgtype_default.go b/pgtype/pgtype_default.go index 9496cb974..8cb512fa5 100644 --- a/pgtype/pgtype_default.go +++ b/pgtype/pgtype_default.go @@ -91,7 +91,25 @@ func initDefaultMap() { defaultMap.RegisterType(&Type{Name: "varchar", OID: VarcharOID, Codec: TextCodec{}}) defaultMap.RegisterType(&Type{Name: "xid", OID: XIDOID, Codec: Uint32Codec{}}) defaultMap.RegisterType(&Type{Name: "xid8", OID: XID8OID, Codec: Uint64Codec{}}) - defaultMap.RegisterType(&Type{Name: "xml", OID: XMLOID, Codec: &XMLCodec{Marshal: xml.Marshal, Unmarshal: xml.Unmarshal}}) + defaultMap.RegisterType(&Type{Name: "xml", OID: XMLOID, Codec: &XMLCodec{ + Marshal: xml.Marshal, + // xml.Unmarshal does not support unmarshalling into *any. However, XMLCodec.DecodeValue calls Unmarshal with a + // *any. Wrap xml.Marshal with a function that copies the data into a new byte slice in this case. Not implementing + // directly in XMLCodec.DecodeValue to allow for the unlikely possibility that someone uses an alternative XML + // unmarshaler that does support unmarshalling into *any. + // + // https://github.com/jackc/pgx/issues/2227 + // https://github.com/jackc/pgx/pull/2228 + Unmarshal: func(data []byte, v any) error { + if v, ok := v.(*any); ok { + dstBuf := make([]byte, len(data)) + copy(dstBuf, data) + *v = dstBuf + return nil + } + return xml.Unmarshal(data, v) + }, + }}) // Range types defaultMap.RegisterType(&Type{Name: "daterange", OID: DaterangeOID, Codec: &RangeCodec{ElementType: defaultMap.oidToType[DateOID]}}) diff --git a/pgtype/xml.go b/pgtype/xml.go index 1d159cf1b..79e3698a4 100644 --- a/pgtype/xml.go +++ b/pgtype/xml.go @@ -192,7 +192,7 @@ func (c *XMLCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (an return nil, nil } - dstBuf := make([]byte, len(src)) - copy(dstBuf, src) - return dstBuf, nil + var dst any + err := c.Unmarshal(src, &dst) + return dst, err }