Skip to content

Commit b845c1f

Browse files
committed
Conditionally import databricks libraries
1 parent d7964e0 commit b845c1f

File tree

1 file changed

+18
-9
lines changed

1 file changed

+18
-9
lines changed

crewai_tools/tools/databricks_query_tool/databricks_query_tool.py

+18-9
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import os
2-
from typing import Any, Dict, List, Optional, Type, Union
2+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union
33

44
from crewai.tools import BaseTool
5-
from databricks.sdk import WorkspaceClient
6-
from pydantic import BaseModel, Field, model_validator
5+
from pydantic import BaseModel, Field, model_validator, ConfigDict, PrivateAttr
76

7+
if TYPE_CHECKING:
8+
from databricks.sdk import WorkspaceClient
89

910
class DatabricksQueryToolSchema(BaseModel):
1011
"""Input schema for DatabricksQueryTool."""
@@ -55,6 +56,10 @@ class DatabricksQueryTool(BaseTool):
5556
>>> results = tool.run(query="SELECT * FROM my_table LIMIT 10")
5657
"""
5758

59+
model_config = ConfigDict(
60+
arbitrary_types_allowed=True, validate_assignment=True, frozen=False
61+
)
62+
5863
name: str = "Databricks SQL Query"
5964
description: str = (
6065
"Execute SQL queries against Databricks workspace tables and return the results."
@@ -67,7 +72,7 @@ class DatabricksQueryTool(BaseTool):
6772
default_schema: Optional[str] = None
6873
default_warehouse_id: Optional[str] = None
6974

70-
_workspace_client: Optional[WorkspaceClient] = None
75+
_workspace_client: Optional["WorkspaceClient"] = PrivateAttr(None)
7176

7277
def __init__(
7378
self,
@@ -89,8 +94,6 @@ def __init__(
8994
self.default_catalog = default_catalog
9095
self.default_schema = default_schema
9196
self.default_warehouse_id = default_warehouse_id
92-
93-
# Validate that Databricks credentials are available
9497
self._validate_credentials()
9598

9699
def _validate_credentials(self) -> None:
@@ -105,10 +108,16 @@ def _validate_credentials(self) -> None:
105108
)
106109

107110
@property
108-
def workspace_client(self) -> WorkspaceClient:
111+
def workspace_client(self) -> "WorkspaceClient":
109112
"""Get or create a Databricks WorkspaceClient instance."""
110113
if self._workspace_client is None:
111-
self._workspace_client = WorkspaceClient()
114+
try:
115+
from databricks.sdk import WorkspaceClient
116+
self._workspace_client = WorkspaceClient()
117+
except ImportError:
118+
raise ImportError(
119+
"`databricks-sdk` package not found, please run `uv add databricks-sdk`"
120+
)
112121
return self._workspace_client
113122

114123
def _format_results(self, results: List[Dict[str, Any]]) -> str:
@@ -733,4 +742,4 @@ def _run(
733742
# Include more details in the error message to help with debugging
734743
import traceback
735744
error_details = traceback.format_exc()
736-
return f"Error executing Databricks query: {str(e)}\n\nDetails:\n{error_details}"
745+
return f"Error executing Databricks query: {str(e)}\n\nDetails:\n{error_details}"

0 commit comments

Comments
 (0)