[pyspark] Add tracker configuration (#10281)
wbo4958 authored Nov 5, 2024
1 parent ccc5f05 commit 197c0ae
Showing 4 changed files with 108 additions and 10 deletions.
69 changes: 64 additions & 5 deletions python-package/xgboost/spark/core.py
@@ -86,6 +86,7 @@
     CommunicatorContext,
     _get_default_params_from_func,
     _get_gpu_id,
+    _get_host_ip,
     _get_max_num_concurrent_tasks,
     _get_rabit_args,
     _get_spark_session,
@@ -121,6 +122,9 @@
     "repartition_random_shuffle",
     "pred_contrib_col",
     "use_gpu",
+    "launch_tracker_on_driver",
+    "tracker_host",
+    "tracker_port",
 ]

 _non_booster_params = ["missing", "n_estimators", "feature_types", "feature_weights"]
@@ -246,6 +250,27 @@ class _SparkXGBParams(
         "A list of str to specify feature names.",
         TypeConverters.toList,
     )
+    launch_tracker_on_driver = Param(
+        Params._dummy(),
+        "launch_tracker_on_driver",
+        "A boolean variable. Set launch_tracker_on_driver to true if you want the tracker to be "
+        "launched on the driver side; otherwise, it will be launched on the executor side.",
+        TypeConverters.toBoolean,
+    )
+    tracker_host = Param(
+        Params._dummy(),
+        "tracker_host",
+        "A string variable. The tracker host IP address. To set the tracker host IP, you must "
+        "first set launch_tracker_on_driver to true.",
+        TypeConverters.toString,
+    )
+    tracker_port = Param(
+        Params._dummy(),
+        "tracker_port",
+        "An integer variable. The port number the tracker listens on. To set the tracker port, "
+        "you must first set launch_tracker_on_driver to true.",
+        TypeConverters.toInt,
+    )

     def set_device(self, value: str) -> "_SparkXGBParams":
         """Set device, optional value: cpu, cuda, gpu"""
@@ -617,6 +642,7 @@ def __init__(self) -> None:
             feature_names=None,
             feature_types=None,
             arbitrary_params_dict={},
+            launch_tracker_on_driver=True,
         )

         self.logger = get_logger(self.__class__.__name__)
@@ -996,6 +1022,33 @@ def _try_stage_level_scheduling(self, rdd: RDD) -> RDD:
         )
         return rdd.withResources(rp)

+    def _get_tracker_args(self) -> Tuple[bool, Dict[str, Any]]:
+        """Start the tracker and return the tracker envs on the driver side"""
+        launch_tracker_on_driver = self.getOrDefault(self.launch_tracker_on_driver)
+        rabit_args = {}
+        if launch_tracker_on_driver:
+            tracker_host: Optional[str] = None
+            if self.isDefined(self.tracker_host):
+                tracker_host = self.getOrDefault(self.tracker_host)
+            else:
+                tracker_host = (
+                    _get_spark_session().sparkContext.getConf().get("spark.driver.host")
+                )
+            assert tracker_host is not None
+            tracker_port = 0
+            if self.isDefined(self.tracker_port):
+                tracker_port = self.getOrDefault(self.tracker_port)
+
+            num_workers = self.getOrDefault(self.num_workers)
+            rabit_args.update(_get_rabit_args(tracker_host, num_workers, tracker_port))
+        else:
+            if self.isDefined(self.tracker_host) or self.isDefined(self.tracker_port):
+                raise ValueError(
+                    "You must enable launch_tracker_on_driver to use "
+                    "tracker_host and tracker_port"
+                )
+        return launch_tracker_on_driver, rabit_args
+
     def _fit(self, dataset: DataFrame) -> "_SparkXGBModel":
         # pylint: disable=too-many-statements, too-many-locals
         self._validate_params()
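To make the two branches above concrete, a hypothetical REPL sketch (not part of the diff; the env keys shown are the ones asserted by the new test at the bottom of this commit):

    est = SparkXGBClassifier(launch_tracker_on_driver=True, num_workers=2)
    on_driver, args = est._get_tracker_args()
    # on_driver is True; the tracker is already running on the driver and args
    # carries its env, e.g. {"n_workers": 2, "dmlc_tracker_uri": <driver host>, ...}.

    est = SparkXGBClassifier(launch_tracker_on_driver=False, num_workers=2)
    on_driver, args = est._get_tracker_args()
    # on_driver is False and args == {}; the executor running partition 0
    # starts the tracker later, inside _train_booster.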
@@ -1014,6 +1067,8 @@ def _fit(self, dataset: DataFrame) -> "_SparkXGBModel":

         num_workers = self.getOrDefault(self.num_workers)

+        launch_tracker_on_driver, rabit_args = self._get_tracker_args()
+
         log_level = get_logger_level(_LOG_TAG)

         def _train_booster(
@@ -1053,21 +1108,25 @@ def _train_booster(
             if use_qdm and (booster_params.get("max_bin", None) is not None):
                 dmatrix_kwargs["max_bin"] = booster_params["max_bin"]

-            _rabit_args = {}
+            _rabit_args = rabit_args
             if context.partitionId() == 0:
-                _rabit_args = _get_rabit_args(context, num_workers)
+                if not launch_tracker_on_driver:
+                    _rabit_args = _get_rabit_args(_get_host_ip(context), num_workers)
                 get_logger(_LOG_TAG, log_level).info(msg)

-            worker_message = {
-                "rabit_msg": _rabit_args,
+            worker_message: Dict[str, Any] = {
                 "use_qdm": use_qdm,
             }

+            if not launch_tracker_on_driver:
+                worker_message["rabit_msg"] = _rabit_args
+
             messages = context.allGather(message=json.dumps(worker_message))
             if len(set(json.loads(x)["use_qdm"] for x in messages)) != 1:
                 raise RuntimeError("The workers' cudf environments are inconsistent")

-            _rabit_args = json.loads(messages[0])["rabit_msg"]
+            if not launch_tracker_on_driver:
+                _rabit_args = json.loads(messages[0])["rabit_msg"]

             evals_result: Dict[str, Any] = {}
             with CommunicatorContext(context, **_rabit_args):
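The hand-off above relies on the barrier allGather: when the tracker runs on an executor, only partition 0 knows the Rabit env, so every worker reads it back from messages[0]. A minimal off-Spark sketch of that agreement step (allGather mocked; the address is a placeholder):

    import json

    def mock_all_gather(worker_messages):
        # Barrier allGather: each task contributes one message, all receive the full list.
        return [json.dumps(m) for m in worker_messages]

    launch_tracker_on_driver = False
    messages = mock_all_gather([
        {"use_qdm": True, "rabit_msg": {"dmlc_tracker_uri": "10.0.0.7"}},  # partition 0
        {"use_qdm": True, "rabit_msg": {}},                                # partition 1
    ])
    assert len({json.loads(m)["use_qdm"] for m in messages}) == 1
    if not launch_tracker_on_driver:
        rabit_args = json.loads(messages[0])["rabit_msg"]  # everyone adopts partition 0's env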
12 changes: 12 additions & 0 deletions python-package/xgboost/spark/estimator.py
@@ -161,6 +161,9 @@ class SparkXGBRegressor(_SparkXGBEstimator):
         Boolean value to specify if enabling sparse data optimization, if True,
         Xgboost DMatrix object will be constructed from sparse matrix instead of
         dense matrix.
+    launch_tracker_on_driver:
+        Boolean value to indicate whether the tracker should be launched on the driver side or
+        the executor side.
     kwargs:
         A dictionary of xgboost parameters, please refer to
@@ -215,6 +218,7 @@ def __init__( # pylint:disable=too-many-arguments
         force_repartition: bool = False,
         repartition_random_shuffle: bool = False,
         enable_sparse_data_optim: bool = False,
+        launch_tracker_on_driver: bool = True,
         **kwargs: Any,
     ) -> None:
         super().__init__()
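A complementary sketch of the executor-side mode exposed here (train_df is an assumed Spark DataFrame with "features" and "label" columns):

    regressor = SparkXGBRegressor(
        launch_tracker_on_driver=False,  # the executor running partition 0 starts the tracker
        num_workers=2,
    )
    model = regressor.fit(train_df)  # tracker_host/tracker_port must NOT be set in this mode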
@@ -341,6 +345,9 @@ class SparkXGBClassifier(_SparkXGBEstimator, HasProbabilityCol, HasRawPrediction
         Boolean value to specify if enabling sparse data optimization, if True,
         Xgboost DMatrix object will be constructed from sparse matrix instead of
         dense matrix.
+    launch_tracker_on_driver:
+        Boolean value to indicate whether the tracker should be launched on the driver side or
+        the executor side.
     kwargs:
         A dictionary of xgboost parameters, please refer to
@@ -395,6 +402,7 @@ def __init__( # pylint:disable=too-many-arguments
         force_repartition: bool = False,
         repartition_random_shuffle: bool = False,
         enable_sparse_data_optim: bool = False,
+        launch_tracker_on_driver: bool = True,
         **kwargs: Any,
     ) -> None:
         super().__init__()
@@ -524,6 +532,9 @@ class SparkXGBRanker(_SparkXGBEstimator):
         Boolean value to specify if enabling sparse data optimization, if True,
         Xgboost DMatrix object will be constructed from sparse matrix instead of
         dense matrix.
+    launch_tracker_on_driver:
+        Boolean value to indicate whether the tracker should be launched on the driver side or
+        the executor side.
     kwargs:
         A dictionary of xgboost parameters, please refer to
@@ -584,6 +595,7 @@ def __init__( # pylint:disable=too-many-arguments
         force_repartition: bool = False,
         repartition_random_shuffle: bool = False,
         enable_sparse_data_optim: bool = False,
+        launch_tracker_on_driver: bool = True,
         **kwargs: Any,
     ) -> None:
         super().__init__()
9 changes: 4 additions & 5 deletions python-package/xgboost/spark/utils.py
@@ -51,11 +51,10 @@ def __init__(self, context: BarrierTaskContext, **args: Any) -> None:
         super().__init__(**args)


-def _start_tracker(context: BarrierTaskContext, n_workers: int) -> Dict[str, Any]:
+def _start_tracker(host: str, n_workers: int, port: int = 0) -> Dict[str, Any]:
     """Start Rabit tracker with n_workers"""
     args: Dict[str, Any] = {"n_workers": n_workers}
-    host = _get_host_ip(context)
-    tracker = RabitTracker(n_workers=n_workers, host_ip=host, sortby="task")
+    tracker = RabitTracker(n_workers=n_workers, host_ip=host, sortby="task", port=port)
     tracker.start()
     thread = Thread(target=tracker.wait_for)
     thread.daemon = True
@@ -64,9 +63,9 @@ def _get_rabit_args(context: BarrierTaskContext, n_workers: int) -> Dict[str, Any
     return args


-def _get_rabit_args(context: BarrierTaskContext, n_workers: int) -> Dict[str, Any]:
+def _get_rabit_args(host: str, n_workers: int, port: int = 0) -> Dict[str, Any]:
     """Get rabit context arguments to send to each worker."""
-    env = _start_tracker(context, n_workers)
+    env = _start_tracker(host, n_workers, port)
     return env

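With the BarrierTaskContext dependency removed, the refactored helpers can be exercised off-Spark. A sketch (the env keys shown are the ones asserted in the new test below; calling this starts a real tracker thread on the local machine):

    from xgboost.spark.utils import _get_rabit_args

    # Start a tracker for two workers on the loopback interface; leaving
    # port at its default of 0 lets the OS pick a free port.
    env = _get_rabit_args("127.0.0.1", n_workers=2)
    print(env["n_workers"], env["dmlc_tracker_uri"])  # -> 2 127.0.0.1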
28 changes: 28 additions & 0 deletions tests/test_distributed/test_with_spark/test_spark_local.py
@@ -1630,6 +1630,34 @@ def test_unsupported_params(self):
         with pytest.raises(ValueError, match="evals_result"):
             SparkXGBClassifier(evals_result={})

+    def test_tracker(self):
+        classifier = SparkXGBClassifier(
+            launch_tracker_on_driver=True,
+            tracker_host="192.168.1.32",
+            tracker_port=59981,
+        )
+        with pytest.raises(Exception, match="Failed to bind socket"):
+            classifier._get_tracker_args()
+
+        classifier = SparkXGBClassifier(
+            launch_tracker_on_driver=False, tracker_host="127.0.0.1", tracker_port=58892
+        )
+        with pytest.raises(
+            ValueError, match="You must enable launch_tracker_on_driver"
+        ):
+            classifier._get_tracker_args()
+
+        classifier = SparkXGBClassifier(
+            launch_tracker_on_driver=True,
+            tracker_host="127.0.0.1",
+            tracker_port=58892,
+            num_workers=2,
+        )
+        launch_tracker_on_driver, rabit_envs = classifier._get_tracker_args()
+        assert launch_tracker_on_driver is True
+        assert rabit_envs["n_workers"] == 2
+        assert rabit_envs["dmlc_tracker_uri"] == "127.0.0.1"
+

 LTRData = namedtuple("LTRData", ("df_train", "df_test", "df_train_1"))

