Skip to content

Commit

Permalink
use isort
Browse files Browse the repository at this point in the history
  • Loading branch information
pipliggins committed Nov 28, 2024
1 parent 323a87a commit 4956466
Show file tree
Hide file tree
Showing 8 changed files with 66 additions and 68 deletions.
6 changes: 5 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,9 @@ repos:
rev: v0.8.0
hooks:
- id: ruff
args: [ --fix ]
args: [ "--select", "I", "--fix" ]
- id: ruff-format
# - repo: https://github.com/asottile/reorder-python-imports
# rev: v3.13.0
# hooks:
# - id: reorder-python-imports
1 change: 0 additions & 1 deletion src/adtl/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import argparse
import importlib.metadata

import json

from adtl.parser import Parser
Expand Down
67 changes: 30 additions & 37 deletions src/adtl/parser.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,36 @@
from __future__ import annotations

import copy
import csv
import hashlib
import io
import itertools
import json
import logging
import itertools
import copy
import re
from collections import defaultdict, Counter
import warnings
from collections import Counter, defaultdict
from datetime import datetime
from pathlib import Path
from functools import lru_cache
from typing import Any, Dict, Iterable, List, Optional, Union, Callable, Literal
from more_itertools import unique_everseen
from pathlib import Path
from typing import Any, Callable, Iterable, Literal

import fastjsonschema
import pint
import tomli
import requests
import fastjsonschema

import tomli
from more_itertools import unique_everseen
from tqdm.autonotebook import tqdm
import warnings

import adtl.transformations as tf
from adtl.transformations import AdtlTransformationWarning

SUPPORTED_FORMATS = {"json": json.load, "toml": tomli.load}
DEFAULT_DATE_FORMAT = "%Y-%m-%d"

StrDict = Dict[str, Any]
Rule = Union[str, StrDict]
Context = Optional[Dict[str, Union[bool, int, str, List[str]]]]
StrDict = dict[str, Any]
Rule = str | StrDict
Context = dict[str, bool | int | str | list[str]] | None


def get_value(row: StrDict, rule: Rule, ctx: Context = None) -> Any:
Expand Down Expand Up @@ -177,7 +178,7 @@ def get_value_unhashed(row: StrDict, rule: Rule, ctx: Context = None) -> Any:
raise ValueError(f"Could not return value for {rule}")


def matching_fields(fields: List[str], pattern: str) -> List[str]:
def matching_fields(fields: list[str], pattern: str) -> list[str]:
"Returns fields matching pattern"
compiled_pattern = re.compile(pattern)
return [f for f in fields if compiled_pattern.match(f)]
Expand Down Expand Up @@ -340,7 +341,7 @@ def flatten(xs):
yield x


def expand_refs(spec_fragment: StrDict, defs: StrDict) -> Union[StrDict, List[StrDict]]:
def expand_refs(spec_fragment: StrDict, defs: StrDict) -> StrDict | list[StrDict]:
"Expand all references (ref) with definitions (defs)"

if spec_fragment == {}:
Expand All @@ -357,14 +358,14 @@ def expand_refs(spec_fragment: StrDict, defs: StrDict) -> Union[StrDict, List[St
return spec_fragment


def expand_for(spec: List[StrDict]) -> List[StrDict]:
def expand_for(spec: list[StrDict]) -> list[StrDict]:
"Expands for expressions in oneToMany table blocks"

out = []

def replace_val(
item: Union[str, float, Dict[str, Any]], replace: Dict[str, Any]
) -> Dict[str, Any]:
item: str | float | dict[str, Any], replace: dict[str, Any]
) -> dict[str, Any]:
block = {}
if isinstance(item, str):
return item.format(**replace)
Expand Down Expand Up @@ -432,12 +433,12 @@ def hash_sensitive(value: str) -> str:
return hashlib.sha256(str(value).encode("utf-8")).hexdigest()


def remove_null_keys(d: Dict[str, Any]) -> Dict[str, Any]:
def remove_null_keys(d: dict[str, Any]) -> dict[str, Any]:
"Removes keys which map to null - but not empty strings or 'unknown' etc types"
return {k: v for k, v in d.items() if v is not None}


def get_date_fields(schema: Dict[str, Any]) -> List[str]:
def get_date_fields(schema: dict[str, Any]) -> list[str]:
"Returns list of date fields from schema"
fields = [
field
Expand All @@ -453,8 +454,8 @@ def get_date_fields(schema: Dict[str, Any]) -> List[str]:


def make_fields_optional(
schema: Dict[str, Any], optional_fields: List[str]
) -> Dict[str, Any]:
schema: dict[str, Any], optional_fields: list[str]
) -> dict[str, Any]:
"Returns JSON schema with required fields modified to drop optional fields"
if optional_fields is None:
return schema
Expand All @@ -481,7 +482,7 @@ def relative_path(source_file, target_file):
return Path(source_file).parent / target_file


def read_definition(file: Path) -> Dict[str, Any]:
def read_definition(file: Path) -> dict[str, Any]:
"Reads definition from file into a dictionary"
if isinstance(file, str):
file = Path(file)
Expand Down Expand Up @@ -520,8 +521,8 @@ class Parser:

def __init__(
self,
spec: Union[str, Path, StrDict],
include_defs: List[str] = [],
spec: str | Path | StrDict,
include_defs: list[str] = [],
quiet: bool = False,
):
"""Loads specification from spec in format (default json)
Expand All @@ -536,7 +537,7 @@ def __init__(

self.data: StrDict = {}
self.defs: StrDict = {}
self.fieldnames: Dict[str, List[str]] = {}
self.fieldnames: dict[str, list[str]] = {}
self.specfile = None
self.include_defs = include_defs
self.validators: StrDict = {}
Expand Down Expand Up @@ -900,11 +901,7 @@ def read_table(self, table: str) -> Iterable[StrDict]:
for row in self.data[table]:
yield row

def write_csv(
self,
table: str,
output: Optional[str] = None,
) -> Optional[str]:
def write_csv(self, table: str, output: str | None = None) -> str | None:
"""Writes to output as CSV a particular table
Args:
Expand Down Expand Up @@ -935,11 +932,7 @@ def writerows(fp, table):
buf = io.StringIO()
return writerows(buf, table).getvalue()

def write_parquet(
self,
table: str,
output: Optional[str] = None,
) -> Optional[str]:
def write_parquet(self, table: str, output: str | None = None) -> str | None:
"""Writes to output as parquet a particular table
Args:
Expand Down Expand Up @@ -1005,7 +998,7 @@ def show_report(self):
print()

def save(
self, output: Optional[str] = None, format: Literal["csv", "parquet"] = "csv"
self, output: str | None = None, format: Literal["csv", "parquet"] = "csv"
):
"""Saves all tables to CSV
Expand Down
9 changes: 6 additions & 3 deletions src/adtl/python_interface.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from __future__ import annotations
from adtl import Parser
import pandas as pd
from typing import Literal

from pathlib import Path
from typing import Literal

import pandas as pd

from adtl import Parser


def parse(
Expand Down
36 changes: 17 additions & 19 deletions src/adtl/transformations.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,29 @@
"""Functions which can be applied to source fields, allowing extensibility"""

from typing import Any, Optional, List
from datetime import datetime, timedelta, date
from __future__ import annotations

from dateutil.relativedelta import relativedelta
from datetime import date, datetime, timedelta
from math import floor
from typing import Any, Literal

from dateutil.relativedelta import relativedelta

try:
import zoneinfo
except ImportError: # pragma: no cover
from backports import zoneinfo # noqa

import pint
import re

from typing import Literal, Union

import warnings

import pint


class AdtlTransformationWarning(UserWarning):
pass


def isNotNull(value: Optional[str]) -> bool:
def isNotNull(value: str | None) -> bool:
"Returns whether value is not null or an empty string"
return value not in [None, ""]

Expand All @@ -33,7 +33,7 @@ def textIfNotNull(field: str, return_val: Any) -> Any:
return return_val if field not in [None, ""] else None


def wordSubstituteSet(value: str, *params) -> List[str]:
def wordSubstituteSet(value: str, *params) -> list[str]:
"""
For a value that can have multiple words, use substitutions from params.
Expand Down Expand Up @@ -69,9 +69,7 @@ def wordSubstituteSet(value: str, *params) -> List[str]:
return sorted(set(out)) if out else None


def getFloat(
value: str, set_decimal: Optional[str] = None, separator: Optional[str] = None
):
def getFloat(value: str, set_decimal: str | None = None, separator: str | None = None):
"""Returns value transformed into a float.
Args:
Expand Down Expand Up @@ -351,13 +349,13 @@ def splitDate(


def startYear(
duration: Union[str, float],
currentdate: Union[list, str],
duration: str | float,
currentdate: list | str,
epoch: float,
dateformat: str = "%Y-%m-%d",
duration_type: Literal["years", "months", "days"] = "years",
provide_month_day: Union[bool, list] = False,
) -> Union[int, float]:
provide_month_day: bool | list = False,
) -> int | float:
"""
Use to calculate year e.g. of birth from date (e.g. current date) and
duration (e.g. age)
Expand Down Expand Up @@ -407,12 +405,12 @@ def startYear(


def startMonth(
duration: Union[str, float],
currentdate: Union[list, str],
duration: str | float,
currentdate: list | str,
epoch: float,
dateformat: str = "%Y-%m-%d",
duration_type: Literal["years", "months", "days"] = "years",
provide_month_day: Union[bool, list] = False,
provide_month_day: bool | list = False,
):
"""
Use to calculate month e.g. of birth from date (e.g. current date) and
Expand Down
8 changes: 4 additions & 4 deletions tests/test_parser.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
import collections
import contextlib
import io
import json
import contextlib
import collections
from pathlib import Path
from typing import Dict, Iterable, Any
from typing import Any, Dict, Iterable

import pytest
import responses
from pytest_unordered import unordered

import adtl.parser as parser
import adtl
import adtl.parser as parser

RULE_SINGLE_FIELD = {"field": "diabetes_mhyn"}
RULE_SINGLE_FIELD_WITH_MAPPING = {
Expand Down
3 changes: 2 additions & 1 deletion tests/test_python_interface.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import adtl
from pathlib import Path

import adtl


def test_parse(snapshot):
adtl.parse(
Expand Down
4 changes: 2 additions & 2 deletions tests/test_transformations.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest

from datetime import datetime

import pytest

import adtl.transformations as transform


Expand Down

0 comments on commit 4956466

Please sign in to comment.