From d42a5e0f741c1d00786820f7b64c9d6d7a3805cd Mon Sep 17 00:00:00 2001 From: Vincent Vanlaer Date: Sun, 19 Nov 2023 00:27:56 +0100 Subject: [PATCH] Add trio.CapacityLimiter.wait_no_borrowers This is the equivalent of trio.testing.wait_all_tasks_blocked but for anything wrapped by a CapacityLimiter. This is useful when writing tests that use to_thread. --- newsfragments/2880.feature.rst | 1 + src/trio/_sync.py | 23 +++++++++++++++++++++++ src/trio/_tests/test_sync.py | 32 ++++++++++++++++++++++++++++++++ 3 files changed, 56 insertions(+) create mode 100644 newsfragments/2880.feature.rst diff --git a/newsfragments/2880.feature.rst b/newsfragments/2880.feature.rst new file mode 100644 index 0000000000..b3296f7bad --- /dev/null +++ b/newsfragments/2880.feature.rst @@ -0,0 +1 @@ +Add `trio.CapacityLimiter.wait_no_borrowers`, which blocks until no tasks are in the capacity limiter. This is intended to be used in the same way as `trio.testing.wait_all_tasks_blocked`. diff --git a/src/trio/_sync.py b/src/trio/_sync.py index 951ff892ea..deff5f65fe 100644 --- a/src/trio/_sync.py +++ b/src/trio/_sync.py @@ -220,6 +220,7 @@ def __init__(self, total_tokens: int | float): # noqa: PYI041 self._borrowers: set[Task | object] = set() # Maps tasks attempting to acquire -> borrower, to handle on-behalf-of self._pending_borrowers: dict[Task, Task | object] = {} + self._event_no_borrowers: Event = Event() # invoke the property setter for validation self.total_tokens: int | float = total_tokens assert self._total_tokens == total_tokens @@ -258,6 +259,10 @@ def _wake_waiters(self) -> None: for woken in self._lot.unpark(count=available): self._borrowers.add(self._pending_borrowers.pop(woken)) + if len(self._borrowers) == 0: + self._event_no_borrowers.set() + self._event_no_borrowers = Event() + @property def borrowed_tokens(self) -> int: """The amount of capacity that's currently in use.""" @@ -401,6 +406,24 @@ def statistics(self) -> CapacityLimiterStatistics: tasks_waiting=len(self._lot), ) 
+ async def wait_no_borrowers(self) -> None: + """Wait until all tokens are free. + + This could be useful when testing code with trio.to_thread to make + sure no tasks are still making progress in a thread. The following + code shows how this could be used:: + + async def wait_all_settled(): + capacity_limiter = trio.to_thread.current_default_thread_limiter() + while True: + await capacity_limiter.wait_no_borrowers() + await trio.testing.wait_all_tasks_blocked() + if capacity_limiter.borrowed_tokens == 0: + break + """ + while self._borrowers: + await self._event_no_borrowers.wait() + @final class Semaphore(AsyncContextManagerMixin): diff --git a/src/trio/_tests/test_sync.py b/src/trio/_tests/test_sync.py index 9179c8a5ae..31ba38d1f4 100644 --- a/src/trio/_tests/test_sync.py +++ b/src/trio/_tests/test_sync.py @@ -165,6 +165,38 @@ async def test_CapacityLimiter_change_total_tokens() -> None: assert c.statistics().tasks_waiting == 0 +async def test_CapacityLimiter_wait_no_borrowers() -> None: + c = CapacityLimiter(3) + no_borrowers_left = False + e1 = Event() + e2 = Event() + + async def wait_event(e: Event) -> None: + async with c: + await e.wait() + + async def wait_capacity_limiter_no_borrowers() -> None: + nonlocal no_borrowers_left + await c.wait_no_borrowers() + no_borrowers_left = True + + async with _core.open_nursery() as nursery: + nursery.start_soon(wait_event, e1) + nursery.start_soon(wait_event, e2) + await wait_all_tasks_blocked() + nursery.start_soon(wait_capacity_limiter_no_borrowers) + await wait_all_tasks_blocked() + assert not no_borrowers_left + + e1.set() + await wait_all_tasks_blocked() + assert not no_borrowers_left + + e2.set() + await wait_all_tasks_blocked() + assert no_borrowers_left + + # regression test for issue #548 async def test_CapacityLimiter_memleak_548() -> None: limiter = CapacityLimiter(total_tokens=1)