Skip to content

Commit

Permalink
Merge pull request #62 from fblackburn1/allow-configure-parallel-option
Browse files Browse the repository at this point in the history
Allow to configure parallel option
  • Loading branch information
hkage authored Jun 14, 2024
2 parents 87ba777 + f864ae3 commit 9d47900
Show file tree
Hide file tree
Showing 8 changed files with 125 additions and 22 deletions.
1 change: 1 addition & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ Usage
--dump-file DUMP_FILE
Create a database dump file with the given name
--init-sql INIT_SQL SQL to run before starting anonymization
--parallel Data anonymization is done in parallel
Despite the database connection values, you will have to define a YAML schema file, that includes
all anonymization rules for that database. Take a look at the `schema documentation`_ or the
Expand Down
20 changes: 18 additions & 2 deletions pganonymize/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import logging
import time

from pganonymize.config import config
from pganonymize.config import config, validate_args_with_config
from pganonymize.constants import DATABASE_ARGS, DEFAULT_SCHEMA_FILE
from pganonymize.providers import provider_registry
from pganonymize.utils import anonymize_tables, create_database_dump, get_connection, truncate_tables
Expand Down Expand Up @@ -47,6 +47,15 @@ def get_arg_parser():
default=False)
parser.add_argument('--dump-file', help='Create a database dump file with the given name')
parser.add_argument('--init-sql', help='SQL to run before starting anonymization', default=False)
parser.add_argument(
'--parallel',
action='store_true',
help=(
'Parallelize anonymization of value.'
'WARNING: `fake.unique.*` providers are not compatible with this option'
),
default=False,
)

return parser

Expand All @@ -65,6 +74,8 @@ def main(args):

config.schema_file = args.schema

validate_args_with_config(args, config)

pg_args = get_pg_args(args)
connection = get_connection(pg_args)
if args.init_sql:
Expand All @@ -75,7 +86,12 @@ def main(args):

start_time = time.time()
truncate_tables(connection)
anonymize_tables(connection, verbose=args.verbose, dry_run=args.dry_run)
anonymize_tables(
connection,
verbose=args.verbose,
dry_run=args.dry_run,
parallel=args.parallel,
)

if not args.dry_run:
connection.commit()
Expand Down
12 changes: 12 additions & 0 deletions pganonymize/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import re

import yaml
from pganonymize.exceptions import InvalidConfiguration


class Config(object):
Expand Down Expand Up @@ -55,3 +56,14 @@ def constructor_env_variables(loader, node):


config = Config()


def validate_args_with_config(args, config):
definitions = config.schema.get('tables', [])
for definition in definitions:
table_definition = list(definition.values())[0]
columns = table_definition.get('fields', [])
for column in columns:
column_config = list(column.values())[0]
if args.parallel and column_config['provider']['name'].startswith('fake.unique'):
raise InvalidConfiguration('`--parallel` option and `fake.unique.*` providers are incompatible')
4 changes: 4 additions & 0 deletions pganonymize/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,7 @@ class ProviderAlreadyRegistered(PgAnonymizeException):

class BadDataFormat(PgAnonymizeException):
"""Raised if the anonymized data cannot be copied."""


class InvalidConfiguration(PgAnonymizeException):
"""Raised if configuration is invalid."""
36 changes: 30 additions & 6 deletions pganonymize/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,14 @@
psycopg2.extras.register_uuid()


def anonymize_tables(connection, verbose=False, dry_run=False):
def anonymize_tables(connection, verbose=False, dry_run=False, parallel=False):
"""
Anonymize a list of tables according to the schema definition.
:param connection: A database connection instance.
:param bool verbose: Display logging information and a progress bar.
:param bool dry_run: Script is running in dry-run mode, no commit expected.
:param bool parallel: Data anonymization is done in parallel.
"""
definitions = config.schema.get('tables', [])
for definition in definitions:
Expand All @@ -44,8 +45,19 @@ def anonymize_tables(connection, verbose=False, dry_run=False):
primary_key = table_definition.get('primary_key', DEFAULT_PRIMARY_KEY)
total_count = get_table_count(connection, table_name, dry_run)
chunk_size = table_definition.get('chunk_size', DEFAULT_CHUNK_SIZE)
build_and_then_import_data(connection, table_name, primary_key, columns, excludes,
search, total_count, chunk_size, verbose=verbose, dry_run=dry_run)
build_and_then_import_data(
connection,
table_name,
primary_key,
columns,
excludes,
search,
total_count,
chunk_size,
verbose=verbose,
dry_run=dry_run,
parallel=parallel,
)
end_time = time.time()
logging.info('{} anonymization took {:.2f}s'.format(table_name, end_time - start_time))

Expand All @@ -63,8 +75,19 @@ def process_row(row, columns, excludes):
return row


def build_and_then_import_data(connection, table, primary_key, columns,
excludes, search, total_count, chunk_size, verbose=False, dry_run=False):
def build_and_then_import_data(
connection,
table,
primary_key,
columns,
excludes,
search,
total_count,
chunk_size,
verbose=False,
dry_run=False,
parallel=False,
):
"""
Select all data from a table and return it together with a list of table columns.
Expand All @@ -78,6 +101,7 @@ def build_and_then_import_data(connection, table, primary_key, columns,
:param int chunk_size: Number of data rows to fetch with the cursor
:param bool verbose: Display logging information and a progress bar.
:param bool dry_run: Script is running in dry-run mode, no commit expected.
:param bool parallel: Data anonymization is done in parallel.
"""
column_names = get_column_names(columns)
sql_columns = SQL(', ').join([Identifier(column_name) for column_name in [primary_key] + column_names])
Expand All @@ -95,7 +119,7 @@ def build_and_then_import_data(connection, table, primary_key, columns,
for i in trange(batches, desc="Processing {} batches for {}".format(batches, table), disable=not verbose):
records = cursor.fetchmany(size=chunk_size)
if records:
data = parmap.map(process_row, records, columns, excludes, pm_pbar=verbose)
data = parmap.map(process_row, records, columns, excludes, pm_pbar=verbose, pm_parallel=parallel)
import_data(connection, temp_table, [primary_key] + column_names, filter(None, data))
apply_anonymized_data(connection, temp_table, table, primary_key, columns)

Expand Down
12 changes: 6 additions & 6 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class TestCli(object):
@pytest.mark.parametrize('cli_args, expected, expected_executes, commit_calls, call_dump', [
['--host localhost --port 5432 --user root --password my-cool-password --dbname db --schema ./tests/schemes/valid_schema.yml -v --init-sql "set work_mem=\'1GB\'"', # noqa
Namespace(verbose=1, list_providers=False, schema='./tests/schemes/valid_schema.yml', dbname='db', user='root',
password='my-cool-password', host='localhost', port='5432', dry_run=False, dump_file=None, init_sql="set work_mem='1GB'"), # noqa
password='my-cool-password', host='localhost', port='5432', dry_run=False, dump_file=None, init_sql="set work_mem='1GB'", parallel=False), # noqa
[call("set work_mem='1GB'"),
call('TRUNCATE TABLE "django_session"'),
call('SELECT COUNT(*) FROM "auth_user"'),
Expand All @@ -32,7 +32,7 @@ class TestCli(object):
],
['--dry-run --host localhost --port 5432 --user root --password my-cool-password --dbname db --schema ./tests/schemes/valid_schema.yml -v --init-sql "set work_mem=\'1GB\'"', # noqa
Namespace(verbose=1, list_providers=False, schema='./tests/schemes/valid_schema.yml', dbname='db', user='root',
password='my-cool-password', host='localhost', port='5432', dry_run=True, dump_file=None, init_sql="set work_mem='1GB'"), # noqa
password='my-cool-password', host='localhost', port='5432', dry_run=True, dump_file=None, init_sql="set work_mem='1GB'", parallel=False), # noqa
[call("set work_mem='1GB'"),
call('TRUNCATE TABLE "django_session"'),
call('SELECT "id", "first_name", "last_name", "email" FROM "auth_user" LIMIT 100'),
Expand All @@ -44,7 +44,7 @@ class TestCli(object):
],
['--dump-file ./dump.sql --host localhost --port 5432 --user root --password my-cool-password --dbname db --schema ./tests/schemes/valid_schema.yml -v --init-sql "set work_mem=\'1GB\'"', # noqa
Namespace(verbose=1, list_providers=False, schema='./tests/schemes/valid_schema.yml', dbname='db', user='root',
password='my-cool-password', host='localhost', port='5432', dry_run=False, dump_file='./dump.sql', init_sql="set work_mem='1GB'"), # noqa
password='my-cool-password', host='localhost', port='5432', dry_run=False, dump_file='./dump.sql', init_sql="set work_mem='1GB'", parallel=False), # noqa
[
call("set work_mem='1GB'"),
call('TRUNCATE TABLE "django_session"'),
Expand All @@ -56,12 +56,12 @@ class TestCli(object):
call('UPDATE "auth_user" t SET "first_name" = s."first_name", "last_name" = s."last_name", "email" = s."email" FROM "tmp_auth_user" s WHERE t."id" = s."id"') # noqa
],
1,
[call('PGPASSWORD=my-cool-password pg_dump -Fc -Z 9 -d db -U root -h localhost -p 5432 -f ./dump.sql', shell=True)]
[call('PGPASSWORD=my-cool-password pg_dump -Fc -Z 9 -d db -U root -h localhost -p 5432 -f ./dump.sql', shell=True)] # noqa
],
['--list-providers',
['--list-providers --parallel',
Namespace(verbose=None, list_providers=True, schema='schema.yml', dbname=None, user=None,
password='', host='localhost', port='5432', dry_run=False, dump_file=None, init_sql=False),
password='', host='localhost', port='5432', dry_run=False, dump_file=None, init_sql=False, parallel=True), # noqa
[], 0, []
]
])
Expand Down
40 changes: 38 additions & 2 deletions tests/test_config.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import os

import pytest
from mock import patch
from mock import patch, Mock

from pganonymize.config import load_schema
from pganonymize.config import load_schema, validate_args_with_config
from pganonymize.exceptions import InvalidConfiguration


@pytest.mark.parametrize('file, envs, expected', [
Expand Down Expand Up @@ -59,3 +60,38 @@
def test_load_schema(file, envs, expected):
with patch.dict(os.environ, envs):
assert load_schema(file) == expected


def test_validate_args_with_config_when_valid():
args = Mock(parallel=False)
schema = {
'tables': [
{
'table_name': {
'fields': [
{'column_name': {'provider': {'name': 'fake.unique.pystr'}}}
]
}
}
]
}
config = Mock(schema=schema)
validate_args_with_config(args, config)


def test_validate_args_with_config_when_invalid():
args = Mock(parallel=True)
schema = {
'tables': [
{
'table_name': {
'fields': [
{'column_name': {'provider': {'name': 'fake.unique.pystr'}}}
]
}
}
]
}
config = Mock(schema=schema)
with pytest.raises(InvalidConfiguration):
validate_args_with_config(args, config)
22 changes: 16 additions & 6 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,12 +236,22 @@ class TestCreateDatabaseDump(object):

@patch('pganonymize.utils.subprocess.call')
def test(self, mock_call):
create_database_dump('/tmp/dump.gz', {'dbname': 'database', 'user': 'foo', 'host': 'localhost', 'port': 5432})
mock_call.assert_called_once_with('pg_dump -Fc -Z 9 -d database -U foo -h localhost -p 5432 -f /tmp/dump.gz',
shell=True)
create_database_dump(
'/tmp/dump.gz',
{'dbname': 'database', 'user': 'foo', 'host': 'localhost', 'port': 5432},
)
mock_call.assert_called_once_with(
'pg_dump -Fc -Z 9 -d database -U foo -h localhost -p 5432 -f /tmp/dump.gz',
shell=True,
)

@patch('pganonymize.utils.subprocess.call')
def test_with_password(self, mock_call):
create_database_dump('/tmp/dump.gz', {'dbname': 'database', 'user': 'foo', 'host': 'localhost', 'port': 5432, 'password': 'pass'})
mock_call.assert_called_once_with('PGPASSWORD=pass pg_dump -Fc -Z 9 -d database -U foo -h localhost -p 5432 -f /tmp/dump.gz',
shell=True)
create_database_dump(
'/tmp/dump.gz',
{'dbname': 'database', 'user': 'foo', 'host': 'localhost', 'port': 5432, 'password': 'pass'},
)
mock_call.assert_called_once_with(
'PGPASSWORD=pass pg_dump -Fc -Z 9 -d database -U foo -h localhost -p 5432 -f /tmp/dump.gz',
shell=True,
)

0 comments on commit 9d47900

Please sign in to comment.