Skip to content

Stop actor meshes #537

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions hyperactor/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ declare_attrs! {

/// Maximum buffer size for split port messages
pub attr SPLIT_MAX_BUFFER_SIZE: usize = 5;

/// Timeout used by proc mesh for stopping an actor.
pub attr STOP_ACTOR_TIMEOUT: Duration = Duration::from_secs(1);
}

/// Load configuration from environment variables
Expand Down
46 changes: 29 additions & 17 deletions hyperactor/src/proc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,7 @@ impl Proc {
/// Call `abort` on the `JoinHandle` associated with the given
/// root actor. If successful return `Some(root.clone())` else
/// `None`.
pub fn abort_root_actor(&mut self, root: &ActorId) -> Option<ActorId> {
pub fn abort_root_actor(&self, root: &ActorId) -> Option<ActorId> {
self.state()
.ledger
.roots
Expand All @@ -511,17 +511,12 @@ impl Proc {
.next()
}

// Iterating over a proc's root actors signaling each to stop.
// Return the root actor IDs and status observers.
async fn destroy(
&mut self,
) -> Result<HashMap<ActorId, watch::Receiver<ActorStatus>>, anyhow::Error> {
tracing::debug!("{}: proc stopping", self.proc_id());

let mut statuses = HashMap::new();
for entry in self.state().ledger.roots.iter() {
/// Signals to a root actor to stop,
/// returning a status observer if successful.
pub fn stop_actor(&self, actor_id: &ActorId) -> Option<watch::Receiver<ActorStatus>> {
if let Some(entry) = self.state().ledger.roots.get(actor_id) {
match entry.value().upgrade() {
None => (), // the root's cell has been dropped
None => None, // the root's cell has been dropped
Some(cell) => {
tracing::info!("sending stop signal to {}", cell.actor_id());
if let Err(err) = cell.signal(Signal::DrainAndStop) {
Expand All @@ -531,15 +526,16 @@ impl Proc {
cell.pid(),
err
);
continue;
None
} else {
Some(cell.status().clone())
}
statuses.insert(cell.actor_id().clone(), cell.status().clone());
}
}
} else {
tracing::error!("no actor {} found in {} roots", actor_id, self.proc_id());
None
}

tracing::debug!("{}: proc stopped", self.proc_id());
Ok(statuses)
}

/// Stop the proc. Returns a pair of:
Expand All @@ -553,7 +549,23 @@ impl Proc {
timeout: Duration,
skip_waiting: Option<&ActorId>,
) -> Result<(Vec<ActorId>, Vec<ActorId>), anyhow::Error> {
let mut statuses = self.destroy().await?;
tracing::debug!("{}: proc stopping", self.proc_id());

let mut statuses = HashMap::new();
for actor_id in self
.state()
.ledger
.roots
.iter()
.map(|entry| entry.key().clone())
.collect::<Vec<_>>()
{
if let Some(status) = self.stop_actor(&actor_id) {
statuses.insert(actor_id, status);
}
}
tracing::debug!("{}: proc stopped", self.proc_id());

let waits: Vec<_> = statuses
.iter_mut()
.filter(|(actor_id, _)| Some(*actor_id) != skip_waiting)
Expand Down
1 change: 1 addition & 0 deletions hyperactor_mesh/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ tracing-subscriber = { version = "0.3.19", features = ["chrono", "env-filter", "
[dev-dependencies]
maplit = "1.0"
timed_test = { version = "0.0.0", path = "../timed_test" }
tracing-test = { version = "0.2.3", features = ["no-env-filter"] }

[lints]
rust = { unexpected_cfgs = { check-cfg = ["cfg(fbcode_build)"], level = "warn" } }
70 changes: 70 additions & 0 deletions hyperactor_mesh/src/actor_mesh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,13 @@ pub trait ActorMesh: Mesh<Id = ActorMeshId> {
self.shape().slice().iter().map(move |rank| gang.rank(rank))
}

/// Stop the actors of this mesh.
///
/// Delegates to the owning proc mesh's `stop_actor_by_name`, passing this
/// mesh's name. The returned future is `Send`, which is why `Self: Sync`
/// is required: the future holds `&self` across an await point.
fn stop(&self) -> impl std::future::Future<Output = Result<(), anyhow::Error>> + Send
where
Self: Sync,
{
async { self.proc_mesh().stop_actor_by_name(self.name()).await }
}

/// Get a serializable reference to this mesh similar to ActorHandle::bind
fn bind(&self) -> ActorMeshRef<Self::Actor> {
ActorMeshRef::attest(
Expand Down Expand Up @@ -1023,6 +1030,69 @@ mod tests {
);
assert!(events.next().await.is_none());
}

/// Stopping one actor mesh should stop its actors while other actor meshes
/// spawned on the same proc mesh keep handling messages.
#[tracing_test::traced_test]
#[tokio::test]
async fn test_stop_actor_mesh() {
use hyperactor::test_utils::pingpong::PingPongActor;
use hyperactor::test_utils::pingpong::PingPongActorParams;
use hyperactor::test_utils::pingpong::PingPongMessage;

// Override MESSAGE_DELIVERY_TIMEOUT to 1 second; `_guard` is held for the
// rest of the test so the override handle stays alive.
let config = hyperactor::config::global::lock();
let _guard = config.override_key(
hyperactor::config::MESSAGE_DELIVERY_TIMEOUT,
tokio::time::Duration::from_secs(1),
);

// Allocate a local proc mesh with two replicas.
let alloc = LocalAllocator
.allocate(AllocSpec {
shape: shape! { replica = 2 },
constraints: Default::default(),
})
.await
.unwrap();
let mesh = ProcMesh::allocate(alloc).await.unwrap();

// Both actor meshes are spawned on the same proc mesh with the same params.
let ping_pong_actor_params = PingPongActorParams::new(
PortRef::attest_message_port(mesh.client().actor_id()),
None,
);
let mesh_one: RootActorMesh<PingPongActor> = mesh
.spawn::<PingPongActor>("mesh_one", &ping_pong_actor_params)
.await
.unwrap();

let mesh_two: RootActorMesh<PingPongActor> = mesh
.spawn::<PingPongActor>("mesh_two", &ping_pong_actor_params)
.await
.unwrap();

// Stop only the second mesh.
mesh_two.stop().await.unwrap();

let ping_two: ActorRef<PingPongActor> = mesh_two.get(0).unwrap();
let pong_two: ActorRef<PingPongActor> = mesh_two.get(1).unwrap();

// The stop is verified through the captured logs: a "stopped actor <id>"
// line must have been emitted for each actor of the stopped mesh.
assert!(logs_contain(&format!(
"stopped actor {}",
ping_two.actor_id()
)));
assert!(logs_contain(&format!(
"stopped actor {}",
pong_two.actor_id()
)));

// Other actor meshes on this proc mesh should still be up and running
let ping_one: ActorRef<PingPongActor> = mesh_one.get(0).unwrap();
let pong_one: ActorRef<PingPongActor> = mesh_one.get(1).unwrap();
// Send a message to the first mesh and require that the once-port
// bound into the message eventually receives a value.
let (done_tx, done_rx) = mesh.client().open_once_port();
pong_one
.send(
mesh.client(),
PingPongMessage(1, ping_one.clone(), done_tx.bind()),
)
.unwrap();
assert!(done_rx.recv().await.is_ok());
}
} // mod local

mod process {
Expand Down
37 changes: 37 additions & 0 deletions hyperactor_mesh/src/proc_mesh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ use std::sync::Arc;

use async_trait::async_trait;
use dashmap::DashMap;
use futures::future::join_all;
use hyperactor::Actor;
use hyperactor::ActorId;
use hyperactor::ActorRef;
use hyperactor::Mailbox;
use hyperactor::Named;
Expand Down Expand Up @@ -56,6 +58,7 @@ use crate::comm::CommActorMode;
use crate::proc_mesh::mesh_agent::GspawnResult;
use crate::proc_mesh::mesh_agent::MeshAgent;
use crate::proc_mesh::mesh_agent::MeshAgentMessageClient;
use crate::proc_mesh::mesh_agent::StopActorResult;
use crate::reference::ProcMeshId;

pub mod mesh_agent;
Expand Down Expand Up @@ -449,6 +452,40 @@ impl ProcMesh {
pub fn shape(&self) -> &Shape {
&self.shape
}

/// Ask every mesh agent to stop the actor named `mesh_name` on its proc.
///
/// One stop request is issued per agent and all requests are awaited
/// together. Each request targets the actor `(proc_id, mesh_name, 0)` on
/// that agent's proc and uses the `STOP_ACTOR_TIMEOUT` configuration value
/// as its deadline.
///
/// Per-actor outcomes (stopped, timed out, not found, or a request error)
/// are only logged; they do not affect the return value, which is always
/// `Ok(())`.
pub async fn stop_actor_by_name(&self, mesh_name: &str) -> Result<(), anyhow::Error> {
    let timeout = hyperactor::config::global::get(hyperactor::config::STOP_ACTOR_TIMEOUT);
    // The message carries the timeout as `u64` milliseconds while
    // `Duration::as_millis` yields `u128`; saturate rather than silently
    // truncating with an `as` cast.
    let timeout_ms = u64::try_from(timeout.as_millis()).unwrap_or(u64::MAX);

    let results = join_all(self.agents().map(|agent| async move {
        let actor_id = ActorId(agent.actor_id().proc_id().clone(), mesh_name.to_string(), 0);
        (
            actor_id.clone(),
            agent
                .clone()
                .stop_actor(&self.client, actor_id, timeout_ms)
                .await,
        )
    }))
    .await;

    for (actor_id, result) in results {
        match result {
            Ok(StopActorResult::Timeout) => {
                tracing::error!("timed out while stopping actor {}", actor_id);
            }
            Ok(StopActorResult::NotFound) => {
                tracing::error!("no actor {} on proc {}", actor_id, actor_id.proc_id());
            }
            Ok(StopActorResult::Success) => {
                tracing::info!("stopped actor {}", actor_id);
            }
            Err(e) => {
                tracing::error!("error stopping actor {}: {}", actor_id, e);
            }
        }
    }
    Ok(())
}
}

/// Proc lifecycle events.
Expand Down
47 changes: 46 additions & 1 deletion hyperactor_mesh/src/proc_mesh/mesh_agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,17 @@ use hyperactor::HandleClient;
use hyperactor::Handler;
use hyperactor::Instance;
use hyperactor::Named;
use hyperactor::OncePortRef;
use hyperactor::PortHandle;
use hyperactor::PortRef;
use hyperactor::ProcId;
use hyperactor::RefClient;
use hyperactor::actor::ActorStatus;
use hyperactor::actor::remote::Remote;
use hyperactor::channel;
use hyperactor::channel::ChannelAddr;
use hyperactor::clock::Clock;
use hyperactor::clock::RealClock;
use hyperactor::mailbox::BoxedMailboxSender;
use hyperactor::mailbox::DeliveryError;
use hyperactor::mailbox::DialMailboxRouter;
Expand All @@ -52,11 +56,17 @@ pub enum GspawnResult {
Error(String),
}

/// Outcome of a `StopActor` request handled by a mesh agent.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Named)]
pub enum StopActorResult {
/// The wait for the actor's `Stopped` status finished before the timeout elapsed.
Success,
/// The timeout elapsed before the wait for the `Stopped` status finished.
Timeout,
/// The proc returned no status observer for the actor, so there was nothing to wait on.
NotFound,
}

#[derive(
Debug,
Clone,
PartialEq,
Eq,
Serialize,
Deserialize,
Handler,
Expand Down Expand Up @@ -91,6 +101,17 @@ pub(crate) enum MeshAgentMessage {
/// reply port; the proc should send its rank to indicate a spawned actor
status_port: PortRef<GspawnResult>,
},

/// Stop a single actor, identified by its actor ID, and report the outcome
StopActor {
/// The actor to stop
actor_id: ActorId,
/// The timeout for waiting for the actor to stop
timeout_ms: u64,
/// The result when trying to stop the actor
#[reply]
stopped: OncePortRef<StopActorResult>,
},
}

/// A mesh agent is responsible for managing procs in a [`ProcMesh`].
Expand Down Expand Up @@ -224,6 +245,30 @@ impl MeshAgentMessageHandler for MeshAgent {
status_port.send(cx, GspawnResult::Success { rank, actor_id })?;
Ok(())
}

/// Handle a `StopActor` request: signal `actor_id` to stop and wait for it.
///
/// Returns `NotFound` when the proc yields no status observer for the
/// actor, `Success` when the wait for the `Stopped` status completes within
/// `timeout_ms` milliseconds, and `Timeout` otherwise. The handler itself
/// never returns an `Err`.
async fn stop_actor(
    &mut self,
    _cx: &Context<Self>,
    actor_id: ActorId,
    timeout_ms: u64,
) -> Result<StopActorResult, anyhow::Error> {
    tracing::info!("Stopping actor: {}", actor_id);

    // Without a status observer there is nothing to wait on.
    let Some(mut status) = self.proc.stop_actor(&actor_id) else {
        return Ok(StopActorResult::NotFound);
    };

    let deadline = tokio::time::Duration::from_millis(timeout_ms);
    let until_stopped =
        status.wait_for(|state: &ActorStatus| matches!(*state, ActorStatus::Stopped));

    // Only the timeout layer is inspected: any completion of the wait before
    // the deadline counts as success.
    let outcome = if RealClock.timeout(deadline, until_stopped).await.is_ok() {
        StopActorResult::Success
    } else {
        StopActorResult::Timeout
    };
    Ok(outcome)
}
}

#[async_trait]
Expand Down
19 changes: 19 additions & 0 deletions monarch_hyperactor/src/actor_mesh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,25 @@ impl PythonActorMesh {
fn __reduce_ex__(&self, _proto: u8) -> PyResult<()> {
Err(self.pickling_err())
}

/// Stop all actors in this mesh, returning a Python awaitable.
///
/// The mesh is first taken out of the shared `inner` cell; if that fails
/// the awaitable raises a `RuntimeError`. A failure while stopping the
/// mesh raises an `Exception` carrying the underlying error message.
fn stop<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
    let shared_mesh = self.inner.clone();
    let stop_task = async move {
        let Ok(mesh) = shared_mesh.take().await else {
            return Err(PyRuntimeError::new_err(
                "`ActorMesh` has already been stopped",
            ));
        };
        mesh.stop()
            .await
            .map_err(|err| PyException::new_err(format!("Failed to stop actor mesh: {}", err)))
    };
    pyo3_async_runtimes::tokio::future_into_py(py, stop_task)
}

/// Whether this mesh has been stopped, i.e. whether borrowing the inner
/// mesh now reports an error.
#[getter]
fn stopped(&self) -> PyResult<bool> {
    let is_stopped = self.inner.borrow().is_err();
    Ok(is_stopped)
}
}

#[pyclass(
Expand Down
14 changes: 14 additions & 0 deletions python/monarch/_rust_bindings/monarch_hyperactor/actor_mesh.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,20 @@ class PythonActorMesh:
"""
...

async def stop(self) -> None:
"""
Stop all actors that are part of this mesh.

The mesh must not be used after stop() has been called; doing so
raises an exception. Calling stop() again on an already stopped
mesh raises a RuntimeError.
"""
...

@property
def stopped(self) -> bool:
"""
Whether this mesh has been stopped.
"""
...

@final
class ActorMeshMonitor:
def __aiter__(self) -> AsyncIterator["ActorSupervisionEvent"]:
Expand Down
Loading
Loading