diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index 939842df3389..b085e6fe8d36 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -758,6 +758,23 @@ def _check_for_bad_pandas_dtypes(pandas_dtypes_series: pd_Series) -> None: f'Fields with bad pandas dtypes: {", ".join(bad_pandas_dtypes)}') +def _pandas_to_numpy( + data: pd_DataFrame, + target_dtype: "np.typing.DTypeLike" +) -> np.ndarray: + _check_for_bad_pandas_dtypes(data.dtypes) + try: + # most common case (no nullable dtypes) + return data.to_numpy(dtype=target_dtype, copy=False) + except TypeError: + # 1.0 <= pd version < 1.1 and nullable dtypes, least common case + # raises error because array is casted to type(pd.NA) and there's no na_value argument + return data.astype(target_dtype, copy=False).values + except ValueError: + # data has nullable dtypes, but we can specify na_value argument and copy will be made + return data.to_numpy(dtype=target_dtype, na_value=np.nan) + + def _data_from_pandas( data: pd_DataFrame, feature_name: _LGBM_FeatureNameConfiguration, @@ -790,22 +807,17 @@ def _data_from_pandas( else: # use cat cols specified by user categorical_feature = list(categorical_feature) # type: ignore[assignment] - # get numpy representation of the data - _check_for_bad_pandas_dtypes(data.dtypes) df_dtypes = [dtype.type for dtype in data.dtypes] - df_dtypes.append(np.float32) # so that the target dtype considers floats + # so that the target dtype considers floats + df_dtypes.append(np.float32) target_dtype = np.result_type(*df_dtypes) - try: - # most common case (no nullable dtypes) - data = data.to_numpy(dtype=target_dtype, copy=False) - except TypeError: - # 1.0 <= pd version < 1.1 and nullable dtypes, least common case - # raises error because array is casted to type(pd.NA) and there's no na_value argument - data = data.astype(target_dtype, copy=False).values - except ValueError: - # data has nullable dtypes, but we can specify na_value argument and copy will be made - data = data.to_numpy(dtype=target_dtype, na_value=np.nan) - return data, feature_name, categorical_feature, pandas_categorical + + return ( + _pandas_to_numpy(data, target_dtype=target_dtype), + feature_name, + categorical_feature, + pandas_categorical + ) def _dump_pandas_categorical( @@ -2805,18 +2817,7 @@ def set_label(self, label: Optional[_LGBM_LabelType]) -> "Dataset": if isinstance(label, pd_DataFrame): if len(label.columns) > 1: raise ValueError('DataFrame for label cannot have multiple columns') - _check_for_bad_pandas_dtypes(label.dtypes) - try: - # most common case (no nullable dtypes) - label = label.to_numpy(dtype=np.float32, copy=False) - except TypeError: - # 1.0 <= pd version < 1.1 and nullable dtypes, least common case - # raises error because array is casted to type(pd.NA) and there's no na_value argument - label = label.astype(np.float32, copy=False).values - except ValueError: - # data has nullable dtypes, but we can specify na_value argument and copy will be made - label = label.to_numpy(dtype=np.float32, na_value=np.nan) - label_array = np.ravel(label) + label_array = np.ravel(_pandas_to_numpy(label, target_dtype=np.float32)) elif _is_pyarrow_array(label): label_array = label else: