Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve infrastructure for experimental dispatching of non existing methods in cuML #6148

Merged
merged 13 commits into from
Dec 12, 2024
Merged
22 changes: 22 additions & 0 deletions python/cuml/cuml/internals/base.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -792,3 +792,25 @@ class UniversalBase(Base):
"""

return False

def _check_cpu_model(self):
dantegd marked this conversation as resolved.
Show resolved Hide resolved
"""
Checks if an estimator already has created a _cpu_model,
and creates one if necessary.
"""
if not hasattr(self, "_cpu_model"):
self.import_cpu_model()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that the call to the import_cpu_model function is not necessary here as we already checked the presence of the _cpu_model_class attribute earlier.

self.build_cpu_model()
self.gpu_to_cpu()

def __getattr__(self, attr):
try:
super().__getattr__(attr)

except AttributeError as ex:
if GlobalSettings().accelerator_active:
if hasattr(self._cpu_model_class, attr):
self._check_cpu_model()
return getattr(self._cpu_model, attr)
wphicks marked this conversation as resolved.
Show resolved Hide resolved

raise ex
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,6 @@ def test_nearest_neighbors_n_jobs(synthetic_data, n_jobs):
assert True, f"NearestNeighbors ran successfully with n_jobs={n_jobs}"


@pytest.mark.xfail(reason="cuML doesn't have radius neighbors method")
def test_nearest_neighbors_radius(synthetic_data):
X, _ = synthetic_data
radius = 1.0
Expand Down Expand Up @@ -143,7 +142,6 @@ def test_nearest_neighbors_kneighbors_graph(synthetic_data):
), f"Each sample should have {n_neighbors} neighbors in the graph"


@pytest.mark.xfail(reason="cuML doesn't have radius neighbors graph method")
def test_nearest_neighbors_radius_neighbors_graph(synthetic_data):
X, _ = synthetic_data
radius = 1.0
Expand Down
Loading