Skip to content

Commit

Permalink
Introduce a plugin API to provide all thread local state, and depreca…
Browse files Browse the repository at this point in the history
…te stdio-specific methods (Cherry-pick of #15890) (#15916)

As described in #15887: `StreamingWorkunit` plugins have never been able to set thread-local `WorkunitStore` state, but that apparently didn't matter until #11331 made it possible for the `StreamingWorkunitContext` file-fetching methods to encounter data which had not yet been fetched (and thus needed to create a workunit for the fetching).

This change updates and "deprecates" the existing `stdio_thread_[gs]et_destination` methods (although it doesn't have access to a decorator to do that), and introduces generic thread-local state methods which include all thread-local state required by engine APIs.

Fixes #15887.

[ci skip-build-wheels]
  • Loading branch information
stuhood authored Jun 24, 2022
1 parent 69a10c7 commit d7e7a1a
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 12 deletions.
6 changes: 3 additions & 3 deletions src/python/pants/bsp/testutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from pants.bsp.context import BSPContext
from pants.bsp.protocol import BSPConnection
from pants.bsp.rules import rules as bsp_rules
from pants.engine.internals import native_engine
from pants.engine.internals.native_engine import PyThreadLocals
from pants.testutil.rule_runner import RuleRunner


Expand Down Expand Up @@ -93,7 +93,7 @@ def setup_bsp_server(
):
rule_runner = rule_runner or RuleRunner(rules=bsp_rules())
notification_names = notification_names or set()
stdio_destination = native_engine.stdio_thread_get_destination()
thread_locals = PyThreadLocals.get_for_current_thread()

with setup_pipes() as pipes:
context = BSPContext()
Expand All @@ -107,7 +107,7 @@ def setup_bsp_server(
)

def run_bsp_server():
native_engine.stdio_thread_set_destination(stdio_destination)
thread_locals.set_for_current_thread()
conn.run()

bsp_thread = Thread(target=run_bsp_server)
Expand Down
5 changes: 5 additions & 0 deletions src/python/pants/engine/internals/native_engine.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -386,5 +386,10 @@ class PyTypes:
class PyStdioDestination:
pass

class PyThreadLocals:
@classmethod
def get_for_current_thread(cls) -> PyThreadLocals: ...
def set_for_current_thread(self) -> None: ...

class PollTimeout(Exception):
pass
29 changes: 24 additions & 5 deletions src/python/pants/engine/streaming_workunit_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from pants.base.specs import Specs
from pants.engine.addresses import Addresses
from pants.engine.fs import Digest, DigestContents, FileDigest, Snapshot
from pants.engine.internals import native_engine
from pants.engine.internals.native_engine import PyThreadLocals
from pants.engine.internals.scheduler import SchedulerSession, Workunit
from pants.engine.internals.selectors import Params
from pants.engine.rules import Get, MultiGet, QueryRule, collect_rules, rule
Expand All @@ -30,6 +30,24 @@
# -----------------------------------------------------------------------------------------------


def thread_locals_get_for_current_thread() -> PyThreadLocals:
"""Gets the engine's thread local state for the current thread.
In order to safely use StreamingWorkunitContext methods from additional threads,
StreamingWorkunit plugins should propagate thread local state from the threads that they are
initialized on to any additional threads that they spawn.
"""
return PyThreadLocals.get_for_current_thread()


def thread_locals_set_for_current_thread(thread_locals: PyThreadLocals) -> None:
"""Sets the engine's thread local state for the current thread.
See `thread_locals_get`.
"""
thread_locals.set_for_current_thread()


@dataclass(frozen=True)
class TargetInfo:
filename: str
Expand Down Expand Up @@ -246,9 +264,9 @@ def __init__(
self.block_until_complete = not allow_async_completion or any(
callback.can_finish_async is False for callback in self.callbacks
)
# Get the parent thread's logging destination. Note that this thread has not yet started
# Get the parent thread's thread locals. Note that this thread has not yet started
# as we are only in the constructor.
self.logging_destination = native_engine.stdio_thread_get_destination()
self.thread_locals = PyThreadLocals.get_for_current_thread()

def poll_workunits(self, *, finished: bool) -> None:
workunits = self.scheduler.poll_workunits(self.max_workunit_verbosity)
Expand All @@ -261,8 +279,9 @@ def poll_workunits(self, *, finished: bool) -> None:
)

def run(self) -> None:
# First, set the thread's logging destination to the parent thread's, meaning the console.
native_engine.stdio_thread_set_destination(self.logging_destination)
# First, set the thread's thread locals to the parent thread's in order to propagate the
# console, workunit stores, etc.
self.thread_locals.set_for_current_thread()
while not self.stop_request.isSet():
self.poll_workunits(finished=False)
self.stop_request.wait(timeout=self.report_interval)
Expand Down
37 changes: 33 additions & 4 deletions src/rust/engine/src/externs/interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ use rule_graph::{self, RuleGraph};
use task_executor::Executor;
use workunit_store::{
ArtifactOutput, ObservationMetric, UserMetadataItem, Workunit, WorkunitState, WorkunitStore,
WorkunitStoreHandle,
};

use crate::externs::fs::{todo_possible_store_missing_digest, PyFileDigest};
Expand Down Expand Up @@ -73,6 +74,7 @@ fn native_engine(py: Python, m: &PyModule) -> PyO3Result<()> {
m.add_class::<PySessionCancellationLatch>()?;
m.add_class::<PyStdioDestination>()?;
m.add_class::<PyTasks>()?;
m.add_class::<PyThreadLocals>()?;
m.add_class::<PyTypes>()?;

m.add_class::<externs::PyGeneratorResponseBreak>()?;
Expand Down Expand Up @@ -225,7 +227,7 @@ impl PyTypes {
struct PyScheduler(Scheduler);

#[pyclass]
struct PyStdioDestination(Arc<stdio::Destination>);
struct PyStdioDestination(PyThreadLocals);

/// Represents configuration related to process execution strategies.
///
Expand Down Expand Up @@ -500,6 +502,30 @@ fn py_result_from_root(py: Python, result: Result<Value, Failure>) -> PyResult {
}
}

#[pyclass]
struct PyThreadLocals(Arc<stdio::Destination>, Option<WorkunitStoreHandle>);

impl PyThreadLocals {
fn get() -> Self {
let stdio_dest = stdio::get_destination();
let workunit_store_handle = workunit_store::get_workunit_store_handle();
Self(stdio_dest, workunit_store_handle)
}
}

#[pymethods]
impl PyThreadLocals {
#[classmethod]
fn get_for_current_thread(_cls: &PyType) -> Self {
Self::get()
}

fn set_for_current_thread(&self) {
stdio::set_thread_destination(self.0.clone());
workunit_store::set_thread_workunit_store_handle(self.1.clone());
}
}

#[pyfunction]
fn nailgun_server_create(
py_executor: &externs::scheduler::PyExecutor,
Expand Down Expand Up @@ -1641,15 +1667,18 @@ fn stdio_thread_console_clear() {
stdio::get_destination().console_clear();
}

// TODO: Deprecated, but without easy access to the decorator. Use
// `PyThreadLocals::get_for_current_thread` instead. Remove in Pants 2.17.0.dev0.
#[pyfunction]
fn stdio_thread_get_destination() -> PyStdioDestination {
let dest = stdio::get_destination();
PyStdioDestination(dest)
PyStdioDestination(PyThreadLocals::get())
}

// TODO: Deprecated, but without easy access to the decorator. Use
// `PyThreadLocals::set_for_current_thread` instead. Remove in Pants 2.17.0.dev0.
#[pyfunction]
fn stdio_thread_set_destination(stdio_destination: &PyStdioDestination) {
stdio::set_thread_destination(stdio_destination.0.clone());
stdio_destination.0.set_for_current_thread();
}

// TODO: Needs to be thread-local / associated with the Console.
Expand Down

0 comments on commit d7e7a1a

Please sign in to comment.