Skip to content

Commit

Permalink
improve package api
Browse files Browse the repository at this point in the history
  • Loading branch information
ohadmata committed May 29, 2024
1 parent ef45f1b commit 899e164
Show file tree
Hide file tree
Showing 10 changed files with 85 additions and 90 deletions.
4 changes: 3 additions & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,9 @@ disable=
R1719,
W0707,
W0108,
R0913
R0913,
R0902,
W0125

# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifier separated by comma (,) or put this option
Expand Down
22 changes: 7 additions & 15 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -129,37 +129,29 @@ shmessy = Shmessy(
locale_formatter: Optional[str] = "en_US",
use_random_sample: Optional[bool] = True,
types_to_ignore: Optional[List[str]] = None,
max_columns_num: Optional[int] = 500
max_columns_num: Optional[int] = 500,
fallback_to_string: Optional[bool] = False, # Fallback to string in case of casting exception
fallback_to_null: Optional[bool] = False, # Fallback to null in case of casting exception
use_csv_sniffer: Optional[bool] = True, # Use python sniffer to identify the dialect (seperator / quote-char / etc...)
fix_column_names: Optional[bool] = False, # Replace non-alphabetic/numeric chars with underscore
)
```

### read_csv
```python
shmessy.read_csv(
filepath_or_buffer: str | TextIO | BinaryIO,
use_sniffer: Optional[bool] = True, # Use python sniffer to identify the dialect (seperator / quote-char / etc...)
fixed_schema: Optional[ShmessySchema] = None, # Fix the given CSV according to this schema
fix_column_names: Optional[bool] = False, # Replace non-alphabetic/numeric chars with underscore
fallback_to_string: Optional[bool] = False, # Fallback to string in case of casting exception
fallback_to_null: Optional[bool] = False, # Fallback to null in case of casting exception
) -> DataFrame
shmessy.read_csv(filepath_or_buffer: Union[str, TextIO, BinaryIO]) -> DataFrame
```

### infer_schema
```python
shmessy.infer_schema(
df: Dataframe # Input dataframe
) -> ShmessySchema
shmessy.infer_schema(df: Dataframe) -> ShmessySchema
```

### fix_schema
```python
shmessy.fix_schema(
df: Dataframe,
fix_column_names: Optional[bool] = False, # Replace non-alphabetic/numeric chars with underscore
fixed_schema: Optional[ShmessySchema] = None, # Fix the given DF according to this schema
fallback_to_string: Optional[bool] = False, # Fallback to string in case of casting exception
fallback_to_null: Optional[bool] = False, # Fallback to null in case of casting exception
) -> DataFrame
```

Expand Down
88 changes: 28 additions & 60 deletions src/shmessy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import csv
import locale
import logging
import time
Expand All @@ -7,14 +6,15 @@
import pandas as pd
from pandas import DataFrame

from .exceptions import TooManyColumnException, exception_router
from .exceptions import exception_router
from .schema import ShmessySchema
from .types_handler import TypesHandler
from .utils import (
_check_number_of_columns,
_fix_column_names,
_fix_column_names_in_df,
_fix_column_names_in_shmessy_schema,
_get_sample_from_csv,
_get_dialect,
_get_sampled_df,
)

Expand All @@ -30,13 +30,21 @@ def __init__(
use_random_sample: Optional[bool] = True,
types_to_ignore: Optional[List[str]] = None,
max_columns_num: Optional[int] = 500,
fallback_to_string: Optional[bool] = False,
fallback_to_null: Optional[bool] = False,
use_csv_sniffer: Optional[bool] = True,
fix_column_names: Optional[bool] = False,
) -> None:
self.__types_handler = TypesHandler(types_to_ignore=types_to_ignore)
self.__sample_size = sample_size
self.__reader_encoding = reader_encoding
self.__locale_formatter = locale_formatter
self.__use_random_sample = use_random_sample
self.__max_columns_num = max_columns_num
self.__fallback_to_string = fallback_to_string
self.__fallback_to_null = fallback_to_null
self.__use_csv_sniffer = use_csv_sniffer
self.__fix_column_names = fix_column_names

self.__inferred_schema: Optional[ShmessySchema] = None

Expand All @@ -48,6 +56,7 @@ def get_inferred_schema(self) -> ShmessySchema:
return self.__inferred_schema

def infer_schema(self, df: DataFrame) -> ShmessySchema:
_check_number_of_columns(df=df, max_columns_num=self.__max_columns_num)
start_time = time.time()
df = _get_sampled_df(
df=df,
Expand All @@ -69,18 +78,10 @@ def fix_schema(
self,
df: DataFrame,
*,
fix_column_names: Optional[bool] = False,
fixed_schema: Optional[ShmessySchema] = None,
fallback_to_string: Optional[bool] = False,
fallback_to_null: Optional[bool] = False,
) -> DataFrame:
try:
existing_columns_num = len(df.columns)
if existing_columns_num > self.__max_columns_num:
raise TooManyColumnException(
max_columns_num=self.__max_columns_num,
existing_columns_num=existing_columns_num,
)
_check_number_of_columns(df=df, max_columns_num=self.__max_columns_num)

if fixed_schema is None:
fixed_schema = self.infer_schema(df)
Expand All @@ -89,11 +90,11 @@ def fix_schema(
df[column.field_name] = self.__types_handler.fix_field(
column=df[column.field_name],
inferred_field=column,
fallback_to_string=fallback_to_string,
fallback_to_null=fallback_to_null,
fallback_to_string=self.__fallback_to_string,
fallback_to_null=self.__fallback_to_null,
)

if fix_column_names:
if self.__fix_column_names:
mapping = _fix_column_names(df)
df = _fix_column_names_in_df(input_df=df, mapping=mapping)
fixed_schema = _fix_column_names_in_shmessy_schema(
Expand All @@ -105,59 +106,26 @@ def fix_schema(
except Exception as e:
exception_router(e)

def read_csv(
self,
filepath_or_buffer: Union[str, TextIO, BinaryIO],
*,
use_sniffer: Optional[bool] = True,
fixed_schema: Optional[ShmessySchema] = None,
fix_column_names: Optional[bool] = False,
fallback_to_string: Optional[bool] = False,
fallback_to_null: Optional[bool] = False,
) -> DataFrame:
def read_csv(self, filepath_or_buffer: Union[str, TextIO, BinaryIO]) -> DataFrame:
try:
dialect = None

if use_sniffer:
try:
dialect = csv.Sniffer().sniff(
sample=_get_sample_from_csv(
filepath_or_buffer=filepath_or_buffer,
sample_size=self.__sample_size,
encoding=self.__reader_encoding,
),
delimiters="".join([",", "\t", ";", ":"]),
)
except Exception as e: # noqa
logger.debug(
f"Could not use python sniffer to infer csv schema, Using pandas default settings: {e}"
)
dialect = (
_get_dialect(
filepath_or_buffer=filepath_or_buffer,
sample_size=self.__sample_size,
reader_encoding=self.__reader_encoding,
)
if self.__use_csv_sniffer
else None
)

df = pd.read_csv(
index_col=False,
filepath_or_buffer=filepath_or_buffer,
dialect=dialect() if dialect else None,
dialect=dialect() if dialect else None, # noqa
encoding=self.__reader_encoding,
)

existing_columns_num = len(df.columns)
if existing_columns_num > self.__max_columns_num:
raise TooManyColumnException(
max_columns_num=self.__max_columns_num,
existing_columns_num=existing_columns_num,
)

if fixed_schema is None:
fixed_schema = self.infer_schema(df)

self.__inferred_schema = fixed_schema
return self.fix_schema(
df=df,
fixed_schema=fixed_schema,
fix_column_names=fix_column_names,
fallback_to_string=fallback_to_string,
fallback_to_null=fallback_to_null,
)
return self.fix_schema(df=df)

except Exception as e:
exception_router(e)
36 changes: 35 additions & 1 deletion src/shmessy/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,44 @@
import csv
import logging
import re
from typing import BinaryIO, Dict, Optional, TextIO, Union
from typing import Any, BinaryIO, Dict, Optional, TextIO, Union

from pandas import DataFrame

from .exceptions import TooManyColumnException
from .schema import ShmessySchema

logger = logging.getLogger(__name__)


def _check_number_of_columns(df: DataFrame, max_columns_num: int) -> None:
existing_columns_num = len(df.columns)
if existing_columns_num > max_columns_num:
raise TooManyColumnException(
max_columns_num=max_columns_num,
existing_columns_num=existing_columns_num,
)


def _get_dialect(
filepath_or_buffer: Union[str, TextIO, BinaryIO],
sample_size: int,
reader_encoding: Optional[str],
) -> Optional[Any]:
try:
return csv.Sniffer().sniff(
sample=_get_sample_from_csv(
filepath_or_buffer=filepath_or_buffer,
sample_size=sample_size,
encoding=reader_encoding,
),
delimiters="".join([",", "\t", ";", ":"]),
)
except Exception as e: # noqa
logger.debug(
f"Could not use python sniffer to infer csv schema, Using pandas default settings: {e}"
)


def _get_sampled_df(df: DataFrame, sample_size: int, random_sample: bool) -> DataFrame:
number_of_rows: int = len(df)
Expand Down
6 changes: 3 additions & 3 deletions tests/intg/test_fix_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@
expected_result=["name&space", "degree", "score%to@100"]
)
def test_fix_column_names(df_data, fix_column_names, expected_result):
shmessy = Shmessy()
shmessy = Shmessy(fix_column_names=fix_column_names)
df = pd.DataFrame(df_data)
df = shmessy.fix_schema(df=df, fix_column_names=fix_column_names)
df = shmessy.fix_schema(df=df)
fixed_schema = shmessy.get_inferred_schema()
assert [column for column in df] == expected_result
assert [column.field_name for column in fixed_schema.columns] == expected_result
Expand All @@ -67,6 +67,6 @@ def test_fix_column_names(df_data, fix_column_names, expected_result):
)
def test_issue_input_columns_as_object(df_data, fix_column_names, expected_result):
df = pd.DataFrame(df_data).astype(object)
df = Shmessy().fix_schema(df=df, fix_column_names=fix_column_names)
df = Shmessy(fix_column_names=fix_column_names).fix_schema(df)
assert [column for column in df] == expected_result

2 changes: 1 addition & 1 deletion tests/intg/test_read_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def test_read_csv_file_with_single_column(files_folder):
def test_read_csv_with_text_data(files_folder):
path = files_folder.as_posix() + "/data_8.csv"
with open(path, mode="rt") as file_input:
Shmessy(use_random_sample=False).read_csv(file_input, fallback_to_string=True)
Shmessy(use_random_sample=False, fallback_to_string=True).read_csv(file_input)
assert True


Expand Down
2 changes: 1 addition & 1 deletion tests/intg/test_read_excel.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def test_read_excel_with_numeric_headers(files_folder):

def test_read_excel_with_numeric_headers_fix_column_names(files_folder):
df = pandas.read_excel(files_folder.as_posix() + "/data_9.xlsx", engine="calamine")
df = Shmessy().fix_schema(df, fix_column_names=True)
df = Shmessy(fix_column_names=True).fix_schema(df)

assert df["0"].dtype == np.dtype("int64")
assert df["First_Name"].dtype == np.dtype("O")
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_boolean_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,9 @@ def test_boolean_fallback_to_null_turn_off(df_data, expected_shmessy_type, expec
expected_numpy_type=np.dtype("O")
)
def test_boolean_fallback_to_null_turn_on(df_data, expected_shmessy_type, expected_numpy_type, expected_result):
shmessy = Shmessy(use_random_sample=False, sample_size=2)
shmessy = Shmessy(use_random_sample=False, sample_size=2, fallback_to_null=True)
df = pd.DataFrame(df_data)
fixed_df = shmessy.fix_schema(df, fallback_to_null=True)
fixed_df = shmessy.fix_schema(df)
result = shmessy.get_inferred_schema()

assert result.columns[0].inferred_type == expected_shmessy_type
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_date_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,9 +189,9 @@ def test_date_type(df_data, expected_shmessy_type, expected_numpy_type, expected
expected_numpy_type=np.dtype("datetime64")
)
def test_date_fallback_to_null_turn_on(df_data, expected_shmessy_type, expected_numpy_type, expected_result):
shmessy = Shmessy(use_random_sample=False, sample_size=2)
shmessy = Shmessy(use_random_sample=False, sample_size=2, fallback_to_null=True)
df = pd.DataFrame(df_data)
fixed_df = shmessy.fix_schema(df, fallback_to_null=True)
fixed_df = shmessy.fix_schema(df)
result = shmessy.get_inferred_schema()

assert result.columns[0].inferred_type == expected_shmessy_type
Expand Down
7 changes: 3 additions & 4 deletions tests/unit/test_property_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,10 @@ def shmessy_bool_st(draw) -> pd.Series:
fallback_to_string=st.booleans(),
)
def test_fix_schema_cols_hp(df, fix_column_names, fallback_to_string):
df_fixed = Shmessy().fix_schema(
df=df,
fix_column_names=fix_column_names,
df_fixed = Shmessy(
fallback_to_string=fallback_to_string,
)
fix_column_names=fix_column_names
).fix_schema(df)
assert set(list(df_fixed)) == set(list(df)) if not fix_column_names else True
allowed_chars = set(string.ascii_lowercase).union(set(string.ascii_uppercase)).union(set(string.digits))
allowed_chars.add("_")
Expand Down

0 comments on commit 899e164

Please sign in to comment.