Skip to content

Commit

Permalink
checkpointing: use CheckpointTransport abstraction
Browse files Browse the repository at this point in the history
  • Loading branch information
d4l3k committed Jan 25, 2025
1 parent beb94f0 commit 0e29ef9
Show file tree
Hide file tree
Showing 9 changed files with 743 additions and 296 deletions.
28 changes: 15 additions & 13 deletions proto/torchft.proto
Original file line number Diff line number Diff line change
Expand Up @@ -72,30 +72,32 @@ service LighthouseService {
// Request sent by a replica's manager to join/confirm a quorum.
message ManagerQuorumRequest {
  // Group rank of the requesting worker.
  int64 rank = 1;
  // Current training step of the requester.
  int64 step = 2;
  // Opaque transport metadata describing how peers can fetch a
  // checkpoint from this replica (CheckpointTransport abstraction).
  string checkpoint_metadata = 3;
  // When true, only allow the quorum to shrink (no new joiners).
  bool shrink_only = 4;
}

// Quorum decision returned to each replica's manager.
message ManagerQuorumResponse {
  int64 quorum_id = 1;
  // Manager address a recovering replica should fetch state from,
  // and the rank to request on that manager.
  string recover_manager_address = 2;
  optional int64 recover_rank = 3;
  // Ranks that are currently recovering and will fetch from us.
  repeated int64 recovering_ranks = 4;
  string store_address = 5;
  // This is information for the replicas which are at the max step.
  int64 max_step = 6;
  optional int64 max_rank = 7;
  int64 max_world_size = 8;
  // This is information for all replicas, including behind replicas.
  int64 replica_rank = 9;
  int64 replica_world_size = 10;
  bool heal = 11;
}

// Asks the manager for the checkpoint transport metadata of `rank`.
message CheckpointMetadataRequest {
  int64 rank = 1;
}

// Opaque metadata the CheckpointTransport uses to fetch a checkpoint.
message CheckpointMetadataResponse {
  string checkpoint_metadata = 1;
}

message ShouldCommitRequest {
Expand All @@ -114,7 +116,7 @@ message KillResponse {}

// Per-replica manager RPCs used by workers during training.
service ManagerService {
  rpc Quorum (ManagerQuorumRequest) returns (ManagerQuorumResponse);
  rpc CheckpointMetadata (CheckpointMetadataRequest) returns (CheckpointMetadataResponse);
  rpc ShouldCommit (ShouldCommitRequest) returns (ShouldCommitResponse);
  rpc Kill (KillRequest) returns (KillResponse);
}
76 changes: 57 additions & 19 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ pub mod torchftpb {
}

use crate::torchftpb::manager_service_client::ManagerServiceClient;
use crate::torchftpb::{CheckpointAddressRequest, ManagerQuorumRequest, ShouldCommitRequest};
use crate::torchftpb::{CheckpointMetadataRequest, ManagerQuorumRequest, ShouldCommitRequest};
use pyo3::prelude::*;

#[pyclass]
Expand Down Expand Up @@ -113,15 +113,15 @@ impl ManagerClient {
py: Python<'_>,
rank: i64,
step: i64,
checkpoint_server_addr: String,
checkpoint_metadata: String,
shrink_only: bool,
timeout: Duration,
) -> Result<(i64, i64, i64, String, String, i64, Option<i64>, i64, bool), StatusError> {
) -> Result<QuorumResult, StatusError> {
py.allow_threads(move || {
let mut request = tonic::Request::new(ManagerQuorumRequest {
rank: rank,
step: step,
checkpoint_server_addr: checkpoint_server_addr,
checkpoint_metadata: checkpoint_metadata,
shrink_only: shrink_only,
});

Expand All @@ -131,38 +131,40 @@ impl ManagerClient {

let response = self.runtime.block_on(self.client.clone().quorum(request))?;
let resp = response.into_inner();
Ok((
resp.quorum_id,
resp.replica_rank,
resp.replica_world_size,
resp.address,
resp.store_address,
resp.max_step,
resp.max_rank,
resp.max_world_size,
resp.heal,
))
Ok(QuorumResult {
quorum_id: resp.quorum_id,
replica_rank: resp.replica_rank,
replica_world_size: resp.replica_world_size,
recover_manager_address: resp.recover_manager_address,
recover_rank: resp.recover_rank,
recovering_ranks: resp.recovering_ranks,
store_address: resp.store_address,
max_step: resp.max_step,
max_rank: resp.max_rank,
max_world_size: resp.max_world_size,
heal: resp.heal,
})
})
}

fn checkpoint_address(
fn checkpoint_metadata(
&self,
py: Python<'_>,
rank: i64,
timeout: Duration,
) -> Result<String, StatusError> {
py.allow_threads(move || {
let mut request = tonic::Request::new(CheckpointAddressRequest { rank: rank });
let mut request = tonic::Request::new(CheckpointMetadataRequest { rank: rank });

// This timeout is processed on the server side so we also enable
// keep alives to detect server health.
request.set_timeout(timeout);

let response = self
.runtime
.block_on(self.client.clone().checkpoint_address(request))?;
.block_on(self.client.clone().checkpoint_metadata(request))?;
let resp = response.into_inner();
Ok(resp.checkpoint_server_address)
Ok(resp.checkpoint_metadata)
})
}

Expand Down Expand Up @@ -194,6 +196,41 @@ impl ManagerClient {
}
}

/// Quorum outcome returned by `ManagerClient::quorum`, exposed to Python.
///
/// Mirrors the fields of the `ManagerQuorumResponse` proto message;
/// `get_all, set_all` makes every field a readable/writable Python attribute.
#[pyclass(get_all, set_all)]
struct QuorumResult {
    quorum_id: i64,
    // Rank / world size covering all replicas, including behind ones.
    replica_rank: i64,
    replica_world_size: i64,
    // Where a recovering replica should fetch state from, and which rank
    // to request there (None when no recovery source was assigned).
    recover_manager_address: String,
    recover_rank: Option<i64>,
    // Ranks that will recover by fetching from this replica.
    recovering_ranks: Vec<i64>,
    store_address: String,
    // Information for the replicas which are at the max step.
    max_step: i64,
    max_rank: Option<i64>,
    max_world_size: i64,
    // NOTE(review): presumably true when this replica must heal/recover
    // from a peer — confirm against the manager's quorum logic.
    heal: bool,
}

#[pymethods]
impl QuorumResult {
#[new]
fn new() -> Self {
Self {
quorum_id: 0,
replica_rank: 0,
replica_world_size: 1,
recover_manager_address: "".to_string(),
recover_rank: None,
recovering_ranks: Vec::new(),
store_address: "".to_string(),
max_step: 0,
max_rank: None,
max_world_size: 1,
heal: false,
}
}
}

fn reset_python_signals(py: Python<'_>) -> PyResult<()> {
// clear python signal handlers
// signal.signal(signal.SIGINT, signal.SIG_DFL)
Expand Down Expand Up @@ -319,6 +356,7 @@ fn torchft(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<Manager>()?;
m.add_class::<ManagerClient>()?;
m.add_class::<Lighthouse>()?;
m.add_class::<QuorumResult>()?;
m.add_function(wrap_pyfunction!(lighthouse_main, m)?)?;

Ok(())
Expand Down
Loading

0 comments on commit 0e29ef9

Please sign in to comment.