Skip to content

Commit

Permalink
feat: automatically refresh role assumptions (#106)
Browse files Browse the repository at this point in the history
* feat: use role assumption helper to assume roles

* refactor: fixup lint/types

* feat: use $AWS_ROLE_CONFIG_PATH to be more consistent

* refactor: remove unused named tuple
  • Loading branch information
blacha authored Aug 24, 2022
1 parent 25a85c1 commit 40e4b52
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 89 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,5 +37,5 @@ Python version is set to `3.8.10` as it is the current version used by `osgeo/gd
3. Run the following command

```bash
docker run -v ${HOME}/.aws/credentials:/root/.aws/credentials:ro -e AWS_PROFILE='your-aws-profile' 'image-id' python create_polygons.py --uri 's3://path-to-the-tiff/image.tif' --destination 'destination-bucket'
docker run -v ${HOME}/.aws/credentials:/root/.aws/credentials:ro -e AWS_PROFILE 'image-id' python create_polygons.py --uri 's3://path-to-the-tiff/image.tif' --destination 'destination-bucket'
```
35 changes: 35 additions & 0 deletions scripts/aws/aws_credential_source.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from dataclasses import dataclass
from typing import Optional


# pylint: disable=too-many-instance-attributes
@dataclass
class CredentialSource:
    """Mapping of an S3 prefix to the IAM role that grants access to it.

    Instances are built from entries of the bucket config JSON via
    ``CredentialSource(**cfg)``, so attribute names must match the JSON
    keys exactly — hence the camelCase names and the ``type`` field
    shadowing the builtin.
    """

    # Base bucket location; may be a subset of the bucket.
    bucket: str
    # Type of role assumption, generally "s3".
    type: str
    # Prefix for which the role is valid; generally starts with "s3://".
    prefix: str
    # AWS account id of the bucket owner.
    accountId: str
    # ARN of the role to assume.
    roleArn: str
    # Role external ID, if it exists.
    externalId: Optional[str] = None
    # Max duration of the assumed session in seconds; default 1 hour.
    roleSessionDuration: Optional[int] = 1 * 60 * 60
    # Flags the role can use: "r" for read-only or "rw" for read-write.
    flags: Optional[str] = None
113 changes: 63 additions & 50 deletions scripts/aws/aws_helper.py
Original file line number Diff line number Diff line change
@@ -1,81 +1,94 @@
import json
from os import environ
from typing import TYPE_CHECKING, NamedTuple
from typing import Any, Dict, List, NamedTuple, Optional
from urllib.parse import urlparse

import boto3
import botocore
from botocore.credentials import AssumeRoleCredentialFetcher, DeferredRefreshableCredentials
from linz_logger import get_log

if TYPE_CHECKING:
from mypy_boto3_s3.service_resource import Bucket
else:
Bucket = object
from scripts.aws.aws_credential_source import CredentialSource

Credentials = NamedTuple("Credentials", [("access_key", str), ("secret_key", str), ("token", str)])
S3Path = NamedTuple("S3Path", [("bucket", str), ("key", str)])

aws_profile = environ.get("AWS_PROFILE")
session = boto3.Session(profile_name=aws_profile)
bucket_roles = {}
bucket_credentials = {}
sessions: Dict[str, boto3.Session] = {}

bucket_roles: List[CredentialSource] = []

client_sts = session.client("sts")

# Load bucket to roleArn mapping for LINZ internal buckets from SSM
def init_roles() -> None:
s3 = boto3.resource("s3")
bucket_config_path = environ.get("AWS_BUCKET_CONFIG_PATH", "s3://linz-bucket-config/config.json")

content_object = s3.Object("linz-bucket-config", "config.json")
# Load bucket to roleArn mapping for LINZ internal buckets from SSM
def _init_roles() -> None:
s3 = session.resource("s3")
config_path = parse_path(bucket_config_path)
content_object = s3.Object(config_path.bucket, config_path.key)
file_content = content_object.get()["Body"].read().decode("utf-8")
json_content = json.loads(file_content)

get_log().debug("bucket_config", config=json_content)
get_log().trace("bucket_config_load", config=bucket_config_path)

for cfg in json_content["buckets"]:
bucket_roles[cfg["bucket"]] = cfg
for cfg in json_content["prefixes"]:
bucket_roles.append(CredentialSource(**cfg))

get_log().debug("bucket_config_loaded", config=bucket_config_path, prefix_count=len(bucket_roles))


def _get_client_creator(local_session: boto3.Session) -> Any:
def client_creator(service_name: str, **kwargs: Any) -> Any:
return local_session.client(service_name, **kwargs)

return client_creator

def get_credentials(bucket_name: str) -> Credentials:
get_log().debug("get_credentials_bucket_name", bucket_name=bucket_name)
if not bucket_roles:
init_roles()
if bucket_name in bucket_roles:
# FIXME: check if the token is expired - add a parameter
if bucket_name not in bucket_credentials:
role_arn = bucket_roles[bucket_name]["roleArn"]
get_log().debug("sts_assume_role", bucket_name=bucket_name, role_arn=role_arn)
assumed_role_object = client_sts.assume_role(RoleArn=role_arn, RoleSessionName="gdal")
bucket_credentials[bucket_name] = Credentials(
assumed_role_object["Credentials"]["AccessKeyId"],
assumed_role_object["Credentials"]["SecretAccessKey"],
assumed_role_object["Credentials"]["SessionToken"],
)

return bucket_credentials[bucket_name]

session_credentials = session.get_credentials()
default_credentials = Credentials(
session_credentials.access_key, session_credentials.secret_key, session_credentials.token
)

return default_credentials
def get_session(prefix: str) -> boto3.Session:
cfg = _get_credential_config(prefix)
if cfg is None:
raise Exception(f"Unable to find role for prefix: {prefix}")

current_session = sessions.get(cfg.roleArn, None)
if current_session is not None:
return current_session

def get_bucket(bucket_name: str) -> Bucket:
credentials = get_credentials(bucket_name=bucket_name)
extra_args: Dict[str, Any] = {"DurationSeconds": cfg.roleSessionDuration}

s3_resource = boto3.resource(
"s3",
aws_access_key_id=credentials.access_key,
aws_secret_access_key=credentials.secret_key,
aws_session_token=credentials.token,
if cfg.externalId:
extra_args["ExternalId"] = cfg.externalId

fetcher = AssumeRoleCredentialFetcher(
client_creator=_get_client_creator(session),
source_credentials=session.get_credentials(),
role_arn=cfg.roleArn,
extra_args=extra_args,
)
botocore_session = botocore.session.Session()

# pylint:disable=protected-access
botocore_session._credentials = DeferredRefreshableCredentials(
method="assume-role", refresh_using=fetcher.fetch_credentials
)
s3_bucket: Bucket = s3_resource.Bucket(bucket_name)
return s3_bucket

current_session = boto3.Session(botocore_session=botocore_session)
sessions[cfg.roleArn] = current_session

get_log().info("role_assume", prefix=prefix, bucket=cfg.bucket, role_arn=cfg.roleArn)
return current_session


def _get_credential_config(prefix: str) -> Optional[CredentialSource]:
get_log().debug("get_credentials_bucket_name", prefix=prefix)
if not bucket_roles:
_init_roles()

for cfg in bucket_roles:
if prefix.startswith(cfg.prefix):
return cfg

def get_bucket_name_from_path(path: str) -> str:
path_parts = path.replace("s3://", "").split("/")
return path_parts.pop(0)
return None


def parse_path(path: str) -> S3Path:
Expand Down
38 changes: 2 additions & 36 deletions scripts/files/fs_s3.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,10 @@
from typing import TYPE_CHECKING, Dict

import boto3
import botocore
from linz_logger import get_log

from scripts.aws.aws_helper import get_credentials, parse_path
from scripts.aws.aws_helper import get_session, parse_path
from scripts.logging.time_helper import time_in_ms

if TYPE_CHECKING:
from mypy_boto3_s3.service_resource import S3ServiceResource
else:
S3ServiceResource = object

# s3_session = [{"bucket", boto3.Session}]
s3_sessions: Dict[str, S3ServiceResource] = {}


def _get_s3_resource(bucket_name: str) -> S3ServiceResource:
"""Return a boto3 S3 Resource with the AWS credentials for a bucket.
Args:
bucket_name (str): The name of the bucket.
Returns:
S3ServiceResource: The boto3 S3 Resource.
"""
session = s3_sessions.get(bucket_name, None)
if session is None:
# TODO implement refreshable session TDE-235
credentials = get_credentials(bucket_name)
session = boto3.Session(
aws_access_key_id=credentials.access_key,
aws_secret_access_key=credentials.secret_key,
aws_session_token=credentials.token,
)
s3_sessions[bucket_name] = session

s3: S3ServiceResource = session.resource("s3")
return s3


def write(destination: str, source: bytes) -> None:
"""Write a source (bytes) in a AWS s3 destination (path in a bucket).
Expand Down Expand Up @@ -81,7 +47,7 @@ def read(path: str, needs_credentials: bool = False) -> bytes:

try:
if needs_credentials:
s3 = _get_s3_resource(s3_path.bucket)
s3 = get_session(path).client("s3")

s3_object = s3.Object(s3_path.bucket, key)
file: bytes = s3_object.get()["Body"].read()
Expand Down
5 changes: 3 additions & 2 deletions scripts/gdal/gdal_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from linz_logger import get_log

from scripts.aws.aws_helper import get_bucket_name_from_path, get_credentials, is_s3
from scripts.aws.aws_helper import get_session, is_s3
from scripts.logging.time_helper import time_in_ms


Expand Down Expand Up @@ -59,7 +59,8 @@ def run_gdal(
if input_file:
if is_s3(input_file):
# Set the credentials for GDAL to be able to read the input file
credentials = get_credentials(get_bucket_name_from_path(input_file))
session = get_session(input_file)
credentials = session.get_credentials()
gdal_env["AWS_ACCESS_KEY_ID"] = credentials.access_key
gdal_env["AWS_SECRET_ACCESS_KEY"] = credentials.secret_key
gdal_env["AWS_SESSION_TOKEN"] = credentials.token
Expand Down

0 comments on commit 40e4b52

Please sign in to comment.