Skip to content

Commit

Permalink
Update Ray Dataset with prediction example (#287)
Browse files Browse the repository at this point in the history

Prediction with Ray Data does not work with distributed loading (see https://discuss.ray.io/t/raytaskerror-typeerror/11486/2). This PR adds a simple example of how to do batch inference with Ray Data.

Ideally, in the short term, we can automatically convert this to Ray Data-based batch inference inside `predict()`. In the long term, we should discontinue all non-Ray-Data APIs and move data source support to Ray Data instead (apart from Petastorm, most data sources should already be supported there anyway).
  • Loading branch information
krfricke authored Aug 4, 2023
1 parent 6c038a2 commit 909848a
Showing 1 changed file with 9 additions and 10 deletions.
19 changes: 9 additions & 10 deletions xgboost_ray/examples/simple_ray_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
import numpy as np
import pandas as pd
import ray
from xgboost import DMatrix

from xgboost_ray import RayDMatrix, RayParams, train


def main(cpus_per_actor, num_actors):
np.random.seed(1234)
# Generate dataset
x = np.repeat(range(8), 16).reshape((32, 4))
# Even numbers --> 0, odd numbers --> 1
Expand All @@ -22,16 +24,7 @@ def main(cpus_per_actor, num_actors):
data.columns = [str(c) for c in data.columns]
data["label"] = y

# There was a recent API change - the first clause covers the new
# and current Ray master API
if hasattr(ray.data, "from_pandas_refs"):
# Generate Ray dataset from 4 partitions
ray_ds = ray.data.from_pandas(data).repartition(num_actors)
else:
# Split into 4 partitions
partitions = [ray.put(part) for part in np.split(data, num_actors)]
ray_ds = ray.data.from_pandas(partitions)

ray_ds = ray.data.from_pandas(data)
train_set = RayDMatrix(ray_ds, "label")

evals_result = {}
Expand Down Expand Up @@ -62,6 +55,12 @@ def main(cpus_per_actor, num_actors):
bst.save_model(model_path)
print("Final training error: {:.4f}".format(evals_result["train"]["error"][-1]))

# Distributed prediction
scored = ray_ds.drop_columns(["label"]).map_batches(
lambda batch: {"pred": bst.predict(DMatrix(batch))}, batch_format="pandas"
)
print(scored.to_pandas())


if __name__ == "__main__":
parser = argparse.ArgumentParser()
Expand Down

0 comments on commit 909848a

Please sign in to comment.