Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 28 additions & 23 deletions pkg/generate/code/synced.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ import (
func ResourceIsSynced(
cfg *ackgenconfig.Config,
r *model.CRD,
// String
// resource variable name
resVarName string,
// Number of levels of indentation to use
indentLevel int,
Expand All @@ -62,24 +62,24 @@ func ResourceIsSynced(
return out
}

for _, condition := range resConfig.Synced.When {
if condition.Path == nil || *condition.Path == "" {
for _, condCfg := range resConfig.Synced.When {
if condCfg.Path == nil || *condCfg.Path == "" {
panic("Received an empty sync condition path. 'SyncCondition.Path' must be provided.")
}
if len(condition.In) == 0 {
if len(condCfg.In) == 0 {
panic("'SyncCondition.In' must be provided.")
}
fp := fieldpath.FromString(*condition.Path)
field, err := getTopLevelField(r, *condition.Path)
fp := fieldpath.FromString(*condCfg.Path)
field, err := getTopLevelField(r, *condCfg.Path)
if err != nil {
msg := fmt.Sprintf("cannot find top level field of path '%s': %v", *condition.Path, err)
msg := fmt.Sprintf("cannot find top level field of path '%s': %v", *condCfg.Path, err)
panic(msg)
}
candidatesVarName := fmt.Sprintf("%sCandidates", field.Names.CamelLower)
if fp.Size() == 2 {
out += scalarFieldEqual(resVarName, candidatesVarName, field.ShapeRef.GoTypeElem(), condition)
out += scalarFieldEqual(resVarName, candidatesVarName, field.ShapeRef.GoTypeElem(), condCfg)
} else {
out += fieldPathSafeEqual(resVarName, candidatesVarName, field, condition)
out += fieldPathSafeEqual(resVarName, candidatesVarName, field, condCfg)
}
}

Expand Down Expand Up @@ -113,21 +113,26 @@ func getTopLevelField(r *model.CRD, fieldPath string) (*model.Field, error) {
}

// scalarFieldEqual returns Go code that compares a scalar field to a given set of values.
func scalarFieldEqual(resVarName string, candidatesVarName string, goType string, condition ackgenconfig.SyncedCondition) string {
func scalarFieldEqual(
resVarName string,
candidatesVarName string,
goType string,
condCfg ackgenconfig.SyncedCondition,
) string {
out := ""
fieldPath := fmt.Sprintf("%s.%s", resVarName, *condition.Path)
fieldPath := fmt.Sprintf("%s.%s", resVarName, *condCfg.Path)

valuesSlice := ""
switch goType {
case "string":
// []string{"AVAILABLE", "ACTIVE"}
valuesSlice = fmt.Sprintf("[]string{\"%s\"}", strings.Join(condition.In, "\", \""))
valuesSlice = fmt.Sprintf("[]string{\"%s\"}", strings.Join(condCfg.In, "\", \""))
case "int64", "PositiveLongObject", "Long":
// []int64{1, 2}
valuesSlice = fmt.Sprintf("[]int{%s}", strings.Join(condition.In, ", "))
valuesSlice = fmt.Sprintf("[]int{%s}", strings.Join(condCfg.In, ", "))
case "bool":
// []bool{false}
valuesSlice = fmt.Sprintf("[]bool{%s}", condition.In)
valuesSlice = fmt.Sprintf("[]bool{%s}", condCfg.In)
default:
panic("not supported type " + goType)
}
Expand Down Expand Up @@ -157,11 +162,11 @@ func fieldPathSafeEqual(
resVarName string,
candidatesVarName string,
field *model.Field,
condition ackgenconfig.SyncedCondition,
condCfg ackgenconfig.SyncedCondition,
) string {
out := ""
rootPath := fmt.Sprintf("%s.%s", resVarName, strings.Split(*condition.Path, ".")[0])
knownShapesPath := strings.Join(strings.Split(*condition.Path, ".")[1:], ".")
rootPath := fmt.Sprintf("%s.%s", resVarName, strings.Split(*condCfg.Path, ".")[0])
knownShapesPath := strings.Join(strings.Split(*condCfg.Path, ".")[1:], ".")

fp := fieldpath.FromString(knownShapesPath)
shapes := fp.IterShapeRefs(field.ShapeRef)
Expand All @@ -171,7 +176,7 @@ func fieldPathSafeEqual(
if index == len(shapes)-1 {
// Some aws-sdk-go scalar shapes don't contain the real name of a shape
// In this case we use the full path given in condition.Path
subFieldPath = fmt.Sprintf("%s.%s", resVarName, *condition.Path)
subFieldPath = fmt.Sprintf("%s.%s", resVarName, *condCfg.Path)
} else {
subFieldPath += "." + shape.Shape.ShapeName
}
Expand All @@ -182,13 +187,13 @@ func fieldPathSafeEqual(
// }
out += "\t}\n"
}
out += scalarFieldEqual(resVarName, candidatesVarName, shapes[len(shapes)-1].GoTypeElem(), condition)
out += scalarFieldEqual(resVarName, candidatesVarName, shapes[len(shapes)-1].GoTypeElem(), condCfg)
return out
}

func fieldPathContainsMapOrArray(fieldPath string, shapeRef *awssdkmodel.ShapeRef) bool {
c := fieldpath.FromString(fieldPath)
sr := c.ShapeRef(shapeRef)
fp := fieldpath.FromString(fieldPath)
sr := fp.ShapeRef(shapeRef)

if sr == nil {
return false
Expand All @@ -197,8 +202,8 @@ func fieldPathContainsMapOrArray(fieldPath string, shapeRef *awssdkmodel.ShapeRe
return true
}
if sr.ShapeName == "structure" {
fieldName := c.PopFront()
return fieldPathContainsMapOrArray(c.Copy().At(1), sr.Shape.MemberRefs[fieldName])
fieldName := fp.PopFront()
return fieldPathContainsMapOrArray(fp.Copy().At(1), sr.Shape.MemberRefs[fieldName])
}
return false
}