Skip to content

Commit

Permalink
private/mode/api: Add codegen typed Errors for RESTJSON and JSONRPC APIs
Browse files Browse the repository at this point in the history
Adds code generated error types for APIs using RESTJSON and JSONRPC
protocol, and modeled errors. This adds generated error types that can
be typed asserted to in order to read error values in addition to Code
and Message.
  • Loading branch information
jasdel committed Jan 6, 2020
1 parent 6a08738 commit b0f208a
Show file tree
Hide file tree
Showing 18 changed files with 594 additions and 202 deletions.
70 changes: 50 additions & 20 deletions private/model/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ type API struct {
EndpointDiscoveryOp *Operation

HasEndpointARN bool `json:"-"`

WithGeneratedTypedErrors bool
}

// A Metadata is the metadata about an API's definition.
Expand Down Expand Up @@ -282,20 +284,27 @@ func (a *API) importsGoCode() string {

// A tplAPI is the top level template for the API
var tplAPI = template.Must(template.New("api").Parse(`
{{ range $_, $o := .OperationList }}
{{ $o.GoCode }}
{{- range $_, $o := .OperationList }}
{{ end }}
{{ $o.GoCode }}
{{- end }}
{{ range $_, $s := .ShapeList }}
{{ if and $s.IsInternal (eq $s.Type "structure") }}{{ $s.GoCode }}{{ end }}
{{- range $_, $s := $.Shapes }}
{{- if and $s.IsInternal (eq $s.Type "structure") (not $s.Exception) }}
{{ end }}
{{ $s.GoCode }}
{{- else if and $s.Exception (or $.WithGeneratedTypedErrors $s.EventFor) }}
{{ $s.GoCode }}
{{- end }}
{{- end }}
{{ range $_, $s := .ShapeList }}
{{ if $s.IsEnum }}{{ $s.GoCode }}{{ end }}
{{- range $_, $s := $.Shapes }}
{{- if $s.IsEnum }}
{{ end }}
{{ $s.GoCode }}
{{- end }}
{{- end }}
`))

// AddImport adds the import path to the generated file's import.
Expand Down Expand Up @@ -598,7 +607,14 @@ func newClient(cfg aws.Config, handlers request.Handlers, partitionID, endpoint,
svc.Handlers.Build.PushBackNamed({{ .ProtocolPackage }}.BuildHandler)
svc.Handlers.Unmarshal.PushBackNamed({{ .ProtocolPackage }}.UnmarshalHandler)
svc.Handlers.UnmarshalMeta.PushBackNamed({{ .ProtocolPackage }}.UnmarshalMetaHandler)
svc.Handlers.UnmarshalError.PushBackNamed({{ .ProtocolPackage }}.UnmarshalErrorHandler)
{{- if and $.WithGeneratedTypedErrors (gt (len $.ShapeListErrors) 0) }}
{{- $_ := $.AddSDKImport "private/protocol" }}
svc.Handlers.UnmarshalError.PushBackNamed(
protocol.NewUnmarshalErrorHandler({{ .ProtocolPackage }}.NewUnmarshalTypedError(exceptionFromCode)).NamedHandler(),
)
{{- else }}
svc.Handlers.UnmarshalError.PushBackNamed({{ .ProtocolPackage }}.UnmarshalErrorHandler)
{{- end }}
{{ if .HasEventStream }}
svc.Handlers.BuildStream.PushBackNamed({{ .ProtocolPackage }}.BuildHandler)
svc.Handlers.UnmarshalStream.PushBackNamed({{ .ProtocolPackage }}.UnmarshalHandler)
Expand Down Expand Up @@ -862,27 +878,41 @@ func resolveShapeValidations(s *Shape, ancestry ...*Shape) {
// A tplAPIErrors is the top level template for the API
var tplAPIErrors = template.Must(template.New("api").Parse(`
const (
{{ range $_, $s := $.ShapeListErrors }}
// {{ $s.ErrorCodeName }} for service response error code
// {{ printf "%q" $s.ErrorName }}.
{{ if $s.Docstring -}}
//
{{ $s.Docstring }}
{{ end -}}
{{ $s.ErrorCodeName }} = {{ printf "%q" $s.ErrorName }}
{{ end }}
{{- range $_, $s := $.ShapeListErrors }}
// {{ $s.ErrorCodeName }} for service response error code
// {{ printf "%q" $s.ErrorName }}.
{{ if $s.Docstring -}}
//
{{ $s.Docstring }}
{{ end -}}
{{ $s.ErrorCodeName }} = {{ printf "%q" $s.ErrorName }}
{{- end }}
)
{{- if $.WithGeneratedTypedErrors }}
{{- $_ := $.AddSDKImport "private/protocol" }}
var exceptionFromCode = map[string]func(protocol.ResponseMetadata)error {
{{- range $_, $s := $.ShapeListErrors }}
"{{ $s.ErrorName }}": newError{{ $s.ShapeName }},
{{- end }}
}
{{- end }}
`))

// APIErrorsGoCode returns the Go code for the errors.go file.
func (a *API) APIErrorsGoCode() string {
a.resetImports()

var buf bytes.Buffer
err := tplAPIErrors.Execute(&buf, a)

if err != nil {
panic(err)
}

return strings.TrimSpace(buf.String())
return a.importsGoCode() + strings.TrimSpace(buf.String())
}

// removeOperation removes an operation, its input/output shapes, as well as
Expand Down
61 changes: 22 additions & 39 deletions private/model/api/eventstream_tmpl.go
Original file line number Diff line number Diff line change
Expand Up @@ -234,14 +234,17 @@ func (es *{{ $esapi.Name }}) waitStreamPartClose() {
{{- end }}
{{- if $outputStream }}
{{- if eq .API.Metadata.Protocol "json" }}
func (es *{{ $esapi.Name}}) {{ $esapi.StreamOutputUnmarshalerForEventName }}(eventType string) (eventstreamapi.Unmarshaler, error) {
type {{ $esapi.StreamOutputUnmarshalerForEventName }} struct {
unmarshalerForEvent eventstreamapi.EventNameForUnmarshaler
output {{ $.OutputRef.GoType }}
}
func (e {{ $esapi.StreamOutputUnmarshalerForEventName }}) UnmarshalerForEventName(eventType string) (eventstreamapi.Unmarshaler, error) {
if eventType == "initial-response" {
return es.output, nil
return e.output, nil
}
return {{ $outputStream.StreamUnmarshalerForEventName }}(eventType)
return e.unmarshalerForEvent.UnmarshalerForEventName(eventType)
}
{{- end }}
Expand All @@ -261,16 +264,26 @@ func (es *{{ $esapi.Name }}) waitStreamPartClose() {
opts = append(opts, eventstream.DecodeWithLogger(r.Config.Logger))
}
var unmarshalerForEvent eventstreamapi.EventNameForUnmarshaler
unmarshalerForEvent = &{{ $outputStream.StreamUnmarshalerForEventName }}{
metadata: protocol.ResponseMetadata{
StatusCode: r.HTTPResponse.StatusCode,
RequestID: r.RequestID,
},
}
{{- if eq .API.Metadata.Protocol "json" }}
unmarshalerForEvent = &{{ $esapi.StreamOutputUnmarshalerForEventName }}{
unmarshalerForEvent: unmarshalerForEvent,
output: es.output,
}
{{- end }}
decoder := eventstream.NewDecoder(r.HTTPResponse.Body, opts...)
eventReader := eventstreamapi.NewEventReader(decoder,
protocol.HandlerPayloadUnmarshal{
Unmarshalers: r.Handlers.UnmarshalStream,
},
{{- if eq .API.Metadata.Protocol "json" }}
es.{{ $esapi.StreamOutputUnmarshalerForEventName }},
{{- else }}
{{ $outputStream.StreamUnmarshalerForEventName }},
{{- end }}
unmarshalerForEvent.UnmarshalerForEventName,
)
es.outputReader = r.HTTPResponse.Body
Expand Down Expand Up @@ -597,33 +610,3 @@ func (s *{{ $.ShapeName}}) MarshalEvent(pm protocol.PayloadMarshaler) (msg event
return msg, err
}
`))

var eventStreamExceptionEventShapeTmpl = template.Must(
template.New("eventStreamExceptionEventShapeTmpl").Parse(`
// Code returns the exception type name.
func (s {{ $.ShapeName }}) Code() string {
{{- if $.ErrorInfo.Code }}
return "{{ $.ErrorInfo.Code }}"
{{- else }}
return "{{ $.ShapeName }}"
{{ end -}}
}
// Message returns the exception's message.
func (s {{ $.ShapeName }}) Message() string {
{{- if index $.MemberRefs "Message_" }}
return *s.Message_
{{- else }}
return ""
{{ end -}}
}
// OrigErr always returns nil, satisfies awserr.Error interface.
func (s {{ $.ShapeName }}) OrigErr() error {
return nil
}
func (s {{ $.ShapeName }}) Error() string {
return fmt.Sprintf("%s: %s", s.Code(), s.Message())
}
`))
8 changes: 6 additions & 2 deletions private/model/api/eventstream_tmpl_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,15 +104,19 @@ func (r *{{ $es.StreamReaderImplName }}) readEventStream() {
}
}
func {{ $es.StreamUnmarshalerForEventName }}(eventType string) (eventstreamapi.Unmarshaler, error) {
type {{ $es.StreamUnmarshalerForEventName }} struct {
metadata protocol.ResponseMetadata
}
func (u {{ $es.StreamUnmarshalerForEventName }}) UnmarshalerForEventName(eventType string) (eventstreamapi.Unmarshaler, error) {
switch eventType {
{{- range $_, $event := $es.Events }}
case {{ printf "%q" $event.Name }}:
return &{{ $event.Shape.ShapeName }}{}, nil
{{- end }}
{{- range $_, $event := $es.Exceptions }}
case {{ printf "%q" $event.Name }}:
return &{{ $event.Shape.ShapeName }}{}, nil
return newError{{ $event.Shape.ShapeName }}(u.metadata).(eventstreamapi.Unmarshaler), nil
{{- end }}
default:
return nil, awserr.New(
Expand Down
6 changes: 6 additions & 0 deletions private/model/api/eventstream_tmpl_readertests.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,12 @@ func valueForType(s *Shape, visited []string) string {
for _, refName := range s.MemberNames() {
fmt.Fprintf(w, "%s: %s,\n", refName, valueForType(s.MemberRefs[refName].Shape, visited))
}
if s.Exception {
fmt.Fprintf(w, `respMetadata: protocol.ResponseMetadata{
StatusCode: 200,
},
`)
}
fmt.Fprintf(w, "}")
return w.String()
case "list":
Expand Down
1 change: 1 addition & 0 deletions private/model/api/load.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ func (a *API) Setup() error {

a.findEndpointDiscoveryOp()
a.injectUnboundedOutputStreaming()
a.enableGeneratedTypedErrors()
if err := a.customizationPasses(); err != nil {
return err
}
Expand Down
22 changes: 13 additions & 9 deletions private/model/api/operation.go
Original file line number Diff line number Diff line change
Expand Up @@ -310,28 +310,32 @@ func (c *{{ .API.StructName }}) {{ .ExportedName }}Request(` +
}
// {{ .ExportedName }} API operation for {{ .API.Metadata.ServiceFullName }}.
{{ if .Documentation -}}
{{- if .Documentation }}
//
{{ .Documentation }}
{{ end -}}
{{- end }}
//
// Returns awserr.Error for service API and SDK errors. Use runtime type assertions
// with awserr.Error's Code and Message methods to get detailed information about
// the error.
//
// See the AWS API reference guide for {{ .API.Metadata.ServiceFullName }}'s
// API operation {{ .ExportedName }} for usage and error information.
{{ if .ErrorRefs -}}
{{- if .ErrorRefs }}
//
// Returned Error Codes:
{{ range $_, $err := .ErrorRefs -}}
// Returned Error {{ if $.API.WithGeneratedTypedErrors }}Types{{ else }}Codes{{ end }}:
{{- range $_, $err := .ErrorRefs -}}
{{- if $.API.WithGeneratedTypedErrors }}
// * {{ $err.ShapeName }}
{{- else }}
// * {{ $err.Shape.ErrorCodeName }} "{{ $err.Shape.ErrorName}}"
{{ if $err.Docstring -}}
{{- end }}
{{- if $err.Docstring }}
{{ $err.IndentedDocstring }}
{{ end -}}
{{- end }}
//
{{ end -}}
{{ end -}}
{{- end }}
{{- end }}
{{ $crosslinkURL := $.API.GetCrosslinkURL $.ExportedName -}}
{{ if ne $crosslinkURL "" -}}
// See also, {{ $crosslinkURL }}
Expand Down
18 changes: 17 additions & 1 deletion private/model/api/passes.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,17 @@ import (
"strings"
)

func (a *API) enableGeneratedTypedErrors() {
switch a.Metadata.Protocol {
case "json":
case "rest-json":
default:
return
}

a.WithGeneratedTypedErrors = true
}

// updateTopLevelShapeReferences moves resultWrapper, locationName, and
// xmlNamespace traits from toplevel shape references to the toplevel
// shapes for easier code generation
Expand Down Expand Up @@ -290,7 +301,12 @@ func exceptionCollides(name string) bool {
switch name {
case "Code",
"Message",
"OrigErr":
"OrigErr",
"Error",
"String",
"GoString",
"RequestID",
"StatusCode":
return true
}
return false
Expand Down
Loading

0 comments on commit b0f208a

Please sign in to comment.