-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.go
79 lines (68 loc) · 1.87 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
package main
import (
"context"
"github.com/aws/aws-lambda-go/lambda"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/rdsdataservice"
"log"
"os"
"strings"
)
// CustomResourceEvent is the payload type the Lambda handler receives.
// Only the RequestType field of the incoming JSON is decoded.
type CustomResourceEvent struct {
// RequestType is read from the "RequestType" JSON key; handler compares it
// against "Create".
RequestType string `json:"RequestType"`
}
// main passes handler to lambda.Start, which is the program's only action.
func main() {
lambda.Start(handler)
}
// handler logs the request type of the incoming event and dispatches on it.
// Only "Create" triggers work (onCreate); every other request type, such as
// "Update" or "Delete", is deliberately a no-op.
func handler(_ context.Context, event CustomResourceEvent) {
	log.Printf("Event: %s", event.RequestType)

	switch event.RequestType {
	case "Create":
		onCreate()
	default:
		// Nothing to do for the remaining request types ("Update", "Delete").
	}
}
// onCreate executes every SQL statement supplied through STATEMENT_*
// environment variables using the RDS Data Service client.
//
// The target is taken from the DB_ARN, SECRET_ARN and DATABASE_NAME
// environment variables. A missing variable, a failure to open the AWS
// session, or any failed statement terminates the process via log.Fatalf.
func onCreate() {
	dbArn := requireEnv("DB_ARN")
	secretArn := requireEnv("SECRET_ARN")
	dbName := requireEnv("DATABASE_NAME")

	// NOTE(review): statements run in the order os.Environ reports them;
	// confirm that is acceptable if statements depend on one another.
	statements := statementsFromEnv(os.Environ())

	sess, err := session.NewSession()
	if err != nil {
		log.Fatalf("Failed to open a session: %v\n", err)
	}
	svc := rdsdataservice.New(sess)

	for _, statement := range statements {
		input := &rdsdataservice.ExecuteStatementInput{
			Database:    aws.String(dbName),
			ResourceArn: aws.String(dbArn),
			SecretArn:   aws.String(secretArn),
			Sql:         aws.String(statement),
		}
		log.Printf("Executing statement: %s\n", input)

		output, err := svc.ExecuteStatement(input)
		if err != nil {
			if awsErr, ok := err.(awserr.Error); ok {
				log.Fatalf("Failed to execute statement %s:\nCode: %s\nMessage: %s\nOrig Message:%v\n", statement,
					awsErr.Code(), awsErr.Message(), awsErr.OrigErr())
			}
			// Errors that are not awserr.Error must also stop execution;
			// otherwise the failure is silently dropped and a nil output
			// is logged as though the statement had succeeded.
			log.Fatalf("Failed to execute statement %s: %v\n", statement, err)
		}
		log.Print(output)
	}
}

// requireEnv returns the value of the environment variable key, terminating
// the process via log.Fatalf when the variable is not set.
func requireEnv(key string) string {
	value, ok := os.LookupEnv(key)
	if !ok {
		log.Fatalf("%s env variable is not set", key)
	}
	return value
}

// statementsFromEnv returns the values of all "KEY=VALUE" entries in environ
// whose entry starts with "STATEMENT_", in the order they appear.
func statementsFromEnv(environ []string) []string {
	var statements []string
	for _, entry := range environ {
		if !strings.HasPrefix(entry, "STATEMENT_") {
			continue
		}
		// Split on the first '=' only: SQL text routinely contains '='
		// (e.g. "UPDATE t SET a=1") and must not be truncated.
		parts := strings.SplitN(entry, "=", 2)
		if len(parts) < 2 {
			// Entry without a value separator; nothing to execute.
			continue
		}
		statements = append(statements, parts[1])
	}
	return statements
}