Skip to content

Commit

Permalink
Replace tuples with named tuples to improve code readability
Browse files Browse the repository at this point in the history
Co-authored-by: Shane Frasier <jeremy.frasier@gwe.cisa.dhs.gov>
  • Loading branch information
dav3r and jsf9k committed May 18, 2024
1 parent 8dff260 commit a9ec40c
Showing 1 changed file with 50 additions and 37 deletions.
87 changes: 50 additions & 37 deletions src/lambda_handler.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
"""AWS Lambda handler to publish egress IPs from a provided list of AWS accounts."""

# Standard Python Libraries
from collections import namedtuple
from datetime import datetime, timezone
from ipaddress import collapse_addresses, ip_network
import logging
import os
import re
from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, TypedDict, Union
from typing import Any, Dict, Iterator, List, Optional, Set, TypedDict, Union

# Third-Party Libraries
import boto3
Expand All @@ -15,9 +16,14 @@
logger = logging.getLogger()
logger.setLevel(default_log_level)

# Use a named tuple to hold AWS credentials
aws_credentials = namedtuple(
"aws_credentials", ["access_key_id", "secret_access_key", "session_token"]
)

def assume_role(role_arn: str, session_name: str) -> Tuple[str, str, str]:
"""Assume the given role and return a tuple containing the assumed role's credentials."""

def assume_role(role_arn: str, session_name: str) -> aws_credentials:
"""Assume the given role and return a named tuple containing the assumed role's credentials."""
# Create an STS session with current credentials
sts: boto3.client = boto3.client("sts")

Expand All @@ -26,7 +32,7 @@ def assume_role(role_arn: str, session_name: str) -> Tuple[str, str, str]:
RoleArn=role_arn, RoleSessionName=session_name
)

return (
return aws_credentials(
response["Credentials"]["AccessKeyId"],
response["Credentials"]["SecretAccessKey"],
response["Credentials"]["SessionToken"],
Expand All @@ -37,28 +43,28 @@ def create_assumed_aws_client(
aws_service: str, role_arn: str, session_name: str
) -> boto3.client:
"""Assume the given role and return an AWS client for the given service using that role."""
role_credentials: Tuple[str, str, str] = assume_role(role_arn, session_name)
role_credentials: aws_credentials = assume_role(role_arn, session_name)

return boto3.client(
aws_service,
aws_access_key_id=role_credentials[0],
aws_secret_access_key=role_credentials[1],
aws_session_token=role_credentials[2],
aws_access_key_id=role_credentials.access_key_id,
aws_secret_access_key=role_credentials.secret_access_key,
aws_session_token=role_credentials.session_token,
)


def create_assumed_aws_resource(
aws_service: str, region: str, role_arn: str, session_name: str
) -> boto3.resource:
"""Assume the given role and return an AWS resource object for the given service using that role."""
role_credentials: Tuple[str, str, str] = assume_role(role_arn, session_name)
role_credentials: aws_credentials = assume_role(role_arn, session_name)

return boto3.resource(
aws_service,
region_name=region,
aws_access_key_id=role_credentials[0],
aws_secret_access_key=role_credentials[1],
aws_session_token=role_credentials[2],
aws_access_key_id=role_credentials.access_key_id,
aws_secret_access_key=role_credentials.secret_access_key,
aws_session_token=role_credentials.session_token,
)


Expand All @@ -72,9 +78,13 @@ def convert_tags(aws_resource: boto3.resource) -> Dict[str, str]:
return tags


# Use a named tuple to hold EC2 information
ec2_info = namedtuple("ec2_info", ["application_tag_value", "public_ip"])


def get_ec2_ips(
ec2: boto3.resource, application_tag_name: str, publish_egress_tag_name: str
) -> Iterator[Tuple[str, str]]:
) -> Iterator[ec2_info]:
"""Create a set of public EC2 IPs.
Yields (application tag value, public_ip) tuples.
Expand All @@ -100,7 +110,7 @@ def get_ec2_ips(
# Send back a tuple associating the public IP to an application.
# If application is unset, return "", so that the IP can be included
# in a list of all IPs if desired (e.g. using app_regex=".*").
yield (tags.get(application_tag_name, ""), instance.public_ip_address)
yield ec2_info(tags.get(application_tag_name, ""), instance.public_ip_address)

for vpc_address in vpc_addresses:
# Convert elastic IP tags from an AWS dictionary into a Python dictionary
Expand All @@ -112,7 +122,7 @@ def get_ec2_ips(
# Send back a tuple associating the public IP to an application.
# If application is unset, return "", so that the IP can be included
# in a list of all IPs if desired (e.g. using app_regex=".*").
yield (eip_tags.get(application_tag_name, ""), vpc_address.public_ip)
yield ec2_info(eip_tags.get(application_tag_name, ""), vpc_address.public_ip)


def get_ec2_regions(
Expand Down Expand Up @@ -188,9 +198,11 @@ class FileConfig(TypedDict):
static_ips: List[str]


def validate_event_data(
event: Dict[str, Any]
) -> Tuple[Dict[str, Any], bool, List[str]]:
# Use a named tuple to store the results of the event validation
event_validation = namedtuple("event_validation", ["event", "valid", "errors"])


def validate_event_data(event: Dict[str, Any]) -> event_validation:
"""Validate the event data and return a tuple containing the validated event, a boolean result (True if valid, False if invalid), and a list of error message strings."""
result = True
errors = []
Expand Down Expand Up @@ -268,36 +280,35 @@ def validate_event_data(
if errors:
result = False

return event, result, errors
return event_validation(event, result, errors)


def task_publish(event: Dict[str, Any]) -> Dict[str, Union[Optional[str], bool]]:
"""Publish the egress IP addresses in the given AWS accounts to an S3 bucket."""
result: Dict[str, Union[Optional[str], bool]] = {"message": None, "success": True}

# Validate all event data before going any further
event_valid: bool
event_errors: List[str]
event, event_valid, event_errors = validate_event_data(event)
if not event_valid:
for e in event_errors:
event_validation_info: event_validation = validate_event_data(event)
if not event_validation_info.valid:
for e in event_validation_info.errors:
logging.error(e)
failed_task(result, " ".join(event_errors))
failed_task(result, " ".join(event_validation_info.errors))
return result
validated_event = event_validation_info.event

# The account IDs to examine for IP addresses
account_ids: List[str] = event["account_ids"]
account_ids: List[str] = validated_event["account_ids"]

# Name of the AWS resource tag whose value represents the application
# associated with an IP address
application_tag_name: str = event.get("application_tag", "Application")
application_tag_name: str = validated_event.get("application_tag", "Application")

# The bucket to publish the files to
bucket_name: str = event["bucket_name"]
bucket_name: str = validated_event["bucket_name"]

# Name of the IAM role to assume that can read the necessary EC2 data
# in each AWS account. Note that this role must exist in each account.
ec2_read_role_name: str = event.get("role_name", "EC2ReadOnly")
ec2_read_role_name: str = validated_event.get("role_name", "EC2ReadOnly")

# A list of dictionaries that define the files to be created and
# published. When an IP is to be published, its associated
Expand All @@ -310,7 +321,7 @@ def task_publish(event: Dict[str, Any]) -> Dict[str, Union[Optional[str], bool]]
# - "filename" (string): the name of the file
# - "static_ips" (list(string)): a list of CIDR blocks that will always
# be included in the published file
file_configs: List[FileConfig] = event["file_configs"]
file_configs: List[FileConfig] = validated_event["file_configs"]

# Header template for each file, comprised of a list of strings.
# When the file is published, newline characters are automatically added
Expand All @@ -320,7 +331,7 @@ def task_publish(event: Dict[str, Any]) -> Dict[str, Union[Optional[str], bool]]
# {filename} - name of the published file
# {timestamp} - timestamp when the file was published
# {description} - description of the published file
file_header: List[str] = event.get(
file_header: List[str] = validated_event.get(
"file_header",
[
"###",
Expand All @@ -332,10 +343,12 @@ def task_publish(event: Dict[str, Any]) -> Dict[str, Union[Optional[str], bool]]
)

# AWS resource tag name indicating whether an IP address should be published
publish_egress_tag_name: str = event.get("publish_egress_tag", "Publish Egress")
publish_egress_tag_name: str = validated_event.get(
"publish_egress_tag", "Publish Egress"
)

# An AWS-style filter definition to limit the queried regions
region_filters: List[Dict[str, Union[str, List[str]]]] = event.get(
region_filters: List[Dict[str, Union[str, List[str]]]] = validated_event.get(
"region_filters", []
)

Expand Down Expand Up @@ -368,19 +381,19 @@ def task_publish(event: Dict[str, Any]) -> Dict[str, Union[Optional[str], bool]]
)

# Get the public IPs of instances that are tagged to be published
for application_tag_value, public_ip in get_ec2_ips(
for ec2_info in get_ec2_ips(
ec2, application_tag_name, publish_egress_tag_name
):
# Loop through all regexes and add IP to set if matched
for config in file_configs:
if config["app_regex"].match(application_tag_value):
config["ip_set"].add(ip_network(public_ip))
if config["app_regex"].match(ec2_info.application_tag_value):
config["ip_set"].add(ip_network(ec2_info.public_ip))

# Use a single timestamp for all files
now = "{:%a %b %d %H:%M:%S UTC %Y}".format(datetime.utcnow())

# The domain to display in the header of each published file
domain: str = event.get("domain", "example.gov")
domain: str = validated_event.get("domain", "example.gov")

# Update each object (file) in the bucket
for config in file_configs:
Expand Down

0 comments on commit a9ec40c

Please sign in to comment.