Commit 90efc72: add file object support

zargot committed Oct 15, 2024
1 parent 0ec296b

Showing 5 changed files with 105 additions and 48 deletions.
98 changes: 71 additions & 27 deletions src/odm_sharing/private/cons.py
@@ -2,8 +2,9 @@
import os
from collections import defaultdict
from dataclasses import dataclass
from io import IOBase
from pathlib import Path
from typing import Dict, Generator, List, Set
from typing import Dict, Generator, List, Set, Tuple, Union, cast

import openpyxl as xl
import pandas as pd
@@ -15,6 +16,14 @@
from odm_sharing.private.utils import qt


@dataclass(frozen=True)
class CsvFile:
table: str
file: IOBase


CsvPath = str
CsvDataSourceList = Union[List[CsvPath], List[CsvFile]]
Sheet = xl.worksheet._read_only.ReadOnlyWorksheet


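A minimal usage sketch (not part of this commit) of the new data-source types; the buffer contents and table name are hypothetical:

import io
from odm_sharing.private.cons import CsvFile

# file objects carry no filename to derive a table name from, so the
# table name is given explicitly
buf = io.StringIO('measureID,value\nm1,42\n')  # hypothetical CSV content
sources = [CsvFile(table='measures', file=buf)]
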
@@ -131,19 +140,39 @@ def _normalize_bool_values(df: pd.DataFrame, bool_cols: Set[ColumnName]
df[col] = df[col].replace({F: '0', T: '1'})


def _connect_csv(paths: List[str]) -> Connection:
def _import_csv(data_source: Union[CsvPath, CsvFile]
) -> Tuple[TableName, pd.DataFrame]:
# XXX: NA-values are not normalized to avoid mutating user data (#31)
ds = data_source
if isinstance(ds, CsvPath):
path = ds
assert path.endswith('.csv')
table = Path(path).stem
logging.info(f'importing {qt(table)} from {path}')
df = pd.read_csv(path, na_filter=False)
return (table, df)
else:
assert isinstance(ds, CsvFile)
logging.info(f'importing {qt(ds.table)} from a file')
df = pd.read_csv(ds.file, na_filter=False) # type: ignore
return (ds.table, df)


def _connect_csv(data_sources: CsvDataSourceList) -> Connection:
'''copies file data to in-memory db
:raises OSError:'''
:raises DataSourceError:
:raises OSError:
'''
assert len(data_sources) > 0
if isinstance(data_sources[0], CsvPath):
_check_csv_paths(cast(List[CsvPath], data_sources))

# XXX: NA-values are not normalized to avoid mutating user data (#31)
dfs = {}
tables = set()
bool_cols = {}
for path in paths:
table = Path(path).stem
logging.info(f'importing {qt(table)} from {path}')
df = pd.read_csv(path, na_filter=False)
for ds in data_sources:
(table, df) = _import_csv(ds)
bool_cols[table] = _find_bool_cols(df, BOOL_VALS)
_normalize_bool_values(df, bool_cols[table])
dfs[table] = df
@@ -244,22 +273,31 @@ def _detect_sqlalchemy(path: str) -> bool:
return False


def _check_datasources(paths: List[str]) -> None:
if not paths:
raise DataSourceError('no data source')
(_, ext) = os.path.splitext(paths[0])
all_same_ext = seq(paths).map(lambda p: p.endswith(ext)).all()
if not all_same_ext:
def _check_csv_paths(paths: List[str]) -> None:
if not seq(paths).map(lambda p: p.endswith('.csv')).all():
raise DataSourceError(
'mixing multiple data source types is not allowed')
'mixing CSV files with other file types is not allowed')


def _detect_csv_input(data_sources: Union[str, CsvDataSourceList]) -> bool:
is_str = isinstance(data_sources, str)
is_list = isinstance(data_sources, list)
is_path_list = is_list and isinstance(data_sources[0], str)
is_file_list = is_list and not is_path_list
return ((is_str and cast(str, data_sources).endswith('.csv')) or
(is_path_list and cast(str, data_sources[0]).endswith('.csv')) or
is_file_list)

def connect(paths: List[str], tables: Set[str] = set()
) -> Connection:

def connect(
data_sources: Union[str, CsvDataSourceList],
tables: Set[str] = set()
) -> Connection:
'''
connects to one or more data sources and returns the connection
:param paths: must all be of the same (file)type.
:param data_sources: filepath or database URL, or list of multiple CSV
paths/files
:param tables: when connecting to an excel file, this acts as a sheet
whitelist
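
A sketch (not part of the diff) of how the new _detect_csv_input helper above classifies inputs; the names are hypothetical:

import io
buf = io.StringIO('a,b\n1,2\n')           # any file object

_detect_csv_input('data.csv')             # True: single CSV path
_detect_csv_input(['a.csv', 'b.csv'])     # True: list of CSV paths
_detect_csv_input([CsvFile('t', buf)])    # True: any list of file objects
_detect_csv_input('db.sqlite')            # False: left to the other branches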
@@ -271,16 +309,22 @@
# as 0/1, which we'll have to convert back (using previously detected bool
# columns) to 'FALSE'/'TRUE' before returning the data to the user. This
# happens in `odm_sharing.sharing.get_data`.
_check_datasources(paths)
if not data_sources:
raise DataSourceError('no data source')
try:
path = paths[0]
is_csv = path.endswith('.csv')
if not is_csv and len(paths) > 1:
logging.warning('ignoring additional inputs (for CSV only)')

if is_csv:
return _connect_csv(paths)
elif path.endswith('.xlsx'):
if _detect_csv_input(data_sources):
csv_data_sources = cast(CsvDataSourceList, data_sources)
return _connect_csv(csv_data_sources)

is_list = isinstance(data_sources, list)
if is_list:
if len(data_sources) > 1:
raise DataSourceError('specifying multiple inputs is only ' +
'allowed for CSV files')

path = (cast(str, data_sources[0]) if is_list
else cast(str, data_sources))
if path.endswith('.xlsx'):
return _connect_excel(path, tables)
elif _detect_sqlite(path):
return _connect_db(f'sqlite:///{path}')
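Taken together, the private connect() now accepts either form. A hedged usage sketch, assuming the hypothetical CSV files exist on disk:

import odm_sharing.private.cons as cons

con = cons.connect(['measures.csv', 'samples.csv'])    # CSV path list, as before
with open('measures.csv') as f:                        # CSV file-object list, new
    con = cons.connect([cons.CsvFile('measures', f)])
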
14 changes: 9 additions & 5 deletions src/odm_sharing/private/rules.py
@@ -1,17 +1,19 @@
import sys
from io import IOBase

from dataclasses import dataclass, field
from enum import EnumMeta
from pathlib import Path
from typing import Any, Dict, List, Union

import pandas as pd
from functional import seq

from odm_sharing.private.stdext import StrValueEnum
from odm_sharing.private.utils import fmt_set, qt
from odm_sharing.private.utils import fmt_set, get_filename, qt

RuleId = int
SchemaFile = IOBase
SchemaPath = str


class RuleMode(StrValueEnum):
@@ -212,15 +214,17 @@ def check_set(ctx: SchemaCtx, actual: str, expected: Union[set, list]
raise ParseError(errors)


def load(schema_path: str) -> Dict[RuleId, Rule]:
def load(schema: Union[SchemaPath, SchemaFile]) -> Dict[RuleId, Rule]:
'''loads a sharing schema
:param schema: file path/object
:returns: rules parsed from schema, by rule id
:raises OSError, ParseError:
'''
filename = Path(schema_path).name
filename = get_filename(schema)
ctx = SchemaCtx(filename)
data = pd.read_csv(schema_path)
data = pd.read_csv(schema) # type: ignore

# replace all different NA values with an empty string
data = data.fillna('')
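With this change a schema can be read from any open file object as well as a path; a sketch, assuming a hypothetical schema.csv:

import odm_sharing.private.rules as rules

with open('schema.csv') as f:    # hypothetical schema file
    ruleset = rules.load(f)      # equivalent to rules.load('schema.csv'),
                                 # except errors report the dummy name 'file-obj'
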
7 changes: 7 additions & 0 deletions src/odm_sharing/private/utils.py
@@ -1,3 +1,5 @@
from io import IOBase
from pathlib import Path
from typing import Iterable, Union


@@ -28,3 +30,8 @@ def gen_output_filename(input_name: str, schema_name: str, org: str,
[schema_name, org] +
([table] if table else []))
return '-'.join(parts) + f'.{ext}'


def get_filename(file: Union[str, IOBase]) -> str:
'''returns the path filename, or a dummy name for file objects'''
return Path(file).name if isinstance(file, str) else 'file-obj'
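
Behaviour sketch of the new helper, per the code above:

from io import BytesIO
from odm_sharing.private.utils import get_filename

get_filename('/tmp/schema.csv')  # 'schema.csv'
get_filename(BytesIO())          # 'file-obj' (dummy name for file objects)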
32 changes: 17 additions & 15 deletions src/odm_sharing/sharing.py
@@ -1,4 +1,4 @@
from pathlib import Path
from io import IOBase
from typing import Dict, List, Tuple, Union

import numpy as np
@@ -10,16 +10,17 @@
import odm_sharing.private.rules as rules
import odm_sharing.private.trees as trees
from odm_sharing.private.common import ColumnName, OrgName, TableName, F, T
from odm_sharing.private.cons import Connection
from odm_sharing.private.cons import Connection, CsvDataSourceList
from odm_sharing.private.queries import OrgTableQueries, Query, TableQuery
from odm_sharing.private.rules import RuleId
from odm_sharing.private.utils import qt
from odm_sharing.private.utils import get_filename, qt


def parse(schema_path: str, orgs: List[str] = []) -> OrgTableQueries:
def parse(schema_file: Union[str, IOBase],
orgs: List[str] = []) -> OrgTableQueries:
'''loads and parses a schema file into query objects
:param schema_path: schema filepath
:param schema_file: schema file path/object
:param orgs: organization whitelist, disabled if empty
:return: a query per table per org. `OrgName` and `TableName` are
@@ -29,14 +30,14 @@ def parse(schema_path: str, orgs: List[str] = []) -> OrgTableQueries:
:raises OSError: if the schema file can't be loaded
:raises ParseError: if the schema parsing fails
'''
ruleset = rules.load(schema_path)
filename = Path(schema_path).name
ruleset = rules.load(schema_file)
filename = get_filename(schema_file)
tree = trees.parse(ruleset, orgs, filename)
return queries.generate(tree)


def connect(
data_sources: Union[str, List[str]],
data_sources: Union[str, CsvDataSourceList],
tables: List[str] = [],
) -> Connection:
'''
@@ -46,14 +47,15 @@
Warning: Even though using a database as input is supported, it hasn't been
tested properly.
:param data_sources: filepath(s) or database URL
:param data_sources: filepath or database URL, or list of CSV files
:param tables: table name whitelist, disabled if empty
:return: the data source connection object
:raises DataSourceError: if the connection couldn't be established
'''
if isinstance(data_sources, str):
# normalize single str/path input as list
if not isinstance(data_sources, list):
data_sources = [data_sources]
return cons.connect(data_sources, set(tables))

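A sketch of the public wrapper's normalization; the file and sheet names are hypothetical:

import odm_sharing.sharing as sharing

# a single non-list input is wrapped as a one-element list before being
# handed to the private layer
con = sharing.connect('data.xlsx', tables=['measures'])
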
@@ -154,17 +156,17 @@ def get_columns(c: Connection, tq: TableQuery


def extract(
schema_path: str,
data_sources: Union[str, List[str]],
schema_file: Union[str, IOBase],
data_sources: Union[str, CsvDataSourceList],
orgs: List[str] = [],
) -> Dict[OrgName, Dict[TableName, pandas.DataFrame]]:
'''high-level function for retrieving filtered data
Warning: Boolean values from CSV/Excel files will be normalized as
TRUE/FALSE.
:param schema_path: rule schema filepath
:param data_sources: filepath(s) or database URL
:param schema_file: rule schema file path/object
:param data_sources: filepath or database URL, or list of CSV files
:param orgs: organization whitelist, disabled if empty
:return: a dataset per table per org. `OrgName` and `TableName` are
@@ -174,7 +176,7 @@ def extract(
data source
'''
con = connect(data_sources)
queries = parse(schema_path, orgs)
queries = parse(schema_file, orgs)
result: Dict[OrgName, Dict[TableName, pandas.DataFrame]] = {}
for org, tablequeries in queries.items():
result[org] = {}
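An end-to-end sketch of the new file-object flow through the public API; the file names, table name, and org are hypothetical, and CsvFile is imported from the private module since the diff shows no public re-export:

import odm_sharing.sharing as sharing
from odm_sharing.private.cons import CsvFile

with open('schema.csv') as schema, open('measures.csv') as data:
    results = sharing.extract(schema, [CsvFile('measures', data)],
                              orgs=['OHRI'])

for org, tables in results.items():
    for table, df in tables.items():
        print(org, table, df.shape)
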
2 changes: 1 addition & 1 deletion src/odm_sharing/tools/share.py
@@ -214,7 +214,7 @@ def share(
try:
for table, data in org_data.items():
if outfmt == OutFmt.CSV:
p = gen_filepath(outdir, schema_name, org, table,
p = gen_filepath(outdir, '', schema_name, org, table,
'csv')
logging.info('writing ' + p.relpath)
data.to_csv(p.abspath, index=False)
