Skip to content

Commit

Permalink
Fix flex checksum validation cfg (#2981)
Browse files Browse the repository at this point in the history
* update checksum validation setter

* add new line

* add new line

* add changelog

* update changelog content

* update changelog

* add integ test case for validation skip and crc64 checksum case
  • Loading branch information
wty-Bryant authored Jan 24, 2025
1 parent 9c76401 commit cb98dee
Show file tree
Hide file tree
Showing 7 changed files with 181 additions and 64 deletions.
9 changes: 9 additions & 0 deletions .changelog/df93fff8f662441fa12dac87614a5064.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
{
"id": "df93fff8-f662-441f-a12d-ac87614a5064",
"type": "bugfix",
"description": "Enable request checksum validation mode by default",
"modules": [
"service/internal/checksum",
"service/s3"
]
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import software.amazon.smithy.model.Model;
import software.amazon.smithy.model.knowledge.TopDownIndex;
import software.amazon.smithy.model.shapes.MemberShape;
import software.amazon.smithy.model.shapes.ShapeType;
import software.amazon.smithy.model.shapes.OperationShape;
import software.amazon.smithy.model.shapes.ServiceShape;
import software.amazon.smithy.aws.traits.HttpChecksumTrait;
Expand Down Expand Up @@ -54,6 +55,10 @@ private static String getRequestValidationModeAccessorFuncName(String operationN
return String.format("get%s%s", operationName, "RequestValidationModeMember");
}

private static String setRequestValidationModeAccessorFuncName(String operationName) {
return String.format("set%s%s", operationName, "RequestValidationModeMember");
}

private static String getAddInputMiddlewareFuncName(String operationName) {
return String.format("add%sInputChecksumMiddlewares", operationName);
}
Expand Down Expand Up @@ -158,7 +163,7 @@ public void writeAdditionalFiles(

goDelegator.useShapeWriter(operation, writer -> {
// generate getter helper function to access input member value
writeGetInputMemberAccessorHelper(writer, model, symbolProvider, operation);
writeInputMemberAccessorHelper(writer, model, symbolProvider, operation);

// generate middleware helper function
if (generateComputeInputChecksums) {
Expand Down Expand Up @@ -212,7 +217,7 @@ public static boolean hasInputChecksumTrait(Model model, ServiceShape service) {
return false;
}

private static boolean hasOutputChecksumTrait(Model model, ServiceShape service, OperationShape operation) {
private static boolean hasOutputChecksumTrait(Model model, ServiceShape service, OperationShape operation) {
if (!hasChecksumTrait(model, service, operation)) {
return false;
}
Expand Down Expand Up @@ -356,6 +361,7 @@ private void writeOutputMiddlewareHelper(
writer.write("""
return $T(stack, $T{
GetValidationMode: $L,
SetValidationMode: $L,
ResponseChecksumValidation: options.ResponseChecksumValidation,
ValidationAlgorithms: $L,
IgnoreMultipartValidation: $L,
Expand All @@ -367,6 +373,7 @@ private void writeOutputMiddlewareHelper(
SymbolUtils.createValueSymbolBuilder("OutputMiddlewareOptions",
AwsGoDependency.SERVICE_INTERNAL_CHECKSUM).build(),
getRequestValidationModeAccessorFuncName(operationName),
setRequestValidationModeAccessorFuncName(operationName),
convertToGoStringList(responseAlgorithms),
ignoreMultipartChecksumValidationMap.getOrDefault(
service.toShapeId(), new HashSet<>()).contains(operation.toShapeId())
Expand All @@ -389,7 +396,7 @@ private String convertToGoStringList(List<String> list) {
return sb.toString();
}

private void writeGetInputMemberAccessorHelper(
private void writeInputMemberAccessorHelper(
GoWriter writer,
Model model,
SymbolProvider symbolProvider,
Expand Down Expand Up @@ -438,6 +445,9 @@ private void writeGetInputMemberAccessorHelper(
String.format("%s gets the request checksum validation mode provided as input.", funcName));
getInputTemplate(writer, symbolProvider, input, funcName, memberName);
writer.insertTrailingNewline();
funcName = setRequestValidationModeAccessorFuncName(operationSymbol.getName());
setInputTemplate(writer, symbolProvider, input, funcName, memberName);
writer.insertTrailingNewline();
}
}

Expand All @@ -459,6 +469,26 @@ private void getInputTemplate(
writer.write("");
}

private void setInputTemplate(
GoWriter writer,
SymbolProvider symbolProvider,
StructureShape input,
String funcName,
String memberName
) {
writer.write(GoWriter.goTemplate("""
func $fn:L(input interface{}, mode string) {
in := input.(*$inputType:L)
in.$member:L = types.$member:L(mode)
}""",
Map.of(
"fn", funcName,
"inputType", symbolProvider.toSymbol(input).getName(),
"member", memberName
)));
writer.write("");
}

private void generateInputComputedChecksumMetadataHelpers(
GoWriter writer,
Model model,
Expand Down
4 changes: 4 additions & 0 deletions service/internal/checksum/middleware_add.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,9 @@ type OutputMiddlewareOptions struct {
// mode and true, or false if no mode is specified.
GetValidationMode func(interface{}) (string, bool)

// SetValidationMode is a function to set the checksum validation mode of input parameters
SetValidationMode func(interface{}, string)

// ResponseChecksumValidation is the user config to opt-in/out response checksum validation
ResponseChecksumValidation aws.ResponseChecksumValidation

Expand Down Expand Up @@ -141,6 +144,7 @@ type OutputMiddlewareOptions struct {
func AddOutputMiddleware(stack *middleware.Stack, options OutputMiddlewareOptions) error {
err := stack.Initialize.Add(&setupOutputContext{
GetValidationMode: options.GetValidationMode,
SetValidationMode: options.SetValidationMode,
ResponseChecksumValidation: options.ResponseChecksumValidation,
}, middleware.Before)
if err != nil {
Expand Down
4 changes: 4 additions & 0 deletions service/internal/checksum/middleware_setup_context.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ type setupOutputContext struct {
// mode and true, or false if no mode is specified.
GetValidationMode func(interface{}) (string, bool)

// SetValidationMode is a function to set the checksum validation mode of input parameters
SetValidationMode func(interface{}, string)

// ResponseChecksumValidation states user config to opt-in/out checksum validation
ResponseChecksumValidation aws.ResponseChecksumValidation
}
Expand All @@ -90,6 +93,7 @@ func (m *setupOutputContext) HandleInitialize(
mode, _ := m.GetValidationMode(in.Parameters)

if m.ResponseChecksumValidation == aws.ResponseChecksumValidationWhenSupported || mode == checksumValidationModeEnabled {
m.SetValidationMode(in.Parameters, checksumValidationModeEnabled)
ctx = setContextOutputValidationMode(ctx, checksumValidationModeEnabled)
}

Expand Down
60 changes: 44 additions & 16 deletions service/internal/checksum/middleware_setup_context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,50 +131,74 @@ func TestSetupOutput(t *testing.T) {
inputParams interface{}
ResponseChecksumValidation aws.ResponseChecksumValidation
getValidationMode func(interface{}) (string, bool)
expectValue string
setValidationMode func(interface{}, string)
expectCtxValue string
expectInputValue string
}{
"user config support checksum found empty": {
ResponseChecksumValidation: aws.ResponseChecksumValidationWhenSupported,
inputParams: Params{Value: ""},
inputParams: &Params{Value: ""},
getValidationMode: func(v interface{}) (string, bool) {
vv := v.(Params)
vv := v.(*Params)
return vv.Value, true
},
expectValue: "ENABLED",
setValidationMode: func(v interface{}, m string) {
vv := v.(*Params)
vv.Value = m
},
expectCtxValue: "ENABLED",
expectInputValue: "ENABLED",
},
"user config support checksum found invalid value": {
ResponseChecksumValidation: aws.ResponseChecksumValidationWhenSupported,
inputParams: Params{Value: "abc123"},
inputParams: &Params{Value: "abc123"},
getValidationMode: func(v interface{}) (string, bool) {
vv := v.(Params)
vv := v.(*Params)
return vv.Value, true

},
expectValue: "ENABLED",
setValidationMode: func(v interface{}, m string) {
vv := v.(*Params)
vv.Value = m
},
expectCtxValue: "ENABLED",
expectInputValue: "ENABLED",
},
"user config require checksum found invalid value": {
ResponseChecksumValidation: aws.ResponseChecksumValidationWhenRequired,
inputParams: Params{Value: "abc123"},
inputParams: &Params{Value: "abc123"},
getValidationMode: func(v interface{}) (string, bool) {
vv := v.(Params)
vv := v.(*Params)
return vv.Value, true
},
expectValue: "",
setValidationMode: func(v interface{}, m string) {
vv := v.(*Params)
vv.Value = m
},
expectCtxValue: "",
expectInputValue: "abc123",
},
"user config require checksum found valid value": {
ResponseChecksumValidation: aws.ResponseChecksumValidationWhenRequired,
inputParams: Params{Value: "ENABLED"},
inputParams: &Params{Value: "ENABLED"},
getValidationMode: func(v interface{}) (string, bool) {
vv := v.(Params)
vv := v.(*Params)
return vv.Value, true
},
expectValue: "ENABLED",
setValidationMode: func(v interface{}, m string) {
vv := v.(*Params)
vv.Value = m
},
expectCtxValue: "ENABLED",
expectInputValue: "ENABLED",
},
}

for name, c := range cases {
t.Run(name, func(t *testing.T) {
m := setupOutputContext{
GetValidationMode: c.getValidationMode,
SetValidationMode: c.setValidationMode,
ResponseChecksumValidation: c.ResponseChecksumValidation,
}

Expand All @@ -185,10 +209,14 @@ func TestSetupOutput(t *testing.T) {
out middleware.InitializeOutput, metadata middleware.Metadata, err error,
) {
v := getContextOutputValidationMode(ctx)
if e, a := c.expectValue, v; e != a {
t.Errorf("expect value %v, got %v", e, a)
if e, a := c.expectCtxValue, v; e != a {
t.Errorf("expect ctx checksum validation mode to be %v, got %v", e, a)
}

in := input.Parameters.(*Params)
if e, a := c.expectInputValue, in.Value; e != a {
t.Errorf("expect input checksum validation mode to be %v, got %v", e, a)
}

return out, metadata, nil
},
))
Expand Down
Loading

0 comments on commit cb98dee

Please sign in to comment.