Skip to content

Commit

Permalink
Update comment
Browse files Browse the repository at this point in the history
  • Loading branch information
kingychiu committed Dec 22, 2023
1 parent 0d34aa8 commit 5077101
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions target_permutation_importances/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ def compute(
Args:
model_cls: The constructor/class of the model.
model_cls_params: The parameters to pass to the model constructor.
model_fit_params: A function that return parameters to pass to the model fit method.
model_fit_params: A Dict or A function that return parameters to pass to the model fit method.
X: The input data.
y: The target vector.
num_actual_runs: Number of actual runs. Defaults to 2.
Expand Down Expand Up @@ -379,7 +379,7 @@ def compute(
model_cls_params={ # The parameters to pass to the model constructor. Update this based on your needs.
"n_jobs": -1,
},
model_fit_params=lambda _: {}, # The parameters to pass to the model fit method. Update this based on your needs.
model_fit_params={}, # The parameters to pass to the model fit method. Update this based on your needs.
X=Xpd, # pd.DataFrame, np.ndarray
y=data.target, # pd.Series, np.ndarray
num_actual_runs=2,
Expand Down Expand Up @@ -424,9 +424,10 @@ def _model_builder(is_random_run: bool, run_idx: int) -> Any:
return model_cls(**_model_cls_params)

def _model_fitter(model: Any, X: XType, y: YType) -> Any:
if isinstance(model_fit_params, dict):
if isinstance(model_fit_params, dict): # pragma: no cover
_model_fit_params = model_fit_params.copy()
else:
# Assume it is a function
_model_fit_params = model_fit_params(
list(X.columns) if isinstance(X, pd.DataFrame) else None,
)
Expand Down

0 comments on commit 5077101

Please sign in to comment.