Skip to content

Commit

Permalink
Use encoding/json as JSON decoder instead of mapstructure (#6680)
Browse files Browse the repository at this point in the history
Fixes #6147
  • Loading branch information
s-mang authored Oct 29, 2019
1 parent 82f1eac commit 78ad820
Show file tree
Hide file tree
Showing 26 changed files with 715 additions and 911 deletions.
63 changes: 7 additions & 56 deletions agent/acl_endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"net/http"
"strconv"
"strings"
"time"

"github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/structs"
Expand Down Expand Up @@ -273,53 +272,6 @@ func (s *HTTPServer) ACLPolicyCreate(resp http.ResponseWriter, req *http.Request
return s.aclPolicyWriteInternal(resp, req, "", true)
}

// fixTimeAndHashFields is used to help in decoding the ExpirationTTL, ExpirationTime, CreateTime, and Hash
// attributes from the ACL Token/Policy create/update requests. It is needed
// to help mapstructure decode things properly when decodeBody is used.
func fixTimeAndHashFields(raw interface{}) error {
rawMap, ok := raw.(map[string]interface{})
if !ok {
return nil
}

if val, ok := rawMap["ExpirationTTL"]; ok {
if sval, ok := val.(string); ok {
d, err := time.ParseDuration(sval)
if err != nil {
return err
}
rawMap["ExpirationTTL"] = d
}
}

if val, ok := rawMap["ExpirationTime"]; ok {
if sval, ok := val.(string); ok {
t, err := time.Parse(time.RFC3339, sval)
if err != nil {
return err
}
rawMap["ExpirationTime"] = t
}
}

if val, ok := rawMap["CreateTime"]; ok {
if sval, ok := val.(string); ok {
t, err := time.Parse(time.RFC3339, sval)
if err != nil {
return err
}
rawMap["CreateTime"] = t
}
}

if val, ok := rawMap["Hash"]; ok {
if sval, ok := val.(string); ok {
rawMap["Hash"] = []byte(sval)
}
}
return nil
}

func (s *HTTPServer) ACLPolicyWrite(resp http.ResponseWriter, req *http.Request, policyID string) (interface{}, error) {
return s.aclPolicyWriteInternal(resp, req, policyID, false)
}
Expand All @@ -331,7 +283,7 @@ func (s *HTTPServer) aclPolicyWriteInternal(resp http.ResponseWriter, req *http.
s.parseToken(req, &args.Token)
s.parseEntMeta(req, &args.Policy.EnterpriseMeta)

if err := decodeBody(req, &args.Policy, fixTimeAndHashFields); err != nil {
if err := decodeBody(req.Body, &args.Policy); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Policy decoding failed: %v", err)}
}

Expand Down Expand Up @@ -521,7 +473,7 @@ func (s *HTTPServer) aclTokenSetInternal(resp http.ResponseWriter, req *http.Req
s.parseToken(req, &args.Token)
s.parseEntMeta(req, &args.ACLToken.EnterpriseMeta)

if err := decodeBody(req, &args.ACLToken, fixTimeAndHashFields); err != nil {
if err := decodeBody(req.Body, &args.ACLToken); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Token decoding failed: %v", err)}
}

Expand Down Expand Up @@ -567,8 +519,7 @@ func (s *HTTPServer) ACLTokenClone(resp http.ResponseWriter, req *http.Request,
}

s.parseEntMeta(req, &args.ACLToken.EnterpriseMeta)

if err := decodeBody(req, &args.ACLToken, fixTimeAndHashFields); err != nil && err.Error() != "EOF" {
if err := decodeBody(req.Body, &args.ACLToken); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Token decoding failed: %v", err)}
}
s.parseToken(req, &args.Token)
Expand Down Expand Up @@ -705,7 +656,7 @@ func (s *HTTPServer) ACLRoleWrite(resp http.ResponseWriter, req *http.Request, r
s.parseToken(req, &args.Token)
s.parseEntMeta(req, &args.Role.EnterpriseMeta)

if err := decodeBody(req, &args.Role, fixTimeAndHashFields); err != nil {
if err := decodeBody(req.Body, &args.Role); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Role decoding failed: %v", err)}
}

Expand Down Expand Up @@ -844,7 +795,7 @@ func (s *HTTPServer) ACLBindingRuleWrite(resp http.ResponseWriter, req *http.Req
s.parseToken(req, &args.Token)
s.parseEntMeta(req, &args.BindingRule.EnterpriseMeta)

if err := decodeBody(req, &args.BindingRule, fixTimeAndHashFields); err != nil {
if err := decodeBody(req.Body, &args.BindingRule); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("BindingRule decoding failed: %v", err)}
}

Expand Down Expand Up @@ -980,7 +931,7 @@ func (s *HTTPServer) ACLAuthMethodWrite(resp http.ResponseWriter, req *http.Requ
s.parseToken(req, &args.Token)
s.parseEntMeta(req, &args.AuthMethod.EnterpriseMeta)

if err := decodeBody(req, &args.AuthMethod, fixTimeAndHashFields); err != nil {
if err := decodeBody(req.Body, &args.AuthMethod); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("AuthMethod decoding failed: %v", err)}
}

Expand Down Expand Up @@ -1029,7 +980,7 @@ func (s *HTTPServer) ACLLogin(resp http.ResponseWriter, req *http.Request) (inte
s.parseDC(req, &args.Datacenter)
s.parseEntMeta(req, &args.Auth.EnterpriseMeta)

if err := decodeBody(req, &args.Auth, nil); err != nil {
if err := decodeBody(req.Body, &args.Auth); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Failed to decode request body:: %v", err)}
}

Expand Down
2 changes: 1 addition & 1 deletion agent/acl_endpoint_legacy.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func (s *HTTPServer) aclSet(resp http.ResponseWriter, req *http.Request, update

// Handle optional request body
if req.ContentLength > 0 {
if err := decodeBody(req, &args.ACL, nil); err != nil {
if err := decodeBody(req.Body, &args.ACL); err != nil {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(resp, "Request decode failed: %v", err)
return nil, nil
Expand Down
85 changes: 6 additions & 79 deletions agent/agent_endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -457,11 +457,8 @@ func (s *HTTPServer) syncChanges() {

func (s *HTTPServer) AgentRegisterCheck(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
var args structs.CheckDefinition
// Fixup the type decode of TTL or Interval.
decodeCB := func(raw interface{}) error {
return FixupCheckType(raw)
}
if err := decodeBody(req, &args, decodeCB); err != nil {

if err := decodeBody(req.Body, &args); err != nil {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(resp, "Request decode failed: %v", err)
return nil, nil
Expand Down Expand Up @@ -606,7 +603,7 @@ type checkUpdate struct {
// APIs.
func (s *HTTPServer) AgentCheckUpdate(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
var update checkUpdate
if err := decodeBody(req, &update, nil); err != nil {
if err := decodeBody(req.Body, &update); err != nil {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(resp, "Request decode failed: %v", err)
return nil, nil
Expand Down Expand Up @@ -758,7 +755,7 @@ func (s *HTTPServer) AgentRegisterService(resp http.ResponseWriter, req *http.Re
var args structs.ServiceDefinition
// Fixup the type decode of TTL or Interval if a check if provided.

if err := decodeBody(req, &args, registerServiceDecodeCB); err != nil {
if err := decodeBody(req.Body, &args); err != nil {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(resp, "Request decode failed: %v", err)
return nil, nil
Expand Down Expand Up @@ -894,76 +891,6 @@ func (s *HTTPServer) AgentRegisterService(resp http.ResponseWriter, req *http.Re
return nil, nil
}

// registerServiceDecodeCB is used in AgentRegisterService for request body decoding
func registerServiceDecodeCB(raw interface{}) error {
rawMap, ok := raw.(map[string]interface{})
if !ok {
return nil
}

// see https://github.com/hashicorp/consul/pull/3557 why we need this
// and why we should get rid of it.
lib.TranslateKeys(rawMap, map[string]string{
"enable_tag_override": "EnableTagOverride",
// Proxy Upstreams
"destination_name": "DestinationName", // string
"destination_type": "DestinationType", // string
"destination_namespace": "DestinationNamespace", // string
"local_bind_port": "LocalBindPort", // int
"local_bind_address": "LocalBindAddress", // string
// Proxy Config
"destination_service_name": "DestinationServiceName", // string (Proxy.)
"destination_service_id": "DestinationServiceID", // string
"local_service_port": "LocalServicePort", // int
"local_service_address": "LocalServiceAddress", // string
// SidecarService
"sidecar_service": "SidecarService", // ServiceDefinition (Connect.)
// Expose Config
"local_path_port": "LocalPathPort", // int (Proxy.Expose.Paths.)
"listener_port": "ListenerPort", // int

// DON'T Recurse into these opaque config maps or we might mangle user's
// keys. Note empty canonical is a special sentinel to prevent recursion.
"Meta": "",

"tagged_addresses": "TaggedAddresses", // map[string]structs.ServiceAddress{Address string; Port int}

// upstreams is an array but this prevents recursion into config field of
// any item in the array.
"Proxy.Config": "",
"Proxy.Upstreams.Config": "",
"Connect.Proxy.Config": "",
"Connect.Proxy.Upstreams.Config": "",

// Same exceptions as above, but for a nested sidecar_service note we use
// the canonical form SidecarService since that is translated by the time
// the lookup here happens.
"Connect.SidecarService.Meta": "",
"Connect.SidecarService.Proxy.Config": "",
"Connect.SidecarService.Proxy.Upstreams.config": "",
})

for k, v := range rawMap {
switch strings.ToLower(k) {
case "check":
if err := FixupCheckType(v); err != nil {
return err
}
case "checks":
chkTypes, ok := v.([]interface{})
if !ok {
continue
}
for _, chkType := range chkTypes {
if err := FixupCheckType(chkType); err != nil {
return err
}
}
}
}
return nil
}

func (s *HTTPServer) AgentDeregisterService(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
serviceID := strings.TrimPrefix(req.URL.Path, "/v1/agent/service/deregister/")

Expand Down Expand Up @@ -1180,7 +1107,7 @@ func (s *HTTPServer) AgentToken(resp http.ResponseWriter, req *http.Request) (in
// The body is just the token, but it's in a JSON object so we can add
// fields to this later if needed.
var args api.AgentToken
if err := decodeBody(req, &args, nil); err != nil {
if err := decodeBody(req.Body, &args); err != nil {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(resp, "Request decode failed: %v", err)
return nil, nil
Expand Down Expand Up @@ -1339,7 +1266,7 @@ func (s *HTTPServer) AgentConnectAuthorize(resp http.ResponseWriter, req *http.R

// Decode the request from the request body
var authReq structs.ConnectAuthorizeRequest
if err := decodeBody(req, &authReq, nil); err != nil {
if err := decodeBody(req.Body, &authReq); err != nil {
return nil, BadRequestError{fmt.Sprintf("Request decode failed: %v", err)}
}

Expand Down
6 changes: 2 additions & 4 deletions agent/catalog_endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,12 @@ import (
"github.com/hashicorp/consul/agent/structs"
)

var durations = NewDurationFixer("interval", "timeout", "deregistercriticalserviceafter")

func (s *HTTPServer) CatalogRegister(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
metrics.IncrCounterWithLabels([]string{"client", "api", "catalog_register"}, 1,
[]metrics.Label{{Name: "node", Value: s.nodeName()}})

var args structs.RegisterRequest
if err := decodeBody(req, &args, durations.FixupDurations); err != nil {
if err := decodeBody(req.Body, &args); err != nil {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(resp, "Request decode failed: %v", err)
return nil, nil
Expand Down Expand Up @@ -46,7 +44,7 @@ func (s *HTTPServer) CatalogDeregister(resp http.ResponseWriter, req *http.Reque
[]metrics.Label{{Name: "node", Value: s.nodeName()}})

var args structs.DeregisterRequest
if err := decodeBody(req, &args, nil); err != nil {
if err := decodeBody(req.Body, &args); err != nil {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(resp, "Request decode failed: %v", err)
return nil, nil
Expand Down
2 changes: 1 addition & 1 deletion agent/config_endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ func (s *HTTPServer) ConfigApply(resp http.ResponseWriter, req *http.Request) (i
s.parseToken(req, &args.Token)

var raw map[string]interface{}
if err := decodeBody(req, &raw, nil); err != nil {
if err := decodeBodyDeprecated(req, &raw, nil); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Request decoding failed: %v", err)}
}

Expand Down
2 changes: 1 addition & 1 deletion agent/connect_ca_endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func (s *HTTPServer) ConnectCAConfigurationSet(resp http.ResponseWriter, req *ht
var args structs.CARequest
s.parseDC(req, &args.Datacenter)
s.parseToken(req, &args.Token)
if err := decodeBody(req, &args.Config, nil); err != nil {
if err := decodeBody(req.Body, &args.Config); err != nil {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(resp, "Request decode failed: %v", err)
return nil, nil
Expand Down
2 changes: 1 addition & 1 deletion agent/coordinate_endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ func (s *HTTPServer) CoordinateUpdate(resp http.ResponseWriter, req *http.Reques
}

args := structs.CoordinateUpdateRequest{}
if err := decodeBody(req, &args, nil); err != nil {
if err := decodeBody(req.Body, &args); err != nil {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(resp, "Request decode failed: %v", err)
return nil, nil
Expand Down
60 changes: 59 additions & 1 deletion agent/discovery_chain_endpoint.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package agent

import (
"encoding/json"
"fmt"
"net/http"
"strings"
Expand Down Expand Up @@ -28,7 +29,7 @@ func (s *HTTPServer) DiscoveryChainRead(resp http.ResponseWriter, req *http.Requ

if req.Method == "POST" {
var raw map[string]interface{}
if err := decodeBody(req, &raw, nil); err != nil {
if err := decodeBody(req.Body, &raw); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Request decoding failed: %v", err)}
}

Expand Down Expand Up @@ -91,6 +92,63 @@ type discoveryChainReadRequest struct {
OverrideConnectTimeout time.Duration
}

func (t *discoveryChainReadRequest) UnmarshalJSON(data []byte) (err error) {
type Alias discoveryChainReadRequest
aux := &struct {
OverrideConnectTimeout interface{}
OverrideProtocol interface{}
OverrideMeshGateway *struct{ Mode interface{} }

OverrideConnectTimeoutSnake interface{} `json:"override_connect_timeout"`
OverrideProtocolSnake interface{} `json:"override_protocol"`
OverrideMeshGatewaySnake *struct{ Mode interface{} } `json:"override_mesh_gateway"`

*Alias
}{
Alias: (*Alias)(t),
}
if err = json.Unmarshal(data, &aux); err != nil {
return err
}

if aux.OverrideConnectTimeout == nil {
aux.OverrideConnectTimeout = aux.OverrideConnectTimeoutSnake
}
if aux.OverrideProtocol == nil {
aux.OverrideProtocol = aux.OverrideProtocolSnake
}
if aux.OverrideMeshGateway == nil {
aux.OverrideMeshGateway = aux.OverrideMeshGatewaySnake
}

// weakly typed input
if aux.OverrideProtocol != nil {
switch v := aux.OverrideProtocol.(type) {
case string, float64, bool:
t.OverrideProtocol = fmt.Sprintf("%v", v)
default:
return fmt.Errorf("OverrideProtocol: invalid type %T", v)
}
}
if aux.OverrideMeshGateway != nil {
t.OverrideMeshGateway.Mode = structs.MeshGatewayMode(fmt.Sprintf("%v", aux.OverrideMeshGateway.Mode))
}

// duration
if aux.OverrideConnectTimeout != nil {
switch v := aux.OverrideConnectTimeout.(type) {
case string:
if t.OverrideConnectTimeout, err = time.ParseDuration(v); err != nil {
return err
}
case float64:
t.OverrideConnectTimeout = time.Duration(v)
}
}

return nil
}

// discoveryChainReadResponse is the API variation of structs.DiscoveryChainResponse
type discoveryChainReadResponse struct {
Chain *structs.CompiledDiscoveryChain
Expand Down
Loading

0 comments on commit 78ad820

Please sign in to comment.