diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index da0a12259866..b097bd0748fe 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -34,6 +34,10 @@ "ctypes._Pointer[ctypes.c_int32]", "ctypes._Pointer[ctypes.c_int64]" ] +_ctypes_int_array = Union[ + "ctypes.Array[ctypes._Pointer[ctypes.c_int32]]", + "ctypes.Array[ctypes._Pointer[ctypes.c_int64]]" +] _ctypes_float_ptr = Union[ "ctypes._Pointer[ctypes.c_float]", "ctypes._Pointer[ctypes.c_double]" @@ -589,13 +593,16 @@ def _convert_from_sliced_object(data: np.ndarray) -> np.ndarray: return data -def _c_float_array(data): +def _c_float_array( + data: np.ndarray +) -> Tuple[_ctypes_float_ptr, int, np.ndarray]: """Get pointer of float numpy array / list.""" if _is_1d_list(data): data = np.array(data, copy=False) if _is_numpy_1d_array(data): data = _convert_from_sliced_object(data) assert data.flags.c_contiguous + ptr_data: _ctypes_float_ptr if data.dtype == np.float32: ptr_data = data.ctypes.data_as(ctypes.POINTER(ctypes.c_float)) type_data = _C_API_DTYPE_FLOAT32 @@ -609,13 +616,16 @@ def _c_float_array(data): return (ptr_data, type_data, data) # return `data` to avoid the temporary copy is freed -def _c_int_array(data): +def _c_int_array( + data: np.ndarray +) -> Tuple[_ctypes_int_ptr, int, np.ndarray]: """Get pointer of int numpy array / list.""" if _is_1d_list(data): data = np.array(data, copy=False) if _is_numpy_1d_array(data): data = _convert_from_sliced_object(data) assert data.flags.c_contiguous + ptr_data: _ctypes_int_ptr if data.dtype == np.int32: ptr_data = data.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)) type_data = _C_API_DTYPE_INT32 @@ -1624,10 +1634,10 @@ def _init_from_sample( # c type: double** # each double* element points to start of each column of sample data. - sample_col_ptr = (ctypes.POINTER(ctypes.c_double) * ncol)() + sample_col_ptr: _ctypes_float_array = (ctypes.POINTER(ctypes.c_double) * ncol)() # c type int** # each int* points to start of indices for each column - indices_col_ptr = (ctypes.POINTER(ctypes.c_int32) * ncol)() + indices_col_ptr: _ctypes_int_array = (ctypes.POINTER(ctypes.c_int32) * ncol)() for i in range(ncol): sample_col_ptr[i] = _c_float_array(sample_data[i])[0] indices_col_ptr[i] = _c_int_array(sample_indices[i])[0] @@ -2374,6 +2384,7 @@ def set_field( dtype = np.int32 if field_name == 'group' else np.float32 data = _list_to_1d_numpy(data, dtype, name=field_name) + ptr_data: Union[_ctypes_float_ptr, _ctypes_int_ptr] if data.dtype == np.float32 or data.dtype == np.float64: ptr_data, type_data, _ = _c_float_array(data) elif data.dtype == np.int32: