Skip to content

Commit

Permalink
Sanitize the string to avoid a connection string injection (#532)
Browse files Browse the repository at this point in the history
We should avoid possible connection injection. To do so, we have to
sanitize and validate inputs for the database
  • Loading branch information
nikpodsh authored Jun 28, 2023
1 parent a48e79a commit 4197f8b
Show file tree
Hide file tree
Showing 8 changed files with 120 additions and 31 deletions.
2 changes: 1 addition & 1 deletion backend/cdkproxymain.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def check_connect(response: Response):
return {
'DH_DOCKER_VERSION': os.environ.get('DH_DOCKER_VERSION'),
'_ts': datetime.now().isoformat(),
'message': f"Connected to database for environment {ENVNAME}({engine.dbconfig.params['host']}:{engine.dbconfig.params['port']})",
'message': f"Connected to database for environment {ENVNAME}({engine.dbconfig.host})",
}
except Exception as e:
logger.exception('DBCONNECTIONERROR')
Expand Down
2 changes: 1 addition & 1 deletion backend/dataall/aws/handlers/quicksight.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ def create_data_source_vpc(AwsAccountId, region, UserName, vpcConnectionId):
DataSourceParameters={
'AuroraPostgreSqlParameters': {
'Host': aurora_params_dict["host"],
'Port': aurora_params_dict["port"],
'Port': "5432",
'Database': aurora_params_dict["dbname"]
}
},
Expand Down
2 changes: 1 addition & 1 deletion backend/dataall/cdkproxy/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def check_connect(response: Response):
engine = connect()
return {
'_ts': datetime.now().isoformat(),
'message': f"Connected to database for environment {ENVNAME}({engine.dbconfig.params['host']}:{engine.dbconfig.params['port']})",
'message': f"Connected to database for environment {ENVNAME}({engine.dbconfig.host})",
}
except Exception as e:
logger.exception('DBCONNECTIONERROR')
Expand Down
19 changes: 10 additions & 9 deletions backend/dataall/db/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,19 +29,19 @@ def __init__(self, dbconfig: DbConfig):
dbconfig.url,
echo=False,
pool_size=1,
connect_args={'options': f"-csearch_path={dbconfig.params['schema']}"},
connect_args={'options': f"-csearch_path={dbconfig.schema}"},
)
try:
if not self.engine.dialect.has_schema(
self.engine, dbconfig.params['schema']
self.engine, dbconfig.schema
):
log.info(
f"Schema not found - init the schema {dbconfig.params['schema']}"
f"Schema not found - init the schema {dbconfig.schema}"
)
self.engine.execute(
sqlalchemy.schema.CreateSchema(dbconfig.params['schema'])
sqlalchemy.schema.CreateSchema(dbconfig.schema)
)
log.info('-- Using schema: %s --', dbconfig.params['schema'])
log.info('-- Using schema: %s --', dbconfig.schema)
except Exception as e:
log.error(f'Could not create schema: {e}')

Expand Down Expand Up @@ -124,10 +124,12 @@ def get_engine(envname=ENVNAME):
creds = json.loads(db_credentials_string['SecretString'])
user = creds['username']
pwd = creds['password']
host = param_store.get_parameter(env=envname, path='aurora/hostname')
database = param_store.get_parameter(env=envname, path='aurora/db')

db_params = {
'host': param_store.get_parameter(env=envname, path='aurora/hostname'),
'port': param_store.get_parameter(env=envname, path='aurora/port'),
'db': param_store.get_parameter(env=envname, path='aurora/db'),
'host': host,
'db': database,
'user': user,
'pwd': pwd,
'schema': schema,
Expand All @@ -136,7 +138,6 @@ def get_engine(envname=ENVNAME):
hostname = 'db' if envname == 'dkrcompose' else 'localhost'
db_params = {
'host': hostname,
'port': '5432',
'db': 'dataall',
'user': 'postgres',
'pwd': 'docker',
Expand Down
56 changes: 46 additions & 10 deletions backend/dataall/db/dbconfig.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,59 @@
import os
import re

_SANITIZE_WORD_REGEX = r"[^\w]" # A-Za-z0-9_
_SANITIZE_HOST_REGEX = r"[^\w.-]"
_SANITIZE_PWD_REGEX = r"[\"\s%+~`#$&*()|\[\]{}:;<>?!'/]+"
_AURORA_HOST_SUFFIX = "rds.amazonaws.com"
_POSTGRES_MAX_LEN = 63
_MAX_HOST_LENGTH = 253

_envname = os.getenv('envname', 'local')


class DbConfig:
def __init__(self, **kwargs):
self.params = kwargs
self.url = f"postgresql+pygresql://{self.params['user']}:{self.params['pwd']}@{self.params['host']}/{self.params['db']}"
def __init__(self, user: str, pwd: str, host: str, db: str, schema: str):
for param in (user, db, schema):
if len(param) > _POSTGRES_MAX_LEN:
raise ValueError(
f"PostgreSQL doesn't allow values more than 63 characters"
f" parameters {user}, {db}, {schema}"
)

if len(host) > _MAX_HOST_LENGTH:
raise ValueError(f"Hostname is too long: {host}")

if _envname not in ['local', 'pytest', 'dkrcompose'] and not host.lower().endswith(_AURORA_HOST_SUFFIX):
raise ValueError(f"Unknown host {host} for the rds")

self.user = self._sanitize_and_compare(_SANITIZE_WORD_REGEX, user, "username")
self.host = self._sanitize_and_compare(_SANITIZE_HOST_REGEX, host, "host")
self.db = self._sanitize_and_compare(_SANITIZE_WORD_REGEX, db, "database name")
self.schema = self._sanitize_and_compare(_SANITIZE_WORD_REGEX, schema, "schema")
pwd = self._sanitize_and_compare(_SANITIZE_PWD_REGEX, pwd, "password")
self.url = f"postgresql+pygresql://{self.user}:{pwd}@{self.host}/{self.db}"

def __str__(self):
lines = []
lines.append(' DbConfig >')
lines = [' DbConfig >']
hr = ' '.join(['+', ''.ljust(10, '-'), '+', ''.ljust(65, '-'), '+'])
lines.append(hr)
header = ' '.join(['+', 'Db Param'.ljust(10), ' ', 'Value'.ljust(65), '+'])
lines.append(header)
hr = ' '.join(['+', ''.ljust(10, '-'), '+', ''.ljust(65, '-'), '+'])
lines.append(hr)
for k in self.params:
v = self.params[k]
if k == 'pwd':
v = '*' * len(self.params[k])
lines.append(' '.join(['|', k.ljust(10), '|', v.ljust(65), '|']))
lines.append(' '.join(['|', "host".ljust(10), '|', self.host.ljust(65), '|']))
lines.append(' '.join(['|', "db".ljust(10), '|', self.db.ljust(65), '|']))
lines.append(' '.join(['|', "user".ljust(10), '|', self.user.ljust(65), '|']))
lines.append(' '.join(['|', "pwd".ljust(10), '|', "*****".ljust(65), '|']))

hr = ' '.join(['+', ''.ljust(10, '-'), '+', ''.ljust(65, '-'), '+'])
lines.append(hr)
return '\n'.join(lines)

@staticmethod
def _sanitize_and_compare(regex, string: str, param_name) -> str:
sanitized = re.sub(regex, "", string)
if sanitized != string:
raise ValueError(f"Can't create a database connection. The {param_name} parameter has invalid symbols."
f" The sanitized string length: {len(sanitized)} < original : {len(string)}")
return sanitized
7 changes: 1 addition & 6 deletions deploy/stacks/aurora.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def __init__(
):
super().__init__(scope, id, **kwargs)

# if exclude_characters property is set make sure that the pwd regex in DbConfig is changed accordingly
db_credentials = rds.DatabaseSecret(
self, f'{resource_prefix}-{envname}-aurora-db', username='dtaadmin'
)
Expand Down Expand Up @@ -152,12 +153,6 @@ def __init__(
parameter_name=f'/dataall/{envname}/aurora/hostname',
string_value=str(database.cluster_endpoint.hostname),
)
ssm.StringParameter(
self,
'DatabasePortParameter',
parameter_name=f'/dataall/{envname}/aurora/port',
string_value=str(database.cluster_endpoint.port),
)

ssm.StringParameter(
self,
Expand Down
5 changes: 2 additions & 3 deletions tests/db/test_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@ def test(db: dataall.db.Engine):
if os.getenv('local') or os.getenv('pytest'):
config: dataall.db.DbConfig = db.dbconfig
print(config)
assert config.params.get('host') == 'localhost'
assert config.params.get('port') == '5432'
assert config.params.get('schema') == 'pytest'
assert config.host == 'localhost'
assert config.schema == 'pytest'
with db.scoped_session() as session:
models = []
models = models + dataall.db.Base.__subclasses__()
Expand Down
58 changes: 58 additions & 0 deletions tests/db/test_dbconfig.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import pytest

from dataall.db import DbConfig


def test_incorrect_database():
with pytest.raises(ValueError):
DbConfig(
user='dataall',
pwd='123456789',
host="dataall.eu-west-1.rds.amazonaws.com",
db='dataall\'; DROP TABLE users;',
schema='dev'
)


def test_incorrect_user():
with pytest.raises(ValueError):
DbConfig(
user='dataall2;^&*end',
pwd='qwsufn3i20d-_s3qaSW3d2',
host="dataall.eu-west-1.rds.amazonaws.com",
db='dataall',
schema='dev'
)


def test_incorrect_pwd():
with pytest.raises(ValueError):
DbConfig(
user='dataall',
pwd='qazxsVFRTGBdfrew-332_c2@dataall.eu-west-1.rds.amazonaws.com/dataall\'; drop table dataset; # ',
host="dataall.eu-west-1.rds.amazonaws.com",
db='dataall',
schema='dev'
)


def test_incorrect_host():
with pytest.raises(ValueError):
DbConfig(
user='dataall',
pwd='q68rjdmwiosoxahGDYJWIdi-9eu93_9dJJ_',
host="dataall.eu-west-1$%#&@*#)$#.rds.amazonaws.com",
db='dataall',
schema='dev'
)


def test_correct_config():
# no exception is raised
DbConfig(
user='dataall',
pwd='q68rjdm_aX',
host="dataall.eu-west-1.rds.amazonaws.com",
db='dataall',
schema='dev'
)

0 comments on commit 4197f8b

Please sign in to comment.