Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use encoding/json as JSON decoder instead of mapstructure #6680

Merged
merged 15 commits into from
Oct 29, 2019
63 changes: 7 additions & 56 deletions agent/acl_endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"net/http"
"strconv"
"strings"
"time"

"github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/agent/structs"
Expand Down Expand Up @@ -273,53 +272,6 @@ func (s *HTTPServer) ACLPolicyCreate(resp http.ResponseWriter, req *http.Request
return s.aclPolicyWriteInternal(resp, req, "", true)
}

// fixTimeAndHashFields is used to help in decoding the ExpirationTTL, ExpirationTime, CreateTime, and Hash
// attributes from the ACL Token/Policy create/update requests. It is needed
// to help mapstructure decode things properly when decodeBody is used.
func fixTimeAndHashFields(raw interface{}) error {
rawMap, ok := raw.(map[string]interface{})
if !ok {
return nil
}

if val, ok := rawMap["ExpirationTTL"]; ok {
if sval, ok := val.(string); ok {
d, err := time.ParseDuration(sval)
if err != nil {
return err
}
rawMap["ExpirationTTL"] = d
}
}

if val, ok := rawMap["ExpirationTime"]; ok {
if sval, ok := val.(string); ok {
t, err := time.Parse(time.RFC3339, sval)
if err != nil {
return err
}
rawMap["ExpirationTime"] = t
}
}

if val, ok := rawMap["CreateTime"]; ok {
if sval, ok := val.(string); ok {
t, err := time.Parse(time.RFC3339, sval)
if err != nil {
return err
}
rawMap["CreateTime"] = t
}
}

if val, ok := rawMap["Hash"]; ok {
if sval, ok := val.(string); ok {
rawMap["Hash"] = []byte(sval)
}
}
return nil
}

func (s *HTTPServer) ACLPolicyWrite(resp http.ResponseWriter, req *http.Request, policyID string) (interface{}, error) {
return s.aclPolicyWriteInternal(resp, req, policyID, false)
}
Expand All @@ -331,7 +283,7 @@ func (s *HTTPServer) aclPolicyWriteInternal(resp http.ResponseWriter, req *http.
s.parseToken(req, &args.Token)
s.parseEntMeta(req, &args.Policy.EnterpriseMeta)

if err := decodeBody(req, &args.Policy, fixTimeAndHashFields); err != nil {
if err := decodeBody(req.Body, &args.Policy); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Policy decoding failed: %v", err)}
}

Expand Down Expand Up @@ -521,7 +473,7 @@ func (s *HTTPServer) aclTokenSetInternal(resp http.ResponseWriter, req *http.Req
s.parseToken(req, &args.Token)
s.parseEntMeta(req, &args.ACLToken.EnterpriseMeta)

if err := decodeBody(req, &args.ACLToken, fixTimeAndHashFields); err != nil {
if err := decodeBody(req.Body, &args.ACLToken); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Token decoding failed: %v", err)}
}

Expand Down Expand Up @@ -567,8 +519,7 @@ func (s *HTTPServer) ACLTokenClone(resp http.ResponseWriter, req *http.Request,
}

s.parseEntMeta(req, &args.ACLToken.EnterpriseMeta)

if err := decodeBody(req, &args.ACLToken, fixTimeAndHashFields); err != nil && err.Error() != "EOF" {
if err := decodeBody(req.Body, &args.ACLToken); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Token decoding failed: %v", err)}
}
s.parseToken(req, &args.Token)
Expand Down Expand Up @@ -705,7 +656,7 @@ func (s *HTTPServer) ACLRoleWrite(resp http.ResponseWriter, req *http.Request, r
s.parseToken(req, &args.Token)
s.parseEntMeta(req, &args.Role.EnterpriseMeta)

if err := decodeBody(req, &args.Role, fixTimeAndHashFields); err != nil {
if err := decodeBody(req.Body, &args.Role); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Role decoding failed: %v", err)}
}

Expand Down Expand Up @@ -844,7 +795,7 @@ func (s *HTTPServer) ACLBindingRuleWrite(resp http.ResponseWriter, req *http.Req
s.parseToken(req, &args.Token)
s.parseEntMeta(req, &args.BindingRule.EnterpriseMeta)

if err := decodeBody(req, &args.BindingRule, fixTimeAndHashFields); err != nil {
if err := decodeBody(req.Body, &args.BindingRule); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("BindingRule decoding failed: %v", err)}
}

Expand Down Expand Up @@ -980,7 +931,7 @@ func (s *HTTPServer) ACLAuthMethodWrite(resp http.ResponseWriter, req *http.Requ
s.parseToken(req, &args.Token)
s.parseEntMeta(req, &args.AuthMethod.EnterpriseMeta)

if err := decodeBody(req, &args.AuthMethod, fixTimeAndHashFields); err != nil {
if err := decodeBody(req.Body, &args.AuthMethod); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("AuthMethod decoding failed: %v", err)}
}

Expand Down Expand Up @@ -1029,7 +980,7 @@ func (s *HTTPServer) ACLLogin(resp http.ResponseWriter, req *http.Request) (inte
s.parseDC(req, &args.Datacenter)
s.parseEntMeta(req, &args.Auth.EnterpriseMeta)

if err := decodeBody(req, &args.Auth, nil); err != nil {
if err := decodeBody(req.Body, &args.Auth); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Failed to decode request body:: %v", err)}
}

Expand Down
2 changes: 1 addition & 1 deletion agent/acl_endpoint_legacy.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func (s *HTTPServer) aclSet(resp http.ResponseWriter, req *http.Request, update

// Handle optional request body
if req.ContentLength > 0 {
if err := decodeBody(req, &args.ACL, nil); err != nil {
if err := decodeBody(req.Body, &args.ACL); err != nil {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(resp, "Request decode failed: %v", err)
return nil, nil
Expand Down
85 changes: 6 additions & 79 deletions agent/agent_endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -457,11 +457,8 @@ func (s *HTTPServer) syncChanges() {

func (s *HTTPServer) AgentRegisterCheck(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
var args structs.CheckDefinition
// Fixup the type decode of TTL or Interval.
decodeCB := func(raw interface{}) error {
return FixupCheckType(raw)
}
if err := decodeBody(req, &args, decodeCB); err != nil {

if err := decodeBody(req.Body, &args); err != nil {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(resp, "Request decode failed: %v", err)
return nil, nil
Expand Down Expand Up @@ -606,7 +603,7 @@ type checkUpdate struct {
// APIs.
func (s *HTTPServer) AgentCheckUpdate(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
var update checkUpdate
if err := decodeBody(req, &update, nil); err != nil {
if err := decodeBody(req.Body, &update); err != nil {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(resp, "Request decode failed: %v", err)
return nil, nil
Expand Down Expand Up @@ -758,7 +755,7 @@ func (s *HTTPServer) AgentRegisterService(resp http.ResponseWriter, req *http.Re
var args structs.ServiceDefinition
// Fixup the type decode of TTL or Interval if a check if provided.

if err := decodeBody(req, &args, registerServiceDecodeCB); err != nil {
if err := decodeBody(req.Body, &args); err != nil {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(resp, "Request decode failed: %v", err)
return nil, nil
Expand Down Expand Up @@ -894,76 +891,6 @@ func (s *HTTPServer) AgentRegisterService(resp http.ResponseWriter, req *http.Re
return nil, nil
}

// registerServiceDecodeCB is used in AgentRegisterService for request body decoding
func registerServiceDecodeCB(raw interface{}) error {
rawMap, ok := raw.(map[string]interface{})
if !ok {
return nil
}

// see https://github.com/hashicorp/consul/pull/3557 why we need this
// and why we should get rid of it.
lib.TranslateKeys(rawMap, map[string]string{
"enable_tag_override": "EnableTagOverride",
// Proxy Upstreams
"destination_name": "DestinationName", // string
"destination_type": "DestinationType", // string
"destination_namespace": "DestinationNamespace", // string
"local_bind_port": "LocalBindPort", // int
"local_bind_address": "LocalBindAddress", // string
// Proxy Config
"destination_service_name": "DestinationServiceName", // string (Proxy.)
"destination_service_id": "DestinationServiceID", // string
"local_service_port": "LocalServicePort", // int
"local_service_address": "LocalServiceAddress", // string
// SidecarService
"sidecar_service": "SidecarService", // ServiceDefinition (Connect.)
// Expose Config
"local_path_port": "LocalPathPort", // int (Proxy.Expose.Paths.)
"listener_port": "ListenerPort", // int

// DON'T Recurse into these opaque config maps or we might mangle user's
// keys. Note empty canonical is a special sentinel to prevent recursion.
"Meta": "",

"tagged_addresses": "TaggedAddresses", // map[string]structs.ServiceAddress{Address string; Port int}

// upstreams is an array but this prevents recursion into config field of
// any item in the array.
"Proxy.Config": "",
"Proxy.Upstreams.Config": "",
"Connect.Proxy.Config": "",
"Connect.Proxy.Upstreams.Config": "",

// Same exceptions as above, but for a nested sidecar_service note we use
// the canonical form SidecarService since that is translated by the time
// the lookup here happens.
"Connect.SidecarService.Meta": "",
"Connect.SidecarService.Proxy.Config": "",
"Connect.SidecarService.Proxy.Upstreams.config": "",
})

for k, v := range rawMap {
switch strings.ToLower(k) {
case "check":
if err := FixupCheckType(v); err != nil {
return err
}
case "checks":
chkTypes, ok := v.([]interface{})
if !ok {
continue
}
for _, chkType := range chkTypes {
if err := FixupCheckType(chkType); err != nil {
return err
}
}
}
}
return nil
}

func (s *HTTPServer) AgentDeregisterService(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
serviceID := strings.TrimPrefix(req.URL.Path, "/v1/agent/service/deregister/")

Expand Down Expand Up @@ -1180,7 +1107,7 @@ func (s *HTTPServer) AgentToken(resp http.ResponseWriter, req *http.Request) (in
// The body is just the token, but it's in a JSON object so we can add
// fields to this later if needed.
var args api.AgentToken
if err := decodeBody(req, &args, nil); err != nil {
if err := decodeBody(req.Body, &args); err != nil {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(resp, "Request decode failed: %v", err)
return nil, nil
Expand Down Expand Up @@ -1339,7 +1266,7 @@ func (s *HTTPServer) AgentConnectAuthorize(resp http.ResponseWriter, req *http.R

// Decode the request from the request body
var authReq structs.ConnectAuthorizeRequest
if err := decodeBody(req, &authReq, nil); err != nil {
if err := decodeBody(req.Body, &authReq); err != nil {
return nil, BadRequestError{fmt.Sprintf("Request decode failed: %v", err)}
}

Expand Down
6 changes: 2 additions & 4 deletions agent/catalog_endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,12 @@ import (
"github.com/hashicorp/consul/agent/structs"
)

var durations = NewDurationFixer("interval", "timeout", "deregistercriticalserviceafter")

func (s *HTTPServer) CatalogRegister(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
metrics.IncrCounterWithLabels([]string{"client", "api", "catalog_register"}, 1,
[]metrics.Label{{Name: "node", Value: s.nodeName()}})

var args structs.RegisterRequest
if err := decodeBody(req, &args, durations.FixupDurations); err != nil {
if err := decodeBody(req.Body, &args); err != nil {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(resp, "Request decode failed: %v", err)
return nil, nil
Expand Down Expand Up @@ -46,7 +44,7 @@ func (s *HTTPServer) CatalogDeregister(resp http.ResponseWriter, req *http.Reque
[]metrics.Label{{Name: "node", Value: s.nodeName()}})

var args structs.DeregisterRequest
if err := decodeBody(req, &args, nil); err != nil {
if err := decodeBody(req.Body, &args); err != nil {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(resp, "Request decode failed: %v", err)
return nil, nil
Expand Down
2 changes: 1 addition & 1 deletion agent/config_endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ func (s *HTTPServer) ConfigApply(resp http.ResponseWriter, req *http.Request) (i
s.parseToken(req, &args.Token)

var raw map[string]interface{}
if err := decodeBody(req, &raw, nil); err != nil {
if err := decodeBodyDeprecated(req, &raw, nil); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Request decoding failed: %v", err)}
}

Expand Down
2 changes: 1 addition & 1 deletion agent/connect_ca_endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func (s *HTTPServer) ConnectCAConfigurationSet(resp http.ResponseWriter, req *ht
var args structs.CARequest
s.parseDC(req, &args.Datacenter)
s.parseToken(req, &args.Token)
if err := decodeBody(req, &args.Config, nil); err != nil {
if err := decodeBody(req.Body, &args.Config); err != nil {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(resp, "Request decode failed: %v", err)
return nil, nil
Expand Down
2 changes: 1 addition & 1 deletion agent/coordinate_endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ func (s *HTTPServer) CoordinateUpdate(resp http.ResponseWriter, req *http.Reques
}

args := structs.CoordinateUpdateRequest{}
if err := decodeBody(req, &args, nil); err != nil {
if err := decodeBody(req.Body, &args); err != nil {
resp.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(resp, "Request decode failed: %v", err)
return nil, nil
Expand Down
60 changes: 59 additions & 1 deletion agent/discovery_chain_endpoint.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package agent

import (
"encoding/json"
"fmt"
"net/http"
"strings"
Expand Down Expand Up @@ -28,7 +29,7 @@ func (s *HTTPServer) DiscoveryChainRead(resp http.ResponseWriter, req *http.Requ

if req.Method == "POST" {
var raw map[string]interface{}
if err := decodeBody(req, &raw, nil); err != nil {
if err := decodeBody(req.Body, &raw); err != nil {
return nil, BadRequestError{Reason: fmt.Sprintf("Request decoding failed: %v", err)}
}

Expand Down Expand Up @@ -91,6 +92,63 @@ type discoveryChainReadRequest struct {
OverrideConnectTimeout time.Duration
}

func (t *discoveryChainReadRequest) UnmarshalJSON(data []byte) (err error) {
type Alias discoveryChainReadRequest
aux := &struct {
OverrideConnectTimeout interface{}
OverrideProtocol interface{}
OverrideMeshGateway *struct{ Mode interface{} }

OverrideConnectTimeoutCamel interface{} `json:"override_connect_timeout"`
OverrideProtocolCamel interface{} `json:"override_protocol"`
OverrideMeshGatewayCamel *struct{ Mode interface{} } `json:"override_mesh_gateway"`

*Alias
}{
Alias: (*Alias)(t),
}
if err = json.Unmarshal(data, &aux); err != nil {
return err
}

if aux.OverrideConnectTimeout == nil {
aux.OverrideConnectTimeout = aux.OverrideConnectTimeoutCamel
}
if aux.OverrideProtocol == nil {
aux.OverrideProtocol = aux.OverrideProtocolCamel
}
if aux.OverrideMeshGateway == nil {
aux.OverrideMeshGateway = aux.OverrideMeshGatewayCamel
}

// weakly typed input
if aux.OverrideProtocol != nil {
switch v := aux.OverrideProtocol.(type) {
case string, float64, bool:
t.OverrideProtocol = fmt.Sprintf("%v", v)
default:
return fmt.Errorf("OverrideProtocol: invalid type %T", v)
}
}
if aux.OverrideMeshGateway != nil {
t.OverrideMeshGateway.Mode = structs.MeshGatewayMode(fmt.Sprintf("%v", aux.OverrideMeshGateway.Mode))
}

// duration
if aux.OverrideConnectTimeout != nil {
switch v := aux.OverrideConnectTimeout.(type) {
case string:
if t.OverrideConnectTimeout, err = time.ParseDuration(v); err != nil {
return err
}
case float64:
t.OverrideConnectTimeout = time.Duration(v)
}
}

return nil
}

// discoveryChainReadResponse is the API variation of structs.DiscoveryChainResponse
type discoveryChainReadResponse struct {
Chain *structs.CompiledDiscoveryChain
Expand Down
Loading