Skip to content

Commit

Permalink
feat: Install Bigframes sklearn dependencies automatically
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 571490689
  • Loading branch information
matthew29tang authored and copybara-github committed Oct 7, 2023
1 parent 9b75259 commit 7aaffe5
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 15 deletions.
16 changes: 11 additions & 5 deletions vertexai/preview/_workflow/serialization_engine/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1138,14 +1138,20 @@ def serialize(
if not _is_valid_gcs_path(gcs_path):
raise ValueError(f"Invalid gcs path: {gcs_path}")

BigframeSerializer._metadata.dependencies = (
supported_frameworks._get_bigframe_deps()
)

# Record the framework in metadata for deserialization
detected_framework = kwargs.get("framework")
BigframeSerializer._metadata.framework = detected_framework
if detected_framework == "torch":

# Reset dependencies and custom_commands in case the framework is different
BigframeSerializer._metadata.dependencies = []
BigframeSerializer._metadata.custom_commands = []

# Add dependencies based on framework
if detected_framework == "sklearn":
sklearn_deps = supported_frameworks._get_pandas_deps()
sklearn_deps += supported_frameworks._get_pyarrow_deps()
BigframeSerializer._metadata.dependencies += sklearn_deps
elif detected_framework == "torch":
# Install using custom_commands to avoid numpy dependency conflict
BigframeSerializer._metadata.custom_commands.append("pip install torchdata")
BigframeSerializer._metadata.custom_commands.append("pip install torcharrow")
Expand Down
10 changes: 0 additions & 10 deletions vertexai/preview/_workflow/shared/supported_frameworks.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,16 +276,6 @@ def _get_deps_if_pandas_dataframe(possible_dataframe: Any) -> List[str]:
return deps


def _get_bigframe_deps() -> List[str]:
    """Return the dependencies needed to deserialize a bigframe.

    The result is a fresh list holding the entries produced by
    ``_get_pandas_deps()`` followed by those from ``_get_pyarrow_deps()``.

    Returns:
        A new list of dependency strings; the helpers' own return values
        are not mutated.
    """
    # Note: bigframe serialization can only occur locally so bigframes
    # should not be installed remotely. Pandas and pyarrow are required
    # to deserialize for sklearn bigframes though.
    return [*_get_pandas_deps(), *_get_pyarrow_deps()]


def _get_pyarrow_deps() -> List[str]:
deps = []
try:
Expand Down

0 comments on commit 7aaffe5

Please sign in to comment.