Update validateDataset.py to show errors on uploads in a more user-friendly way #573

Merged: 2 commits, Feb 1, 2023
132 changes: 129 additions & 3 deletions lab/pyutils/validateDataset.py
@@ -54,16 +54,63 @@
MIN_ROW_PER_CLASS = 2


def check_dataframe(df, target_column):
    '''
    Check whether each feature column of df contains only clean numeric data.
    Missing values (NaN) are not allowed anywhere in df.
    Strings are not allowed in feature columns.
    inf or -inf values are not allowed in feature columns.
    '''

    error_message = "Found error in data:"

    # find columns that contain missing values (NaN) anywhere in df
    nan_cols = df.columns[df.isnull().any()].tolist()
    if len(nan_cols) > 0:
        error_message += " * 'MISSING VALUE' in " + str(nan_cols)

    df_non_target = df.drop(columns=target_column)
    inf_cols_list = []

    # find columns whose data type is object
    # pandas stores strings with the object dtype;
    # a column that mixes string and numeric data also has the object dtype.
    # computed on df_non_target so that a non-numeric target column
    # (allowed for classification) is not flagged as a string feature
    str_cols = df_non_target.columns[df_non_target.dtypes == object].tolist()

    # find which numeric feature columns contain infinity or -infinity
    non_str_cols = df_non_target.columns.difference(str_cols)
    for col in non_str_cols:
        if np.isinf(df[col]).any():
            inf_cols_list.append(col)

    if len(inf_cols_list) > 0:
        error_message += " * '+INFINITY or -INFINITY' in " + str(inf_cols_list)

    if len(str_cols) > 0:
        error_message += " * 'STRING' in " + str(str_cols)

    return error_message
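
As an illustration (a sketch by way of example, not part of this diff), check_dataframe accumulates every problem it finds into one message, so the upload error shows all offending columns at once instead of one error per attempt:

    import numpy as np
    import pandas as pd

    # hypothetical toy frame: one NaN, one inf, one string feature
    df = pd.DataFrame({
        "a": [1.0, np.nan],   # missing value
        "b": [1.0, np.inf],   # infinity
        "c": ["x", "y"],      # string feature
        "target": [0, 1],
    })
    print(check_dataframe(df, "target"))
    # -> one string, roughly: Found error in data: * 'MISSING VALUE' in ['a']
    #    * '+INFINITY or -INFINITY' in ['b'] * 'STRING' in ['c']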


def validate_data_from_server(file_id, prediction_type, target_field, categories=None, ordinals=None, **kwargs):
    # Read the dataset into memory
    raw_data = get_file_from_server(file_id)
    df = pd.read_csv(StringIO(raw_data), sep=None, engine='python', **kwargs)
    # return validate_data(df, prediction_type, target_field, categories, ordinals)
    return validate_data_updated(df, prediction_type, target_field, categories, ordinals)


def validate_data_from_filepath(file_id, prediction_type, target_field, categories=None, ordinals=None, **kwargs):
    # Read the dataset into memory
    df = pd.read_csv(file_id, sep=None, engine='python', **kwargs)
    return validate_data(df, prediction_type, target_field, categories, ordinals)
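
For reference (hypothetical path and column names), the filepath variant is what the CLI entry point in main() below calls; sep=None lets pandas sniff the delimiter, which is why engine='python' is passed:

    success, error_message = validate_data_from_filepath(
        "data/my_dataset.csv",   # hypothetical path
        "classification",
        "target",
    )
    if not success:
        print(error_message)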


@@ -190,6 +237,87 @@ def validate_data(df, prediction_type="classification", target_column=None, cate
    return True, None


def validate_data_updated(df, prediction_type="classification", target_column=None, categories=None, ordinals=None):
    '''
    Check that a df is valid.

    This function checks the following:
    - prediction_type is valid
    - the number of rows and columns is valid
    - the target column is valid
    - no missing values in df
    - no strings in feature columns
    - no inf or -inf in feature columns

    @return tuple
        boolean - validation result
        string - error message, or None if validation passed
    '''

    # check that the prediction type is valid
    if prediction_type not in ["classification", "regression"]:
        logger.warn(f"Invalid prediction type: '{prediction_type}'")
        return False, f"Invalid prediction type: '{prediction_type}'"

    # check that the number of rows is valid
    if df.shape[0] < MIN_ROWS:
        msg = "Dataset has dimensions {}, datasets must have at least {} rows.".format(
            df.shape, MIN_ROWS)
        logger.warn(msg)
        return False, msg

    # check that the number of columns is valid
    if df.shape[1] < MIN_COLS:
        msg = "Dataset has dimensions {}, datasets must have at least {} columns.".format(
            df.shape, MIN_COLS)
        logger.warn(msg)
        return False, msg

    # target column validation
    if target_column is not None:
        if target_column not in df.columns:
            logger.warn("Target column '" + target_column + "' not in data")
            return False, "Target column '" + target_column + "' not in data"
        if categories and target_column in categories:
            logger.warn("Target column '" + target_column +
                        "' cannot be a categorical feature")
            return False, "Target column '" + target_column + "' cannot be a categorical feature"
        if ordinals and target_column in ordinals:
            logger.warn("Target column '" + target_column +
                        "' cannot be an ordinal feature")
            return False, "Target column '" + target_column + "' cannot be an ordinal feature"

    # the remaining checks run only when a target column is specified
    if target_column:

        # classification
        if prediction_type == "classification":
            # the target column of a classification problem does not need to be numeric

            # check rows per class
            counts = df.groupby(target_column).count()
            # use the first count column; any column works, since groupby().count()
            # yields the same row count per class in every column of a clean frame
            fails_validation = counts[counts[counts.columns[0]]
                                      < MIN_ROW_PER_CLASS]
            if not fails_validation.empty:
                msg = "Classification datasets must have at least {} rows per class, class(es) '{}' have fewer.".format(
                    MIN_ROW_PER_CLASS, list(fails_validation.index.values))
                logger.warn(msg)
                return False, msg

        # check_dataframe() verifies that the columns contain only clean data:
        # missing values are not allowed in df,
        # strings are not allowed in feature columns,
        # inf or -inf are not allowed in feature columns
        if len(df.columns) > 0:
            error_message = check_dataframe(df, target_column)
            if error_message != "Found error in data:":
                logger.warn(error_message)
                return False, error_message

    return True, None
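
A usage sketch for the new validator (toy frame; assumes it satisfies the MIN_ROWS and MIN_COLS constants defined near the top of the file):

    import pandas as pd

    df = pd.DataFrame({
        "f1": range(12),
        "f2": [i * 0.5 for i in range(12)],
        "target": [0, 1] * 6,
    })
    ok, msg = validate_data_updated(df, "classification", "target")
    # -> (True, None) on a clean frame

    df.loc[0, "f1"] = float("nan")
    ok, msg = validate_data_updated(df, "classification", "target")
    # -> (False, "Found error in data: * 'MISSING VALUE' in ['f1']")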


def get_file_from_server(file_id):
    '''
    Retrieve a file from the main Aliro server
@@ -262,8 +390,6 @@ def main():
    if args.JSON_ORDINALS:
        ordinals = simplejson.loads(args.JSON_ORDINALS)
    prediction_type = args.PREDICTION_TYPE

    if args.IDENTIFIER_TYPE == 'filepath':
        success, errorMessage = validate_data_from_filepath(
1 change: 1 addition & 0 deletions lab/webapp/src/components/FileUpload/index.js
@@ -2448,6 +2448,7 @@ handleCatFeaturesUserTextCancel() {
          <Modal.Content>{this.state.errorModalContent}</Modal.Content>



        </Modal>
      )
    }