diff --git a/ludwig/error.py b/ludwig/error.py new file mode 100644 index 00000000000..2eb67f45320 --- /dev/null +++ b/ludwig/error.py @@ -0,0 +1,46 @@ +# Copyright (c) 2022 Predibase, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from ludwig.api_annotations import PublicAPI + + +@PublicAPI +class LudwigError(Exception): + """Base class for all custom exceptions raised by the Ludwig framework.""" + + pass + + +@PublicAPI +class InputDataError(LudwigError, ValueError): + """Exception raised for errors in the input data. + + Appropriate for data which is not convertible to the input feature type, columns with all missing values, + categorical columns with only one category, etc... + + Attributes: + column - The name of the input column which caused the error + feature_type - The Ludwig feature type which caused the error (number, binary, category...). + message - An error message describing the situation. + """ + + def __init__(self, column_name: str, feature_type: str, message: str): + self.column_name = column_name + self.feature_type = feature_type + self.message = message + super().__init__(message) + + def __str__(self): + return f'Column "{self.column_name}" as {self.feature_type} feature: {self.message}' diff --git a/ludwig/features/binary_feature.py b/ludwig/features/binary_feature.py index 1caaa7c82f5..f4dc433e770 100644 --- a/ludwig/features/binary_feature.py +++ b/ludwig/features/binary_feature.py @@ -41,6 +41,7 @@ TIED, TYPE, ) +from ludwig.error import InputDataError from ludwig.features.base_feature import BaseFeatureMixin, InputFeature, OutputFeature, PredictModule from ludwig.schema.features.binary_feature import BinaryInputFeatureConfig, BinaryOutputFeatureConfig from ludwig.utils import calibration, output_feature_utils, strings_utils @@ -165,9 +166,8 @@ def get_feature_meta(column: DataFrame, preprocessing_parameters: Dict[str, Any] distinct_values = backend.df_engine.compute(column.drop_duplicates()) if len(distinct_values) > 2: - raise ValueError( - f"Binary feature column {column.name} expects 2 distinct values, " - f"found: {distinct_values.values.tolist()}" + raise InputDataError( + column.name, BINARY, f"expects 2 distinct values, found {distinct_values.values.tolist()}" ) if preprocessing_parameters["fallback_true_label"]: fallback_true_label = preprocessing_parameters["fallback_true_label"] diff --git a/ludwig/features/category_feature.py b/ludwig/features/category_feature.py index 66ac6e15383..76f378619fa 100644 --- a/ludwig/features/category_feature.py +++ b/ludwig/features/category_feature.py @@ -42,6 +42,7 @@ TOP_K, TYPE, ) +from ludwig.error import InputDataError from ludwig.features.base_feature import BaseFeatureMixin, InputFeature, OutputFeature, PredictModule from ludwig.schema.features.category_feature import CategoryInputFeatureConfig, CategoryOutputFeatureConfig from ludwig.utils import calibration, output_feature_utils @@ -136,8 +137,12 @@ def get_feature_meta(column, preprocessing_parameters, backend): num_most_frequent=preprocessing_parameters["most_common"], processor=backend.df_engine, ) - - return {"idx2str": idx2str, "str2idx": str2idx, "str2freq": str2freq, "vocab_size": len(str2idx)} + vocab_size = len(str2idx) + if vocab_size <= 1: + raise InputDataError( + column.name, CATEGORY, f"At least 2 distinct values are required, column only contains {str(idx2str)}" + ) + return {"idx2str": idx2str, "str2idx": str2idx, "str2freq": str2freq, "vocab_size": vocab_size} @staticmethod def feature_data(column, metadata):