[Performance] [Speculative decoding] Support draft model on different tensor-parallel size than target model #4933

Draft
wants to merge 4 commits into main
Conversation

GeauxEric
Contributor

FIX #4632

@GeauxEric GeauxEric marked this pull request as draft May 21, 2024 05:03
@rkooo567 rkooo567 requested a review from cadedaniel May 21, 2024 10:40
@cadedaniel
Collaborator

Thanks! Ping me when this PR is ready!

@wooyeonlee0
Contributor

Great! I'm looking forward to this feature :)

@GeauxEric
Contributor Author

GeauxEric commented May 27, 2024

I'm not very familiar with distributed training and inference, so I spent some time reading the code base.

@cadedaniel, I have two questions about the expected behavior.

First question: for the proposal model, does its TP_size also need to satisfy world_size == TP_size * PP_size? If so, then when TP_size == 1 we would have PP_size > 1, which is not supported yet. A temporary solution is to let each worker keep its own copy of the small draft model.

Second question: if we do need to perform distributed inference for the draft model, then since there are now two models (scoring and proposal), does the function initialize_model_parallel need to be invoked twice for each worker, once for the scoring model and once for the proposal model?

def initialize_model_parallel(
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
    backend: Optional[str] = None,
) -> None:
    """
    Initialize model parallel groups.
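
For illustration, a minimal sketch (hypothetical helper, not vLLM's actual code) of how a worker could skip re-creating the global process group when a second model is initialized on the same rank:

import torch.distributed as dist

def ensure_process_group(world_size: int, rank: int, backend: str = "nccl") -> None:
    # Hypothetical guard: create the global process group only once; a second
    # model initialized on the same worker reuses the existing group.
    if dist.is_initialized():
        return
    dist.init_process_group(backend=backend, world_size=world_size, rank=rank)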

@GeauxEric GeauxEric marked this pull request as ready for review May 27, 2024 22:16
@cadedaniel
Collaborator

cadedaniel commented Jun 3, 2024

First question: for the proposal model, does its TP_size also need to satisfy world_size == TP_size * PP_size? If so, then when TP_size == 1 we would have PP_size > 1, which is not supported yet. A temporary solution is to let each worker keep its own copy of the small draft model.

Let's leave out the PP case for now. In the future we can add more configurations that benefit PP latency. You can assume PP size is always 1.
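
For illustration, a minimal sketch (hypothetical function; the divisibility constraint is an assumption, not something decided in this thread) of validating the draft model's parallel configuration under the PP_size == 1 assumption:

def validate_draft_parallel_config(draft_tp: int, target_tp: int, draft_pp: int = 1) -> None:
    # Assumption: pipeline parallelism for the draft model is out of scope for now.
    if draft_pp != 1:
        raise ValueError("Pipeline parallelism is not yet supported for the draft model.")
    # Assumption for illustration: the draft TP size evenly divides the target TP
    # size (e.g. draft TP 1 alongside target TP 4).
    if target_tp % draft_tp != 0:
        raise ValueError(
            f"Draft TP size {draft_tp} must evenly divide target TP size {target_tp}.")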

Second question: if we do need to perform distributed inference for the draft model, then since there are now two models (scoring and proposal), does the function initialize_model_parallel need to be invoked twice for each worker, once for the scoring model and once for the proposal model?

Good question. In my internal fork we had the ability to skip initialization the second time; see should_init_distributed_env=False in the following code.

    def init_model(self):
        """Initialize the model on all ranks.
        This also creates a single-rank process group containing only the
        self process.
        """
        world_rank = torch.distributed.get_rank()
        self._single_tp_group = torch.distributed.new_group([world_rank])

        with patch_tensor_parallel_group(self._single_tp_group):
            self._worker.init_model(should_init_distributed_env=False)
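
patch_tensor_parallel_group is not shown in this thread; a plausible sketch (assuming a module-level tensor-parallel group variable that parallel layers read, which is a simplification of the real implementation) could look like:

from contextlib import contextmanager

_TP_GROUP = None  # stand-in for the module-level tensor-parallel group used by parallel layers

@contextmanager
def patch_tensor_parallel_group(group):
    # Temporarily swap in the given group (e.g. the single-rank group created
    # above) so the draft model initializes against it, then restore the original.
    global _TP_GROUP
    saved = _TP_GROUP
    _TP_GROUP = group
    try:
        yield
    finally:
        _TP_GROUP = saved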

Then in the spec decode worker we initialize the larger model first.

# The scorer worker model is initialized first in case the proposer
# model has a smaller TP degree than the target worker.
self.scorer_worker.init_device()
self.proposer_worker.init_device()

@GeauxEric GeauxEric marked this pull request as draft June 4, 2024 20:45