diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py
index 94ae05b8905..fd3453ad56a 100644
--- a/distributed/tests/test_worker.py
+++ b/distributed/tests/test_worker.py
@@ -46,6 +46,8 @@
 from distributed.protocol import pickle
 from distributed.scheduler import Scheduler
 from distributed.utils_test import (
+    BlockedGatherDep,
+    BlockedGetData,
     TaskStateMetadataPlugin,
     _LockedCommPool,
     assert_story,
@@ -2618,7 +2620,7 @@ def sink(a, b, *args):
         if peer_addr == a.address and msg["op"] == "get_data":
             break
 
-    # Provoke an "impossible transision exception"
+    # Provoke an "impossible transition exception"
     # By choosing a state which doesn't exist we're not running into validation
     # errors and the state machine should raise if we want to transition from
     # fetch to memory
@@ -3137,30 +3139,6 @@ async def test_task_flight_compute_oserror(c, s, a, b):
     assert_story(sum_story, expected_sum_story, strict=True)
 
 
-class BlockedGatherDep(Worker):
-    def __init__(self, *args, **kwargs):
-        self.in_gather_dep = asyncio.Event()
-        self.block_gather_dep = asyncio.Event()
-        super().__init__(*args, **kwargs)
-
-    async def gather_dep(self, *args, **kwargs):
-        self.in_gather_dep.set()
-        await self.block_gather_dep.wait()
-        return await super().gather_dep(*args, **kwargs)
-
-
-class BlockedGetData(Worker):
-    def __init__(self, *args, **kwargs):
-        self.in_get_data = asyncio.Event()
-        self.block_get_data = asyncio.Event()
-        super().__init__(*args, **kwargs)
-
-    async def get_data(self, comm, *args, **kwargs):
-        self.in_get_data.set()
-        await self.block_get_data.wait()
-        return await super().get_data(comm, *args, **kwargs)
-
-
 @gen_cluster(client=True, nthreads=[])
 async def test_gather_dep_cancelled_rescheduled(c, s):
     """At time of writing, the gather_dep implementation filtered tasks again
diff --git a/distributed/tests/test_worker_state_machine.py b/distributed/tests/test_worker_state_machine.py
index 1b6f2e3f80b..d852a42499a 100644
--- a/distributed/tests/test_worker_state_machine.py
+++ b/distributed/tests/test_worker_state_machine.py
@@ -3,9 +3,17 @@
 
 import pytest
 
+from distributed import Worker, wait
 from distributed.protocol.serialize import Serialize
 from distributed.utils import recursive_to_dict
-from distributed.utils_test import _LockedCommPool, assert_story, gen_cluster, inc
+from distributed.utils_test import (
+    BlockedGetData,
+    _LockedCommPool,
+    assert_story,
+    freeze_data_fetching,
+    gen_cluster,
+    inc,
+)
 from distributed.worker_state_machine import (
     ExecuteFailureEvent,
     ExecuteSuccessEvent,
@@ -16,12 +24,13 @@
     SendMessageToScheduler,
     StateMachineEvent,
     TaskState,
+    TaskStateState,
     UniqueTaskHeap,
     merge_recs_instructions,
 )
 
 
-async def wait_for_state(key, state, dask_worker):
+async def wait_for_state(key: str, state: TaskStateState, dask_worker: Worker) -> None:
     while key not in dask_worker.tasks or dask_worker.tasks[key].state != state:
         await asyncio.sleep(0.005)
 
@@ -247,26 +256,17 @@ def test_executefailure_to_dict():
 
 @gen_cluster(client=True)
 async def test_fetch_to_compute(c, s, a, b):
-    # Block ensure_communicating to ensure we indeed know that the task is in
-    # fetch and doesn't leave it accidentally
-    old_out_connections, b.total_out_connections = b.total_out_connections, 0
-    old_comm_threshold, b.comm_threshold_bytes = b.comm_threshold_bytes, 0
-
-    f1 = c.submit(inc, 1, workers=[a.address], key="f1", allow_other_workers=True)
-    f2 = c.submit(inc, f1, workers=[b.address], key="f2")
-
-    await wait_for_state(f1.key, "fetch", b)
-    await a.close()
-
-    b.total_out_connections = old_out_connections
-    b.comm_threshold_bytes = old_comm_threshold
+    with freeze_data_fetching(b):
+        f1 = c.submit(inc, 1, workers=[a.address], key="f1", allow_other_workers=True)
+        f2 = c.submit(inc, f1, workers=[b.address], key="f2")
+        await wait_for_state(f1.key, "fetch", b)
+        await a.close()
 
     await f2
 
     assert_story(
         b.log,
-        # FIXME: This log should be replaced with an
-        # StateMachineEvent/Instruction log
+        # FIXME: This log should be replaced with a StateMachineEvent log
         [
             (f2.key, "compute-task", "released"),
             # This is a "please fetch" request. We don't have anything like
@@ -285,23 +285,180 @@ async def test_fetch_to_compute(c, s, a, b):
 
 @gen_cluster(client=True)
 async def test_fetch_via_amm_to_compute(c, s, a, b):
-    # Block ensure_communicating to ensure we indeed know that the task is in
-    # fetch and doesn't leave it accidentally
-    old_out_connections, b.total_out_connections = b.total_out_connections, 0
-    old_comm_threshold, b.comm_threshold_bytes = b.comm_threshold_bytes, 0
-
-    f1 = c.submit(inc, 1, workers=[a.address], key="f1", allow_other_workers=True)
+    with freeze_data_fetching(b):
+        f1 = c.submit(inc, 1, workers=[a.address], key="f1", allow_other_workers=True)
+        await f1
+        s.request_acquire_replicas(b.address, [f1.key], stimulus_id="test")
+        await wait_for_state(f1.key, "fetch", b)
+        await a.close()
 
     await f1
-    s.request_acquire_replicas(b.address, [f1.key], stimulus_id="test")
 
-    await wait_for_state(f1.key, "fetch", b)
-    await a.close()
+    assert_story(
+        b.log,
+        # FIXME: This log should be replaced with a StateMachineEvent log
+        [
+            (f1.key, "ensure-task-exists", "released"),
+            (f1.key, "released", "fetch", "fetch", {}),
+            (f1.key, "compute-task", "fetch"),
+            (f1.key, "put-in-memory"),
+        ],
+    )
 
-    b.total_out_connections = old_out_connections
-    b.comm_threshold_bytes = old_comm_threshold
 
-    await f1
+@pytest.mark.parametrize("as_deps", [False, True])
+@gen_cluster(client=True, nthreads=[("", 1)] * 3)
+async def test_lose_replica_during_fetch(c, s, w1, w2, w3, as_deps):
+    """
+    as_deps=True
+        0. task x is a dependency of y1 and y2
+        1. scheduler calls handle_compute("y1", who_has={"x": [w2, w3]}) on w1
+        2. x transitions released -> fetch
+        3. the network stack is busy, so x does not transition to flight yet.
+        4. scheduler calls handle_compute("y2", who_has={"x": [w3]}) on w1
+        5. when x finally reaches the top of the data_needed heap, w1 will not try
+           contacting w2
+
+    as_deps=False
+        1. scheduler calls handle_acquire_replicas(who_has={"x": [w2, w3]}) on w1
+        2. x transitions released -> fetch
+        3. the network stack is busy, so x does not transition to flight yet.
+        4. scheduler calls handle_acquire_replicas(who_has={"x": [w3]}) on w1
+        5. when x finally reaches the top of the data_needed heap, w1 will not try
+           contacting w2
+    """
+    x = (await c.scatter({"x": 1}, workers=[w2.address, w3.address], broadcast=True))[
+        "x"
+    ]
+
+    # Make sure find_missing is not involved
+    w1.periodic_callbacks["find-missing"].stop()
+
+    with freeze_data_fetching(w1, jump_start=True):
+        if as_deps:
+            y1 = c.submit(inc, x, key="y1", workers=[w1.address])
+        else:
+            s.request_acquire_replicas(w1.address, ["x"], stimulus_id="test")
+
+        await wait_for_state("x", "fetch", w1)
+        assert w1.tasks["x"].who_has == {w2.address, w3.address}
+
+        assert len(s.tasks["x"].who_has) == 2
+        await w2.close()
+        while len(s.tasks["x"].who_has) > 1:
+            await asyncio.sleep(0.01)
+
+        if as_deps:
+            y2 = c.submit(inc, x, key="y2", workers=[w1.address])
+        else:
+            s.request_acquire_replicas(w1.address, ["x"], stimulus_id="test")
+
+        while w1.tasks["x"].who_has != {w3.address}:
+            await asyncio.sleep(0.01)
+
+    await wait_for_state("x", "memory", w1)
+    assert_story(
+        w1.story("request-dep"),
+        [("request-dep", w3.address, {"x"})],
+        # This tests that there has been no attempt to contact w2.
+        # If the assumption being tested breaks, this will fail 50% of the time.
+        strict=True,
+    )
+
+
+@gen_cluster(client=True, nthreads=[("", 1)] * 2)
+async def test_fetch_to_missing(c, s, a, b):
+    """
+    1. task x is a dependency of y
+    2. scheduler calls handle_compute("y", who_has={"x": [b]}) on a
+    3. x transitions released -> fetch -> flight; a connects to b
+    4. b responds it's busy. x transitions flight -> fetch
+    5. The busy state triggers an RPC call to Scheduler.who_has
+    6. the scheduler responds {"x": []}, because in the meantime b has been removed
+       from the scheduler, and with it the only known replica of the key.
+    7. x is transitioned fetch -> missing
+    """
+    x = await c.scatter({"x": 1}, workers=[b.address])
+    b.total_in_connections = 0
+    # Crucially, unlike with `c.submit(inc, x, workers=[a.address])`, the scheduler
+    # doesn't keep track of acquire-replicas requests, so it won't proactively inform a
+    # when we call remove_worker later on
+    s.request_acquire_replicas(a.address, ["x"], stimulus_id="test")
+
+    # state will flip-flop between fetch and flight every 150ms, which is the retry
+    # period for busy workers.
+    await wait_for_state("x", "fetch", a)
+    assert b.address in a.busy_workers
+
+    # Sever connection between b and s, but not between b and a.
+    # If a tries fetching from b after this, b will keep responding {status: busy}.
+    b.periodic_callbacks["heartbeat"].stop()
+    await s.remove_worker(b.address, close=False, stimulus_id="test")
+
+    await wait_for_state("x", "missing", a)
+
+    assert_story(
+        a.story("x"),
+        [
+            ("x", "ensure-task-exists", "released"),
+            ("x", "released", "fetch", "fetch", {}),
+            ("gather-dependencies", b.address, {"x"}),
+            ("x", "fetch", "flight", "flight", {}),
+            ("request-dep", b.address, {"x"}),
+            ("busy-gather", b.address, {"x"}),
+            ("x", "flight", "fetch", "fetch", {}),
+            ("x", "fetch", "missing", "missing", {}),
+        ],
+        # There may be a round of find_missing() after this.
+        # Due to timings, there also may be multiple attempts to connect from a to b.
+        strict=False,
+    )
+
+
+@pytest.mark.skip(reason="https://github.com/dask/distributed/issues/6446")
+@gen_cluster(client=True)
+async def test_new_replica_while_all_workers_in_flight(c, s, w1, w2):
+    """A task is stuck in 'fetch' state because all workers that hold a replica are in
+    flight. While in this state, a new replica appears on a different worker and the
+    scheduler informs the waiting worker through a new acquire-replicas or
+    compute-task op.
+
+    In real life, this will typically happen when the Active Memory Manager replicates a
+    key to multiple workers and some workers are much faster than others to acquire it,
+    due to unrelated tasks being in flight, so 2 seconds later the AMM reiterates the
+    request, passing a larger who_has.
+
+    Test that, when this happens, the task is immediately acquired from the new worker,
+    without waiting for the original replica holders to get out of flight.
+    """
+    # Make sure find_missing is not involved
+    w1.periodic_callbacks["find-missing"].stop()
+
+    async with BlockedGetData(s.address) as w3:
+        x = c.submit(inc, 1, key="x", workers=[w3.address])
+        y = c.submit(inc, 2, key="y", workers=[w3.address])
+        await wait([x, y])
+        s.request_acquire_replicas(w1.address, ["x"], stimulus_id="test")
+        await w3.in_get_data.wait()
+        assert w1.tasks["x"].state == "flight"
+        s.request_acquire_replicas(w1.address, ["y"], stimulus_id="test")
+        # This cannot progress beyond fetch because w3 is already in flight
+        await wait_for_state("y", "fetch", w1)
+
+        # Simulate that the AMM also requires that w2 acquires a replica of y.
+        # The replica lands on w2 soon afterwards, while the w3->w1 comms remain
+        # blocked by the unrelated transfer of x.
+        w2.update_data({"y": 3}, report=True)
+        ws2 = s.workers[w2.address]
+        while ws2 not in s.tasks["y"].who_has:
+            await asyncio.sleep(0.01)
+
+        # 2 seconds later, the AMM reiterates that w1 should acquire a replica of y
+        s.request_acquire_replicas(w1.address, ["y"], stimulus_id="test")
+        await wait_for_state("y", "memory", w1)
+
+        # Finally let the other worker get out of flight
+        w3.block_get_data.set()
+        await wait_for_state("x", "memory", w1)
 
 
 @gen_cluster(client=True)
diff --git a/distributed/utils_test.py b/distributed/utils_test.py
index edd74bef17e..a2780f15f1f 100644
--- a/distributed/utils_test.py
+++ b/distributed/utils_test.py
@@ -2261,3 +2261,89 @@ def wait_for_log_line(
         if match in line:
             return line
         i += 1
+
+
+class BlockedGatherDep(Worker):
+    """A Worker that sets event `in_gather_dep` the first time it enters the gather_dep
+    method and then does not initiate any comms, thus leaving the task(s) in flight
+    indefinitely, until the test sets `block_gather_dep`
+
+    Example
+    -------
+    .. code-block:: python
+
+        @gen_cluster()
+        async def test1(s, a, b):
+            async with BlockedGatherDep(s.address) as x:
+                # [do something to cause x to fetch data from a or b]
+                await x.in_gather_dep.wait()
+                # [do something that must happen while the tasks are in flight]
+                x.block_gather_dep.set()
+                # [from this moment on, x is a regular worker]
+
+    See also
+    --------
+    BlockedGetData
+    """
+
+    def __init__(self, *args, **kwargs):
+        self.in_gather_dep = asyncio.Event()
+        self.block_gather_dep = asyncio.Event()
+        super().__init__(*args, **kwargs)
+
+    async def gather_dep(self, *args, **kwargs):
+        self.in_gather_dep.set()
+        await self.block_gather_dep.wait()
+        return await super().gather_dep(*args, **kwargs)
+
+
+class BlockedGetData(Worker):
+    """A Worker that sets event `in_get_data` the first time it enters the get_data
+    method and then does not answer the comms, thus leaving the task(s) in flight
+    indefinitely, until the test sets `block_get_data`
+
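+    Example
+    -------
+    A usage sketch mirroring the `BlockedGatherDep` example above; the test name and
+    the `s`, `a`, `b` fixtures are illustrative placeholders:
+
+    .. code-block:: python
+
+        @gen_cluster()
+        async def test1(s, a, b):
+            async with BlockedGetData(s.address) as x:
+                # [do something to cause a or b to fetch data from x]
+                await x.in_get_data.wait()
+                # [do something that must happen while the tasks are in flight]
+                x.block_get_data.set()
+                # [from this moment on, x is a regular worker]
+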
+    See also
+    --------
+    BlockedGatherDep
+    """
+
+    def __init__(self, *args, **kwargs):
+        self.in_get_data = asyncio.Event()
+        self.block_get_data = asyncio.Event()
+        super().__init__(*args, **kwargs)
+
+    async def get_data(self, comm, *args, **kwargs):
+        self.in_get_data.set()
+        await self.block_get_data.wait()
+        return await super().get_data(comm, *args, **kwargs)
+
+
+@contextmanager
+def freeze_data_fetching(w: Worker, *, jump_start: bool = False):
+    """Prevent any task from transitioning from fetch to flight on the worker while
+    inside the context, simulating a situation where the worker's network comms are
+    saturated.
+
+    This is not the same as setting the worker to Status=paused, which would also
+    inform the Scheduler and prevent further tasks from being enqueued on the worker.
+
+    Parameters
+    ----------
+    w: Worker
+        The Worker on which tasks will not transition from fetch to flight
+    jump_start: bool
+        If False, tasks will remain in fetch state after exiting the context, until
+        something else triggers ensure_communicating.
+        If True, trigger ensure_communicating on exit; this simulates e.g. an unrelated
+        worker moving out of in_flight_workers.
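+
+    Example
+    -------
+    A minimal usage sketch from a `@gen_cluster(client=True)` test, where `c`, `a` and
+    `b` are the client and two workers (keys and callables are illustrative):
+
+    .. code-block:: python
+
+        with freeze_data_fetching(b):
+            f1 = c.submit(inc, 1, workers=[a.address], allow_other_workers=True)
+            f2 = c.submit(inc, f1, workers=[b.address])
+            # f1 reaches the fetch state on b and stays there until the context exits
+        await f2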
+    """
+    old_out_connections = w.total_out_connections
+    old_comm_threshold = w.comm_threshold_bytes
+    w.total_out_connections = 0
+    w.comm_threshold_bytes = 0
+    yield
+    w.total_out_connections = old_out_connections
+    w.comm_threshold_bytes = old_comm_threshold
+    if jump_start:
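+        # Flip the status to paused and back to running so that the resulting status
+        # change triggers ensure_communicating()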
+        w.status = Status.paused
+        w.status = Status.running
diff --git a/distributed/worker.py b/distributed/worker.py
index c146cb56fff..0ad15cd3fb0 100644
--- a/distributed/worker.py
+++ b/distributed/worker.py
@@ -1890,7 +1890,7 @@ def handle_acquire_replicas(
             if ts.state != "memory":
                 recommendations[ts] = "fetch"
 
-        self.update_who_has(who_has)
+        self._update_who_has(who_has)
         self.transitions(recommendations, stimulus_id=stimulus_id)
 
         if self.validate:
@@ -1994,7 +1994,7 @@ def handle_compute_task(
             for dep_key, value in nbytes.items():
                 self.tasks[dep_key].nbytes = value
 
-            self.update_who_has(who_has)
+            self._update_who_has(who_has)
         else:  # pragma: nocover
             raise RuntimeError(f"Unexpected task state encountered {ts} {stimulus_id}")
 
@@ -2101,8 +2101,7 @@ def transition_released_waiting(
             if dep_ts.state != "memory":
                 ts.waiting_for_data.add(dep_ts)
                 dep_ts.waiters.add(ts)
-                if dep_ts.state not in {"fetch", "flight"}:
-                    recommendations[dep_ts] = "fetch"
+                recommendations[dep_ts] = "fetch"
 
         if ts.waiting_for_data:
             self.waiting_for_data_count += 1
@@ -2689,7 +2688,7 @@ def _transition(
             assert not args
             finish, *args = finish  # type: ignore
 
-        if ts is None or ts.state == finish:
+        if ts.state == finish:
             return {}, []
 
         start = ts.state
@@ -2992,8 +2991,11 @@ def _ensure_communicating(self, *, stimulus_id: str) -> RecsInstrs:
             if ts.state != "fetch" or ts.key in all_keys_to_gather:
                 continue
 
+            if not ts.who_has:
+                recommendations[ts] = "missing"
+                continue
+
             if self.validate:
-                assert ts.who_has
                 assert self.address not in ts.who_has
 
             workers = [
@@ -3146,8 +3148,14 @@ def _select_keys_for_gather(
 
         while tasks:
             ts = tasks.peek()
-            if ts.state != "fetch" or ts.key in all_keys_to_gather:
+            if (
+                ts.state != "fetch"
                 # Do not acquire the same key twice if multiple workers hold replicas
+                or ts.key in all_keys_to_gather
+                # A replica is still available (otherwise the state would not be
+                # 'fetch' anymore), but not on this worker. See _update_who_has().
+                or worker not in ts.who_has
+            ):
                 tasks.pop()
                 continue
             if total_bytes + ts.get_nbytes() > self.target_message_size:
@@ -3373,7 +3381,7 @@ def done_event():
                 who_has = await retry_operation(
                     self.scheduler.who_has, keys=refresh_who_has
                 )
-                self.update_who_has(who_has)
+                self._update_who_has(who_has)
 
     @log_errors
     def _readd_busy_worker(self, worker: str) -> None:
@@ -3397,7 +3405,7 @@ async def find_missing(self) -> None:
                 self.scheduler.who_has,
                 keys=[ts.key for ts in self._missing_dep_flight],
             )
-            self.update_who_has(who_has)
+            self._update_who_has(who_has)
             recommendations: Recs = {}
             for ts in self._missing_dep_flight:
                 if ts.who_has:
@@ -3411,34 +3419,46 @@ async def find_missing(self) -> None:
                 "find-missing"
             ].callback_time = self.periodic_callbacks["heartbeat"].callback_time
 
-    def update_who_has(self, who_has: dict[str, Collection[str]]) -> None:
-        try:
-            for dep, workers in who_has.items():
-                if not workers:
-                    continue
+    def _update_who_has(self, who_has: Mapping[str, Collection[str]]) -> None:
+        for key, workers in who_has.items():
+            ts = self.tasks.get(key)
+            if not ts:
+                # The worker sent a refresh-who-has request to the scheduler but, by the
+                # time the answer comes back, some of the keys have been forgotten.
+                continue
+            workers = set(workers)
+
+            if self.address in workers:
+                workers.remove(self.address)
+                # This can only happen if rebalance() recently asked to release a key,
+                # but the RPC call hasn't returned yet. rebalance() is flagged as not
+                # being safe to run while the cluster is not at rest and has already
+                # been pencilled in to be redesigned on top of the AMM.
+                # It is not necessary to send a message back to the
+                # scheduler here, because it is guaranteed that there's already a
+                # release-worker-data message in transit to it.
+                if ts.state != "memory":
+                    logger.debug(  # pragma: nocover
+                        "Scheduler claims worker %s holds data for task %s, "
+                        "which is not true.",
+                        self.address,
+                        ts,
+                    )
 
-                if dep in self.tasks:
-                    dep_ts = self.tasks[dep]
-                    if self.address in workers and self.tasks[dep].state != "memory":
-                        logger.debug(
-                            "Scheduler claims worker %s holds data for task %s which is not true.",
-                            self.name,
-                            dep,
-                        )
-                        # Do not mutate the input dict. That's rude
-                        workers = set(workers) - {self.address}
-                    dep_ts.who_has.update(workers)
+            if ts.who_has == workers:
+                continue
 
-                    for worker in workers:
-                        self.has_what[worker].add(dep)
-                        self.data_needed_per_worker[worker].push(dep_ts)
-        except Exception as e:  # pragma: no cover
-            logger.exception(e)
-            if LOG_PDB:
-                import pdb
+            for worker in ts.who_has - workers:
+                self.has_what[worker].discard(key)
+                # Can't remove from self.data_needed_per_worker; there is logic
+                # in _select_keys_for_gather to deal with this
 
-                pdb.set_trace()
-            raise
+            for worker in workers - ts.who_has:
+                self.has_what[worker].add(key)
+                if ts.state == "fetch":
+                    self.data_needed_per_worker[worker].push(ts)
+
+            ts.who_has = workers
 
     def handle_steal_request(self, key: str, stimulus_id: str) -> None:
         # There may be a race condition between stealing and releasing a task.
@@ -4170,8 +4190,8 @@ def validate_task_fetch(self, ts):
         assert self.address not in ts.who_has
         assert not ts.done
         assert ts in self.data_needed
-        assert ts.who_has
-
+        # Note: ts.who_has may have been emptied by _update_who_has, but the task
+        # won't transition to missing until it reaches the top of the data_needed heap.
         for w in ts.who_has:
             assert ts.key in self.has_what[w]
             assert ts in self.data_needed_per_worker[w]
@@ -4262,6 +4282,7 @@ def validate_state(self):
                 assert ts.state is not None
                 # check that worker has task
                 for worker in ts.who_has:
+                    assert worker != self.address
                     assert ts.key in self.has_what[worker]
                 # check that deps have a set state and that dependency<->dependent links
                 # are there
@@ -4286,6 +4307,7 @@ def validate_state(self):
             # FIXME https://github.com/dask/distributed/issues/6319
             # assert self.waiting_for_data_count == waiting_for_data_count
             for worker, keys in self.has_what.items():
+                assert worker != self.address
                 for k in keys:
                     assert worker in self.tasks[k].who_has