diff --git a/additional_codes/tests/test_business_rules.py b/additional_codes/tests/test_business_rules.py
index 976242da5..0dbc55fa6 100644
--- a/additional_codes/tests/test_business_rules.py
+++ b/additional_codes/tests/test_business_rules.py
@@ -84,6 +84,7 @@ def test_ACN2_type_must_exist(reference_nonexistent_record):
 def test_ACN2_allowed_application_codes(app_code, expect_error):
     """The referenced additional code type must have as application code "non-
     Meursing" or "Export Refund for Processed Agricultural Goods”."""
+
     additional_code = factories.AdditionalCodeFactory.create(
         type__application_code=app_code,
     )
diff --git a/exporter/management/commands/export_quotas.py b/exporter/management/commands/export_quotas.py
new file mode 100644
index 000000000..1c7ea385b
--- /dev/null
+++ b/exporter/management/commands/export_quotas.py
@@ -0,0 +1,50 @@
+import logging
+from typing import Any
+from typing import Optional
+
+from django.core.management import BaseCommand
+from django.core.management.base import CommandParser
+
+from exporter.quotas.tasks import export_and_upload_quotas_csv
+
+logger = logging.getLogger(__name__)
+
+
+class Command(BaseCommand):
+    help = (
+        "Create a CSV of quotas for use within data workspace to produce the "
+        "HMRC tariff open data CSV. The filename takes the form "
+        "quotas_export_<yyyymmdd>.csv. Care should be taken to ensure that "
+        "there is sufficient local file system storage to accommodate the "
+        "CSV file (although it should not be very large - less than 5MB, "
+        "1.8MB at time of creation). If you choose to target remote S3 "
+        "storage, then a temporary local copy of the file will be created "
+        "and cleaned up."
+    )
+
+    def add_arguments(self, parser: CommandParser) -> None:
+        parser.add_argument(
+            "--asynchronous",
+            action="store_const",
+            help="Queue the CSV export task to run in an asynchronous process.",
+            const=True,
+            default=False,
+        )
+        parser.add_argument(
+            "--save-local",
+            help=(
+                "Save the quotas CSV to the local file system under the "
+                "(existing) directory given by DIRECTORY_PATH."
+            ),
+            dest="DIRECTORY_PATH",
+        )
+        return super().add_arguments(parser)
+
+    def handle(self, *args: Any, **options: Any) -> Optional[str]:
+        logger.info("Triggering quotas export to CSV")
+
+        local_path = options["DIRECTORY_PATH"]
+        if options["asynchronous"]:
+            export_and_upload_quotas_csv.delay(local_path)
+        else:
+            export_and_upload_quotas_csv(local_path)
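A minimal usage sketch of the new management command, assuming a standard Django project shell; the /tmp path is illustrative only:

    from django.core.management import call_command

    # Run synchronously and write the CSV under an existing local directory.
    call_command("export_quotas", "--save-local", "/tmp")

    # Queue the export as an asynchronous Celery task that uploads to S3.
    call_command("export_quotas", "--asynchronous")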
diff --git a/exporter/quotas/__init__.py b/exporter/quotas/__init__.py
new file mode 100644
index 000000000..9fac8697d
--- /dev/null
+++ b/exporter/quotas/__init__.py
@@ -0,0 +1,43 @@
+"""
+Quotas Export
+=============
+
+The quotas export system will query the TAP database for published quota data
+and store it in a CSV file.
+
+The general process is:
+
+1. Query the TAP database for the correct dataset to export.
+2. Iterate the query result and create the data for the output.
+3. Write the data to the CSV file.
+4. Upload the result to the designated storage (S3 or local).
+
+This process has been chosen to optimise for:
+
+- Speed: query and data production will be a lot faster when processed at
+  source.
+- Testability: we have the facility to test the output and process within TAP
+  effectively.
+- Adaptability: with test coverage highlighting any issues caused by database
+  changes etc., the adaptability of this implementation is high.
+- Data quality: using TAP to produce the data will improve the quality of the
+  output as it uses the same filters and joins as TAP itself does - removing
+  the need to run queries in SQL, which has been problematic and is difficult
+  to maintain.
+"""
+
+import os
+import shutil
+from itertools import chain
+from pathlib import Path
+from tempfile import NamedTemporaryFile
+
+import apsw
+from django.apps import apps
+from django.conf import settings
+
+from exporter.quotas import runner
+from exporter.quotas import tasks
+
+
+def make_export(quotas_csv_named_temp_file: NamedTemporaryFile):
+    quota_csv_exporter = runner.QuotaExport(quotas_csv_named_temp_file)
+    quota_csv_exporter.run()
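A short sketch of driving make_export directly, assuming a configured Django/TAP environment with quota data available; the destination path is illustrative:

    import shutil
    from tempfile import NamedTemporaryFile

    from exporter.quotas import make_export

    with NamedTemporaryFile() as tmp:
        make_export(tmp)  # write quota CSV rows into the temp file
        shutil.copy(tmp.name, "/tmp/quotas_export.csv")  # persist before cleanup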
diff --git a/exporter/quotas/runner.py b/exporter/quotas/runner.py
new file mode 100644
index 000000000..b0515b0d9
--- /dev/null
+++ b/exporter/quotas/runner.py
@@ -0,0 +1,351 @@
+import csv
+import logging
+from datetime import date
+from datetime import timedelta
+from tempfile import NamedTemporaryFile
+from typing import List
+
+from django.db.models import Q
+
+from commodities.models import GoodsNomenclatureDescription
+from measures.models import Measure
+from quotas.models import QuotaDefinition
+from quotas.models import QuotaOrderNumberOrigin
+
+logger = logging.getLogger(__name__)
+
+
+def normalise_loglevel(loglevel):
+    """
+    Attempt conversion of `loglevel` from a string integer value (e.g. "20") to
+    its loglevel name (e.g. "INFO").
+
+    This function can be used after, for instance, copying log levels from
+    environment variables, when the incorrect representation (int as string
+    rather than the log level name) may occur.
+    """
+    try:
+        return logging._levelToName.get(int(loglevel))
+    except (ValueError, TypeError):
+        return loglevel
+
+
+class QuotaExport:
+    """Runs the export command against TAP data to extract quota CSV data."""
+
+    def __init__(self, target_file: NamedTemporaryFile):
+        self.target_file = target_file
+
+    @staticmethod
+    def csv_headers():
+        """
+        Produces a list of headers for the CSV.
+
+        Returns:
+            list: list of header names
+        """
+        quota_headers = [
+            "quota_definition__sid",
+            "quota__order_number",
+            "quota__geographical_areas",
+            "quota__geographical_area_exclusions",
+            "quota__headings",
+            "quota__commodities",
+            "quota__measurement_unit",
+            "quota__monetary_unit",
+            "quota_definition__description",
+            "quota_definition__validity_start_date",
+            "quota_definition__validity_end_date",
+            # 'quota_definition__suspension_periods', from HMRC data
+            # 'quota_definition__blocking_periods', from HMRC data
+            # 'quota_definition__status', from HMRC data
+            # 'quota_definition__last_allocation_date', from HMRC data
+            "quota_definition__initial_volume",
+            # 'quota_definition__balance', from HMRC data
+            # 'quota_definition__fill_rate', from HMRC data
+            "api_query_date",  # used to query the HMRC API
+        ]
+
+        return quota_headers
+
+    def run(self):
+        """
+        Produces data for the quota export CSV from the TAP database.
+
+        Returns:
+            None: rows are written to the NamedTemporaryFile provided on init
+        """
+
+        quotas = QuotaDefinition.objects.latest_approved().filter(
+            sid__gte=20000,
+            valid_between__startswith__lte=date.today() + timedelta(weeks=52 * 3),
+        )
+
+        with open(self.target_file.name, "wt") as file:
+            writer = csv.writer(file)
+            writer.writerow(self.csv_headers())
+            for quota in quotas:
+                item_ids = self.get_goods_nomenclature_item_ids(quota)
+                geographical_areas, geographical_area_exclusions = (
+                    self.get_geographical_areas_and_exclusions(quota)
+                )
+                goods_nomenclature_headings = self.get_goods_nomenclature_headings(
+                    item_ids,
+                )
+                if geographical_areas != "" and goods_nomenclature_headings != "":
+                    quota_data = [
+                        quota.sid,
+                        quota.order_number.order_number,
+                        geographical_areas,
+                        geographical_area_exclusions,
+                        goods_nomenclature_headings,
+                        "|".join(item_ids),
+                        self.get_measurement_unit(quota),
+                        self.get_monetary_unit(quota),
+                        quota.description,
+                        quota.valid_between.lower,
+                        quota.valid_between.upper,
+                        quota.initial_volume,
+                        self.get_api_query_date(quota),
+                    ]
+
+                    writer.writerow(quota_data)
+
+    @staticmethod
+    def get_geographical_areas_and_exclusions(quota):
+        """
+        Returns a tuple of geographical areas and exclusions associated with a
+        Quota.
+
+        Args:
+            quota: the quota to be queried
+
+        Returns:
+            tuple(str, str): geographical areas and exclusions
+        """
+        geographical_areas = []
+        geographical_area_exclusions = []
+
+        # get all geographical areas that are / were / will be enabled on the
+        # end date of the quota
+        for origin in (
+            QuotaOrderNumberOrigin.objects.latest_approved()
+            .filter(
+                order_number__order_number=quota.order_number.order_number,
+                valid_between__startswith__lte=quota.valid_between.upper,
+            )
+            .filter(
+                Q(valid_between__endswith__gte=quota.valid_between.upper)
+                | Q(valid_between__endswith=None),
+            )
+        ):
+            geographical_areas.append(
+                origin.geographical_area.descriptions.latest_approved()
+                .last()
+                .description,
+            )
+            for (
+                exclusion
+            ) in origin.quotaordernumberoriginexclusion_set.latest_approved():
+                geographical_area_exclusions.append(
+                    f"{exclusion.excluded_geographical_area.descriptions.latest_approved().last().description} excluded from {origin.geographical_area.descriptions.latest_approved().last().description}",
+                )
+
+        geographical_areas_str = "|".join(geographical_areas)
+        geographical_area_exclusions_str = "|".join(geographical_area_exclusions)
+
+        return geographical_areas_str, geographical_area_exclusions_str
+
+    @staticmethod
+    def get_monetary_unit(quota):
+        """
+        Returns the monetary unit associated with a Quota as a string.
+
+        Args:
+            quota: the quota to be queried
+
+        Returns:
+            str or None: monetary unit as string or None
+        """
+        monetary_unit = None
+        if quota.monetary_unit:
+            monetary_unit = (
+                f"{quota.monetary_unit.description} ({quota.monetary_unit.code})"
+            )
+        return monetary_unit
+
+    @staticmethod
+    def get_measurement_unit(quota):
+        """
+        Returns the measurement unit associated with a Quota as a string.
+
+        Args:
+            quota: the quota to be queried
+
+        Returns:
+            str or None: measurement unit as string or None
+        """
+        if quota.measurement_unit:
+            measurement_unit_description = f"{quota.measurement_unit.description}"
+            if quota.measurement_unit.abbreviation != "":
+                measurement_unit_description = (
+                    measurement_unit_description
+                    + f" ({quota.measurement_unit.abbreviation})"
+                )
+            return measurement_unit_description
+        return None
+
+    @staticmethod
+    def get_api_query_date(quota):
+        """
+        Returns the most appropriate date for querying the HMRC API.
+
+        Dates are checked against the current date and collected; the oldest of
+        the dates is used as the API query date. Typically, this will be
+        today's date, or the end date of the quota if it is earlier than
+        today's date.
+
+        Note: quotas that start in the future will not be populated on the
+        HMRC API, so None is returned to indicate this query can be safely
+        skipped.
+
+        Args:
+            quota: The quota to be queried
+
+        Returns:
+            date or None: a date if available, or None if the quota is in the
+            future
+        """
+        api_query_dates = []
+
+        # collect possible query dates, but only for current and historical, not future
+        if quota.valid_between.lower <= date.today():
+            if quota.valid_between.upper:
+                # when not infinity
+                api_query_dates.append(quota.valid_between.upper)
+            else:
+                # when infinity
+                api_query_dates.append(date.today())
+
+            tap_measures = quota.order_number.measure_set.latest_approved().filter(
+                # has valid between with end date and today's date is within that range
+                Q(
+                    valid_between__startswith__lte=date.today(),
+                    valid_between__endswith__gte=date.today(),
+                )
+                |
+                # has an open-ended date range but started before today
+                Q(
+                    valid_between__startswith__lte=date.today(),
+                    valid_between__endswith=None,
+                ),
+            )
+
+            for tap_measure in tap_measures:
+                if tap_measure.valid_between.upper is None:
+                    api_query_dates.append(date.today())
+                else:
+                    api_query_dates.append(tap_measure.valid_between.upper)
+
+            api_query_dates.sort()
+        else:
+            # no query dates for future quotas
+            api_query_dates = [None]
+
+        return api_query_dates[0]
+
+    @staticmethod
+    def get_associated_measures(quota):
+        """
+        Returns associated measures for the quota.
+
+        Args:
+            quota: The quota to be queried
+
+        Returns:
+            TrackedModelQuerySet(Measure): a set of measures associated with
+            the provided quota
+        """
+        measures = (
+            Measure.objects.latest_approved()
+            .filter(
+                order_number=quota.order_number,
+                valid_between__startswith__lte=quota.valid_between.upper,
+            )
+            .filter(
+                Q(
+                    valid_between__endswith__gte=quota.valid_between.upper,
+                )
+                | Q(
+                    valid_between__endswith=None,
+                ),
+            )
+        )
+
+        return measures
+
+    def get_goods_nomenclature_item_ids(self, quota):
+        """
+        Collects associated item_ids for a quota.
+
+        Args:
+            quota: The quota to be queried
+
+        Returns:
+            list(str): list of strings, each containing the associated item_id
+            for a measure
+        """
+        item_ids = []
+        for measure in self.get_associated_measures(quota):
+            item_ids.append(measure.goods_nomenclature.item_id)
+
+        return item_ids
+
+    def get_goods_nomenclature_headings(self, item_ids: List[str]):
+        """
+        Returns a string representing the headings and descriptions for
+        measures associated with a quota. Headings are at the 4 digit level,
+        e.g. 1234000000.
+
+        Args:
+            item_ids: list(str): a list of strings representing item_ids
+
+        Returns:
+            str: unique headings and associated descriptions for each heading,
+            separated by the "|" character (bar)
+        """
+        heading_item_ids = []
+        headings = []
+
+        for item_id in item_ids:
+            heading_item_id = item_id[:4]
+            if heading_item_id not in heading_item_ids:
+                heading_and_desc = (
+                    heading_item_id
+                    + "-"
+                    + self.get_goods_nomenclature_description(
+                        heading_item_id + "000000",
+                    )
+                )
+                headings.append(heading_and_desc)
+                heading_item_ids.append(heading_item_id)
+
+        return "|".join(headings)
+
+    @staticmethod
+    def get_goods_nomenclature_description(item_id):
+        """
+        Returns the description associated with an item_id.
+
+        Args:
+            item_id: the item_id to be queried
+
+        Returns:
+            str: the current description for the item_id
+        """
+        description = (
+            GoodsNomenclatureDescription.objects.latest_approved()
+            .filter(described_goods_nomenclature__item_id=item_id)
+            .order_by("-validity_start")
+            .first()
+        )
+
+        return description.description
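A minimal, self-contained sketch of the heading-condensing behaviour implemented by get_goods_nomenclature_headings above: item ids are truncated to their 4-digit heading, de-duplicated in first-seen order, and joined with "|". The describe callable is a stand-in for the database description lookup:

    def headings_for(item_ids, describe):
        seen, headings = set(), []
        for item_id in item_ids:
            heading = item_id[:4]
            if heading not in seen:
                seen.add(heading)
                # Pad the heading to a 10-digit item id for the description lookup.
                headings.append(f"{heading}-{describe(heading + '000000')}")
        return "|".join(headings)

    # headings_for(["0102030405", "0102778899"], lambda _: "zzz") -> "0102-zzz"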
diff --git a/exporter/quotas/tasks.py b/exporter/quotas/tasks.py
new file mode 100644
index 000000000..1f7a1de3f
--- /dev/null
+++ b/exporter/quotas/tasks.py
@@ -0,0 +1,56 @@
+import logging
+import os
+from datetime import date
+
+from common.celery import app
+from exporter import storages
+
+logger = logging.getLogger(__name__)
+
+
+def get_output_filename():
+    """Generate a dated output filename of the form
+    quotas_export_<yyyymmdd>.csv."""
+    date_str = date.today().strftime("%Y%m%d")
+    return f"quotas_export_{date_str}.csv"
+
+
+@app.task
+def export_and_upload_quotas_csv(local_path: str = None) -> bool:
+    """
+    Generates an export of latest published quota data from the TAP database to
+    a CSV file.
+
+    If `local_path` is provided, then the quotas CSV file will be saved in
+    that directory location (note that in this case `local_path` must be an
+    existing directory path on the local file system).
+
+    If `local_path` is not provided, then the quotas CSV file will be saved
+    to the configured S3 bucket.
+    """
+    csv_file_name = get_output_filename()
+
+    if local_path:
+        logger.info("Quota export process targeting local file system.")
+        storage = storages.QuotaLocalStorage(location=local_path)
+    else:
+        logger.info("Quota export process targeting S3 file system.")
+        storage = storages.QuotaS3Storage()
+
+    export_filename = storage.generate_filename(csv_file_name)
+
+    logger.info(f"Checking for existing file {export_filename}.")
+    if storage.exists(export_filename):
+        logger.info(
+            f"File {export_filename} already exists. Exiting process, "
+            f"pid={os.getpid()}.",
+        )
+        return False
+
+    logger.info(f"Generating quotas CSV export {export_filename}.")
+    storage.export_csv(export_filename)
+    logger.info(f"Quotas CSV export {export_filename} complete.")
+    return True
diff --git a/exporter/sqlite/tasks.py b/exporter/sqlite/tasks.py
index b803d662a..451e477a4 100644
--- a/exporter/sqlite/tasks.py
+++ b/exporter/sqlite/tasks.py
@@ -47,10 +47,10 @@ def export_and_upload_sqlite(local_path: str = None) -> bool:
     db_name = get_output_filename()
 
     if local_path:
-        logger.info("SQLite export process targetting local file system.")
+        logger.info("SQLite export process targeting local file system.")
         storage = storages.SQLiteLocalStorage(location=local_path)
     else:
-        logger.info("SQLite export process targetting S3 file system.")
+        logger.info("SQLite export process targeting S3 file system.")
         storage = storages.SQLiteS3Storage()
 
     export_filename = storage.generate_filename(db_name)
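As a usage aside, the task can be exercised either synchronously or via a worker, assuming a running Celery broker for the .delay form; the directory path is illustrative:

    from exporter.quotas.tasks import export_and_upload_quotas_csv

    export_and_upload_quotas_csv("/tmp")      # synchronous, saves to the local file system
    export_and_upload_quotas_csv.delay(None)  # queued, uploads to the configured S3 bucket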
+ """ + + if path.getsize(file_path) == 0: + raise EmptyFileException(f"{file_path} has zero size.") + + return True + + class HMRCStorage(S3Boto3Storage): def get_default_settings(self): # Importing settings here makes it possible for tests to override_settings @@ -68,19 +94,19 @@ def get_object_parameters(self, name): return super().get_object_parameters(name) -class SQLiteExportMixin: +class DataExportMixin: """Mixin class used to define a common export API among SQLite Storage subclasses.""" def export_database(self, filename: str): """Export Tamato's primary database to an SQLite file format, saving to - Storage's backing store (S3, local file system, etc).""" + Storage's backing store (S3, local file system, etc.).""" raise NotImplementedError class SQLiteS3StorageBase(S3Boto3Storage): """Storage base class used for remotely storing SQLite database files to an - AWS S3-like backing store (AWS S3, Minio, etc).""" + AWS S3-like backing store (AWS S3, Minio, etc.).""" def get_default_settings(self): from django.conf import settings @@ -104,7 +130,37 @@ def generate_filename(self, filename: str) -> str: return super().generate_filename(filename) -class SQLiteS3VFSStorage(SQLiteExportMixin, SQLiteS3StorageBase): +class QuotasExportS3StorageBase(S3Boto3Storage): + """Storage base class used for remotely storing Quotas Export CSV file to an + AWS S3-like backing store (AWS S3, Minio, etc.).""" + + def get_default_settings(self): + from django.conf import settings + + quotas_s3_settings = dict( + super().get_default_settings(), + bucket_name=settings.QUOTAS_EXPORT_STORAGE_BUCKET_NAME, + access_key=settings.QUOTAS_EXPORT_S3_ACCESS_KEY_ID, + secret_key=settings.QUOTAS_EXPORT_S3_SECRET_ACCESS_KEY, + region_name=settings.QUOTAS_EXPORT_S3_REGION_NAME, + endpoint_url=settings.S3_ENDPOINT_URL, + default_acl="private", + ) + print(quotas_s3_settings) + + return quotas_s3_settings + + def generate_filename(self, filename: str) -> str: + from django.conf import settings + + filename = path.join( + settings.QUOTAS_EXPORT_DESTINATION_FOLDER, + filename, + ) + return super().generate_filename(filename) + + +class SQLiteS3VFSStorage(DataExportMixin, SQLiteS3StorageBase): """ Storage class used for remotely storing SQLite database files to an AWS S3-like backing store. @@ -132,7 +188,7 @@ def export_database(self, filename: str): self.bucket.Object(filename).upload_fileobj(vfs_fileobj) -class SQLiteS3Storage(SQLiteExportMixin, SQLiteS3StorageBase): +class SQLiteS3Storage(DataExportMixin, SQLiteS3StorageBase): """ Storage class used for remotely storing SQLite database files to an AWS S3-like backing store. 
@@ -153,7 +209,7 @@ def export_database(self, filename: str): self.save(filename, temp_sqlite_db.file) -class SQLiteLocalStorage(SQLiteExportMixin, Storage): +class SQLiteLocalStorage(DataExportMixin, Storage): """Storage class used for storing SQLite database files to the local file system.""" @@ -175,3 +231,54 @@ def export_database(self, filename: str): logger.info(f"Saving {filename} to local file system storage.") sqlite.make_export(connection) connection.close() + + +class QuotaLocalStorage(Storage): + """Storage class used for storing quota CSV data to the local file + system.""" + + def __init__(self, location) -> None: + self._location = Path(location).expanduser().resolve() + logger.info(f"Normalised path `{location}` to `{self._location}`.") + if not self._location.is_dir(): + raise Exception(f"Directory does not exist: {location}.") + + def path(self, name: str) -> str: + return str(self._location.joinpath(name)) + + def exists(self, name: str) -> bool: + return Path(self.path(name)).exists() + + @log_timing(logger_function=logger.info) + def export_csv(self, filename: str): + with NamedTemporaryFile() as quotas_csv_named_temp_file: + logger.info(f"Saving {filename} to local file system storage.") + quotas.make_export(quotas_csv_named_temp_file) + if is_valid_quotas_csv(quotas_csv_named_temp_file.name): + # Only save to S3 if the CSV file is valid. + destination_file_path = os.path.join(self._location, filename) + self.save_local(destination_file_path, quotas_csv_named_temp_file) + + def save_local(self, destination_file_path, file_to_save): + shutil.copy(file_to_save.name, destination_file_path) + + +class QuotaS3Storage(DataExportMixin, QuotasExportS3StorageBase): + """ + Storage class used for remotely storing SQLite database files to an AWS + S3-like backing store. + + This class applies a strategy that first creates a temporary instance of the + SQLite file on the local file system before transferring its contents to S3. + """ + + @log_timing(logger_function=logger.info) + def export_csv(self, filename: str): + with NamedTemporaryFile() as quotas_csv_named_temp_file: + quotas.make_export(quotas_csv_named_temp_file) + logger.info(f"Saving {filename} to S3 storage.") + if is_valid_quotas_csv(quotas_csv_named_temp_file.name): + # Only save to S3 if the CSV file is valid. 
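A sketch of exercising the local storage path end to end, assuming a configured Django/TAP environment; the directory must already exist, and the filename is illustrative:

    from exporter.storages import QuotaLocalStorage

    storage = QuotaLocalStorage(location="/tmp")
    if not storage.exists("quotas_export_20240101.csv"):
        # Writes to a temp file, validates it, then copies it into place.
        storage.export_csv("quotas_export_20240101.csv")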
diff --git a/exporter/tests/quotas_export/test_quotas.py b/exporter/tests/quotas_export/test_quotas.py
new file mode 100644
index 000000000..c603be318
--- /dev/null
+++ b/exporter/tests/quotas_export/test_quotas.py
@@ -0,0 +1,6 @@
+from exporter.quotas import make_export
+
+
+def test_make_export():
+    # Smoke test: the export entry point is importable and callable.
+    assert callable(make_export)
diff --git a/exporter/tests/quotas_export/test_quotas_runner.py b/exporter/tests/quotas_export/test_quotas_runner.py
new file mode 100644
index 000000000..23707679f
--- /dev/null
+++ b/exporter/tests/quotas_export/test_quotas_runner.py
@@ -0,0 +1,317 @@
+from datetime import date
+from datetime import timedelta
+from tempfile import NamedTemporaryFile
+
+import pytest
+
+from common.tests import factories
+from common.util import TaricDateRange
+from exporter.quotas.runner import QuotaExport
+
+pytestmark = pytest.mark.django_db
+
+
+@pytest.mark.exporter
+class TestQuotaExport:
+    target_class = QuotaExport
+
+    def get_target(self):
+        ntf = NamedTemporaryFile()
+        return self.target_class(ntf)
+
+    def test_init(self):
+        ntf = NamedTemporaryFile()
+        target = self.target_class(ntf)
+        assert target.target_file == ntf
+
+    def test_csv_headers(self):
+        target = self.get_target()
+        assert len(target.csv_headers()) == 13
+
+    def test_get_geographical_areas_and_exclusions(self):
+        # seed setup
+        quota = factories.QuotaDefinitionFactory(
+            valid_between=TaricDateRange(
+                date.today() + timedelta(days=-365),
+                date.today() + timedelta(days=+365),
+            ),
+        )
+        exclusion = factories.QuotaOrderNumberOriginExclusionFactory(
+            origin=quota.order_number.quotaordernumberorigin_set.first(),
+        )
+
+        target = self.get_target()
+        geo_areas, geo_area_exclusions = target.get_geographical_areas_and_exclusions(
+            quota,
+        )
+        for origin in quota.order_number.quotaordernumberorigin_set.all():
+            assert (
+                origin.geographical_area.descriptions.all().last().description
+                in geo_areas
+            )
+        assert (
+            exclusion.excluded_geographical_area.descriptions.all().last().description
+            in geo_area_exclusions
+        )
+
+    def test_get_monetary_unit_populated(self):
+        quota = factories.QuotaDefinitionFactory(
+            valid_between=TaricDateRange(
+                date.today() + timedelta(days=-365),
+                date.today() + timedelta(days=+365),
+            ),
+            monetary_unit=factories.MonetaryUnitFactory(),
+            measurement_unit=None,
+        )
+
+        target = self.get_target()
+        assert (
+            target.get_monetary_unit(quota)
+            == f"{quota.monetary_unit.description} ({quota.monetary_unit.code})"
+        )
+
+    def test_get_monetary_unit_blank(self):
+        quota = factories.QuotaDefinitionFactory(
+            valid_between=TaricDateRange(
+                date.today() + timedelta(days=-365),
+                date.today() + timedelta(days=+365),
+            ),
+            monetary_unit=None,
+            measurement_unit=factories.MeasurementUnitFactory(),
+        )
+
+        target = self.get_target()
+        assert target.get_monetary_unit(quota) is None
+
+    def test_get_measurement_unit_blank(self):
+        quota = factories.QuotaDefinitionFactory(
+            valid_between=TaricDateRange(
+                date.today() + timedelta(days=-365),
+                date.today() + timedelta(days=+365),
+            ),
+            monetary_unit=factories.MonetaryUnitFactory(),
+            measurement_unit=None,
+        )
+
+        target = self.get_target()
+        assert target.get_measurement_unit(quota) is None
+
+    def test_get_measurement_unit_populated(self):
+        quota = factories.QuotaDefinitionFactory(
+            valid_between=TaricDateRange(
+                date.today() + timedelta(days=-365),
+                date.today() + timedelta(days=+365),
+            ),
+            monetary_unit=None,
+            measurement_unit=factories.MeasurementUnitFactory(
+                code="AAA",
+                description="BBB",
+            ),
+        )
+
+        target = self.get_target()
+        assert target.get_measurement_unit(quota) == "BBB"
+
+    def test_get_measurement_unit_populated_with_abbreviation(self):
+        quota = factories.QuotaDefinitionFactory(
+            valid_between=TaricDateRange(
+                date.today() + timedelta(days=-365),
+                date.today() + timedelta(days=+365),
+            ),
+            monetary_unit=None,
+            measurement_unit=factories.MeasurementUnitFactory(
+                code="AAA",
+                description="BBB",
+                abbreviation="zzz",
+            ),
+        )
+
+        target = self.get_target()
+        assert target.get_measurement_unit(quota) == "BBB (zzz)"
+
+    def test_get_api_query_date(self):
+        quota = factories.QuotaDefinitionFactory(
+            valid_between=TaricDateRange(
+                date.today() + timedelta(days=-365),
+                date.today() + timedelta(days=365),
+            ),
+        )
+
+        target = self.get_target()
+        assert target.get_api_query_date(quota) == date.today() + timedelta(days=+365)
+
+    def test_get_api_query_measures_included(self):
+        quota = factories.QuotaDefinitionFactory(
+            valid_between=TaricDateRange(
+                date.today() + timedelta(days=-365),
+            ),
+        )
+
+        factories.MeasureFactory(
+            order_number=quota.order_number,
+            valid_between=TaricDateRange(
+                date.today() + timedelta(days=-365),
+            ),
+        )
+
+        factories.MeasureFactory(
+            order_number=quota.order_number,
+            valid_between=TaricDateRange(
+                date.today() + timedelta(days=-365),
+                date.today() + timedelta(days=365),
+            ),
+        )
+
+        target = self.get_target()
+        assert target.get_api_query_date(quota) == date.today()
+
+    def test_get_associated_measures(self):
+        quota = factories.QuotaDefinitionFactory(
+            valid_between=TaricDateRange(
+                date.today() + timedelta(days=-365),
+                date.today() + timedelta(days=10),
+            ),
+        )
+
+        measure_included_1 = factories.MeasureFactory(
+            order_number=quota.order_number,
+            valid_between=TaricDateRange(
+                date.today() + timedelta(days=-365),
+            ),
+        )
+
+        measure_excluded_1 = factories.MeasureFactory(
+            order_number=quota.order_number,
+            valid_between=TaricDateRange(
+                date.today() + timedelta(days=20),
+                date.today() + timedelta(days=365),
+            ),
+        )
+
+        target = self.get_target()
+        measures = target.get_associated_measures(quota)
+        assert measure_included_1 in measures
+        assert measure_excluded_1 not in measures
+
+    def test_get_goods_nomenclature_item_ids(self):
+        quota = factories.QuotaDefinitionFactory(
+            valid_between=TaricDateRange(
+                date.today() + timedelta(days=-365),
+                date.today() + timedelta(days=10),
+            ),
+        )
+
+        measure_included_1 = factories.MeasureFactory(
+            order_number=quota.order_number,
+            valid_between=TaricDateRange(
+                date.today() + timedelta(days=-365),
+            ),
+        )
+
+        measure_excluded_1 = factories.MeasureFactory(
+            order_number=quota.order_number,
+            valid_between=TaricDateRange(
+                date.today() + timedelta(days=20),
+                date.today() + timedelta(days=365),
+            ),
+        )
+
+        target = self.get_target()
+        item_ids = target.get_goods_nomenclature_item_ids(quota)
+        assert measure_included_1.goods_nomenclature.item_id in item_ids
+        assert measure_excluded_1.goods_nomenclature.item_id not in item_ids
+
+    def test_get_goods_nomenclature_headings(self):
+        quota = factories.QuotaDefinitionFactory(
+            valid_between=TaricDateRange(
+                date.today() + timedelta(days=-365),
+                date.today() + timedelta(days=10),
+            ),
+        )
+
+        factories.MeasureFactory(
+            order_number=quota.order_number,
+            valid_between=TaricDateRange(
+                date.today() + timedelta(days=-365),
+            ),
+            goods_nomenclature__description__description="gggg",
+            goods_nomenclature__item_id="0102030405",
+        )
+
+        factories.MeasureFactory(
+            order_number=quota.order_number,
+            valid_between=TaricDateRange(
+                date.today() + timedelta(days=-365),
+            ),
+            goods_nomenclature__description__description="hhhh",
+            goods_nomenclature__item_id="0102778899",
+        )
+
+        factories.GoodsNomenclatureFactory(
+            item_id="0102000000",
+            description__description="zzz",
+        )
+
+        target = self.get_target()
+        item_ids = target.get_goods_nomenclature_item_ids(quota)
+        headings = target.get_goods_nomenclature_headings(item_ids)
+        assert headings == "0102-zzz"
+
+    def test_run(self):
+        quota = factories.QuotaDefinitionFactory(
+            valid_between=TaricDateRange(
+                date.today() + timedelta(days=-365),
+                date.today() + timedelta(days=10),
+            ),
+            sid=20001,
+            order_number__order_number="056789",
+        )
+
+        factories.MeasureFactory(
+            order_number=quota.order_number,
+            valid_between=TaricDateRange(
+                date.today() + timedelta(days=-365),
+            ),
+            goods_nomenclature__description__description="gggg",
+            goods_nomenclature__item_id="0102030405",
+        )
+
+        factories.MeasureFactory(
+            order_number=quota.order_number,
+            valid_between=TaricDateRange(
+                date.today() + timedelta(days=-365),
+            ),
+            goods_nomenclature__description__description="hhhh",
+            goods_nomenclature__item_id="0102778899",
+        )
+
+        factories.GoodsNomenclatureFactory(
+            item_id="0102000000",
+            description__description="zzz",
+        )
+
+        with NamedTemporaryFile() as ntf:
+            target = self.target_class(ntf)
+            target.run()
+            content = open(ntf.name, "r").read()
+
+        headers_str = (
+            "quota_definition__sid,quota__order_number,"
+            + "quota__geographical_areas,"
+            + "quota__geographical_area_exclusions,"
+            + "quota__headings,"
+            + "quota__commodities,"
+            + "quota__measurement_unit,"
+            + "quota__monetary_unit,"
+            + "quota_definition__description,"
+            + "quota_definition__validity_start_date,"
+            + "quota_definition__validity_end_date,"
+            + "quota_definition__initial_volume,"
+            + "api_query_date"
+        )
+
+        assert headers_str in content
+        # check rows count
+        assert len(content.split("\n")) > 2
+        assert "0102-zzz" in content
+        assert "20001,056789" in content
diff --git a/exporter/tests/quotas_export/test_quotas_tasks.py b/exporter/tests/quotas_export/test_quotas_tasks.py
new file mode 100644
index 000000000..8b1378917
--- /dev/null
+++ b/exporter/tests/quotas_export/test_quotas_tasks.py
@@ -0,0 +1 @@
+
diff --git a/exporter/tests/test_exporter_commands.py b/exporter/tests/test_exporter_commands.py
index 762819c88..b6257c0cb 100644
--- a/exporter/tests/test_exporter_commands.py
+++ b/exporter/tests/test_exporter_commands.py
@@ -24,6 +24,7 @@
         (None, "/tmp"),
     ),
 )
+@pytest.mark.exporter
 def test_dump_sqlite_command(asynchronous_flag, save_local_flag_value):
     flags = []
     if asynchronous_flag:
@@ -50,6 +51,7 @@
 
 
 @pytest.mark.skip()
+@pytest.mark.exporter
 def test_upload_command_uploads_queued_workbasket_to_s3(
     approved_transaction,
     hmrc_storage,
@@ -97,6 +99,7 @@
 
 
 @pytest.mark.skip(reason="broken test - TODO")
+@pytest.mark.exporter
 def test_dump_command_outputs_queued_workbasket(approved_transaction, capsys):
     """Exercise HMRCStorage and verify content is saved to bucket."""
     with capsys.disabled():
@@ -124,6 +127,7 @@
     assert codes == expected_codes
 
 
+@pytest.mark.exporter
 def test_dump_command_exits_on_unchecked_workbasket():
     workbasket = WorkBasketFactory.create(
         status=WorkflowStatus.EDITING,
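With the exporter marker registered in pyproject.toml later in this diff, the tests added above can be selected as a group with `pytest -m exporter`, or excluded with `pytest -m "not exporter"`.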
diff --git a/exporter/tests/test_exporter_tasks.py b/exporter/tests/test_exporter_tasks.py
index 7e6cdf19f..95c1f32ec 100644
--- a/exporter/tests/test_exporter_tasks.py
+++ b/exporter/tests/test_exporter_tasks.py
@@ -21,6 +21,7 @@ class SentinelError(Exception):
     pass
 
 
+@pytest.mark.exporter
 def test_upload_workbaskets_uploads_queued_workbasket_to_s3(
    approved_transaction,
    hmrc_storage,
@@ -83,6 +84,7 @@
         SentinelError(),
     ],
 )
+@pytest.mark.exporter
 def test_upload_workbaskets_retries(mock_save, settings):
     """Verify if HMRCStorage.save raises a boto.ConnectionError the task
     upload_workflow task retries based on
@@ -112,6 +114,7 @@
         SentinelError(),
     ],
 )
+@pytest.mark.exporter
 def test_notify_hmrc_retries(mock_post, settings, hmrc_storage, responses):
     """Verify if HMRCStorage.save raises a boto.ConnectionError the task
     upload_workflow task retries based on
diff --git a/exporter/tests/test_files/empty.csv b/exporter/tests/test_files/empty.csv
new file mode 100644
index 000000000..e69de29bb
diff --git a/exporter/tests/test_files/some.csv b/exporter/tests/test_files/some.csv
new file mode 100644
index 000000000..8eed507b2
--- /dev/null
+++ b/exporter/tests/test_files/some.csv
@@ -0,0 +1 @@
+quota_definition__sid,quota__order_number,quota__geographical_areas,quota__geographical_area_exclusions,quota__headings,quota__commodities,quota__measurement_unit,quota__monetary_unit,quota_definition__description,quota_definition__validity_start_date,quota_definition__validity_end_date,quota_definition__initial_volume,api_query_date
diff --git a/exporter/tests/test_files/valid.csv b/exporter/tests/test_files/valid.csv
new file mode 100644
index 000000000..3bea3e45e
--- /dev/null
+++ b/exporter/tests/test_files/valid.csv
@@ -0,0 +1,2 @@
+some,headers
+and,values
\ No newline at end of file
diff --git a/exporter/tests/test_sqlite.py b/exporter/tests/test_sqlite.py
index bb4f7cdf9..97f702e12 100644
--- a/exporter/tests/test_sqlite.py
+++ b/exporter/tests/test_sqlite.py
@@ -1,6 +1,4 @@
-import sqlite3
 import tempfile
-from contextlib import nullcontext
 from io import BytesIO
 from os import path
 from pathlib import Path
@@ -16,7 +14,6 @@
 from exporter.sqlite import tasks
 from exporter.sqlite.runner import Runner
 from exporter.sqlite.runner import SQLiteMigrator
-from exporter.storages import EmptyFileException
 from exporter.storages import is_valid_sqlite
 from workbaskets.validators import WorkflowStatus
 
@@ -47,46 +44,11 @@ def sqlite_database(sqlite_template: Runner) -> Iterator[Runner]:
     yield Runner(in_memory_database)
 
 
-def get_test_file_path(filename):
-    return path.join(
-        path.dirname(path.abspath(__file__)),
-        "test_files",
-        filename,
-    )
-
-
-@pytest.mark.parametrize(
-    ("test_file_path, expect_context"),
-    (
-        (
-            get_test_file_path("valid_sqlite.db"),
-            nullcontext(),
-        ),
-        (
-            "/invalid/file/path",
-            pytest.raises(FileNotFoundError),
-        ),
-        (
-            get_test_file_path("empty_sqlite.db"),
-            pytest.raises(EmptyFileException),
-        ),
-        (
-            get_test_file_path("invalid_sqlite.db"),
-            pytest.raises(sqlite3.DatabaseError),
-        ),
-    ),
-)
-def test_is_valid_sqlite(test_file_path, expect_context):
-    """Test that `is_valid_sqlite()` raises correct exceptions for invalid
-    SQLite files and succeeds for valid SQLite files."""
-    with expect_context:
-        is_valid_sqlite(test_file_path)
-
-
 @pytest.mark.parametrize(
     ("migrations_in_tmp_dir"),
     (False, True),
 )
+@pytest.mark.exporter
 def test_sqlite_migrator(migrations_in_tmp_dir):
     """Test SQLiteMigrator."""
     with tempfile.NamedTemporaryFile() as sqlite_file:
diff --git a/exporter/tests/test_storages.py b/exporter/tests/test_storages.py
new file mode 100644
index 000000000..1ab43e672
--- /dev/null
+++ b/exporter/tests/test_storages.py
@@ -0,0 +1,378 @@
+import csv
+import os
+import sqlite3
+from contextlib import nullcontext
+from os import path
+from pathlib import Path
+from unittest.mock import MagicMock
+from unittest.mock import patch
+
+import pytest
+import sqlite_s3vfs
+from django.conf import settings
+
+from exporter.storages import EmptyFileException
+from exporter.storages import HMRCStorage
+from exporter.storages import QuotaLocalStorage
+from exporter.storages import QuotaS3Storage
+from exporter.storages import QuotasExportS3StorageBase
+from exporter.storages import SQLiteLocalStorage
+from exporter.storages import SQLiteS3Storage
+from exporter.storages import SQLiteS3StorageBase
+from exporter.storages import SQLiteS3VFSStorage
+from exporter.storages import is_valid_quotas_csv
+from exporter.storages import is_valid_sqlite
+
+pytestmark = pytest.mark.django_db
+
+
+def get_test_file_path(filename) -> str:
+    return str(
+        path.join(
+            path.dirname(path.abspath(__file__)),
+            "test_files",
+            filename,
+        ),
+    )
+
+
+@pytest.fixture()
+def fake_connection():
+    pass
+
+
+@pytest.fixture()
+def mock_make_export(fake_connection):
+    def mock_make_export(*args, **kwargs):
+        pass
+
+    # Return the inner stub so the fixture actually provides a callable.
+    return mock_make_export
+
+
+@pytest.fixture()
+def fake_apsw_connection():
+    class FakeASPWConnection:
+        def close(self):
+            pass
+
+    return FakeASPWConnection()
+
+
+@pytest.mark.parametrize(
+    ("test_file_path, expect_context"),
+    (
+        (
+            get_test_file_path("valid_sqlite.db"),
+            nullcontext(),
+        ),
+        (
+            "/invalid/file/path",
+            pytest.raises(FileNotFoundError),
+        ),
+        (
+            get_test_file_path("empty_sqlite.db"),
+            pytest.raises(EmptyFileException),
+        ),
+        (
+            get_test_file_path("invalid_sqlite.db"),
+            pytest.raises(sqlite3.DatabaseError),
+        ),
+    ),
+)
+@pytest.mark.exporter
+def test_is_valid_sqlite(test_file_path, expect_context):
+    """Test that `is_valid_sqlite()` raises correct exceptions for invalid
+    SQLite files and succeeds for valid SQLite files."""
+    with expect_context:
+        is_valid_sqlite(test_file_path)
+
+
+@pytest.mark.parametrize(
+    ("test_file_path, expect_context"),
+    (
+        (
+            get_test_file_path("valid_sqlite.db"),
+            nullcontext(),
+        ),
+        (
+            "/invalid/file/path",
+            pytest.raises(FileNotFoundError),
+        ),
+        (
+            get_test_file_path("empty.csv"),
+            pytest.raises(EmptyFileException),
+        ),
+    ),
+)
+@pytest.mark.exporter
+def test_is_valid_quotas_csv(test_file_path, expect_context):
+    """Test that `is_valid_quotas_csv()` raises correct exceptions for invalid
+    quotas CSV files and succeeds for valid files."""
+    with expect_context:
+        is_valid_quotas_csv(test_file_path)
+
+
+@pytest.mark.exporter
+class TestSQLiteS3StorageBase:
+    target_class = SQLiteS3StorageBase
+
+    def get_target(self):
+        return self.target_class()
+
+    def test_get_default_settings(self):
+        target = self.get_target()
+        default_settings = target.get_default_settings()
+
+        assert default_settings["bucket_name"] == settings.SQLITE_STORAGE_BUCKET_NAME
+        assert default_settings["access_key"] == settings.SQLITE_S3_ACCESS_KEY_ID
+        assert default_settings["secret_key"] == settings.SQLITE_S3_SECRET_ACCESS_KEY
+        assert default_settings["endpoint_url"] == settings.SQLITE_S3_ENDPOINT_URL
+        assert default_settings["default_acl"] == "private"
+
+    def test_generate_filename(self):
+        target = self.get_target()
+        file_name = target.generate_filename("zzz.zzz")
+        assert file_name == "sqlite/zzz.zzz"
+
+
+@pytest.mark.exporter
+class TestQuotasExportS3StorageBase:
+    target_class = QuotasExportS3StorageBase
+
+    def get_target(self):
+        return self.target_class()
+
+    def test_get_default_settings(self):
+        target = self.get_target()
+        default_settings = target.get_default_settings()
+
+        assert (
+            default_settings["bucket_name"]
+            == settings.QUOTAS_EXPORT_STORAGE_BUCKET_NAME
+        )
+        assert default_settings["access_key"] == settings.QUOTAS_EXPORT_S3_ACCESS_KEY_ID
+        assert (
+            default_settings["secret_key"]
+            == settings.QUOTAS_EXPORT_S3_SECRET_ACCESS_KEY
+        )
+        assert default_settings["endpoint_url"] == settings.S3_ENDPOINT_URL
+        assert default_settings["default_acl"] == "private"
+
+    def test_generate_filename(self):
+        target = self.get_target()
+        file_name = target.generate_filename("zzz.zzz")
+        assert file_name == "quotas_export/zzz.zzz"
+
+
+@pytest.mark.exporter
+class TestHMRCStorage:
+    target_class = HMRCStorage
+
+    def get_target(self):
+        return self.target_class()
+
+    def test_get_default_settings(self):
+        target = self.get_target()
+        default_settings = target.get_default_settings()
+
+        assert default_settings["bucket_name"] == settings.HMRC_STORAGE_BUCKET_NAME
+        assert default_settings["default_acl"] == "private"
+
+    def test_get_object_parameters(self):
+        target = self.get_target()
+        params = target.get_object_parameters("file.ext")
+        assert params == {"ContentDisposition": "attachment; filename=file.ext"}
+
+
+@pytest.mark.exporter
+class TestQuotaLocalStorage:
+    target_class = QuotaLocalStorage
+    target_class_location = "exporter/tests/test_files"
+
+    def get_target(self, location=None):
+        if location is None:
+            return self.target_class(self.target_class_location)
+        return self.target_class(location)
+
+    def test_init_bad_location(self):
+        with pytest.raises(Exception) as e:
+            self.get_target("zzz")
+
+        assert str(e.value) == "Directory does not exist: zzz."
+
+    def test_init_good_location(self):
+        target = self.get_target()
+        resolved_location = Path(self.target_class_location).expanduser().resolve()
+        assert target._location == resolved_location
+
+    def test_path(self):
+        target = self.get_target()
+        target_path = target.path("a.csv")
+        expected_path = (
+            str(Path(self.target_class_location).expanduser().resolve()) + "/a.csv"
+        )
+        assert target_path == expected_path
+
+    def test_exists(self):
+        target = self.get_target()
+        target_exists = target.exists("valid.csv")
+        assert target_exists
+
+        target_exists = target.exists("zzz.csv")
+        assert not target_exists
+
+    def test_export_csv(self):
+        def mocked_make_export(named_temp_file):
+            with open(named_temp_file.name, "wt") as file:
+                writer = csv.writer(file)
+                writer.writerow(["header1", "header2", "header3"])
+                writer.writerow(["data1", "data2", "data3"])
+
+        # Apply the patch as a context manager so it is actually active.
+        with patch("exporter.quotas.make_export", mocked_make_export):
+            target = self.get_target()
+            target.export_csv("some.csv")
+        assert os.path.exists(os.path.join(target._location, "some.csv"))
+
+
+@pytest.mark.exporter
+class TestSQLiteS3VFSStorage:
+    target_class = SQLiteS3VFSStorage
+
+    def get_target(self):
+        return self.target_class()
+
+    @patch(
+        "exporter.storages.SQLiteS3VFSStorage.listdir",
+        return_value=(
+            [],
+            ["xxx"],
+        ),
+    )
+    def test_exists(self, mocked_list_dir):
+        target = self.get_target()
+        assert target.exists("xxx")
+        mocked_list_dir.assert_called_once()
+
+    def test_vfs(self):
+        target = self.get_target()
+        assert type(target.vfs) == sqlite_s3vfs.S3VFS
+
+    @patch("exporter.sqlite.make_export", return_value=mock_make_export)
+    def test_export_database(self, patched_make_export):
+        class FakeConnection:
+            def close(self):
+                pass
+
+        with patch("apsw.Connection") as mocked_connection:
+            fake_bucket = MagicMock()
+            mocked_connection.return_value = FakeConnection()
+            target = self.get_target()
+            target._bucket = fake_bucket
+            target.export_database("valid.file")
+            patched_make_export.assert_called_once()
+
+
+@pytest.mark.exporter
+class TestSQLiteS3Storage:
+    target_class = SQLiteS3Storage
+
+    def get_target(self):
+        return self.target_class()
+
+    @patch("exporter.sqlite.make_export", return_value=mock_make_export)
+    @patch("exporter.storages.is_valid_sqlite", return_value=True)
+    def test_export_database(self, patched_make_export, patched_is_valid_sqlite):
+        class FakeConnection:
+            def close(self):
+                pass
+
+        with patch("apsw.Connection") as mocked_connection:
+            fake_bucket = MagicMock()
+            mocked_connection.return_value = FakeConnection()
+            target = self.get_target()
+            target._bucket = fake_bucket
+            target.export_database("valid.file")
+            patched_make_export.assert_called_once()
+
+    @patch("exporter.sqlite.make_export", return_value=mock_make_export)
+    def test_export_database_zero_file_size(self, patched_make_export):
+        class FakeConnection:
+            def close(self):
+                pass
+
+        with patch("apsw.Connection") as mocked_connection:
+            fake_bucket = MagicMock()
+            mocked_connection.return_value = FakeConnection()
+            target = self.get_target()
+            target._bucket = fake_bucket
+            with pytest.raises(EmptyFileException) as e:
+                target.export_database("valid.file")
+
+            patched_make_export.assert_called_once()
+
+
+@pytest.mark.exporter
+class TestSQLiteLocalStorage:
+    target_class = SQLiteLocalStorage
+
+    def get_target(self):
+        return self.target_class("exporter/tests/test_files")
+
+    def test_path(self):
+        target = self.get_target()
+        assert str(target.path("some_file.type")) == str(
+            target._location.joinpath("some_file.type"),
+        )
+
+    def test_exists(self):
+        target = self.get_target()
+        assert not target.exists("dfsgdfg")
+
+    @patch("exporter.sqlite.make_export", return_value=mock_make_export)
+    def test_export_database(self, patched_make_export):
+        class FakeConnection:
+            def close(self):
+                pass
+
+        with patch("apsw.Connection") as mocked_connection:
+            fake_bucket = MagicMock()
+            mocked_connection.return_value = FakeConnection()
+            target = self.get_target()
+            target._bucket = fake_bucket
+            target.export_database("valid.file")
+            patched_make_export.assert_called_once()
+
+
+@pytest.mark.exporter
+class TestQuotaS3Storage:
+    target_class = QuotaS3Storage
+
+    def get_target(self):
+        return self.target_class()
+
+    @patch("exporter.storages.is_valid_quotas_csv", return_value=True)
+    def test_export_csv(self, patched_is_valid_quotas_csv):
+        def mocked_make_export(named_temp_file):
+            with open(named_temp_file.name, "wt") as file:
+                writer = csv.writer(file)
+                writer.writerow(["header1", "header2", "header3"])
+                writer.writerow(["data1", "data2", "data3"])
+
+        # Apply the patch as a context manager so it is actually active.
+        with patch("exporter.quotas.make_export", mocked_make_export):
+            target = self.get_target()
+            with patch.object(target, "save") as mock_save:
+                mock_save.return_value = None
+                target.export_csv("valid.file")
+                patched_is_valid_quotas_csv.assert_called_once()
+                mock_save.assert_called_once()
+
+    def test_export_csv_invalid(self):
+        def mocked_make_export(named_temp_file):
+            pass
+
+        with patch(
+            "exporter.quotas.make_export",
+            mocked_make_export,
+        ) as patched_make_export:
+            target = self.get_target()
+            with pytest.raises(EmptyFileException) as e:
+                target.export_csv("valid.file")
+            assert "has zero size." in str(e.value)
diff --git a/exporter/tests/test_util.py b/exporter/tests/test_util.py
index 7b4cc533e..c4504d92f 100644
--- a/exporter/tests/test_util.py
+++ b/exporter/tests/test_util.py
@@ -1,7 +1,10 @@
+import pytest
+
 from exporter.util import exceptions_as_messages
 from exporter.util import item_timer
 
 
+@pytest.mark.exporter
 def test_exceptions_as_messages():
     exception_list = {
         "first_exception": [Exception("test")],
@@ -16,6 +19,7 @@
     }
 
 
+@pytest.mark.exporter
 def test_item_timer():
     """Verify that item_timer yields a tuple containing the time to retrieve
     each item and the item itself."""
diff --git a/exporter/tests/test_views.py b/exporter/tests/test_views.py
index 24e84bff9..90c70c0ba 100644
--- a/exporter/tests/test_views.py
+++ b/exporter/tests/test_views.py
@@ -6,7 +6,7 @@
 from exporter import views
 
 
-@pytest.mark.django_db
+@pytest.mark.exporter
 def test_activity_stream(admin_client: Client):
     certificate_type = factories.CertificateTypeFactory.create()
     factories.CertificateFactory.create_batch(
diff --git a/measures/editors.py b/measures/editors.py
index 0a55451f2..daa98e457 100644
--- a/measures/editors.py
+++ b/measures/editors.py
@@ -4,15 +4,15 @@
 
 from django.db.transaction import atomic
 
-from workbaskets import models as workbasket_models
-from measures import models as measure_models
+from common.models.utils import override_current_transaction
 from common.util import TaricDateRange
 from common.validators import UpdateType
-from common.models.utils import override_current_transaction
+from measures import models as measure_models
 from measures.util import update_measure_components
 from measures.util import update_measure_condition_components
 from measures.util import update_measure_excluded_geographical_areas
 from measures.util import update_measure_footnote_associations
+from workbaskets import models as workbasket_models
 
 
 class MeasuresEditor:
@@ -23,7 +23,7 @@ class MeasuresEditor:
     """The workbasket with which created measures will be associated."""
 
     selected_measures: List
-    """ The measures in which the edits will apply to."""
+    """The measures to which the edits will apply."""
 
     data: Dict
     """Validated, cleaned and accumulated data created by the Form instances of
diff --git a/measures/forms/wizard.py b/measures/forms/wizard.py
index d83142df3..f10c6f98a 100644
--- a/measures/forms/wizard.py
+++ b/measures/forms/wizard.py
@@ -836,7 +836,7 @@ def serializable_init_kwargs(cls, kwargs: Dict) -> Dict:
     def deserialize_init_kwargs(cls, form_kwargs: Dict) -> Dict:
         serialized_selected_measures_pks = form_kwargs.get("selected_measures")
         deserialized_selected_measures = models.Measure.objects.filter(
-            pk__in=serialized_selected_measures_pks
+            pk__in=serialized_selected_measures_pks,
         )
 
         kwargs = {
@@ -904,7 +904,7 @@ def serializable_init_kwargs(cls, kwargs: Dict) -> Dict:
     def deserialize_init_kwargs(cls, form_kwargs: Dict) -> Dict:
         serialized_selected_measures_pks = form_kwargs.get("selected_measures")
         deserialized_selected_measures = models.Measure.objects.filter(
-            pk__in=serialized_selected_measures_pks
+            pk__in=serialized_selected_measures_pks,
         )
 
         kwargs = {
@@ -957,7 +957,7 @@ def serializable_init_kwargs(cls, kwargs: Dict) -> Dict:
     def deserialize_init_kwargs(cls, form_kwargs: Dict) -> Dict:
         serialized_selected_measures_pks = form_kwargs.get("selected_measures")
         deserialized_selected_measures = models.Measure.objects.filter(
-            pk__in=serialized_selected_measures_pks
+            pk__in=serialized_selected_measures_pks,
        )
 
         kwargs = {
@@ -1027,7 +1027,7 @@ def serializable_init_kwargs(cls, kwargs: Dict) -> Dict:
     def deserialize_init_kwargs(cls, form_kwargs: Dict) -> Dict:
         serialized_selected_measures_pks = form_kwargs.get("selected_measures")
         deserialized_selected_measures = models.Measure.objects.filter(
-            pk__in=serialized_selected_measures_pks
+            pk__in=serialized_selected_measures_pks,
         )
 
         kwargs = {
diff --git a/measures/models/bulk_processing.py b/measures/models/bulk_processing.py
index 908e90d98..bcdfdf1a0 100644
--- a/measures/models/bulk_processing.py
+++ b/measures/models/bulk_processing.py
@@ -17,8 +17,8 @@
 from common.celery import app
 from common.models.mixins import TimestampedMixin
 from common.models.utils import override_current_transaction
-from measures.models.tracked_models import Measure
 from measures.editors import MeasuresEditor
+from measures.models.tracked_models import Measure
 
 logger = logging.getLogger(__name__)
 
@@ -443,8 +443,8 @@ def create(
 
 class MeasuresBulkEditor(BulkProcessor):
     """
-    Model class used to bulk edit Measures instances from serialized form
-    data.
+    Model class used to bulk edit Measures instances from serialized form data.
+
     The stored form data is serialized and deserialized by Forms that subclass
     SerializableFormMixin.
     """
@@ -509,11 +509,13 @@ def edit_measures(self) -> Iterable[Measure]:
         ):
             cleaned_data = self.get_forms_cleaned_data()
             deserialized_selected_measures = Measure.objects.filter(
-                pk__in=self.selected_measures
+                pk__in=self.selected_measures,
             )
 
             measures_editor = MeasuresEditor(
-                self.workbasket, deserialized_selected_measures, cleaned_data
+                self.workbasket,
+                deserialized_selected_measures,
+                cleaned_data,
             )
 
             return measures_editor.edit_measures()
diff --git a/measures/tests/test_bulk_processing.py b/measures/tests/test_bulk_processing.py
index 0b29330df..ea9ab5ff9 100644
--- a/measures/tests/test_bulk_processing.py
+++ b/measures/tests/test_bulk_processing.py
@@ -11,7 +11,6 @@
 from common.validators import ApplicabilityCode
 from measures import forms
 from measures.models import MeasuresBulkCreator
-from measures.models import MeasuresBulkEditor
 from measures.models import ProcessingState
 from measures.tests.factories import MeasuresBulkCreatorFactory
 from measures.tests.factories import MeasuresBulkEditorFactory
@@ -384,7 +383,7 @@ def test_bulk_editor_get_forms_cleaned_data_errors(
                     "start_date_0": "",
                     "start_date_1": "",
                     "start_date_2": "",
-                }
+                },
             },
             "Enter the day, month and year",
         ),
diff --git a/measures/tests/test_forms.py b/measures/tests/test_forms.py
index 294fe9ea5..535d1c812 100644
--- a/measures/tests/test_forms.py
+++ b/measures/tests/test_forms.py
@@ -1970,8 +1970,8 @@ def test_simple_measure_edit_forms_serialize_deserialize(
     request,
     duty_sentence_parser,
 ):
-    """Test that the EditMeasure simple forms that use the
-    SerializableFormMixin behave correctly and as expected."""
+    """Test that the EditMeasure simple forms that use the SerializableFormMixin
+    behave correctly and as expected."""
 
     # Create some measures to apply this data to, for the kwargs
     quota_order_number = factories.QuotaOrderNumberFactory()
diff --git a/measures/tests/test_views.py b/measures/tests/test_views.py
index f6faf38b6..8968342d4 100644
--- a/measures/tests/test_views.py
+++ b/measures/tests/test_views.py
@@ -58,9 +58,9 @@
 @pytest.fixture()
 def mocked_diff_components():
-    """Mocks `diff_components()` inside `update_measure_components()` that is called in
-    `MeasureEditWizard` to prevent parsing errors where test measures lack a
-    duty sentence."""
+    """Mocks `diff_components()` inside `update_measure_components()` that is
+    called in `MeasureEditWizard` to prevent parsing errors where test measures
+    lack a duty sentence."""
     with patch(
         "measures.editors.update_measure_components",
     ) as update_measure_components:
diff --git a/measures/util.py b/measures/util.py
index 80849bde4..4dfa0ebc4 100644
--- a/measures/util.py
+++ b/measures/util.py
@@ -1,20 +1,17 @@
 import decimal
+import logging
 from datetime import date
 from math import floor
+from typing import List
+from typing import Type
 
-from common.models import TrackedModel
 from common.models.transactions import Transaction
 from common.validators import UpdateType
-
 from geo_areas.models import GeographicalArea
 from geo_areas.utils import get_all_members_of_geo_groups
 from measures import models as measure_models
-from typing import List
-from typing import Type
 from workbaskets import models as workbasket_models
 
-import logging
-
 logger = logging.getLogger(__name__)
 
@@ -125,8 +122,7 @@ def update_measure_condition_components(
     measure: "measure_models.Measure",
     workbasket: "workbasket_models.WorkBasket",
 ):
-    """Updates the measure condition components associated to the
-    measure."""
+    """Updates the measure condition components associated with the measure."""
     conditions = measure.conditions.current()
     for condition in conditions:
         condition.new_version(
diff --git a/measures/views/mixins.py b/measures/views/mixins.py
index feb0d21ac..6fb61435a 100644
--- a/measures/views/mixins.py
+++ b/measures/views/mixins.py
@@ -1,5 +1,5 @@
-from typing import Type
 from typing import Dict
+from typing import Type
 
 from common.models import TrackedModel
 from measures import models
@@ -54,13 +54,18 @@ def get_queryset(self):
 
 
 class MeasureSerializableWizardMixin:
-    """A Mixin for the wizard forms that utilise asynchronous bulk processing. This mixin provides the functionality to go through each form
-    and serialize the data ready for storing in the database."""
+    """
+    A Mixin for the wizard forms that utilise asynchronous bulk processing.
+
+    This mixin provides the functionality to go through each form and serialize
+    the data ready for storing in the database.
+    """
 
     def get_data_form_list(self) -> dict:
         """
         Returns a form list based on form_list, conditionally including only
         those items as per condition_list and also appearing in data_form_list.
+
         The list is generated dynamically because conditions in condition_list
         may be dynamic.
 
         Essentially, version of `WizardView.get_form_list()` filtering in only
@@ -76,6 +81,7 @@
     def all_serializable_form_data(self) -> Dict:
         """
         Returns serializable data for all wizard steps.
+
         This is a re-implementation of
         MeasureCreateWizard.get_all_cleaned_data(), but using self.data after
         is_valid() has been successfully run.
@@ -91,6 +97,7 @@
     def serializable_form_data_for_step(self, step) -> Dict:
         """
         Returns serializable data for a wizard step.
+
         This is a re-implementation of WizardView.get_cleaned_data_for_step(),
         returning the serializable version of data in place of the form's
         regular cleaned_data.
diff --git a/measures/views/wizard.py b/measures/views/wizard.py
index b69b6d1b9..9b61071ef 100644
--- a/measures/views/wizard.py
+++ b/measures/views/wizard.py
@@ -1,6 +1,4 @@
 import logging
-from typing import Dict
-from typing import List
 
 from crispy_forms_gds.helper import FormHelper
 from django.conf import settings
@@ -179,7 +177,9 @@ def edit_measures(self, selected_measures, cleaned_data):
         wizard forms."""
 
         measures_editor = MeasuresEditor(
-            self.workbasket, selected_measures, cleaned_data
+            self.workbasket,
+            selected_measures,
+            cleaned_data,
         )
 
         return measures_editor.edit_measures()
@@ -210,7 +210,9 @@ def sync_done(self, form_list, **kwargs):
 
 @method_decorator(require_current_workbasket, name="dispatch")
 class MeasureCreateWizard(
-    PermissionRequiredMixin, NamedUrlSessionWizardView, MeasureSerializableWizardMixin
+    PermissionRequiredMixin,
+    NamedUrlSessionWizardView,
+    MeasureSerializableWizardMixin,
 ):
     """
     Multipart form wizard for creating a single measure.
diff --git a/pii-secret-exclude.txt b/pii-secret-exclude.txt
index 013ace85a..d38441101 100644
--- a/pii-secret-exclude.txt
+++ b/pii-secret-exclude.txt
@@ -30,3 +30,6 @@ common/static/common/js/components/GeoAreaField/tests/__snapshots__/index.test.j
 common/static/common/js/components/GeoAreaForm/tests/__snapshots__/index.test.js.snap
 common/migrations/0001_initial.py
 common/jinja2/common/homepage.jinja
+exporter/tests/test_files/empty.csv
+exporter/tests/test_files/some.csv
+exporter/tests/test_files/valid.csv
diff --git a/pyproject.toml b/pyproject.toml
index fbd3393b3..60e26da3d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -98,5 +98,6 @@ norecursedirs = [
 ]
 markers = [
     "importer_v2",
-    "reference_documents"
+    "reference_documents",
+    "exporter"
 ]
diff --git a/quotas/tests/test_business_rules.py b/quotas/tests/test_business_rules.py
index 281379e81..156728348 100644
--- a/quotas/tests/test_business_rules.py
+++ b/quotas/tests/test_business_rules.py
@@ -2,7 +2,6 @@
 from datetime import date
 from decimal import Decimal

-from django.forms import ValidationError
 import pytest
 from django.db import DataError
@@ -659,12 +658,16 @@ def test_QA2(date_ranges):
     ],
 )
 def test_QA2_dict(
-    sub_definition_valid_between, main_definition_valid_between, expected_response
+    sub_definition_valid_between,
+    main_definition_valid_between,
+    expected_response,
 ):
-    """As above, but checking between a definition and a dict with date ranges"""
+    """As above, but checking between a definition and a dict with date
+    ranges."""
     assert (
         business_rules.check_QA2_dict(
-            sub_definition_valid_between, main_definition_valid_between
+            sub_definition_valid_between,
+            main_definition_valid_between,
         )
         == expected_response
     )
@@ -733,7 +736,7 @@ def test_QA3_dict(
     sub_init_volume,
     expected_response,
 ):
-    """As above, but checking between a definition and a dict"""
+    """As above, but checking between a definition and a dict."""

     assert (
         business_rules.check_QA3_dict(
@@ -786,7 +789,7 @@ def test_QA4(coefficient, expect_error):
     ],
 )
 def test_QA4_dict(coefficient, expected_response):
-    """As above, but checking between a definition and a dict"""
+    """As above, but checking between a definition and a dict."""

     assert business_rules.check_QA4_dict(coefficient) == expected_response


@@ -806,8 +809,9 @@ def test_QA5(existing_volume, new_volume, coeff, type, error_expected):
     Whenever a sub-quota is defined with the ‘equivalent’ type, it must have
     the same volume as the other sub-quotas associated with the parent quota.

-    Moreover it must be defined with a coefficient not equal to 1.
-    When a sub-quota relationship type is defined as 'equivalent' it must have the same volume as the ones associated with the parent quota
+    Moreover it must be defined with a coefficient not equal to 1. When a sub-
+    quota relationship type is defined as 'equivalent' it must have the same
+    volume as the ones associated with the parent quota

     A sub-quota defined with the 'normal' type must have a coefficient of 1.
""" diff --git a/reference_documents/management/commands/ref_doc_csv_importer.py b/reference_documents/management/commands/ref_doc_csv_importer.py index ee0f51dc4..dccd46e08 100644 --- a/reference_documents/management/commands/ref_doc_csv_importer.py +++ b/reference_documents/management/commands/ref_doc_csv_importer.py @@ -206,7 +206,6 @@ def create_ref_docs_and_versions(self): self.quotas_df["Document Version"] == version ] - add_to_index = 1 for index, row in quotas_df.iterrows(): # split order numbers order_number = row["Quota Number"] diff --git a/reference_documents/views/reference_document_version_views.py b/reference_documents/views/reference_document_version_views.py index 6226ba703..7bbf972cd 100644 --- a/reference_documents/views/reference_document_version_views.py +++ b/reference_documents/views/reference_document_version_views.py @@ -121,7 +121,7 @@ def duties_row_data(self): }, { "html": f"Edit " - f"Delete", + f"Delete", }, ], ) diff --git a/settings/common.py b/settings/common.py index dacc84c50..d89faeccf 100644 --- a/settings/common.py +++ b/settings/common.py @@ -69,11 +69,9 @@ }, } - # Auto field type specification required since Django 3.2. DEFAULT_AUTO_FIELD = "django.db.models.AutoField" - # -- Application DJANGO_CORE_APPS = [ @@ -273,7 +271,6 @@ # -- Security SECRET_KEY = os.environ.get("SECRET_KEY", "@@i$w*ct^hfihgh21@^8n+&ba@_l3x") - # Whitelist values for the HTTP Host header, to prevent certain attacks # App runs inside GOV.UK PaaS, so we can allow all hosts ALLOWED_HOSTS = re.split(r"\s|,", os.environ.get("ALLOWED_HOSTS", "")) @@ -287,7 +284,6 @@ paas_hosts = json.loads(os.environ["VCAP_APPLICATION"])["uris"] ALLOWED_HOSTS.extend(paas_hosts) - # Sets the X-Content-Type-Options: nosniff header SECURE_CONTENT_TYPE_NOSNIFF = True @@ -318,7 +314,6 @@ # URL path where static files are served STATIC_URL = "/assets/" - # -- Database if MAINTENANCE_MODE: @@ -393,7 +388,6 @@ os.environ.get("EXPORTER_UPLOAD_DEFAULT_RETRY_DELAY", "8"), ) - EXPORTER_MAXIMUM_ENVELOPE_SIZE = 39 * 1024 * 1024 EXPORTER_DISABLE_NOTIFICATION = is_truthy( os.environ.get("EXPORTER_DISABLE_NOTIFICATION", "false"), @@ -430,7 +424,6 @@ HMRC_STORAGE_BUCKET_NAME = os.environ.get("HMRC_STORAGE_BUCKET_NAME", "hmrc") HMRC_STORAGE_DIRECTORY = os.environ.get("HMRC_STORAGE_DIRECTORY", "tohmrc/staging/") - # S3 settings for packaging automation. 
 if is_copilot():
@@ -466,6 +459,11 @@
             IMPORTER_S3_REGION_NAME = credentials["aws_region"]
             IMPORTER_S3_ACCESS_KEY_ID = credentials["aws_access_key_id"]
             IMPORTER_S3_SECRET_ACCESS_KEY = credentials["aws_secret_access_key"]
+        if "quotas-export" in bucket["name"]:
+            QUOTAS_EXPORT_STORAGE_BUCKET_NAME = credentials["bucket_name"]
+            QUOTAS_EXPORT_S3_REGION_NAME = credentials["aws_region"]
+            QUOTAS_EXPORT_S3_ACCESS_KEY_ID = credentials["aws_access_key_id"]
+            QUOTAS_EXPORT_S3_SECRET_ACCESS_KEY = credentials["aws_secret_access_key"]
 else:
     IMPORTER_S3_REGION_NAME = os.environ.get("AWS_REGION", "eu-west-2")
     IMPORTER_S3_ACCESS_KEY_ID = os.environ.get("S3_ACCESS_KEY_ID")
@@ -485,6 +483,22 @@
     IMPORTER_STORAGE_BUCKET_NAME = os.environ.get(
         "IMPORTER_STORAGE_BUCKET_NAME",
         "importer",
     )
+    QUOTAS_EXPORT_STORAGE_BUCKET_NAME = os.environ.get(
+        "QUOTAS_EXPORT_STORAGE_BUCKET_NAME",
+        "quotas-export-local",
+    )
+    QUOTAS_EXPORT_S3_REGION_NAME = os.environ.get(
+        "QUOTAS_EXPORT_S3_REGION_NAME",
+        "eu-west-2",
+    )
+    QUOTAS_EXPORT_S3_ACCESS_KEY_ID = os.environ.get(
+        "QUOTAS_EXPORT_S3_ACCESS_KEY_ID",
+        "quotas-export-id",
+    )
+    QUOTAS_EXPORT_S3_SECRET_ACCESS_KEY = os.environ.get(
+        "QUOTAS_EXPORT_S3_SECRET_ACCESS_KEY",
+        "quotas-export-key",
+    )

     S3_ENDPOINT_URL = os.environ.get(
         "S3_ENDPOINT_URL",
@@ -523,6 +537,11 @@
     os.environ.get("CROWN_DEPENDENCIES_API_DEFAULT_RETRY_DELAY", "8"),
 )

+# quota export S3 additional settings
+QUOTAS_EXPORT_DESTINATION_FOLDER = os.environ.get(
+    "QUOTAS_EXPORT_DESTINATION_FOLDER",
+    "quotas_export/",
+)

 # SQLite AWS settings
 if is_copilot():
@@ -580,7 +599,6 @@
 CROWN_DEPENDENCIES_GET_API_KEY = os.environ.get("CROWN_DEPENDENCIES_GET_API_KEY", "")
 CROWN_DEPENDENCIES_POST_API_KEY = os.environ.get("CROWN_DEPENDENCIES_POST_API_KEY", "")

-
 if is_copilot():
     CELERY_BROKER_URL = (
         os.getenv("CELERY_BROKER_URL", default=None) + "?ssl_cert_reqs=required"
@@ -841,7 +859,6 @@
 )
 PATH_XSD_COMMODITIES_TARIC = Path(PATH_COMMODITIES_ASSETS, "commodities_taric3.xsd")

-
 # Default username for envelope data imports
 DATA_IMPORT_USERNAME = os.environ.get("TAMATO_IMPORT_USERNAME", "test")

@@ -910,13 +927,11 @@
 else:
     BASE_SERVICE_URL = os.environ.get("BASE_SERVICE_URL")

-
 # ClamAV
 CLAM_AV_USERNAME = os.environ.get("CLAM_AV_USERNAME", "")
 CLAM_AV_PASSWORD = os.environ.get("CLAM_AV_PASSWORD", "")
 CLAM_AV_DOMAIN = os.environ.get("CLAM_AV_DOMAIN", "")

-
 FILE_UPLOAD_HANDLERS = (
     "django_chunk_upload_handlers.clam_av.ClamAVFileUploadHandler",
     "django.core.files.uploadhandler.MemoryFileUploadHandler",  # defaults
@@ -924,7 +939,6 @@
 )  # Order is important

 DATA_MIGRATION_BATCH_SIZE = int(os.environ.get("DATA_MIGRATION_BATCH_SIZE", "10000"))
-
 # Asynchronous / background (bulk) object creation and editing config.
 MEASURES_ASYNC_CREATION = is_truthy(os.environ.get("MEASURES_ASYNC_CREATION", "true"))
 MEASURES_ASYNC_EDIT = is_truthy(os.environ.get("MEASURES_ASYNC_EDIT", "true"))
diff --git a/wsgi.py b/wsgi.py
index e537f5824..90e580e84 100644
--- a/wsgi.py
+++ b/wsgi.py
@@ -10,7 +10,6 @@
 import os

 import dotenv
-
 from django.core.wsgi import get_wsgi_application

 dotenv.load_dotenv()
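
Note: the QUOTAS_EXPORT_* settings added above carry everything needed to construct an S3 client. Below is a minimal sketch of an upload helper assuming those settings; the helper name and the choice to join the destination folder and filename directly are illustrative assumptions, not the exporter's actual storage code.

import boto3
from django.conf import settings


def upload_quotas_csv(local_path: str, filename: str) -> None:
    """Illustrative helper (not TAP code): push a locally generated quotas
    CSV to S3 using the QUOTAS_EXPORT_* settings from settings/common.py."""
    client = boto3.client(
        "s3",
        region_name=settings.QUOTAS_EXPORT_S3_REGION_NAME,
        aws_access_key_id=settings.QUOTAS_EXPORT_S3_ACCESS_KEY_ID,
        aws_secret_access_key=settings.QUOTAS_EXPORT_S3_SECRET_ACCESS_KEY,
    )
    # Objects land under the configured destination folder
    # (default "quotas_export/").
    client.upload_file(
        Filename=local_path,
        Bucket=settings.QUOTAS_EXPORT_STORAGE_BUCKET_NAME,
        Key=settings.QUOTAS_EXPORT_DESTINATION_FOLDER + filename,
    )
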