Prevent infinite transition loops; more aggressive validate_state() #6318

Merged: 4 commits, merged on May 12, 2022
7 changes: 7 additions & 0 deletions distributed/distributed-schema.yaml
@@ -1027,6 +1027,13 @@ properties:
type: boolean
description: Enter Python Debugger on scheduling error

transition-counter-max:
oneOf:
- enum: [false]
- type: integer
description: Cause the scheduler or workers to break if they reach this
number of transitions

system-monitor:
type: object
description: |
4 changes: 4 additions & 0 deletions distributed/distributed.yaml
@@ -277,6 +277,10 @@ distributed:
log-length: 10000 # default length of logs to keep in memory
log-format: '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
pdb-on-err: False # enter debug mode on scheduling error
# Cause scheduler and workers to break if they reach this many transitions.
# Used to debug infinite transition loops.
# Note: setting this will cause healthy long-running services to eventually break.
transition-counter-max: False
Comment on lines +280 to +283

Member:
I think we should try not to clutter this file with settings that we only use for internal tests. Apart from tests, I don't see the usefulness of a global limit on the number of transitions.

Collaborator (Author):
Indeed, this is for tests only.

system-monitor:
interval: 500ms
event-loop: tornado
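For reference, a minimal sketch (not part of the diff) of how the new option can be set programmatically. The option name comes from the schema above; the 50_000 limit is just the value that gen_cluster later applies in utils_test.py.

```python
import dask

# Values set here are picked up by SchedulerState.__init__ (see scheduler.py below)
# and, analogously, by the worker at construction time.
# False (the production default) disables the check entirely.
with dask.config.set({"distributed.admin.transition-counter-max": 50_000}):
    pass  # create the Scheduler / Workers under test inside this block
```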
29 changes: 22 additions & 7 deletions distributed/scheduler.py
@@ -1262,6 +1262,7 @@ class SchedulerState:
"validate",
"workers",
"transition_counter",
"transition_counter_max",
"plugins",
"UNKNOWN_TASK_DURATION",
"MEMORY_RECENT_TO_OLD_TIME",
@@ -1354,6 +1355,9 @@ def __init__(
/ 2.0
)
self.transition_counter = 0
self.transition_counter_max = dask.config.get(
"distributed.admin.transition-counter-max"
)

@property
def memory(self) -> MemoryState:
@@ -1430,16 +1434,24 @@ def _transition(self, key, finish: str, stimulus_id: str, *args, **kwargs):
Scheduler.transitions : transitive version of this function
"""
try:
recommendations = {} # type: ignore
worker_msgs = {} # type: ignore
client_msgs = {} # type: ignore

ts: TaskState = self.tasks.get(key) # type: ignore
if ts is None:
return recommendations, client_msgs, worker_msgs
return {}, {}, {}
start = ts._state
if start == finish:
return recommendations, client_msgs, worker_msgs
return {}, {}, {}

# Notes:
# - in case of transition through released, this counter is incremented by 2
# - this increase happens before the actual transitions, so that it can
# catch potential infinite recursions
self.transition_counter += 1
if self.transition_counter_max:
assert self.transition_counter < self.transition_counter_max
Collaborator:
It might be nice if this raised something like a TransitionCounterMaxExceeded error, to be consistent with the workers.

Collaborator (Author):
Yeah, but the scheduler doesn't have anything like the worker's InvalidTransitionError.
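For context only: a hypothetical sketch of what such a dedicated scheduler-side exception could look like. It is not part of this PR (the scheduler keeps the plain assert above); the class name simply mirrors the worker-side TransitionCounterMaxExceeded mentioned in the comment.

```python
# Hypothetical sketch, not part of this PR.
class TransitionCounterMaxExceeded(RuntimeError):
    """Raised when the scheduler exceeds its configured transition-counter-max."""

    def __init__(self, key: str, counter: int, max_counter: int):
        super().__init__(
            f"transition_counter={counter} exceeded the configured maximum of "
            f"{max_counter} while transitioning key {key!r}"
        )
        self.key = key
        self.counter = counter
        self.max_counter = max_counter


# The guard in SchedulerState._transition would then become something like:
#     if self.transition_counter_max and self.transition_counter >= self.transition_counter_max:
#         raise TransitionCounterMaxExceeded(key, self.transition_counter, self.transition_counter_max)
```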


recommendations = {} # type: ignore
worker_msgs = {} # type: ignore
client_msgs = {} # type: ignore

if self.plugins:
dependents = set(ts.dependents)
@@ -1451,7 +1463,7 @@ def _transition(self, key, finish: str, stimulus_id: str, *args, **kwargs):
recommendations, client_msgs, worker_msgs = func(
key, stimulus_id, *args, **kwargs
) # type: ignore
self.transition_counter += 1

elif "released" not in start_finish:
assert not args and not kwargs, (args, kwargs, start_finish)
a_recs: dict
@@ -3173,6 +3185,7 @@ def _to_dict(self, *, exclude: Container[str] = ()) -> dict:
info = super()._to_dict(exclude=exclude)
extra = {
"transition_log": self.transition_log,
"transition_counter": self.transition_counter,
"log": self.log,
"tasks": self.tasks,
"task_groups": self.task_groups,
@@ -4496,6 +4509,8 @@ def validate_state(self, allow_overlap: bool = False) -> None:
actual_total_occupancy,
self.total_occupancy,
)
if self.transition_counter_max:
assert self.transition_counter < self.transition_counter_max

###################
# Manage Messages #
62 changes: 60 additions & 2 deletions distributed/tests/test_scheduler.py
@@ -21,6 +21,7 @@
from dask.utils import apply, parse_timedelta, stringify, tmpfile, typename

from distributed import (
CancelledError,
Client,
Event,
Lock,
@@ -3215,11 +3216,67 @@ async def test_computations_futures(c, s, a, b):
assert "inc" in str(computation.groups)


@gen_cluster(client=True)
async def test_transition_counter(c, s, a, b):
@gen_cluster(client=True, nthreads=[("", 1)])
async def test_transition_counter(c, s, a):
assert s.transition_counter == 0
assert a.transition_counter == 0
await c.submit(inc, 1)
assert s.transition_counter > 1
assert a.transition_counter > 1


@pytest.mark.slow
@gen_cluster(client=True)
async def test_transition_counter_max_scheduler(c, s, a, b):
# This is set by @gen_cluster; it's False in production
assert s.transition_counter_max > 0
s.transition_counter_max = 1
with captured_logger("distributed.scheduler") as logger:
with pytest.raises(CancelledError):
await c.submit(inc, 2)
assert s.transition_counter > 1
with pytest.raises(AssertionError):
s.validate_state()
assert "transition_counter_max" in logger.getvalue()
# Scheduler state is corrupted. Avoid test failure on gen_cluster teardown.
s.validate = False


@gen_cluster(client=True, nthreads=[("", 1)])
async def test_transition_counter_max_worker(c, s, a):
# This is set by @gen_cluster; it's False in production
assert s.transition_counter_max > 0
a.transition_counter_max = 1
with captured_logger("distributed.core") as logger:
fut = c.submit(inc, 2)
while True:
try:
a.validate_state()
except AssertionError:
break
await asyncio.sleep(0.01)

assert "TransitionCounterMaxExceeded" in logger.getvalue()
# Worker state is corrupted. Avoid test failure on gen_cluster teardown.
a.validate = False


@gen_cluster(
client=True,
nthreads=[("", 1)],
config={"distributed.admin.transition-counter-max": False},
)
async def test_disable_transition_counter_max(c, s, a):
"""Test that the cluster can run indefinitely if transition_counter_max is disabled.
This is the default outside of @gen_cluster.
"""
assert s.transition_counter_max is False
assert a.transition_counter_max is False
assert await c.submit(inc, 1) == 2
assert s.transition_counter > 1
assert a.transition_counter > 1
s.validate_state()
a.validate_state()


@gen_cluster(
@@ -3339,6 +3396,7 @@ async def test_Scheduler__to_dict(c, s, a):
"status",
"thread_id",
"transition_log",
"transition_counter",
"log",
"memory",
"tasks",
14 changes: 12 additions & 2 deletions distributed/tests/test_stress.py
@@ -228,7 +228,12 @@ async def test_stress_steal(c, s, *workers):


@pytest.mark.slow
@gen_cluster(nthreads=[("127.0.0.1", 1)] * 10, client=True, timeout=180)
@gen_cluster(
nthreads=[("", 1)] * 10,
client=True,
timeout=180,
config={"distributed.admin.transition-counter-max": 500_000},
)
async def test_close_connections(c, s, *workers):
da = pytest.importorskip("dask.array")
x = da.random.random(size=(1000, 1000), chunks=(1000, 1))
@@ -291,7 +296,12 @@ async def test_no_delay_during_large_transfer(c, s, w):


@pytest.mark.slow
@gen_cluster(client=True, Worker=Nanny, nthreads=[("127.0.0.1", 2)] * 6)
@gen_cluster(
client=True,
Worker=Nanny,
nthreads=[("", 2)] * 6,
config={"distributed.admin.transition-counter-max": 500_000},
)
async def test_chaos_rechunk(c, s, *workers):
s.allowed_failures = 10000

1 change: 1 addition & 0 deletions distributed/tests/test_worker.py
@@ -3434,6 +3434,7 @@ async def test_Worker__to_dict(c, s, a):
"busy_workers",
"log",
"stimulus_log",
"transition_counter",
"tasks",
"logs",
"config",
36 changes: 22 additions & 14 deletions distributed/utils_test.py
@@ -27,14 +27,6 @@
from time import sleep
from typing import Any, Generator, Literal

from distributed.compatibility import MACOS
from distributed.scheduler import Scheduler

try:
import ssl
except ImportError:
ssl = None # type: ignore

import pytest
import yaml
from tlz import assoc, memoize, merge
@@ -43,12 +35,12 @@

import dask

from distributed import system
from distributed import Scheduler, system
from distributed import versions as version_module
from distributed.client import Client, _global_clients, default_client
from distributed.comm import Comm
from distributed.comm.tcp import TCP
from distributed.compatibility import WINDOWS
from distributed.compatibility import MACOS, WINDOWS
from distributed.config import initialize_logging
from distributed.core import (
CommClosedError,
@@ -79,6 +71,11 @@
)
from distributed.worker import WORKER_ANY_RUNNING, InvalidTransition, Worker

try:
import ssl
except ImportError:
ssl = None # type: ignore

try:
import dask.array # register config
except ImportError:
@@ -447,8 +444,6 @@ async def background_read():

def run_scheduler(q, nputs, config, port=0, **kwargs):
with dask.config.set(config):
from distributed import Scheduler

# On Python 2.7 and Unix, fork() is used to spawn child processes,
# so avoid inheriting the parent's IO loop.
with pristine_loop() as loop:
@@ -999,6 +994,7 @@ async def test_foo(scheduler, worker1, worker2, pytest_fixture_a, pytest_fixture
worker_kwargs = merge(
{"memory_limit": system.MEMORY_LIMIT, "death_timeout": 15}, worker_kwargs
)
config = merge({"distributed.admin.transition-counter-max": 50_000}, config)

def _(func):
if not iscoroutinefunction(func):
@@ -1052,8 +1048,7 @@ async def coro():
task = asyncio.create_task(coro)
coro2 = asyncio.wait_for(asyncio.shield(task), timeout)
result = await coro2
if s.validate:
s.validate_state()
validate_state(s, *workers)

except asyncio.TimeoutError:
assert task
@@ -1073,6 +1068,10 @@ async def coro():
while not task.cancelled():
await asyncio.sleep(0.01)

# Hopefully, the hang has been caused by inconsistent state,
# which should be much more meaningful than the timeout
validate_state(s, *workers)

# Remove as much of the traceback as possible; it's
# uninteresting boilerplate from utils_test and asyncio and
# not from the code being tested.
@@ -1205,6 +1204,15 @@ async def dump_cluster_state(
print(f"Dumped cluster state to {fname}")


def validate_state(*servers: Scheduler | Worker | Nanny) -> None:
"""Run validate_state() on the Scheduler and all the Workers of the cluster.
Excludes workers wrapped by Nannies and workers manually started by the test.
"""
for s in servers:
if s.validate and hasattr(s, "validate_state"):
s.validate_state() # type: ignore
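A small self-contained illustration (assuming the validate_state helper above is in scope; the stand-in classes are hypothetical, not real servers) of the filtering described in the docstring: servers without a validate_state method, such as Nannies, and servers with validation disabled are skipped.

```python
# Hypothetical stand-ins that mimic the attributes the helper inspects.
class _FakeWorker:
    validate = True

    def validate_state(self) -> None:
        print("worker state validated")


class _FakeNanny:
    validate = True  # has no validate_state() method, so it is skipped


class _DisabledWorker:
    validate = False  # validation turned off, so it is skipped

    def validate_state(self) -> None:
        raise AssertionError("never reached")


validate_state(_FakeWorker(), _FakeNanny(), _DisabledWorker())  # prints once
```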


def raises(func, exc=Exception):
try:
func()