Skip to content

Commit

Permalink
Allow a custom port range for EC2 VMs
Browse files Browse the repository at this point in the history
Set the additional text with a comma-separated list of
ports i.e. 22,443,80,8080 and these will be added to the
security group.

Signed-off-by: Alex Ellis (OpenFaaS Ltd) <alexellis2@gmail.com>
  • Loading branch information
alexellis committed May 3, 2023
1 parent 8351174 commit 0b465cc
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 7 deletions.
47 changes: 40 additions & 7 deletions provision/ec2.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@ package provision

import (
"fmt"
"github.com/aws/aws-sdk-go/aws/credentials"
"strconv"
"strings"

"github.com/aws/aws-sdk-go/aws/credentials"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/ec2"
Expand Down Expand Up @@ -40,10 +41,12 @@ func (p *EC2Provisioner) Provision(host BasicHost) (*ProvisionedHost, error) {
}
pro := host.Additional["pro"]

ports := host.Additional["ports"]

var vpcID = host.Additional["vpc-id"]
var subnetID = host.Additional["subnet-id"]

groupID, name, err := p.createEC2SecurityGroup(vpcID, port, pro)
groupID, name, err := p.createEC2SecurityGroup(vpcID, port, pro, ports)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -85,6 +88,7 @@ func (p *EC2Provisioner) Provision(host BasicHost) (*ProvisionedHost, error) {
return nil, fmt.Errorf("could not create host: %s", runResult.String())
}

// AE: not sure why this error isn't handled?
_, err = p.ec2Provisioner.CreateTags(&ec2.CreateTagsInput{
Resources: []*string{runResult.Instances[0].InstanceId},
Tags: []*ec2.Tag{
Expand Down Expand Up @@ -247,9 +251,21 @@ func (p *EC2Provisioner) lookupID(request HostDeleteRequest) (string, error) {
}

// createEC2SecurityGroup creates a security group for the exit-node
func (p *EC2Provisioner) createEC2SecurityGroup(vpcID string, controlPort int, pro string) (*string, *string, error) {
ports := []int{80, 443, controlPort}
proPorts := []int{1024, 65535}
func (p *EC2Provisioner) createEC2SecurityGroup(vpcID string, controlPort int, pro, extraPorts string) (*string, *string, error) {
ports := []int{controlPort}

proPortRange := []int{1024, 65535}

if len(extraPorts) > 0 {
extraPorts, err := parsePorts(extraPorts)
if err != nil {
return nil, nil, err
}
ports = append(ports, extraPorts...)

proPortRange = []int{}
}

groupName := "inlets-" + uuid.New().String()
var input = &ec2.CreateSecurityGroupInput{
Description: aws.String("inlets security group"),
Expand All @@ -271,8 +287,9 @@ func (p *EC2Provisioner) createEC2SecurityGroup(vpcID string, controlPort int, p
return group.GroupId, &groupName, err
}
}
if pro == "true" {
err = p.createEC2SecurityGroupRule(*group.GroupId, proPorts[0], proPorts[1])

if pro == "true" && len(proPortRange) == 2 {
err = p.createEC2SecurityGroupRule(*group.GroupId, proPortRange[0], proPortRange[1])
if err != nil {
return group.GroupId, &groupName, err
}
Expand All @@ -281,6 +298,22 @@ func (p *EC2Provisioner) createEC2SecurityGroup(vpcID string, controlPort int, p
return group.GroupId, &groupName, nil
}

func parsePorts(extraPorts string) ([]int, error) {
var ports []int
parts := strings.Split(extraPorts, ",")
for _, part := range parts {
if trimmed := strings.TrimSpace(part); len(trimmed) > 0 {
port, err := strconv.Atoi(trimmed)
if err != nil {
return nil, err
}
ports = append(ports, port)
}
}

return ports, nil
}

func (p *EC2Provisioner) createEC2SecurityGroupRule(groupID string, fromPort, toPort int) error {
_, err := p.ec2Provisioner.AuthorizeSecurityGroupIngress(&ec2.AuthorizeSecurityGroupIngressInput{
CidrIp: aws.String("0.0.0.0/0"),
Expand Down
63 changes: 63 additions & 0 deletions provision/ec2_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package provision

import "testing"

func Test_parsePorts_empty(t *testing.T) {

ports, err := parsePorts("")
if err != nil {
t.Fatal(err)
}

if len(ports) != 0 {
t.Fatalf("Expected empty slice, got %d", len(ports))
}
}

func Test_parsePorts_single(t *testing.T) {

wantPort := 80
str := "80"
ports, err := parsePorts(str)
if err != nil {
t.Fatal(err)
}

if len(ports) != 1 {
t.Fatalf("Want single port, got %d", len(ports))
}

if ports[0] != wantPort {
t.Fatalf("Want port %d, got %d", wantPort, ports[0])
}
}

func Test_parsePorts_multiple(t *testing.T) {

wantPorts := []int{27017, 22}

str := "27017,22"

ports, err := parsePorts(str)
if err != nil {
t.Fatal(err)
}

if len(ports) != len(wantPorts) {
t.Fatalf("Want %d ports, got %d", len(wantPorts), len(ports))
}

found := 0

for _, port := range ports {
for _, wantPort := range wantPorts {
if port == wantPort {
found++
}
}
}

if found != len(wantPorts) {
t.Fatalf("Want %v ports, got %v", wantPorts, ports)
}
}

0 comments on commit 0b465cc

Please sign in to comment.