Skip to content

Commit

Permalink
migrate to latest API
Browse files Browse the repository at this point in the history
  • Loading branch information
pang-wu committed Dec 8, 2024
1 parent 19b010b commit b442a3a
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 4 deletions.
2 changes: 1 addition & 1 deletion python/raydp/spark/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,4 @@
"spark_dataframe_to_ray_dataset",
"ray_dataset_to_spark_dataframe",
"from_spark_recoverable"
]
]
2 changes: 1 addition & 1 deletion python/raydp/spark/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,4 +254,4 @@ def ray_dataset_to_spark_dataframe(spark: sql.SparkSession,
elif isinstance(sample, pa.Table):
return _convert_by_udf(spark, blocks, locations, schema)
else:
raise RuntimeError("ray.to_spark only supports arrow type blocks")
raise RuntimeError("ray.to_spark only supports arrow type blocks")
6 changes: 4 additions & 2 deletions python/raydp/tf/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,10 @@ def fit(self,
label_cols = self._label_columns
if not isinstance(label_cols, list):
label_cols = [label_cols]
preprocessor = Concatenator(output_column_name="features",
exclude=label_cols)
preprocessor = Concatenator(
columns=self._feature_columns - label_cols,
output_column_name="features",
)
train_loop_config["feature_columns"] = "features"
train_ds = preprocessor.transform(train_ds)
if evaluate_ds is not None:
Expand Down

0 comments on commit b442a3a

Please sign in to comment.