
Support column split in approx tree method #8847

Merged (3 commits, Mar 1, 2023)

Conversation

@rongou (Contributor) commented Feb 25, 2023

Since each worker has a distinct set of columns/features, and we already build histograms locally (#8811) and partition rows collaboratively (#8828), the only remaining change is to handle split finding correctly. We first find the local best splits using the existing approach, then perform a round of allgather to collect the best splits from all workers, and update each node's split to the globally best one.
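The selection step described above can be sketched roughly as follows. This is a hedged illustration, not XGBoost's actual internals: the `allgather` stand-in, the `Split` dataclass, and the gain-based tie-breaking rule are all assumptions made for the sketch.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Split:
    feature: int      # global feature index (hypothetical representation)
    threshold: float  # split threshold
    gain: float       # loss reduction from this split


def allgather(local_values: List[Split]) -> List[List[Split]]:
    # Stand-in for a collective allgather: every worker receives the
    # contribution from every other worker. Here "workers" are just
    # entries in a list, so each one gets a copy of the full list.
    return [list(local_values) for _ in local_values]


def find_global_best(local_bests: List[Split]) -> Split:
    # After allgather, every worker applies the same deterministic
    # reduction, so all workers agree on the winning split without
    # further communication. Ties on gain break toward the lower
    # feature index (an assumption for determinism in this sketch).
    return max(local_bests, key=lambda s: (s.gain, -s.feature))


# Each worker holds a distinct column subset and proposes its local best.
local_bests = [
    Split(feature=0, threshold=0.5, gain=1.2),  # worker 0
    Split(feature=7, threshold=3.1, gain=2.4),  # worker 1
    Split(feature=4, threshold=1.0, gain=0.9),  # worker 2
]

gathered = allgather(local_bests)
global_bests = [find_global_best(bests) for bests in gathered]

# Every worker independently arrives at the same globally best split.
assert all(b == global_bests[0] for b in global_bests)
print(global_bests[0])
```

The key property is that the reduction is deterministic and applied identically on every worker, so a single allgather round suffices and no broadcast of the winner is needed.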

@rongou (Contributor, Author) commented Feb 25, 2023

@trivialfis @hcho3

@trivialfis (Member) commented:

> the only change remaining is to handle split finding correctly.

That's exciting! Do you have a complete example that I can run? Would love to try it out.

@rongou (Contributor, Author) commented Feb 26, 2023

At least on a small dataset, it seems to produce results identical to row split:

```python
import filecmp
import multiprocessing

import xgboost as xgb
from xgboost import RabitTracker


def train(split, rank):
    # data_split_mode: 0 = row split, 1 = column split
    dtrain = xgb.DMatrix('demo/data/agaricus.txt.train', data_split_mode=split)
    dtest = xgb.DMatrix('demo/data/agaricus.txt.test', data_split_mode=split)
    param = {"max_depth": 2, "eta": 1, "objective": "binary:logistic"}
    watchlist = [(dtest, "eval"), (dtrain, "train")]
    num_round = 2
    bst = xgb.train(param, dtrain, num_boost_round=num_round, evals=watchlist)
    if rank == 0:
        bst.save_model(f'agaricus.model.{split}.json')


def run_worker(rabit_env, rank):
    with xgb.collective.CommunicatorContext(**rabit_env):
        print("Training with row split")
        train(0, rank)
        print("Training with column split")
        train(1, rank)


def main():
    world_size = 2
    tracker = RabitTracker(host_ip='127.0.0.1', n_workers=world_size)
    tracker.start(world_size)

    workers = []
    for rank in range(world_size):
        worker = multiprocessing.Process(target=run_worker, args=(tracker.worker_envs(), rank))
        workers.append(worker)
        worker.start()
    for worker in workers:
        worker.join()
        assert worker.exitcode == 0

    # Compare the models produced by the two split modes byte for byte.
    result = filecmp.cmp('agaricus.model.0.json', 'agaricus.model.1.json', shallow=False)
    print(f'Two models are equal: {result}')


if __name__ == "__main__":
    main()
```
