
Implementing Proactive Error Propagation in torchFT #188

@WarrenZhu050413

Proactive Error Propagation

This issue tracks the implementation of proactive error propagation in torchFT, coming from discussions with @d4l3k and @guodong.

Would love to hear feedback as to which parts are most necessary and should be implemented first!

Overview

Currently, torchFT relies on a timeout mechanism to detect errors from neighboring nodes during a forward pass. This becomes prohibitively expensive in high-failure scenarios. One potential solution is to introduce proactive error propagation mechanisms that complement the timeouts.

torchFT Relies on a Timeout Mechanism for Failure Detection

Currently in torchFT, the failure response mechanism is the following (visualized below):

Upon failure:

  1. Horizontal DDP
    a. For the horizontal (DDP) peers, the gradient all_reduce in train.py times out
    b. The should_commit RPC returns false for all ranks in the corresponding replica group

  2. Vertical MP
    a. The failed rank never issues its should_commit RPC, so should_commit hangs indefinitely
    b. The Manager does not join the next quorum
    c. The cluster manager (e.g., Slurm/Kubernetes) detects the failure through routine health checks and reschedules the job on another node

[Figure: torchFT current failure recovery flow]

With this mechanism, each failure incurs an overhead equal to the time until the DDP peers issue the first gradient allreduce, plus the allreduce timeout for the horizontal peers to register the failure once the allreduce has been issued.

This overhead could be reduced by lowering the allreduce timeout. However, doing so increases the probability of false timeouts under network jitter, which also costs lost iterations. Furthermore, the timeout cannot be reduced below the time a normal allreduce takes to complete.
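
Put differently, as a rough formula (assuming a single failure and no other sources of delay), the failure detection latency is approximately:

$$T_{\text{detect}} \approx T_{\text{until first allreduce}} + T_{\text{allreduce timeout}}, \qquad T_{\text{allreduce timeout}} \ge T_{\text{normal allreduce}}$$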

A Solution: Proactive Error Propagation

A potential way to decrease the failure recovery overhead is to introduce a proactive error propagation mechanism. Whenever an error happens, it is proactively propagated to neighboring nodes, which can then recover from it immediately rather than waiting for a timeout. See the figure below:

[Figure: failure recovery flow with proactive error propagation]

Implementation

There are three subproblems to solve:

  1. Propagation of Failure to Each Rank
  2. Interrupt the training thread to signal error
  3. Recover from error

1. Propagation of Failure to Each Rank

This could be implemented in various ways. Given torchFT's existing infrastructure, utilizing the current ManagerServer and LighthouseServer makes the most sense to me. The error propagation flow would then be the following (note that each ManagerClient corresponds to a training process):

ManagerClient x (fails)
      |
  [heartbeat]
      v
+---------------+             +-----------------------------------+
| ManagerServer | --[push]--> | Intra-ReplicaGroup ManagerClients |
+---------------+             +-----------------------------------+
      |
  [heartbeat]
      v
+---------------+
| Lighthouse    |
| Server        |
+---------------+
      |
   [push]
      v
+----------------------+
| Other ManagerServers |
+----------------------+
      |
   [push]
      v
+-----------------------------------------+
| Other Intra-ReplicaGroup ManagerClients |
+-----------------------------------------+

This can be split into the following subproblems:

Northbound:
a) From ManagerClient to ManagerServer
TODO: Implement a heartbeat mechanism from ManagerClient to ManagerServer.
b) From ManagerServer to LighthouseServer
* This could be embedded into the current heartbeat mechanism or use a separate channel. Embedding it into the heartbeat is easier, although a separate mechanism may make sense if we want to carry additional reconfiguration metadata in the future.
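
As a rough illustration of (a) and (b), here is a minimal client-side sketch, assuming the heartbeat payload is extended to carry failure information. `HeartbeatPayload`, `send_heartbeat`, and `get_state` are hypothetical names, not existing torchFT APIs:

```python
# Illustrative sketch only: HeartbeatPayload, send_heartbeat, and get_state are
# hypothetical names, not existing torchFT APIs.
import threading
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class HeartbeatPayload:
    replica_id: str
    rank: int
    step: int
    # Set when the local training process has observed a failure; the ManagerServer
    # would forward this to the LighthouseServer on its next heartbeat.
    error: Optional[str] = None


def heartbeat_loop(
    send_heartbeat: Callable[[HeartbeatPayload], None],  # e.g., a unary RPC to the ManagerServer
    get_state: Callable[[], HeartbeatPayload],
    stop: threading.Event,
    interval_s: float = 1.0,
) -> None:
    """Periodically report liveness (and any pending error) to the ManagerServer."""
    while not stop.is_set():
        send_heartbeat(get_state())
        stop.wait(interval_s)
```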

Southbound:
c) From LighthouseServer to ManagerServer
TODO: Implement a push mechanism. A gRPC stream seems to make the most sense here (as @d4l3k suggested).
d) From ManagerServer to ManagerClient
Similar to c).
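
For (c) and (d), a rough sketch of the receiving side, assuming the push channel is exposed to Python as a long-lived stream (e.g., a gRPC server-streaming RPC). `subscribe_failures` is a stand-in for a stub method that does not exist yet, and failure events are modeled as plain dicts:

```python
# Illustrative sketch only: subscribe_failures stands in for a future server-streaming
# stub method; failure events are modeled as plain dicts.
import threading
from typing import Callable, Iterator


def listen_for_failures(
    subscribe_failures: Callable[[], Iterator[dict]],  # long-lived stream of failure events
    on_error_notification: Callable[[dict], None],
) -> threading.Thread:
    """Run a daemon thread that blocks on the push stream and dispatches each notification."""

    def _run() -> None:
        for event in subscribe_failures():  # blocks until the server pushes the next event
            on_error_notification(event)

    thread = threading.Thread(target=_run, name="torchft-failure-listener", daemon=True)
    thread.start()
    return thread
```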

2. Interrupt the training thread to signal error

  1. First, there needs to be a listening thread monitoring the push notifications.
  2. Upon receiving a push notification, the listener invokes an on_error_notification function, which needs to be implemented.
  3. on_error_notification then needs to communicate with the training thread in some way so that the training thread can reconfigure.

This could be done at various levels of complexity. A natural way of dealing with this is an interrupt mechanism. However, interrupts in Python cannot interrupt blocking calls into C, so blocking calls into the TCPStore or waits on futures could not be interrupted. The same most likely applies to other blocking calls at the CUDA level, e.g., when the CUDA queue fills up. Thus, we have to consider how to deal with these blocking calls.

The main blocking call that the main thread makes is waiting for collective communication futures.

In Manager.py:

for work in self._pending_work:
    work.wait()

The future could be interrupted through pg.abort(), which would cancel all futures associated with the process group, along with the group itself.

One could also, at the expense of additional complexity, call future.set_exception() on specific futures only, avoiding aborting the whole process group.

This may be beneficial in cases where the process group does not have to be reconfigured because no machines involved in the pg have failed. Though this is not too useful for HSDP, it could be useful for more flexible parallelism reconfiguration strategies, like [Oobleck](https://dl.acm.org/doi/abs/10.1145/3600006.3613152).
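
As a minimal sketch, assuming the listener thread from above and the pg.abort() behavior described here, the blocking wait could be interrupted roughly as follows. The class and attribute names are illustrative rather than torchFT's actual Manager implementation:

```python
# Sketch only: mirrors the Manager's pending-work loop; the error bookkeeping and
# wiring are assumptions, not torchFT's actual implementation.
import threading
from typing import Optional


class ErrorAwareWait:
    def __init__(self, pg) -> None:
        self._pg = pg
        self._pending_work = []
        self._error: Optional[Exception] = None
        self._lock = threading.Lock()

    def on_error_notification(self, event: dict) -> None:
        # Called from the listener thread when a peer failure is pushed to this rank.
        with self._lock:
            self._error = RuntimeError(f"peer failure reported: {event}")
        self._pg.abort()  # cancels the process group and all futures associated with it

    def wait_pending_work(self) -> None:
        try:
            for work in self._pending_work:
                work.wait()  # raises once the process group has been aborted
        except Exception:
            with self._lock:
                if self._error is not None:
                    raise self._error  # surface the pushed error to the training loop
            raise
        finally:
            self._pending_work.clear()
```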

3. Recover from error

A simple implementation of error recovery would:

  1. Reconfigure the process group given the error information.
  2. Issue a new quorum.
  3. For efficiency, have a default Fast Quorum that assumes no other error happens except for the current error.
  4. Continue with the next iteration.

Additional nice-to-have features include changing the local batch size (#186) and the ability to reuse work already done in the previous iteration for the next one.

One additional implementation complexity is how to handle a further pushed error arriving while reconfiguration is already in progress. Ideally, we would also want to be able to interrupt the reconfiguration process safely and efficiently. This would require adding abort/interrupt functionality to the TCPStore during the rendezvous phase of process group reconfiguration.
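
A hedged sketch of what steps 1-4 might look like from the training loop's perspective. `wait_pending_work`, `reconfigure_from_error`, and `start_quorum(fast=True)` are hypothetical names used only for illustration, while `should_commit` is the existing Manager API:

```python
# Sketch only: reconfigure_from_error, start_quorum(fast=...), and wait_pending_work
# are hypothetical; should_commit is the existing torchFT Manager API.
def train_step_with_recovery(manager, model, optimizer, batch) -> None:
    try:
        loss = model(batch).sum()
        loss.backward()
        manager.wait_pending_work()          # may raise if a failure was pushed (see section 2)
        if manager.should_commit():
            optimizer.step()
        optimizer.zero_grad()
    except RuntimeError as err:
        # 1. Reconfigure the process group from the propagated error information.
        manager.reconfigure_from_error(err)
        # 2./3. Issue a new quorum; by default a fast quorum that assumes no error
        #       other than the one just reported.
        manager.start_quorum(fast=True)
        # 4. Drop this iteration's partial work and continue with the next iteration.
        optimizer.zero_grad()
```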

Benchmark

To test the efficacy of the implementation, we should benchmark the time until the next optimizer.step() when we inject different kinds of failures at random points in the training loop.

There are three distinct types of failures to test:

a) ManagerClient Failure
b) ManagerServer Failure
c) Inter-ReplicaGroup Communication Failure (e.g., from network partitioning)
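
A sketch of the benchmark harness, where `train_step` and `inject_failure` are placeholders for the actual training loop and failure-injection hooks:

```python
# Sketch only: inject_failure and train_step are placeholders for the real harness.
import random
import time
from typing import Callable


def time_to_recovery(
    train_step: Callable[[], bool],       # returns True iff optimizer.step() ran this iteration
    inject_failure: Callable[[], None],   # e.g., kill a ManagerClient/ManagerServer or partition the network
    num_steps: int = 100,
) -> float:
    """Inject a failure at a random step and measure seconds until the next completed optimizer.step()."""
    fail_at = random.randrange(num_steps)
    t_fail = None
    for step in range(num_steps):
        if step == fail_at:
            inject_failure()
            t_fail = time.monotonic()
        stepped = train_step()
        if t_fail is not None and stepped:
            return time.monotonic() - t_fail
    raise RuntimeError("training did not recover within the benchmark window")
```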
