Clean up column subset code; usage of set methods, method type signature, removed unused code
multimeric committed Nov 20, 2017
1 parent f795e57 commit 75b4899
Showing 1 changed file with 18 additions and 13 deletions.
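
For context, a minimal sketch of the set-method change described in the commit message (the names below are illustrative and not part of the repository; both expressions are equivalent ways of asking whether the requested columns are a subset of the schema's columns):

    schema_names = {'name', 'age', 'height'}  # hypothetical schema column names
    requested = ['name', 'age']               # hypothetical subset passed to validate()

    # Old style: intersect the schema names with the requested names and compare
    old_check = set(schema_names).intersection(requested) == set(requested)

    # New style: a direct subset test, which reads more clearly
    new_check = set(requested).issubset(schema_names)

    assert old_check == new_check
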
31 changes: 18 additions & 13 deletions pandas_schema/schema.py
@@ -10,6 +10,7 @@ class Schema:
"""
A schema that defines the columns required in the target DataFrame
"""

def __init__(self, columns: typing.Iterable[Column], ordered: bool = False):
"""
:param columns: A list of column objects
@@ -28,35 +29,39 @@ def __init__(self, columns: typing.Iterable[Column], ordered: bool = False):
        self.columns = list(columns)
        self.ordered = ordered

-    def validate(self, df: pd.DataFrame, columns: typing.List[Column]=None) -> typing.List[ValidationWarning]:
+    def validate(self, df: pd.DataFrame, columns: typing.List[str] = None) -> typing.List[ValidationWarning]:
        """
        Runs a full validation of the target DataFrame using the internal columns list
        :param df: A pandas DataFrame to validate
        :param columns: A list of columns indicating a subset of the schema that we want to validate
        :return: A list of ValidationWarning objects that list the ways in which the DataFrame was invalid
        """
        errors = []
        df_cols = len(df.columns)

-        # If no columns are passed, validate against every column in the schema
-        # This is the default behaviour
+        # If no columns are passed, validate against every column in the schema. This is the default behaviour
        if columns is None:
            schema_cols = len(self.columns)
            columns_to_pair = self.columns
            if df_cols != schema_cols:
-                errors.append(ValidationWarning('Invalid number of columns. The schema specifies {}, but the data frame has {}'.format(schema_cols,
-                                                                                                                                        df_cols)))
+                errors.append(
+                    ValidationWarning(
+                        'Invalid number of columns. The schema specifies {}, but the data frame has {}'.format(
+                            schema_cols,
+                            df_cols)
+                    )
+                )
                return errors

-        # Else check that columns passed in as an argument are part of the
-        # current schema, else raise an error
+        # If we did pass in columns, check that they are part of the current schema
        else:
-            if set(self.get_column_names()).intersection(columns) == set(columns):
-                schema_cols = len(columns)
+            if set(columns).issubset(self.get_column_names()):
                columns_to_pair = [column for column in self.columns if column.name in columns]
            else:
-                raise PanSchArgumentError('Columns {} passed in are not part of the schema'.format(
-                    set(columns).difference(self.columns)))
+                raise PanSchArgumentError(
+                    'Columns {} passed in are not part of the schema'.format(set(columns).difference(self.columns))
+                )

        # We associate the column objects in the schema with data frame series either by name or by position, depending
        # on the value of self.ordered
@@ -69,11 +74,11 @@ def validate(self, df: pd.DataFrame, columns: typing.List[Column]=None) -> typing.List[ValidationWarning]:

                # Throw an error if the schema column isn't in the data frame
                if column.name not in df:
-                    errors.append(ValidationWarning('The column {} exists in the schema but not in the data frame'.format(column.name)))
+                    errors.append(ValidationWarning(
+                        'The column {} exists in the schema but not in the data frame'.format(column.name)))
                    return errors

                column_pairs.append((df[column.name], column))

-
        # Iterate over each pair of schema columns and data frame series and run validations
        for series, column in column_pairs:
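As a usage note on the signature change, a hypothetical example follows; the schema, validation and data below are made up and not taken from the repository, but the call shape follows the new columns parameter, which now takes column names (strings) rather than Column objects:

    import pandas as pd
    from pandas_schema import Column, Schema
    from pandas_schema.validation import InRangeValidation

    # Hypothetical schema: two columns, one with a range validation
    schema = Schema([
        Column('age', [InRangeValidation(0, 120)]),
        Column('height'),
    ])

    df = pd.DataFrame({'age': [25, 130], 'height': [1.7, 1.8]})

    # Validate only the 'age' column by passing its name as a string
    for warning in schema.validate(df, columns=['age']):
        print(warning)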
