Skip to content

Commit

Permalink
Merge pull request #24 from caddac/category-support
Browse files Browse the repository at this point in the history
Fix support for category dtypes
  • Loading branch information
multimeric authored May 8, 2019
2 parents c157e14 + 16d1a45 commit fe3171d
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 5 deletions.
7 changes: 5 additions & 2 deletions pandas_schema/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from . import column
from .validation_warning import ValidationWarning
from .errors import PanSchArgumentError
from pandas.api.types import is_categorical_dtype, is_numeric_dtype


class _BaseValidation:
Expand Down Expand Up @@ -84,10 +85,12 @@ def get_errors(self, series: pd.Series, column: 'column.Column'):
simple_validation = ~self.validate(series)
if column.allow_empty:
# Failing results are those that are not empty, and fail the validation
if np.issubdtype(series.dtype, np.number):
validated = ~series.isna() & simple_validation
# explicitly check to make sure the series isn't a category because issubdtype will FAIL if it is
if is_categorical_dtype(series) or is_numeric_dtype(series):
validated = ~series.isnull() & simple_validation
else:
validated = (series.str.len() > 0) & simple_validation

else:
validated = simple_validation

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def run(self):
],
keywords='pandas csv verification schema',
packages=find_packages(include=['pandas_schema']),
install_requires=['numpy', 'pandas'],
install_requires=['numpy', 'pandas>=0.19'],
cmdclass={
'build_readme': BuildReadme,
'build_site': BuildHtmlDocs
Expand Down
34 changes: 32 additions & 2 deletions test/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@ def seriesEquality(self, s1: pd.Series, s2: pd.Series, msg: str = None):
if not s1.equals(s2):
raise self.failureException(msg)

def validate_and_compare(self, series: list, expected_result: bool, msg: str = None):
def validate_and_compare(self, series: list, expected_result: bool, msg: str = None, series_dtype: object = None):
"""
Checks that every element in the provided series is equal to `expected_result` after validation
:param series_dtype: Explicitly specifies the dtype for the generated Series
:param series: The series to check
:param expected_result: Whether the elements in this series should pass the validation
:param msg: The message to display if this test fails
Expand All @@ -31,7 +32,7 @@ def validate_and_compare(self, series: list, expected_result: bool, msg: str = N
self.addTypeEqualityFunc(pd.Series, self.seriesEquality)

# Convert the input list to a series and validate it
results = self.validator.validate(pd.Series(series))
results = self.validator.validate(pd.Series(series, dtype=series_dtype))

# Now find any items where their validation does not correspond to the expected_result
for item, result in zip(series, results):
Expand Down Expand Up @@ -639,3 +640,32 @@ def test_in_range_allow_empty_false_with_error(self):
validator = InRangeValidation(min=4)
errors = validator.get_errors(pd.Series(self.vals), Column('', allow_empty=False))
self.assertEqual(len(errors), len(self.vals))


class PandasDtypeTests(ValidationTestBase):
    """
    Runs get_errors() against Series whose dtype is pandas-specific rather than a numpy dtype.
    Every case here builds its input with dtype='category'.
    """

    def setUp(self):
        # Membership check against a/b/c with case sensitivity switched off
        self.validator = InListValidation(['a', 'b', 'c'], case_sensitive=False)

    def test_valid_elements(self):
        # Listed values in either case plus one empty cell, empties allowed: expect no errors
        categorical = pd.Series(['a', 'b', 'c', None, 'A', 'B', 'C'], dtype='category')
        col = Column('', allow_empty=True)
        found = self.validator.get_errors(categorical, col)
        self.assertEqual(len(found), 0)

    def test_invalid_empty_elements(self):
        # Three unlisted values plus one empty cell, empties not allowed: expect four errors
        categorical = pd.Series(['aa', 'bb', 'd', None], dtype='category')
        col = Column('', allow_empty=False)
        found = self.validator.get_errors(categorical, col)
        self.assertEqual(len(found), 4)

    def test_invalid_and_empty_elements(self):
        # One listed value plus one empty cell, empties not allowed: expect a single error
        categorical = pd.Series(['a', None], dtype='category')
        col = Column('', allow_empty=False)
        found = self.validator.get_errors(categorical, col)
        self.assertEqual(len(found), 1)

    def test_invalid_elements(self):
        # Three unlisted, non-empty values with empties allowed: expect three errors
        categorical = pd.Series(['aa', 'bb', 'd'], dtype='category')
        col = Column('', allow_empty=True)
        found = self.validator.get_errors(categorical, col)
        self.assertEqual(len(found), 3)

0 comments on commit fe3171d

Please sign in to comment.