Skip to content

Commit

Permalink
use direct import for simpler code
Browse files Browse the repository at this point in the history
  • Loading branch information
ANazaret committed Nov 4, 2024
1 parent cf938d2 commit f990214
Showing 1 changed file with 5 additions and 7 deletions.
12 changes: 5 additions & 7 deletions src/treeffuser/samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from typing import Union

import numpy as np
import sklearn
from jaxtyping import Float
from sklearn.neighbors import KernelDensity
from tqdm import tqdm


Expand Down Expand Up @@ -118,7 +118,7 @@ def sample_kde(
self,
bandwidth: Union[float, Literal["scott", "silverman"]] = 1.0,
verbose: bool = False,
) -> List[sklearn.neighbors.KernelDensity]:
) -> List[KernelDensity]:
"""
Compute the Kernel Density Estimate (KDE) for each `x`.
Estimate: `KDE[Y | X = x]` for each `x` using Gaussian kernels from `sklearn.neighbors`.
Expand All @@ -135,8 +135,8 @@ def sample_kde(
Returns
-------
kdes : list of sklearn.neighbors.KernelDensity
A list of `sklearn.neighbors.KernelDensity` objects, one for each `x`.
kdes : list of KernelDensity
A list of `KernelDensity` objects, one for each `x`.
"""
kdes = []
for i in tqdm(
Expand All @@ -148,9 +148,7 @@ def sample_kde(
y_i = self._samples[:, i, None]
else:
y_i = self._samples[:, i, :]
kde = sklearn.neighbors.KernelDensity(
bandwidth=bandwidth, algorithm="auto", kernel="gaussian"
)
kde = KernelDensity(bandwidth=bandwidth, algorithm="auto", kernel="gaussian")
kde.fit(y_i)
kdes.append(kde)

Expand Down

0 comments on commit f990214

Please sign in to comment.