diff --git a/src/python/pants/bsp/testutil.py b/src/python/pants/bsp/testutil.py index 935a330950e..3f39acd0dc0 100644 --- a/src/python/pants/bsp/testutil.py +++ b/src/python/pants/bsp/testutil.py @@ -16,7 +16,7 @@ from pants.bsp.context import BSPContext from pants.bsp.protocol import BSPConnection from pants.bsp.rules import rules as bsp_rules -from pants.engine.internals import native_engine +from pants.engine.internals.native_engine import PyThreadLocals from pants.testutil.rule_runner import RuleRunner @@ -93,7 +93,7 @@ def setup_bsp_server( ): rule_runner = rule_runner or RuleRunner(rules=bsp_rules()) notification_names = notification_names or set() - stdio_destination = native_engine.stdio_thread_get_destination() + thread_locals = PyThreadLocals.get_for_current_thread() with setup_pipes() as pipes: context = BSPContext() @@ -107,7 +107,7 @@ def setup_bsp_server( ) def run_bsp_server(): - native_engine.stdio_thread_set_destination(stdio_destination) + thread_locals.set_for_current_thread() conn.run() bsp_thread = Thread(target=run_bsp_server) diff --git a/src/python/pants/engine/internals/native_engine.pyi b/src/python/pants/engine/internals/native_engine.pyi index 2905c61994c..2baa3657f6a 100644 --- a/src/python/pants/engine/internals/native_engine.pyi +++ b/src/python/pants/engine/internals/native_engine.pyi @@ -386,5 +386,10 @@ class PyTypes: class PyStdioDestination: pass +class PyThreadLocals: + @classmethod + def get_for_current_thread(cls) -> PyThreadLocals: ... + def set_for_current_thread(self) -> None: ... + class PollTimeout(Exception): pass diff --git a/src/python/pants/engine/streaming_workunit_handler.py b/src/python/pants/engine/streaming_workunit_handler.py index 354021008c7..e7b00c320fe 100644 --- a/src/python/pants/engine/streaming_workunit_handler.py +++ b/src/python/pants/engine/streaming_workunit_handler.py @@ -12,7 +12,7 @@ from pants.base.specs import Specs from pants.engine.addresses import Addresses from pants.engine.fs import Digest, DigestContents, FileDigest, Snapshot -from pants.engine.internals import native_engine +from pants.engine.internals.native_engine import PyThreadLocals from pants.engine.internals.scheduler import SchedulerSession, Workunit from pants.engine.internals.selectors import Params from pants.engine.rules import Get, MultiGet, QueryRule, collect_rules, rule @@ -30,6 +30,24 @@ # ----------------------------------------------------------------------------------------------- +def thread_locals_get_for_current_thread() -> PyThreadLocals: + """Gets the engine's thread local state for the current thread. + + In order to safely use StreamingWorkunitContext methods from additional threads, + StreamingWorkunit plugins should propagate thread local state from the threads that they are + initialized on to any additional threads that they spawn. + """ + return PyThreadLocals.get_for_current_thread() + + +def thread_locals_set_for_current_thread(thread_locals: PyThreadLocals) -> None: + """Sets the engine's thread local state for the current thread. + + See `thread_locals_get`. + """ + thread_locals.set_for_current_thread() + + @dataclass(frozen=True) class TargetInfo: filename: str @@ -246,9 +264,9 @@ def __init__( self.block_until_complete = not allow_async_completion or any( callback.can_finish_async is False for callback in self.callbacks ) - # Get the parent thread's logging destination. Note that this thread has not yet started + # Get the parent thread's thread locals. Note that this thread has not yet started # as we are only in the constructor. - self.logging_destination = native_engine.stdio_thread_get_destination() + self.thread_locals = PyThreadLocals.get_for_current_thread() def poll_workunits(self, *, finished: bool) -> None: workunits = self.scheduler.poll_workunits(self.max_workunit_verbosity) @@ -261,8 +279,9 @@ def poll_workunits(self, *, finished: bool) -> None: ) def run(self) -> None: - # First, set the thread's logging destination to the parent thread's, meaning the console. - native_engine.stdio_thread_set_destination(self.logging_destination) + # First, set the thread's thread locals to the parent thread's in order to propagate the + # console, workunit stores, etc. + self.thread_locals.set_for_current_thread() while not self.stop_request.isSet(): self.poll_workunits(finished=False) self.stop_request.wait(timeout=self.report_interval) diff --git a/src/rust/engine/src/externs/interface.rs b/src/rust/engine/src/externs/interface.rs index 279df62d6fd..cbbcab44b38 100644 --- a/src/rust/engine/src/externs/interface.rs +++ b/src/rust/engine/src/externs/interface.rs @@ -42,6 +42,7 @@ use rule_graph::{self, RuleGraph}; use task_executor::Executor; use workunit_store::{ ArtifactOutput, ObservationMetric, UserMetadataItem, Workunit, WorkunitState, WorkunitStore, + WorkunitStoreHandle, }; use crate::externs::fs::{todo_possible_store_missing_digest, PyFileDigest}; @@ -73,6 +74,7 @@ fn native_engine(py: Python, m: &PyModule) -> PyO3Result<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; @@ -225,7 +227,7 @@ impl PyTypes { struct PyScheduler(Scheduler); #[pyclass] -struct PyStdioDestination(Arc); +struct PyStdioDestination(PyThreadLocals); /// Represents configuration related to process execution strategies. /// @@ -500,6 +502,30 @@ fn py_result_from_root(py: Python, result: Result) -> PyResult { } } +#[pyclass] +struct PyThreadLocals(Arc, Option); + +impl PyThreadLocals { + fn get() -> Self { + let stdio_dest = stdio::get_destination(); + let workunit_store_handle = workunit_store::get_workunit_store_handle(); + Self(stdio_dest, workunit_store_handle) + } +} + +#[pymethods] +impl PyThreadLocals { + #[classmethod] + fn get_for_current_thread(_cls: &PyType) -> Self { + Self::get() + } + + fn set_for_current_thread(&self) { + stdio::set_thread_destination(self.0.clone()); + workunit_store::set_thread_workunit_store_handle(self.1.clone()); + } +} + #[pyfunction] fn nailgun_server_create( py_executor: &externs::scheduler::PyExecutor, @@ -1641,15 +1667,18 @@ fn stdio_thread_console_clear() { stdio::get_destination().console_clear(); } +// TODO: Deprecated, but without easy access to the decorator. Use +// `PyThreadLocals::get_for_current_thread` instead. Remove in Pants 2.17.0.dev0. #[pyfunction] fn stdio_thread_get_destination() -> PyStdioDestination { - let dest = stdio::get_destination(); - PyStdioDestination(dest) + PyStdioDestination(PyThreadLocals::get()) } +// TODO: Deprecated, but without easy access to the decorator. Use +// `PyThreadLocals::set_for_current_thread` instead. Remove in Pants 2.17.0.dev0. #[pyfunction] fn stdio_thread_set_destination(stdio_destination: &PyStdioDestination) { - stdio::set_thread_destination(stdio_destination.0.clone()); + stdio_destination.0.set_for_current_thread(); } // TODO: Needs to be thread-local / associated with the Console.