Proactive Error Propagation
This issue tracks the implementation of proactive error propagation in torchFT, coming from discussions with @d4l3k and @guodong.
Would love to hear feedback as to which parts are most necessary and should be implemented first!
Overview
Currently, torchFT relies on a timeout mechanism to detect errors from neighboring nodes during a forward pass. This becomes prohibitively expensive in scenarios with high failure rates. One potential solution is to introduce proactive error propagation mechanisms that complement timeouts.
TorchFT Relies on a Timeout Mechanism for Failure Detection
Currently in torchFT, the failure response mechanism is the following (visualized below):
Upon failure:

- Horizontal DDP
  a. For the horizontal (DDP) `train.py` processes, the `all_reduce` hits its timeout.
  b. The `should_commit` RPC returns `false` for all ranks in the corresponding replica group.
- Vertical MP
  a. No `should_commit` RPC arrives, so the `should_commit` RPC hangs indefinitely.
  b. The Manager does not join the next quorum.
  c. The cluster manager (e.g., Slurm/Kubernetes) detects the failure through routine health checks and reschedules the job on another node.

With this mechanism, each failure incurs an overhead of the time until the first gradient allreduce is issued by the DDP peers, plus the `all_reduce_timeout` it takes for the horizontal peers to register the failure after the allreduce is issued.
This overhead could potentially be decreased by reducing the allreduce timeout. However, this increases the probability of false timeouts in the presence of network jitter, which also leads to lost iterations. Further, the timeout cannot be reduced below the time it normally takes for the allreduce to complete.
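To make the cost concrete, here is a rough back-of-envelope sketch; all numbers (time to first allreduce, timeout value, mean time between failures) are illustrative assumptions rather than torchFT measurements:

```python
# Rough model of per-failure detection overhead under the timeout-only scheme.
# All numbers below are illustrative assumptions, not torchFT measurements.

time_to_first_allreduce_s = 2.0          # time until DDP peers issue the first gradient allreduce
all_reduce_timeout_s = 30.0              # timeout after which peers register the failure
mean_time_between_failures_s = 15 * 60   # assumed MTBF across the whole fleet

# Detection overhead paid on every failure when relying on timeouts alone.
detection_overhead_s = time_to_first_allreduce_s + all_reduce_timeout_s

# Fraction of wall-clock time lost to failure detection alone
# (ignoring re-rendezvous, checkpoint reload, and the lost iteration itself).
lost_fraction = detection_overhead_s / mean_time_between_failures_s
print(f"~{detection_overhead_s:.0f}s lost per failure, "
      f"~{100 * lost_fraction:.1f}% of time spent just detecting failures")
```

Shrinking the timeout lowers this overhead but, as noted above, raises the false-timeout rate under network jitter; proactive propagation instead aims to remove the timeout term from the common case.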
A Solution: Proactive Error Propagation
A potential way to decrease the failure recovery overhead is to introduce a proactive error propagation mechanism: whenever an error happens, it is proactively propagated to the neighboring nodes, which can then recover from it without waiting for a timeout. See the figure below:

Implementation
There are three subproblems to solve:
- Propagation of Failure to Each Rank
- Interrupt the training thread to signal error
- Recover from error
1. Propagation of Failure to Each Rank
This could be implemented in various ways. Given torchFT's existing infrastructure, utilizing the current `ManagerServer` and `LighthouseServer` makes the most sense to me. The error propagation flow would then be the following. Note that each `ManagerClient` corresponds to a training process.
```
ManagerClient x (fails)
          |
     [heartbeat]
          v
  +---------------+
  | ManagerServer |
  +---------------+
      |        |
[heartbeat]  [push]
      |        |
      v        v
  +---------------+
  |  Lighthouse   |
  |    Server     |
  +---------------+
          |
       [push]
          v
  +----------------------+
  | Other ManagerServers |
  +----------------------+
          |
       [push]
          v
  +-----------------------------------------+
  | Other Intra-ReplicaGroup ManagerClients |
  +-----------------------------------------+
```
This can be split into the following subproblems:
Northbound:

a) From `ManagerClient` to `ManagerServer`
   * TODO: Implement a heartbeat mechanism from `ManagerClient` to `ManagerServer`.

b) From `ManagerServer` to `LighthouseServer`
   * This could be embedded into the current heartbeat mechanism, or use a separate mechanism. Embedding into the heartbeat is easier, although a separate mechanism may make sense in order to include more reconfiguration metadata in the future.

Southbound:

c) From `LighthouseServer` to `ManagerServer`
   * TODO: Implement a push mechanism. A gRPC stream seems to make the most sense here (as @d4l3k suggested); a rough sketch of the client side of such a stream is given after this list.

d) From `ManagerServer` to `ManagerClient`
   * Similar to c).
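For the southbound pushes in c) and d), a server-side streaming RPC lets each subscriber hold a long-lived connection and receive failure notifications as soon as they are known. The sketch below is a rough Python illustration of the client side of such a stream; the `LighthouseServiceStub`, `SubscribeFailuresRequest`, and `replica_id` field are hypothetical names, not part of torchFT's current proto definitions:

```python
import threading

import grpc

# Hypothetical generated gRPC stubs; torchFT does not define this service today.
from torchft.proto import lighthouse_pb2, lighthouse_pb2_grpc  # hypothetical


def subscribe_to_failures(lighthouse_addr: str, on_error_notification) -> threading.Thread:
    """Open a long-lived server-streaming RPC and forward each failure event."""

    def _listen() -> None:
        with grpc.insecure_channel(lighthouse_addr) as channel:
            stub = lighthouse_pb2_grpc.LighthouseServiceStub(channel)  # hypothetical
            request = lighthouse_pb2.SubscribeFailuresRequest()        # hypothetical
            try:
                # The server keeps the stream open and yields one message per failure.
                for notification in stub.SubscribeFailures(request):
                    on_error_notification(notification.replica_id)     # hypothetical field
            except grpc.RpcError:
                # The stream broke (e.g., the Lighthouse restarted); a real
                # implementation would reconnect with backoff here.
                pass

    thread = threading.Thread(target=_listen, daemon=True, name="failure-listener")
    thread.start()
    return thread
```

The same shape would apply to d), with the `ManagerServer` acting as the streaming server and each `ManagerClient` as the subscriber.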
2. Interrupt the training thread to signal error
- First, there would need to be a listening thread monitoring the push notifications.
- Upon push notifications, an `on_error_notification` function needs to be implemented.
- The `on_error_notification` function would need to communicate with the training thread in some way for it to reconfigure (a minimal sketch follows below).
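As a minimal starting point, `on_error_notification` could simply record the failure and set a `threading.Event` that the training loop checks between steps. The sketch below is purely illustrative (none of these names exist in torchFT today); the catch, discussed next, is that this alone cannot wake a thread that is already blocked inside a C-level call:

```python
import threading

# Shared state between the listener thread and the training thread (illustrative).
error_event = threading.Event()
failed_replicas: list[str] = []


def on_error_notification(replica_id: str) -> None:
    """Invoked by the listener thread when a failure push arrives."""
    failed_replicas.append(replica_id)
    error_event.set()


def maybe_reconfigure() -> None:
    """Called by the training thread between steps (cooperative check)."""
    if error_event.is_set():
        # Reconfiguration itself is covered in section 3 below.
        print(f"reconfiguring around failed replicas: {failed_replicas}")
        error_event.clear()
```

The training loop would call `maybe_reconfigure()` once per iteration; anything already blocked inside `work.wait()` at the time of the failure still needs the abort path discussed below.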
This could be done at various levels of complexity. A natural way of dealing with this is through an interrupt mechanism. However, interrupts in Python cannot interrupt blocking calls in C, so blocking calls into the `TCPStore` or waits on futures would not be interruptible. The same most likely applies to other blocking calls at the CUDA level, e.g., when the CUDA queue fills up. Thus, we have to consider how to deal with these blocking calls.
The main blocking call that the main thread makes is waiting for collective communication futures.
In `Manager.py`:

```python
for work in self._pending_work:
    work.wait()
```
The future could be interrupted through `pg.abort()`, which would cancel all futures associated with the process group, along with the group itself.

One could also, at the expense of additional complexity, call `future.set_exception()` to cancel only specific futures instead of aborting the whole process group. This may be beneficial when the process group does not have to be reconfigured because no machines involved in the `pg` have failed. Though this is not too useful for HSDP, it could be useful in more flexible parallelism reconfiguration strategies, like [Oobleck](https://dl.acm.org/doi/abs/10.1145/3600006.3613152).
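Putting the pieces together, here is a minimal sketch, assuming the listener thread has a handle to the same process group (`pg`) whose pending work the training thread is waiting on, and assuming an aborted wait surfaces as a `RuntimeError` (the exact exception type depends on the backend):

```python
import threading

abort_requested = threading.Event()


def on_error_notification(pg) -> None:
    """Listener-thread side: called when a failure push arrives.

    `pg` is the Manager's process group, assumed to expose `abort()` as
    discussed above.
    """
    abort_requested.set()
    # Cancels the group and all futures/work associated with it, so the
    # training thread's blocking wait() returns instead of waiting for the
    # allreduce timeout.
    pg.abort()


def wait_for_pending_work(pending_work) -> bool:
    """Training-thread side: returns False if this step must be discarded."""
    try:
        for work in pending_work:
            work.wait()
    except RuntimeError:
        # Assumed failure mode of an aborted process group: wait() raises.
        if abort_requested.is_set():
            return False  # proceed to reconfiguration / new quorum
        raise
    return True
```

The `future.set_exception()` variant would look similar, except that only the futures belonging to the failed peer would be failed and the process group itself would be left intact.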
3. Recover from error
A simple implementation of error recovery (sketched after this list) would:
- Reconfigure the process group given the error information.
- Issue a new quorum.
- For efficiency, have a default Fast Quorum that assumes no other error happens except for the current error.
- Continue with the next iteration.
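A rough sketch of that recovery path, continuing from `wait_for_pending_work()` above; `reconfigure_pg`, `fast_quorum`, `full_quorum`, and `QuorumMismatchError` are hypothetical helpers standing in for torchFT's actual process-group and quorum machinery:

```python
class QuorumMismatchError(Exception):
    """Hypothetical: raised when the fast-quorum assumption does not hold."""


def recover_from_error(manager, failed_replica_ids: list[str]) -> None:
    """Illustrative recovery path after a proactive failure notification."""
    # 1. Reconfigure the process group using the error information.
    manager.reconfigure_pg(exclude=failed_replica_ids)   # hypothetical helper

    # 2. Issue a new quorum. The default "fast quorum" optimistically assumes
    #    that no failure other than the reported one has occurred; fall back to
    #    a full quorum if that assumption turns out to be wrong.
    try:
        manager.fast_quorum(exclude=failed_replica_ids)   # hypothetical helper
    except QuorumMismatchError:
        manager.full_quorum()                             # hypothetical helper

    # 3. Continue with the next iteration; the interrupted step is discarded
    #    (or, as a nice-to-have, its completed work is reused).
```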
Additional nice-to-have features would include changing the local batch size (#186) and the ability to reuse the work already done in the interrupted iteration for the next one.
One additional implementation complexity is how to handle a further error push arriving while reconfiguration is in progress. Ideally, we would want to be able to interrupt the reconfiguration process safely and efficiently as well. This would require adding abort/interrupt functionality to the `TCPStore` during the rendezvous phase of process group reconfiguration.
Benchmark
To test the efficacy of the implementation, we should benchmark the time until the next `optimizer.step()` when we inject different kinds of failures at random points in the training loop.
There are three distinct types of failures to test (a simple harness is sketched below):
a) `ManagerClient` failure
b) `ManagerServer` failure
c) Inter-ReplicaGroup communication failure (e.g., from network partitioning)
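A rough sketch of such a benchmark loop; `train_one_step` and `inject_failure` are placeholders for the actual training step and for whatever mechanism we use to kill the corresponding process or partition the network:

```python
import random
import time

FAILURE_KINDS = ["manager_client", "manager_server", "inter_replica_comm"]


def benchmark_recovery(train_one_step, inject_failure, num_trials: int = 10,
                       steps_per_trial: int = 100) -> dict[str, list[float]]:
    """Measure time from failure injection until the next committed optimizer.step().

    `train_one_step()` runs one iteration and returns True once optimizer.step()
    has committed; `inject_failure(kind)` is a placeholder for killing the
    corresponding process or partitioning the network.
    """
    results: dict[str, list[float]] = {kind: [] for kind in FAILURE_KINDS}
    for kind in FAILURE_KINDS:
        for _ in range(num_trials):
            fail_at = random.randrange(steps_per_trial)  # random point in the loop
            for step in range(steps_per_trial):
                if step == fail_at:
                    inject_failure(kind)
                    start = time.monotonic()
                    # Time until the next step commits is the recovery overhead.
                    while not train_one_step():
                        pass
                    results[kind].append(time.monotonic() - start)
                    break
                train_one_step()
    return results
```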