Skip to content

Commit

Permalink
feat(rds): add public command to generate report of publicly exposed …
Browse files Browse the repository at this point in the history
…RDS instances (#20)

* feat(rds): add public command to generate report of publicly exposed RDS instances

* chore(rds): added UsageText to the public command

* chore: fix lint errors
  • Loading branch information
clok authored Sep 30, 2020
1 parent 9ecccba commit 73fb4c3
Show file tree
Hide file tree
Showing 10 changed files with 307 additions and 41 deletions.
8 changes: 4 additions & 4 deletions iam/get.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,9 @@ func ListUsers(showOnly string) error {
}
for _, user := range data {
if user.HasConsoleAccess() {
summaryStats["consoleAccess"] += 1
summaryStats["consoleAccess"]++
}
summaryStats[user.CheckStatus()] += 1
summaryStats[user.CheckStatus()]++

if showOnly == "" || showOnly == user.CheckStatus() {
t.AppendRow([]interface{}{
Expand All @@ -101,9 +101,9 @@ func ListUsers(showOnly string) error {
for _, key := range user.accessKeys {
switch aws.StringValue(key.status) {
case "Active":
summaryStats["activeKeys"] += 1
summaryStats["activeKeys"]++
case "Inactive":
summaryStats["inactiveKeys"] += 1
summaryStats["inactiveKeys"]++
}
st.AppendRow([]interface{}{
aws.StringValue(key.id),
Expand Down
28 changes: 28 additions & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,34 @@ throttling from AWS with an exponential backoff with retry.
return nil
},
},
{
Name: "public",
Usage: "Produce report of instances that have public interfaces attached",
UsageText: `
Produces a report that displays a list RDS servers that are configured as Publicly Accessible.
The report contains:
DB INSTANCE:
- Name of the instance
ENGINE:
- RDS DB engine
SECURITY GROUPS:
- Security Group ID
- Security Group Name
- Inbound Port
- CIDR rules applied to the Port
`,
Action: func(c *cli.Context) error {
err := rds.ListPublicInterfaces()
if err != nil {
return cli.NewExitError(err, 2)
}
return nil
},
},
},
},
{
Expand Down
89 changes: 89 additions & 0 deletions rds/public.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package rds

import (
"fmt"
"github.com/GoodwayGroup/gw-aws-audit/sg"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/rds"
as "github.com/clok/awssession"
"github.com/clok/kemba"
"github.com/jedib0t/go-pretty/v6/table"
"net"
"os"
)

// ListPublicInterfaces will list RDS instances with a public interface attached.
func ListPublicInterfaces() error {
k := kemba.New("gw-aws-audit:rds:ListPublicInterfaces")
sess, err := as.New()
if err != nil {
return err
}
client := rds.New(sess)
cnt := 0

var result *rds.DescribeDBInstancesOutput
result, err = client.DescribeDBInstances(&rds.DescribeDBInstancesInput{})

if err != nil {
fmt.Println("Failed to list instances")
return err
}

t := table.NewWriter()
t.SetOutputMirror(os.Stdout)
t.SetStyle(table.StyleLight)
t.AppendHeader(table.Row{"DB Instance", "Engine", "Security Groups"})

k.Printf("checking %d RDS instances", len(result.DBInstances))
for _, db := range result.DBInstances {
if aws.BoolValue(db.PubliclyAccessible) {
cnt++

var sgIDs []*string
for _, sec := range db.VpcSecurityGroups {
sgIDs = append(sgIDs, sec.VpcSecurityGroupId)
}
sgs, err := sg.GetSecurityGroups(sgIDs)
if err != nil {
return err
}
var securityGroups []*sg.SecurityGroup
for _, sec := range sgs {
securityGroups = append(securityGroups, sec)
}
k.Log(securityGroups)

var ips []string
var stub string
for _, sec := range securityGroups {
for token, rule := range sec.Rules() {
port, _, _ := sec.ParseRuleToken(token)
for _, ip := range rule {
_, ipv4Net, _ := net.ParseCIDR(aws.StringValue(ip.CidrIp))
ips = append(ips, ipv4Net.String())
}
stub = fmt.Sprintf("%s\t%s\t%s\n\n\t", sec.ID(), sec.Name(), port)
for i, ip := range ips {
if i != 0 && i%4 == 0 {
stub = fmt.Sprintf("%s\n\t", stub)
}
stub = fmt.Sprintf("%s %20s", stub, ip)
}
stub = fmt.Sprintf("%s\n", stub)
}
}

name := aws.StringValue(db.DBInstanceIdentifier)
engine := aws.StringValue(db.Engine)

t.AppendRow([]interface{}{name, engine, stub})
}
}

// There are a LOT of metrics to consider
// See: https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/USER_Monitoring.OS.html
t.AppendFooter(table.Row{"DB Instances", cnt})
t.Render()
return nil
}
2 changes: 1 addition & 1 deletion sg/attached.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ func ListAttachedSecurityGroups() error {
return err
}

var attached []*securityGroup
var attached []*SecurityGroup
for _, sg := range sgs {
if sg.attached != nil {
attached = append(attached, sg)
Expand Down
2 changes: 1 addition & 1 deletion sg/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func GenerateExternalAWSIPReport() error {
return err
}

var securityGroups []*securityGroup
var securityGroups []*SecurityGroup
for _, sg := range sgs {
securityGroups = append(securityGroups, sg)
}
Expand Down
2 changes: 1 addition & 1 deletion sg/detached.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ func ListDetachedSecurityGroups() error {
return err
}

var detached []*securityGroup
var detached []*SecurityGroup
for _, sg := range sgs {
if sg.attached == nil {
detached = append(detached, sg)
Expand Down
44 changes: 37 additions & 7 deletions sg/get.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,25 +13,33 @@ var (
ksg = kemba.New("gw-aws-audit:sg")
)

func getAllSecurityGroups() (map[string]*securityGroup, error) {
kl := ksg.Extend("get-all-sg")
// GetSecurityGroups will retrieve a list of Security Group IDs with mapped ports
func GetSecurityGroups(sgIDs []*string) (map[string]*SecurityGroup, error) {
kl := ksg.Extend("get-sg")
sess, err := as.New()
if err != nil {
return nil, err
}
client := ec2.New(sess)

kl.Printf("retrieving SG IDs: %# v", sgIDs)
var results *ec2.DescribeSecurityGroupsOutput
results, err = client.DescribeSecurityGroups(&ec2.DescribeSecurityGroupsInput{
MaxResults: aws.Int64(1000),
GroupIds: sgIDs,
})
if err != nil {
fmt.Println("Failed to list Security Groups")
return nil, err
}

kl.Printf("found %d security groups", len(results.SecurityGroups))
secGroups := make(map[string]*securityGroup, len(results.SecurityGroups))
secGroups := processSecurityGroupsResponse(results)

return secGroups, nil
}

func processSecurityGroupsResponse(results *ec2.DescribeSecurityGroupsOutput) map[string]*SecurityGroup {
secGroups := make(map[string]*SecurityGroup, len(results.SecurityGroups))
for _, sec := range results.SecurityGroups {
rules := map[string][]*ec2.IpRange{}
for _, rule := range sec.IpPermissions {
Expand Down Expand Up @@ -63,12 +71,34 @@ func getAllSecurityGroups() (map[string]*securityGroup, error) {
}
}

secGroups[aws.StringValue(sec.GroupId)] = &securityGroup{
secGroups[aws.StringValue(sec.GroupId)] = &SecurityGroup{
id: aws.StringValue(sec.GroupId),
name: aws.StringValue(sec.GroupName),
rules: rules,
}
}
return secGroups
}

func getAllSecurityGroups() (map[string]*SecurityGroup, error) {
kl := ksg.Extend("get-all-sg")
sess, err := as.New()
if err != nil {
return nil, err
}
client := ec2.New(sess)

var results *ec2.DescribeSecurityGroupsOutput
results, err = client.DescribeSecurityGroups(&ec2.DescribeSecurityGroupsInput{
MaxResults: aws.Int64(1000),
})
if err != nil {
fmt.Println("Failed to list Security Groups")
return nil, err
}

kl.Printf("found %d security groups", len(results.SecurityGroups))
secGroups := processSecurityGroupsResponse(results)

return secGroups, nil
}
Expand Down Expand Up @@ -101,7 +131,7 @@ func buildPortToken(fromPort string, toPort string, proto *string, securityGroup
return strings.Join(parts, "::")
}

func detectAttachedSecurityGroups(sgs map[string]*securityGroup) error {
func detectAttachedSecurityGroups(sgs map[string]*SecurityGroup) error {
kl := ksg.Extend("detect-attached")
sess, err := as.New()
if err != nil {
Expand Down Expand Up @@ -171,7 +201,7 @@ func detectAttachedSecurityGroups(sgs map[string]*securityGroup) error {
return nil
}

func getAnnotatedSecurityGroups() (map[string]*securityGroup, error) {
func getAnnotatedSecurityGroups() (map[string]*SecurityGroup, error) {
// get all sgs in a region
sgs, err := getAllSecurityGroups()
if err != nil {
Expand Down
12 changes: 6 additions & 6 deletions sg/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func generateReport(c *cli.Context, checkFxn func(a []string, b string) bool, po
return err
}

var securityGroups []*securityGroup
var securityGroups []*SecurityGroup
for _, sg := range sgs {
if sg.attached != nil || c.Bool("all") {
securityGroups = append(securityGroups, sg)
Expand Down Expand Up @@ -95,7 +95,7 @@ func parseToken(token string) (port string, protocol string, sgIDs string) {
return parts[0], parts[1], parts[2]
}

func processSecurityGroups(securityGroups []*securityGroup, groupedCIDRs *groupedIPBlockRules, checkFxn func(a []string, b string) bool, ports []string, ignoredProtocols map[string]bool) (*groupedSecurityGroups, error) {
func processSecurityGroups(securityGroups []*SecurityGroup, groupedCIDRs *groupedIPBlockRules, checkFxn func(a []string, b string) bool, ports []string, ignoredProtocols map[string]bool) (*groupedSecurityGroups, error) {
kl := ksg.Extend("processSecurityGroups")
mappedSGs := newGroupedSecurityGroups()

Expand Down Expand Up @@ -208,7 +208,7 @@ func generateIPBlockRules(c *cli.Context) (*groupedIPBlockRules, error) {
return groupedCIDRs, nil
}

func printTable(data map[*securityGroup][]*portToIP) {
func printTable(data map[*SecurityGroup][]*portToIP) {
t := table.NewWriter()
t.SetOutputMirror(os.Stdout)
t.SetStyle(table.StyleLight)
Expand All @@ -222,7 +222,7 @@ func printTable(data map[*securityGroup][]*portToIP) {
if i == 0 {
id = sec.id
name = sec.name
usage = sec.getAttachmentsAsString()
usage = sec.GetAttachmentsAsString()
}
t.AppendRow([]interface{}{
id,
Expand All @@ -238,7 +238,7 @@ func printTable(data map[*securityGroup][]*portToIP) {
t.Render()
}

func printAmazonTable(data map[*securityGroup][]*portToIP) {
func printAmazonTable(data map[*SecurityGroup][]*portToIP) {
t := table.NewWriter()
t.SetOutputMirror(os.Stdout)
t.SetStyle(table.StyleLight)
Expand All @@ -252,7 +252,7 @@ func printAmazonTable(data map[*securityGroup][]*portToIP) {
if i == 0 {
id = sec.id
name = sec.name
usage = sec.getAttachmentsAsString()
usage = sec.GetAttachmentsAsString()
}
t.AppendRow([]interface{}{
id,
Expand Down
Loading

0 comments on commit 73fb4c3

Please sign in to comment.