Skip to content

Commit 934235d

Browse files
authored
Conditionally import Databricks library (#243)
Databricks is an optional dependency, but the tool package is imported by default, leading to ImportError exceptions. Related: crewAIInc/crewAI#2390
1 parent d7964e0 commit 934235d

File tree

1 file changed

+13
-8
lines changed

1 file changed

+13
-8
lines changed

crewai_tools/tools/databricks_query_tool/databricks_query_tool.py

+13-8
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import os
2-
from typing import Any, Dict, List, Optional, Type, Union
2+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union
33

44
from crewai.tools import BaseTool
5-
from databricks.sdk import WorkspaceClient
65
from pydantic import BaseModel, Field, model_validator
76

7+
if TYPE_CHECKING:
8+
from databricks.sdk import WorkspaceClient
89

910
class DatabricksQueryToolSchema(BaseModel):
1011
"""Input schema for DatabricksQueryTool."""
@@ -67,7 +68,7 @@ class DatabricksQueryTool(BaseTool):
6768
default_schema: Optional[str] = None
6869
default_warehouse_id: Optional[str] = None
6970

70-
_workspace_client: Optional[WorkspaceClient] = None
71+
_workspace_client: Optional["WorkspaceClient"] = None
7172

7273
def __init__(
7374
self,
@@ -89,8 +90,6 @@ def __init__(
8990
self.default_catalog = default_catalog
9091
self.default_schema = default_schema
9192
self.default_warehouse_id = default_warehouse_id
92-
93-
# Validate that Databricks credentials are available
9493
self._validate_credentials()
9594

9695
def _validate_credentials(self) -> None:
@@ -105,10 +104,16 @@ def _validate_credentials(self) -> None:
105104
)
106105

107106
@property
108-
def workspace_client(self) -> WorkspaceClient:
107+
def workspace_client(self) -> "WorkspaceClient":
109108
"""Get or create a Databricks WorkspaceClient instance."""
110109
if self._workspace_client is None:
111-
self._workspace_client = WorkspaceClient()
110+
try:
111+
from databricks.sdk import WorkspaceClient
112+
self._workspace_client = WorkspaceClient()
113+
except ImportError:
114+
raise ImportError(
115+
"`databricks-sdk` package not found, please run `uv add databricks-sdk`"
116+
)
112117
return self._workspace_client
113118

114119
def _format_results(self, results: List[Dict[str, Any]]) -> str:
@@ -733,4 +738,4 @@ def _run(
733738
# Include more details in the error message to help with debugging
734739
import traceback
735740
error_details = traceback.format_exc()
736-
return f"Error executing Databricks query: {str(e)}\n\nDetails:\n{error_details}"
741+
return f"Error executing Databricks query: {str(e)}\n\nDetails:\n{error_details}"

0 commit comments

Comments
 (0)