From 90efc7263c2ded96d210c23939bc4fb1ebbcb199 Mon Sep 17 00:00:00 2001
From: zargot
Date: Fri, 11 Oct 2024 11:08:05 -0400
Subject: [PATCH] add file object support

---
 src/odm_sharing/private/cons.py  | 98 +++++++++++++++++++++++---------
 src/odm_sharing/private/rules.py | 14 +++--
 src/odm_sharing/private/utils.py |  7 +++
 src/odm_sharing/sharing.py       | 32 ++++++-----
 src/odm_sharing/tools/share.py   |  2 +-
 5 files changed, 105 insertions(+), 48 deletions(-)

diff --git a/src/odm_sharing/private/cons.py b/src/odm_sharing/private/cons.py
index 70ab7ad1..9a3b9534 100644
--- a/src/odm_sharing/private/cons.py
+++ b/src/odm_sharing/private/cons.py
@@ -2,8 +2,9 @@
 import os
 from collections import defaultdict
 from dataclasses import dataclass
+from io import IOBase
 from pathlib import Path
-from typing import Dict, Generator, List, Set
+from typing import Dict, Generator, List, Set, Tuple, Union, cast
 
 import openpyxl as xl
 import pandas as pd
@@ -15,6 +16,14 @@ from odm_sharing.private.utils import qt
 
+@dataclass(frozen=True)
+class CsvFile:
+    table: str
+    file: IOBase
+
+
+CsvPath = str
+CsvDataSourceList = Union[List[CsvPath], List[CsvFile]]
 Sheet = xl.worksheet._read_only.ReadOnlyWorksheet
@@ -131,19 +140,39 @@ def _normalize_bool_values(df: pd.DataFrame, bool_cols: Set[ColumnName]
         df[col] = df[col].replace({F: '0', T: '1'})
 
 
-def _connect_csv(paths: List[str]) -> Connection:
+def _import_csv(data_source: Union[CsvPath, CsvFile]
+                ) -> Tuple[TableName, pd.DataFrame]:
+    # XXX: NA-values are not normalized to avoid mutating user data (#31)
+    ds = data_source
+    if isinstance(ds, CsvPath):
+        path = ds
+        assert path.endswith('.csv')
+        table = Path(path).stem
+        logging.info(f'importing {qt(table)} from {path}')
+        df = pd.read_csv(path, na_filter=False)
+        return (table, df)
+    else:
+        assert isinstance(ds, CsvFile)
+        logging.info(f'importing {qt(ds.table)} from a file')
+        df = pd.read_csv(ds.file, na_filter=False)  # type: ignore
+        return (ds.table, df)
+
+
+def _connect_csv(data_sources: CsvDataSourceList) -> Connection:
     '''copies file data to in-memory db
 
-    :raises OSError:'''
+    :raises DataSourceError:
+    :raises OSError:
+    '''
+    assert len(data_sources) > 0
+    if isinstance(data_sources[0], CsvPath):
+        _check_csv_paths(cast(List[CsvPath], data_sources))
 
-    # XXX: NA-values are not normalized to avoid mutating user data (#31)
     dfs = {}
     tables = set()
     bool_cols = {}
-    for path in paths:
-        table = Path(path).stem
-        logging.info(f'importing {qt(table)} from {path}')
-        df = pd.read_csv(path, na_filter=False)
+    for ds in data_sources:
+        (table, df) = _import_csv(ds)
         bool_cols[table] = _find_bool_cols(df, BOOL_VALS)
         _normalize_bool_values(df, bool_cols[table])
         dfs[table] = df
@@ -244,22 +273,31 @@ def _detect_sqlalchemy(path: str) -> bool:
     return False
 
 
-def _check_datasources(paths: List[str]) -> None:
-    if not paths:
-        raise DataSourceError('no data source')
-    (_, ext) = os.path.splitext(paths[0])
-    all_same_ext = seq(paths).map(lambda p: p.endswith(ext)).all()
-    if not all_same_ext:
+def _check_csv_paths(paths: List[str]) -> None:
+    if not seq(paths).map(lambda p: p.endswith('.csv')).all():
         raise DataSourceError(
-            'mixing multiple data source types is not allowed')
+            'mixing CSV files with other file types is not allowed')
+
+
+def _detect_csv_input(data_sources: Union[str, CsvDataSourceList]) -> bool:
+    is_str = isinstance(data_sources, str)
+    is_list = isinstance(data_sources, list)
+    is_path_list = is_list and isinstance(data_sources[0], str)
+    is_file_list = is_list and not is_path_list
+    return ((is_str and cast(str, data_sources).endswith('.csv')) or
+            (is_path_list and cast(str, data_sources[0]).endswith('.csv')) or
+            is_file_list)
 
 
-def connect(paths: List[str], tables: Set[str] = set()
-            ) -> Connection:
+def connect(
+    data_sources: Union[str, CsvDataSourceList],
+    tables: Set[str] = set()
+) -> Connection:
     '''
     connects to one or more data sources and returns the connection
 
-    :param paths: must all be of the same (file)type.
+    :param data_sources: filepath or database URL, or list of multiple CSV
+        paths/files
     :param tables: when connecting to an excel file, this acts as a sheet
         whitelist
 
@@ -271,16 +309,22 @@ def connect(paths: List[str], tables: Set[str] = set()
     # as 0/1, which we'll have to convert back (using previously detected bool
     # columns) to 'FALSE'/'TRUE' before returning the data to the user. This
     # happens in `odm_sharing.sharing.get_data`.
-    _check_datasources(paths)
+    if not data_sources:
+        raise DataSourceError('no data source')
     try:
-        path = paths[0]
-        is_csv = path.endswith('.csv')
-        if not is_csv and len(paths) > 1:
-            logging.warning('ignoring additional inputs (for CSV only)')
-
-        if is_csv:
-            return _connect_csv(paths)
-        elif path.endswith('.xlsx'):
+        if _detect_csv_input(data_sources):
+            csv_data_sources = cast(CsvDataSourceList, data_sources)
+            return _connect_csv(csv_data_sources)
+
+        is_list = isinstance(data_sources, list)
+        if is_list:
+            if len(data_sources) > 1:
+                raise DataSourceError('specifying multiple inputs is only ' +
+                                      'allowed for CSV files')
+
+        path = (cast(str, data_sources[0]) if is_list
+                else cast(str, data_sources))
+        if path.endswith('.xlsx'):
             return _connect_excel(path, tables)
         elif _detect_sqlite(path):
             return _connect_db(f'sqlite:///{path}')
diff --git a/src/odm_sharing/private/rules.py b/src/odm_sharing/private/rules.py
index 4a402a1e..fc815258 100644
--- a/src/odm_sharing/private/rules.py
+++ b/src/odm_sharing/private/rules.py
@@ -1,17 +1,19 @@
 import sys
+from io import IOBase
 from dataclasses import dataclass, field
 from enum import EnumMeta
-from pathlib import Path
 from typing import Any, Dict, List, Union
 
 import pandas as pd
 from functional import seq
 
 from odm_sharing.private.stdext import StrValueEnum
-from odm_sharing.private.utils import fmt_set, qt
+from odm_sharing.private.utils import fmt_set, get_filename, qt
 
 
 RuleId = int
+SchemaFile = IOBase
+SchemaPath = str
 
 
 class RuleMode(StrValueEnum):
@@ -212,15 +214,17 @@ def check_set(ctx: SchemaCtx, actual: str, expected: Union[set, list]
     raise ParseError(errors)
 
 
-def load(schema_path: str) -> Dict[RuleId, Rule]:
+def load(schema: Union[SchemaPath, SchemaFile]) -> Dict[RuleId, Rule]:
     '''loads a sharing schema
 
+    :param schema: file path/object
+
+    :returns: rules parsed from schema, by rule id
     :raises OSError, ParseError:
     '''
-    filename = Path(schema_path).name
+    filename = get_filename(schema)
     ctx = SchemaCtx(filename)
-    data = pd.read_csv(schema_path)
+    data = pd.read_csv(schema)  # type: ignore
 
     # replace all different NA values with an empty string
     data = data.fillna('')
diff --git a/src/odm_sharing/private/utils.py b/src/odm_sharing/private/utils.py
index 8a978c4d..4b8c4ff3 100644
--- a/src/odm_sharing/private/utils.py
+++ b/src/odm_sharing/private/utils.py
@@ -1,3 +1,5 @@
+from io import IOBase
+from pathlib import Path
 from typing import Iterable, Union
 
 
@@ -28,3 +30,8 @@ def gen_output_filename(input_name: str, schema_name: str, org: str,
         [schema_name, org] + ([table] if table else []))
     return '-'.join(parts) + f'.{ext}'
+
+
+def get_filename(file: Union[str, IOBase]) -> str:
+    '''returns the path filename, or a dummy name for file objects'''
+    return Path(file).name if isinstance(file, str) else 'file-obj'
diff --git a/src/odm_sharing/sharing.py b/src/odm_sharing/sharing.py
index 6915d121..51680b48 100644
--- a/src/odm_sharing/sharing.py
+++ b/src/odm_sharing/sharing.py
@@ -1,4 +1,4 @@
-from pathlib import Path
+from io import IOBase
 from typing import Dict, List, Tuple, Union
 
 import numpy as np
@@ -10,16 +10,17 @@
 import odm_sharing.private.rules as rules
 import odm_sharing.private.trees as trees
 from odm_sharing.private.common import ColumnName, OrgName, TableName, F, T
-from odm_sharing.private.cons import Connection
+from odm_sharing.private.cons import Connection, CsvDataSourceList
 from odm_sharing.private.queries import OrgTableQueries, Query, TableQuery
 from odm_sharing.private.rules import RuleId
-from odm_sharing.private.utils import qt
+from odm_sharing.private.utils import get_filename, qt
 
 
-def parse(schema_path: str, orgs: List[str] = []) -> OrgTableQueries:
+def parse(schema_file: Union[str, IOBase],
+          orgs: List[str] = []) -> OrgTableQueries:
     '''loads and parses a schema file into query objects
 
-    :param schema_path: schema filepath
+    :param schema_file: schema file path/object
     :param orgs: organization whitelist, disabled if empty
 
     :return: a query per table per org. `OrgName` and `TableName` are
@@ -29,14 +30,14 @@ def parse(schema_path: str, orgs: List[str] = []) -> OrgTableQueries:
     :raises OSError: if the schema file can't be loaded
     :raises ParseError: if the schema parsing fails
     '''
-    ruleset = rules.load(schema_path)
-    filename = Path(schema_path).name
+    ruleset = rules.load(schema_file)
+    filename = get_filename(schema_file)
     tree = trees.parse(ruleset, orgs, filename)
     return queries.generate(tree)
 
 
 def connect(
-    data_sources: Union[str, List[str]],
+    data_sources: Union[str, CsvDataSourceList],
     tables: List[str] = [],
 ) -> Connection:
     '''
@@ -46,14 +47,15 @@ def connect(
     Warning: Even tho using a database as input is supported, it hasn't
     been tested properly.
 
-    :param data_sources: filepath(s) or database URL
+    :param data_sources: filepath or database URL, or list of CSV files
     :param tables: table name whitelist, disabled if empty
 
     :return: the data source connection object
 
     :raises DataSourceError: if the connection couldn't be established
     '''
-    if isinstance(data_sources, str):
+    # normalize single str/path input as list
+    if not isinstance(data_sources, list):
         data_sources = [data_sources]
     return cons.connect(data_sources, set(tables))
@@ -154,8 +156,8 @@ def get_columns(c: Connection, tq: TableQuery
 
 
 def extract(
-    schema_path: str,
-    data_sources: Union[str, List[str]],
+    schema_file: Union[str, IOBase],
+    data_sources: Union[str, CsvDataSourceList],
     orgs: List[str] = [],
 ) -> Dict[OrgName, Dict[TableName, pandas.DataFrame]]:
     '''high-level function for retrieving filtered data
@@ -163,8 +165,8 @@ def extract(
     Warning: Boolean values from CSV/Excel files will be normalized as
     TRUE/FALSE.
 
-    :param schema_path: rule schema filepath
-    :param data_sources: filepath(s) or database URL
+    :param schema_file: rule schema file path/object
+    :param data_sources: filepath or database URL, or list of CSV files
     :param orgs: organization whitelist, disabled if empty
 
     :return: a dataset per table per org. `OrgName` and `TableName` are
@@ -174,7 +176,7 @@
         data source
     '''
     con = connect(data_sources)
-    queries = parse(schema_path, orgs)
+    queries = parse(schema_file, orgs)
     result: Dict[OrgName, Dict[TableName, pandas.DataFrame]] = {}
     for org, tablequeries in queries.items():
         result[org] = {}
diff --git a/src/odm_sharing/tools/share.py b/src/odm_sharing/tools/share.py
index 74430286..f8aa74b7 100644
--- a/src/odm_sharing/tools/share.py
+++ b/src/odm_sharing/tools/share.py
@@ -214,7 +214,7 @@ def share(
     try:
         for table, data in org_data.items():
             if outfmt == OutFmt.CSV:
-                p = gen_filepath(outdir, schema_name, org, table,
+                p = gen_filepath(outdir, '', schema_name, org, table,
                                  'csv')
                 logging.info('writing ' + p.relpath)
                 data.to_csv(p.abspath, index=False)
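
A usage sketch (not part of the diff): with this patch, both the sharing
schema and the CSV inputs can be passed as open file objects instead of
paths. A CsvFile pairs the stream with an explicit table name, since a file
object has no filename to derive it from; note that CsvFile currently lives
in the private cons module. The filenames below are hypothetical.

    from odm_sharing import sharing
    from odm_sharing.private.cons import CsvFile

    with open('schema.csv', 'rb') as schema_f, \
         open('measures.csv', 'rb') as measures_f:
        # table name is given explicitly; for path inputs it is derived
        # from the filename stem instead
        sources = [CsvFile(table='measures', file=measures_f)]
        results = sharing.extract(schema_f, sources)
        for org, tables in results.items():
            for table, df in tables.items():
                print(org, table, len(df))

Path-based usage is unchanged, e.g. sharing.extract('schema.csv',
['measures.csv', 'samples.csv']). Per CsvDataSourceList, a list must be all
paths or all file objects, and multiple non-CSV inputs now raise
DataSourceError instead of the old "ignoring additional inputs" warning.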
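
The new get_filename helper is what keeps log and error messages working for
both input kinds; a quick sketch of its behavior per the utils.py hunk:

    import io

    from odm_sharing.private.utils import get_filename

    get_filename('some/dir/schema.csv')    # -> 'schema.csv'
    get_filename(io.BytesIO(b'a,b\n1,2'))  # -> 'file-obj' (dummy name)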