diff --git a/protocol/payment.go b/protocol/payment.go index 7f9c622..6b54bfa 100644 --- a/protocol/payment.go +++ b/protocol/payment.go @@ -120,6 +120,8 @@ func (p *PaymentRequestInfoData) UnmarshalJSON(data []byte) error { var rails Iden3PaymentRailsRequestV1 var railsCol []Iden3PaymentRailsRequestV1 + p.crypto, p.rails = nil, nil + if err := json.Unmarshal(data, &crypto); err == nil { if crypto.Type == Iden3PaymentRequestCryptoV1Type { p.crypto = append(p.crypto, crypto) @@ -361,23 +363,34 @@ func (p PaymentContext) MarshalJSON() ([]byte, error) { // UnmarshalJSON unmarshal the PaymentContext from JSON. func (p *PaymentContext) UnmarshalJSON(data []byte) error { - var str string - var strCol []string - var itemCol []interface{} - - if err := json.Unmarshal(data, &str); err == nil { - p.str = &str - return nil + var o any + if err := json.Unmarshal(data, &o); err != nil { + return err } - if err := json.Unmarshal(data, &strCol); err == nil { - p.strCol = strCol - return nil - } - if err := json.Unmarshal(data, &itemCol); err == nil { - p.itemCol = itemCol - return nil + + switch v := o.(type) { + case string: + p.str = &v + p.strCol = nil + p.itemCol = nil + case []any: + p.str = nil + p.itemCol = nil + p.strCol = make([]string, len(v)) + for i := range v { + s, ok := v[i].(string) + if !ok { + p.strCol = nil + p.itemCol = v + break + } + p.strCol[i] = s + } + default: + return errors.Errorf("failed to unmarshal PaymentContext: %s", string(data)) } - return errors.Errorf("failed to unmarshal PaymentContext: %s", string(data)) + + return nil } // Data returns the data in the union. diff --git a/protocol/payment_test.go b/protocol/payment_test.go index 45e860b..53f47e5 100644 --- a/protocol/payment_test.go +++ b/protocol/payment_test.go @@ -453,6 +453,11 @@ func TestPaymentContext(t *testing.T) { payload: []byte(`[{"field":"context1"}, "context in a string"]`), expectedPayload: []byte(`[{"field":"context1"}, "context in a string"]`), }, + { + desc: "A list of heterogeneous objects, first is a string", + payload: []byte(`["context in a string", {"field":"context1"}]`), + expectedPayload: []byte(`["context in a string", {"field":"context1"}]`), + }, } { t.Run(tc.desc, func(t *testing.T) { var msg PaymentContext