diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py
index 3a01ba4f39f1..eb59ff81a6bb 100644
--- a/python-package/xgboost/spark/core.py
+++ b/python-package/xgboost/spark/core.py
@@ -86,6 +86,7 @@
     CommunicatorContext,
     _get_default_params_from_func,
     _get_gpu_id,
+    _get_host_ip,
     _get_max_num_concurrent_tasks,
     _get_rabit_args,
     _get_spark_session,
@@ -121,6 +122,9 @@
     "repartition_random_shuffle",
     "pred_contrib_col",
     "use_gpu",
+    "launch_tracker_on_driver",
+    "tracker_host",
+    "tracker_port",
 ]

 _non_booster_params = ["missing", "n_estimators", "feature_types", "feature_weights"]
@@ -246,6 +250,27 @@ class _SparkXGBParams(
         "A list of str to specify feature names.",
         TypeConverters.toList,
     )
+    launch_tracker_on_driver = Param(
+        Params._dummy(),
+        "launch_tracker_on_driver",
+        "A boolean variable. Set launch_tracker_on_driver to true if you want the tracker to be "
+        "launched on the driver side; otherwise, it will be launched on the executor side.",
+        TypeConverters.toBoolean,
+    )
+    tracker_host = Param(
+        Params._dummy(),
+        "tracker_host",
+        "A string variable. The tracker host IP address. To set the tracker host IP, "
+        "launch_tracker_on_driver must first be set to true.",
+        TypeConverters.toString,
+    )
+    tracker_port = Param(
+        Params._dummy(),
+        "tracker_port",
+        "An integer variable. The port number the tracker listens on. To set the tracker "
+        "port, launch_tracker_on_driver must first be set to true.",
+        TypeConverters.toInt,
+    )

     def set_device(self, value: str) -> "_SparkXGBParams":
         """Set device, optional value: cpu, cuda, gpu"""
@@ -617,6 +642,7 @@ def __init__(self) -> None:
             feature_names=None,
             feature_types=None,
             arbitrary_params_dict={},
+            launch_tracker_on_driver=True,
         )

         self.logger = get_logger(self.__class__.__name__)
@@ -996,6 +1022,33 @@ def _try_stage_level_scheduling(self, rdd: RDD) -> RDD:
         )
         return rdd.withResources(rp)

+    def _get_tracker_args(self) -> Tuple[bool, Dict[str, Any]]:
+        """Start the tracker and return the tracker envs on the driver side"""
+        launch_tracker_on_driver = self.getOrDefault(self.launch_tracker_on_driver)
+        rabit_args: Dict[str, Any] = {}
+        if launch_tracker_on_driver:
+            tracker_host: Optional[str] = None
+            if self.isDefined(self.tracker_host):
+                tracker_host = self.getOrDefault(self.tracker_host)
+            else:
+                tracker_host = (
+                    _get_spark_session().sparkContext.getConf().get("spark.driver.host")
+                )
+            assert tracker_host is not None
+            tracker_port = 0
+            if self.isDefined(self.tracker_port):
+                tracker_port = self.getOrDefault(self.tracker_port)
+
+            num_workers = self.getOrDefault(self.num_workers)
+            rabit_args.update(_get_rabit_args(tracker_host, num_workers, tracker_port))
+        else:
+            if self.isDefined(self.tracker_host) or self.isDefined(self.tracker_port):
+                raise ValueError(
+                    "You must enable launch_tracker_on_driver to use "
+                    "tracker_host and tracker_port"
+                )
+        return launch_tracker_on_driver, rabit_args
+
     def _fit(self, dataset: DataFrame) -> "_SparkXGBModel":
         # pylint: disable=too-many-statements, too-many-locals
         self._validate_params()
@@ -1014,6 +1067,8 @@ def _fit(self, dataset: DataFrame) -> "_SparkXGBModel":

         num_workers = self.getOrDefault(self.num_workers)

+        launch_tracker_on_driver, rabit_args = self._get_tracker_args()
+
         log_level = get_logger_level(_LOG_TAG)

         def _train_booster(
@@ -1053,21 +1108,25 @@
             if use_qdm and (booster_params.get("max_bin", None) is not None):
                 dmatrix_kwargs["max_bin"] = booster_params["max_bin"]

-            _rabit_args = {}
+            _rabit_args = rabit_args
             if context.partitionId() == 0:
-                _rabit_args = _get_rabit_args(context, num_workers)
+                if not launch_tracker_on_driver:
+                    _rabit_args = _get_rabit_args(_get_host_ip(context), num_workers)
                 get_logger(_LOG_TAG, log_level).info(msg)

-            worker_message = {
-                "rabit_msg": _rabit_args,
+            worker_message: Dict[str, Any] = {
                 "use_qdm": use_qdm,
             }

+            if not launch_tracker_on_driver:
+                worker_message["rabit_msg"] = _rabit_args
+
             messages = context.allGather(message=json.dumps(worker_message))
             if len(set(json.loads(x)["use_qdm"] for x in messages)) != 1:
                 raise RuntimeError("The workers' cudf environments are in-consistent ")
-            _rabit_args = json.loads(messages[0])["rabit_msg"]
+            if not launch_tracker_on_driver:
+                _rabit_args = json.loads(messages[0])["rabit_msg"]

             evals_result: Dict[str, Any] = {}
             with CommunicatorContext(context, **_rabit_args):
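The core.py change routes tracker bootstrap through the new `_get_tracker_args()` on the driver when `launch_tracker_on_driver` is true (the new default), and falls back to the old executor-side path otherwise. A minimal usage sketch under those semantics; it assumes a running Spark session, and the host IP and port below are hypothetical placeholders:

```python
from xgboost.spark import SparkXGBClassifier

# Driver-side tracker (new default). tracker_host/tracker_port are optional:
# if unset, the host falls back to spark.driver.host and the port is chosen
# by the OS (port=0). The concrete values here are hypothetical.
classifier = SparkXGBClassifier(
    launch_tracker_on_driver=True,
    tracker_host="10.0.0.5",
    tracker_port=59999,
    num_workers=4,
)

# Executor-side tracker (pre-existing behavior): the barrier task owning
# partition 0 starts the tracker, and tracker_host/tracker_port must be unset.
classifier_executor_side = SparkXGBClassifier(launch_tracker_on_driver=False)
```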
diff --git a/python-package/xgboost/spark/estimator.py b/python-package/xgboost/spark/estimator.py
index 51e2e946f8a5..8a4840846ac2 100644
--- a/python-package/xgboost/spark/estimator.py
+++ b/python-package/xgboost/spark/estimator.py
@@ -161,6 +161,9 @@ class SparkXGBRegressor(_SparkXGBEstimator):
         Boolean value to specify if enabling sparse data optimization, if True,
         Xgboost DMatrix object will be constructed from sparse matrix instead of
         dense matrix.
+    launch_tracker_on_driver:
+        Boolean value to indicate whether the tracker should be launched on the driver side or
+        the executor side.

     kwargs:
         A dictionary of xgboost parameters, please refer to
@@ -215,6 +218,7 @@ def __init__(  # pylint:disable=too-many-arguments
         force_repartition: bool = False,
         repartition_random_shuffle: bool = False,
         enable_sparse_data_optim: bool = False,
+        launch_tracker_on_driver: bool = True,
         **kwargs: Any,
     ) -> None:
         super().__init__()
@@ -341,6 +345,9 @@ class SparkXGBClassifier(_SparkXGBEstimator, HasProbabilityCol, HasRawPrediction
         Boolean value to specify if enabling sparse data optimization, if True,
         Xgboost DMatrix object will be constructed from sparse matrix instead of
         dense matrix.
+    launch_tracker_on_driver:
+        Boolean value to indicate whether the tracker should be launched on the driver side or
+        the executor side.

     kwargs:
         A dictionary of xgboost parameters, please refer to
@@ -395,6 +402,7 @@ def __init__(  # pylint:disable=too-many-arguments
         force_repartition: bool = False,
         repartition_random_shuffle: bool = False,
         enable_sparse_data_optim: bool = False,
+        launch_tracker_on_driver: bool = True,
         **kwargs: Any,
     ) -> None:
         super().__init__()
@@ -524,6 +532,9 @@ class SparkXGBRanker(_SparkXGBEstimator):
         Boolean value to specify if enabling sparse data optimization, if True,
         Xgboost DMatrix object will be constructed from sparse matrix instead of
         dense matrix.
+    launch_tracker_on_driver:
+        Boolean value to indicate whether the tracker should be launched on the driver side or
+        the executor side.

     kwargs:
         A dictionary of xgboost parameters, please refer to
@@ -584,6 +595,7 @@ def __init__(  # pylint:disable=too-many-arguments
         force_repartition: bool = False,
         repartition_random_shuffle: bool = False,
         enable_sparse_data_optim: bool = False,
+        launch_tracker_on_driver: bool = True,
         **kwargs: Any,
     ) -> None:
         super().__init__()
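The three estimators simply forward `launch_tracker_on_driver` to `_SparkXGBParams`; the host/port validation lives in `_get_tracker_args()` in core.py. A short sketch of that error path, mirroring the new test further below (the check runs before any training work, so no cluster is needed for the rejection itself):

```python
from xgboost.spark import SparkXGBRegressor

# Setting tracker_host or tracker_port while the driver-side tracker is
# disabled is rejected up front by _get_tracker_args().
regressor = SparkXGBRegressor(launch_tracker_on_driver=False, tracker_host="127.0.0.1")
try:
    regressor._get_tracker_args()
except ValueError as exc:
    # "You must enable launch_tracker_on_driver to use tracker_host and tracker_port"
    print(exc)
```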
diff --git a/python-package/xgboost/spark/utils.py b/python-package/xgboost/spark/utils.py
index 0a421031ecd4..f03770059564 100644
--- a/python-package/xgboost/spark/utils.py
+++ b/python-package/xgboost/spark/utils.py
@@ -51,11 +51,10 @@ def __init__(self, context: BarrierTaskContext, **args: Any) -> None:
         super().__init__(**args)


-def _start_tracker(context: BarrierTaskContext, n_workers: int) -> Dict[str, Any]:
+def _start_tracker(host: str, n_workers: int, port: int = 0) -> Dict[str, Any]:
     """Start Rabit tracker with n_workers"""
     args: Dict[str, Any] = {"n_workers": n_workers}
-    host = _get_host_ip(context)
-    tracker = RabitTracker(n_workers=n_workers, host_ip=host, sortby="task")
+    tracker = RabitTracker(n_workers=n_workers, host_ip=host, sortby="task", port=port)
     tracker.start()
     thread = Thread(target=tracker.wait_for)
     thread.daemon = True
@@ -64,9 +63,9 @@ def _start_tracker(context: BarrierTaskContext, n_workers: int) -> Dict[str, Any]:
     return args


-def _get_rabit_args(context: BarrierTaskContext, n_workers: int) -> Dict[str, Any]:
+def _get_rabit_args(host: str, n_workers: int, port: int = 0) -> Dict[str, Any]:
     """Get rabit context arguments to send to each worker."""
-    env = _start_tracker(context, n_workers)
+    env = _start_tracker(host, n_workers, port)
     return env
diff --git a/tests/test_distributed/test_with_spark/test_spark_local.py b/tests/test_distributed/test_with_spark/test_spark_local.py
index feb7b18bc035..523e1e64bbc8 100644
--- a/tests/test_distributed/test_with_spark/test_spark_local.py
+++ b/tests/test_distributed/test_with_spark/test_spark_local.py
@@ -1630,6 +1630,34 @@ def test_unsupported_params(self):
         with pytest.raises(ValueError, match="evals_result"):
             SparkXGBClassifier(evals_result={})

+    def test_tracker(self):
+        classifier = SparkXGBClassifier(
+            launch_tracker_on_driver=True,
+            tracker_host="192.168.1.32",
+            tracker_port=59981,
+        )
+        with pytest.raises(Exception, match="Failed to bind socket"):
+            classifier._get_tracker_args()
+
+        classifier = SparkXGBClassifier(
+            launch_tracker_on_driver=False, tracker_host="127.0.0.1", tracker_port=58892
+        )
+        with pytest.raises(
+            ValueError, match="You must enable launch_tracker_on_driver"
+        ):
+            classifier._get_tracker_args()
+
+        classifier = SparkXGBClassifier(
+            launch_tracker_on_driver=True,
+            tracker_host="127.0.0.1",
+            tracker_port=58892,
+            num_workers=2,
+        )
+        launch_tracker_on_driver, rabit_envs = classifier._get_tracker_args()
+        assert launch_tracker_on_driver is True
+        assert rabit_envs["n_workers"] == 2
+        assert rabit_envs["dmlc_tracker_uri"] == "127.0.0.1"
+

 LTRData = namedtuple("LTRData", ("df_train", "df_test", "df_train_1"))
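With the utils.py change, `_get_rabit_args` takes a plain host string instead of a `BarrierTaskContext`, so it can be driven from the driver as well as from an executor. A quick round-trip sketch of the returned env dict; the key names follow the assertions in `test_tracker`, and the port is a hypothetical free port:

```python
from xgboost.spark.utils import _get_rabit_args

# Boots a RabitTracker bound to the given host/port and returns the arguments
# each worker needs to join the collective.
env = _get_rabit_args("127.0.0.1", n_workers=2, port=58892)
assert env["n_workers"] == 2
assert env["dmlc_tracker_uri"] == "127.0.0.1"
```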