Dataall code/infrastructure AWS migrations (#1374)
### Feature or Bugfix
- Feature


### Detail
While working on #1280 we encountered a problem: we need to revise all
share policies in existing AWS deployments (remove `s3:*` permissions and
add appropriately scoped ones). This is a "migration", but it has nothing
to do with the database, so it does not belong in the alembic migrations
module.

As a general solution to such problems, we introduce 'dataall
migrations' -- a module invoked from a Lambda function that is triggered
by an update of the backend stack, after the DB migrations run. The module
follows the same logic as database migrations: revisions are executed one
by one, and the current revision uid is stored in the SSM Parameter Store.

The module contains:
 - `MigrationBase` -- a Python base class for future migrations
 - `MigrationManager` -- a Python class that runs the migrations
 - a "versions" folder holding Python files with subclasses of
`MigrationBase`

Each subclass defines:
 - the uid of the revision, e.g. `revision_id()` returning `'51132fed-c36d-470c-9946-5164581856cb'`
 - a short description, e.g. `description()` returning `'Remove Wildcard from Sharing Policy'`
 - the chain order: `next_migration()` returns the following migration class, or `None` for the latest one
 - `up` and `down` methods for upgrade and downgrade respectively
   
To add a new migration, a developer creates a Python class in the
"versions" folder (the folder name is hardcoded for now) and defines the
members described above. At the moment this can only be done manually.
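To illustrate, here is a minimal sketch of what such a class in "versions" might look like. The class name, uid, and printed behavior are made up for illustration; in the real module the class would subclass `MigrationBase` from `migrations.dataall_migrations.base_migration`:

```python
from typing import Type, Union


# Illustrative only: a hypothetical migration following the MigrationBase
# interface (revision_id, description, next_migration, up, down).
class ExampleMigration:
    @classmethod
    def revision_id(cls) -> str:
        # unique uid of this revision (made-up value)
        return 'c0ffee00-0000-4000-8000-000000000001'

    @classmethod
    def description(cls) -> str:
        return 'Example migration (illustrative)'

    @classmethod
    def next_migration(cls) -> Union[Type, None]:
        # None marks this as the last migration in the chain
        return None

    @classmethod
    def up(cls) -> None:
        print('Applying example migration')

    @classmethod
    def down(cls) -> None:
        print('Reverting example migration')
```

Because the chain is linked forward via `next_migration`, registering a new migration also means pointing the previously last migration's `next_migration` at the new class.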


### Security
Please answer the questions below briefly where applicable, or write
`N/A`. Based on
[OWASP 10](https://owasp.org/Top10/en/).

- Does this PR introduce or modify any input fields or queries - this
includes
fetching data from storage outside the application (e.g. a database, an
S3 bucket)?
  - Is the input sanitized?
- What precautions are you taking before deserializing the data you
consume?
  - Is injection prevented by parametrizing queries?
  - Have you ensured no `eval` or similar functions are used?
- Does this PR introduce any functionality or component that requires
authorization?
- How have you ensured it respects the existing AuthN/AuthZ mechanisms?
  - Are you logging failed auth attempts?
- Are you using or adding any cryptographic features?
  - Do you use standard, proven implementations?
  - Are the used keys controlled by the customer? Where are they stored?
- Are you introducing any new policies/roles/users?
  - Have you used the least-privilege principle? How?


By submitting this pull request, I confirm that my contribution is made
under the terms of the Apache 2.0 license.

---------

Co-authored-by: Sofia Sazonova <sazonova@amazon.co.uk>
SofiaSazonova and Sofia Sazonova authored Jul 9, 2024
1 parent 82f2dc6 commit bd8e1e8
Showing 15 changed files with 643 additions and 24 deletions.
34 changes: 13 additions & 21 deletions backend/dataall/base/aws/parameter_store.py
@@ -29,37 +29,29 @@ def client(AwsAccountId=None, region=None, role=None):
     def get_parameter_value(AwsAccountId=None, region=None, parameter_path=None):
         if not parameter_path:
             raise Exception('Parameter name is None')
-        try:
-            parameter_value = ParameterStoreManager.client(AwsAccountId, region).get_parameter(Name=parameter_path)[
-                'Parameter'
-            ]['Value']
-        except ClientError as e:
-            raise Exception(e)
+        parameter_value = ParameterStoreManager.client(AwsAccountId, region).get_parameter(Name=parameter_path)[
+            'Parameter'
+        ]['Value']
         return parameter_value

     @staticmethod
     def get_parameters_by_path(AwsAccountId=None, region=None, parameter_path=None):
         if not parameter_path:
             raise Exception('Parameter name is None')
-        try:
-            parameter_values = ParameterStoreManager.client(AwsAccountId, region).get_parameters_by_path(
-                Path=parameter_path
-            )['Parameters']
-        except ClientError as e:
-            raise Exception(e)
+        parameter_values = ParameterStoreManager.client(AwsAccountId, region).get_parameters_by_path(
+            Path=parameter_path
+        )['Parameters']
         return parameter_values

     @staticmethod
-    def update_parameter(AwsAccountId, region, parameter_name, parameter_value):
+    def update_parameter(AwsAccountId, region, parameter_name, parameter_value, parameter_type='String'):
         if not parameter_name:
             raise Exception('Parameter name is None')
         if not parameter_value:
             raise Exception('Parameter value is None')
-        try:
-            response = ParameterStoreManager.client(AwsAccountId, region).put_parameter(
-                Name=parameter_name, Value=parameter_value, Overwrite=True
-            )['Version']
-        except ClientError as e:
-            raise Exception(e)
-        else:
-            return str(response)
+        response = ParameterStoreManager.client(AwsAccountId, region).put_parameter(
+            Name=parameter_name, Value=parameter_value, Overwrite=True, Type=parameter_type
+        )['Version']
+
+        return str(response)
1 change: 1 addition & 0 deletions backend/dataall/core/environment/cdk/pivot_role_stack.py
@@ -132,6 +132,7 @@ def create_pivot_role(self, principal_id: str, external_id: str) -> iam.Role:
                         f'arn:aws:iam::{principal_id}:role/*graphql-role',
                         f'arn:aws:iam::{principal_id}:role/*awsworker-role',
                         f'arn:aws:iam::{principal_id}:role/*ecs-tasks-role',
+                        f'arn:aws:iam::{principal_id}:role/*dataall-migration-role',
                     ]
                 },
             },
38 changes: 38 additions & 0 deletions backend/deployment_triggers/dataall_migrate_handler.py
@@ -0,0 +1,38 @@
import logging
import os
from migrations.dataall_migrations.migrationmanager import MigrationManager
from dataall.base.db import get_engine

logger = logging.getLogger()
logger.setLevel(os.environ.get('LOG_LEVEL', 'INFO'))

ENVNAME = os.environ.get('envname', 'local')
ENGINE = get_engine(envname=ENVNAME)
PARAM_KEY = f'/dataall/{ENVNAME}/dataall-migration/revision'


def get_current_revision():
    with ENGINE.scoped_session() as session:
        row = session.query('revision from dataall_migrations').first()
        return row[0] if row else row


def put_latest_revision(old_revision, new_revision):
    with ENGINE.scoped_session() as session:
        if old_revision:
            sql_params = "UPDATE dataall_migrations SET revision='{}' WHERE revision='{}';".format(
                new_revision, old_revision
            )
        else:
            sql_params = "INSERT INTO dataall_migrations VALUES('{}');".format(new_revision)
        session.execute(sql_params)


def handler(event, context) -> None:
    revision = get_current_revision()
    current_key = revision or '0'
    manager = MigrationManager(current_key)
    new_version = manager.upgrade()
    if not new_version:
        raise Exception('Data.all migration failed.')
    put_latest_revision(revision, new_version)
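The revision bookkeeping this handler performs can be sketched in a self-contained way with sqlite3 standing in for the data.all engine (`get_engine`/`scoped_session` belong to the real deployment and are not reproduced here); the single-column `dataall_migrations` table mirrors the one used above, and parametrized queries replace the string formatting:

```python
import sqlite3

# Sketch only: sqlite3 stands in for the data.all engine. The table layout
# mirrors the single-column dataall_migrations table used by the handler.
conn = sqlite3.connect(':memory:')
conn.execute('CREATE TABLE dataall_migrations (revision TEXT)')


def get_current_revision(conn):
    row = conn.execute('SELECT revision FROM dataall_migrations').fetchone()
    return row[0] if row else None


def put_latest_revision(conn, old_revision, new_revision):
    # parametrized queries instead of the string formatting in the diff
    if old_revision:
        conn.execute(
            'UPDATE dataall_migrations SET revision=? WHERE revision=?',
            (new_revision, old_revision),
        )
    else:
        conn.execute('INSERT INTO dataall_migrations VALUES (?)', (new_revision,))


put_latest_revision(conn, None, '0')
put_latest_revision(conn, '0', '51132fed-c36d-470c-9946-5164581856cb')
print(get_current_revision(conn))  # → 51132fed-c36d-470c-9946-5164581856cb
```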
49 changes: 49 additions & 0 deletions backend/migrations/dataall_migrations/base_migration.py
@@ -0,0 +1,49 @@
import logging
import os
from abc import ABC, abstractmethod
from typing import Type, Union

logger = logging.getLogger()
logger.setLevel(os.environ.get('LOG_LEVEL', 'INFO'))


class MigrationBase(ABC):
    @classmethod
    @abstractmethod
    def revision_id(cls) -> str:
        """
        Unique revision identifier.
        """
        ...

    @classmethod
    @abstractmethod
    def description(cls) -> str:
        """
        Short description of migration logic and purpose.
        """
        ...

    @classmethod
    @abstractmethod
    def next_migration(cls) -> Union[Type['MigrationBase'], None]:
        """
        Returns the next migration class.
        """
        ...

    @classmethod
    @abstractmethod
    def up(cls) -> None:
        """
        Performs the upgrade.
        """
        ...

    @classmethod
    @abstractmethod
    def down(cls) -> None:
        """
        Performs the downgrade.
        """
        ...
129 changes: 129 additions & 0 deletions backend/migrations/dataall_migrations/migrationmanager.py
@@ -0,0 +1,129 @@
import os
from collections import deque
from typing import Deque
from migrations.dataall_migrations.versions.initial import InitMigration
from migrations.dataall_migrations.base_migration import MigrationBase
from typing import Type, Union

import logging

logger = logging.getLogger()
logger.setLevel(os.environ.get('LOG_LEVEL', 'INFO'))


class MigrationManager:
    def __init__(self, current_revision='0', initial_migration=InitMigration):
        self.initial_migration = initial_migration
        self.previous_migrations: Deque[Union[Type[MigrationBase], None]] = deque()
        self.current_migration = initial_migration

        while True:
            self.previous_migrations.append(self.current_migration)
            if self.current_migration.revision_id() == current_revision:
                break
            self.current_migration = self.current_migration.next_migration()
            if not self.current_migration:
                raise Exception(f'Revision {current_revision} is not found.')

    def _save_upgraded(self, executed_ups: Deque[Union[Type[MigrationBase], None]]):
        while executed_ups:
            self.previous_migrations.append(executed_ups.popleft())

    def _save_downgrades(self, executed_downs: Deque[Union[Type[MigrationBase], None]]):
        while executed_downs:
            try:
                self.previous_migrations.remove(executed_downs.pop())
            except Exception as e:
                ...

    def _check_downgrade_id(self, target_revision_id):
        for pm in self.previous_migrations:
            if pm.revision_id() == target_revision_id:
                return True
        if target_revision_id == self.current_migration.revision_id():
            return False

        raise Exception(f'Failed to find {target_revision_id} in migration history.')

    def _check_upgrade_id(self, target_revision_id):
        if target_revision_id is None:
            if (
                self.current_migration.next_migration() is None
                or self.current_migration.revision_id() == target_revision_id
            ):
                return False
            return True

        revision = self.current_migration.next_migration()
        while revision is not None:
            if revision.revision_id() == target_revision_id:
                return True
            revision = revision.next_migration()

        raise Exception(f'Failed to find {target_revision_id}.')

    def upgrade(self, target_revision_id=None):
        if not self._check_upgrade_id(target_revision_id):
            logger.info('Data-all version is up to date')
            return self.current_migration.revision_id()

        logger.info(f"Upgrade from {self.current_migration.revision_id()} to {target_revision_id or 'latest'}")
        executed_upgrades: Deque[Union[Type[MigrationBase], None]] = deque()
        saved_start_migration = self.current_migration
        self.current_migration = self.current_migration.next_migration()
        while self.current_migration is not None:
            try:
                logger.info(f'Applying migration {self.current_migration.__name__}')
                self.current_migration.up()
                executed_upgrades.append(self.current_migration)
                logger.info(f'Migration {self.current_migration.__name__} completed')
                if (
                    self.current_migration.revision_id() == target_revision_id
                    or self.current_migration.next_migration() is None
                ):
                    break
                self.current_migration = self.current_migration.next_migration()
            except Exception as e:
                # Something went wrong, revert
                logger.exception(f'An error occurred while applying the migration.{e}.')
                while executed_upgrades:
                    migration = executed_upgrades.pop()
                    migration.down()
                self.current_migration = saved_start_migration
                return False
        logger.info('Upgrade completed')
        self._save_upgraded(executed_upgrades)
        return self.current_migration.revision_id()

    def downgrade(self, target_revision_id='0'):
        if not self._check_downgrade_id(target_revision_id):
            logger.info(f'Current revision is {self.current_migration.revision_id()}')
            return

        logger.info(f"Downgrade from {self.current_migration.revision_id()} to {target_revision_id or 'initial'}")
        executed_downgrades: Deque[Union[Type[MigrationBase], None]] = deque()
        while self.current_migration:
            if self.previous_migrations:
                self.current_migration = self.previous_migrations[-1]
                if self.current_migration.revision_id() == target_revision_id:
                    break
                self.previous_migrations.pop()
            else:
                break

            try:
                logger.info(f'Reverting migration {self.current_migration.__name__}')
                self.current_migration.down()
                executed_downgrades.append(self.current_migration)
                logger.info(f'Migration {self.current_migration.__name__} completed')
            except Exception as e:
                logger.exception(f'An error occurred while reverting the migration.{e}.')
                while executed_downgrades:
                    up_migration = executed_downgrades.pop()
                    up_migration.up()
                    self.current_migration = up_migration
                return False

        logger.info('Downgrade completed')
        self._save_downgrades(executed_downgrades)
        return self.current_migration.revision_id()
Empty file.
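The forward-linked chain that MigrationManager walks can be reduced to a short, self-contained sketch. M0/M1 and the `applied` log are made up for demonstration; the rollback on failure mirrors the `except` branch of `upgrade` above:

```python
from collections import deque


# Illustrative two-node chain: M0 is the initial migration, M1 the latest.
class M1:
    applied = []  # records up()/down() calls for demonstration

    @classmethod
    def revision_id(cls):
        return 'rev-1'

    @classmethod
    def next_migration(cls):
        return None

    @classmethod
    def up(cls):
        cls.applied.append('up')

    @classmethod
    def down(cls):
        cls.applied.append('down')


class M0:
    @classmethod
    def revision_id(cls):
        return '0'

    @classmethod
    def next_migration(cls):
        return M1

    @classmethod
    def up(cls):
        pass

    @classmethod
    def down(cls):
        pass


def upgrade(initial, current_revision):
    """Walk the chain from `initial` to `current_revision`, then apply every
    later migration; on failure, down() reverts what was applied, as in
    MigrationManager.upgrade."""
    node = initial
    while node.revision_id() != current_revision:
        node = node.next_migration()
        if node is None:
            raise Exception(f'Revision {current_revision} is not found.')
    executed = deque()
    node = node.next_migration()
    while node is not None:
        try:
            node.up()
            executed.append(node)
            node = node.next_migration()
        except Exception:
            while executed:
                executed.pop().down()  # revert in reverse order
            return None
    return executed[-1].revision_id() if executed else current_revision


print(upgrade(M0, '0'))  # → rev-1
```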
25 changes: 25 additions & 0 deletions backend/migrations/dataall_migrations/versions/initial.py
@@ -0,0 +1,25 @@
from migrations.dataall_migrations.base_migration import MigrationBase
from migrations.dataall_migrations.versions.remove_wildcard_share_policy import RemoveWildCard
from typing import Type, Union


class InitMigration(MigrationBase):
    @classmethod
    def revision_id(cls) -> str:
        return '0'

    @classmethod
    def description(cls) -> str:
        return 'Initial migration'

    @classmethod
    def next_migration(cls) -> Union[Type['MigrationBase'], None]:
        return RemoveWildCard

    @classmethod
    def up(cls) -> None:
        print('Initial migration. Up.')

    @classmethod
    def down(cls) -> None:
        print('Initial migration. Down')
backend/migrations/dataall_migrations/versions/remove_wildcard_share_policy.py
@@ -0,0 +1,81 @@
import os

from dataall.modules.s3_datasets_shares.services.s3_share_managed_policy_service import (
    S3SharePolicyService,
    IAM_S3_ACCESS_POINTS_STATEMENT_SID,
    IAM_S3_BUCKETS_STATEMENT_SID,
)
from migrations.dataall_migrations.base_migration import MigrationBase
from dataall.base.aws.iam import IAM
from dataall.base.db import get_engine
from dataall.core.environment.db.environment_repositories import EnvironmentRepository
import json
from typing import Type, Union
import logging

logger = logging.getLogger()
logger.setLevel(os.environ.get('LOG_LEVEL', 'INFO'))


class RemoveWildCard(MigrationBase):
    @classmethod
    def revision_id(cls) -> str:
        return '51132fed-c36d-470c-9946-5164581856cb'

    @classmethod
    def description(cls) -> str:
        return 'Remove Wildcard from Sharing Policy'

    @classmethod
    def next_migration(cls) -> Union[Type['MigrationBase'], None]:
        return None

    @classmethod
    def up(cls):
        logger.info('removing wildcard from sharing policy')
        envname = os.environ.get('envname', 'local')
        engine = get_engine(envname=envname)
        with engine.scoped_session() as session:
            all_envs = EnvironmentRepository.query_all_active_environments(session)
            for env in all_envs:
                cons_roles = EnvironmentRepository.query_all_environment_consumption_roles(
                    session, env.environmentUri, filter=None
                )
                for role in cons_roles:
                    share_policy_service = S3SharePolicyService(
                        environmentUri=env.environmentUri,
                        account=env.AwsAccountId,
                        region=env.region,
                        role_name=role.IAMRoleName,
                        resource_prefix=env.resourcePrefix,
                    )
                    share_resource_policy_name = share_policy_service.generate_policy_name()
                    version_id, policy_document = IAM.get_managed_policy_default_version(
                        env.AwsAccountId, env.region, policy_name=share_resource_policy_name
                    )
                    if policy_document is not None:
                        statements = policy_document.get('Statement', [])
                        for statement in statements:
                            if statement['Sid'] in [
                                f'{IAM_S3_BUCKETS_STATEMENT_SID}S3',
                                f'{IAM_S3_ACCESS_POINTS_STATEMENT_SID}S3',
                            ]:
                                actions = set(statement['Action'])
                                if 's3:*' in actions:
                                    actions.remove('s3:*')
                                    actions.add('s3:List*')
                                    actions.add('s3:Describe*')
                                    actions.add('s3:GetObject')
                                    statement['Action'] = list(actions)
                        policy_document['Statement'] = statements
                        IAM.update_managed_policy_default_version(
                            env.AwsAccountId,
                            env.region,
                            share_resource_policy_name,
                            version_id,
                            json.dumps(policy_document),
                        )

    @classmethod
    def down(cls):
        logger.info('Downgrade is not supported.')
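The per-statement rewrite that `RemoveWildCard.up` applies can be isolated as a pure function. The Sid values and the sample document below are stand-ins for the real `IAM_S3_*_STATEMENT_SID` constants, which this sketch does not import:

```python
# Sketch of the statement rewrite performed by RemoveWildCard.up;
# SHARE_SIDS stands in for the real IAM_S3_*_STATEMENT_SID + 'S3' values.
SHARE_SIDS = {'BucketStatementS3', 'AccessPointsStatementS3'}


def remove_wildcard(policy_document):
    for statement in policy_document.get('Statement', []):
        if statement.get('Sid') in SHARE_SIDS:
            actions = set(statement['Action'])
            if 's3:*' in actions:
                # swap the wildcard for narrower list/read permissions
                actions.discard('s3:*')
                actions.update({'s3:List*', 's3:Describe*', 's3:GetObject'})
                statement['Action'] = sorted(actions)
    return policy_document


doc = {'Statement': [{'Sid': 'BucketStatementS3', 'Action': ['s3:*', 's3:PutObject']}]}
print(remove_wildcard(doc)['Statement'][0]['Action'])
# → ['s3:Describe*', 's3:GetObject', 's3:List*', 's3:PutObject']
```

Statements whose Sid is not in the share-policy set are left untouched, matching the `Sid` check in the migration.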