From 881c8ad311e09035344b4d00a4daf8f5069987a5 Mon Sep 17 00:00:00 2001 From: Adam Hearn <22334119+hearnadam@users.noreply.github.com> Date: Mon, 2 Jun 2025 21:36:49 -0700 Subject: [PATCH] Enqueue+Dequeue: add `shutdownCause` method - implement method for `Hub` and `Queue` --- .../shared/src/test/scala/zio/HubSpec.scala | 125 ++++++++++++++++++ .../shared/src/test/scala/zio/QueueSpec.scala | 64 +++++++++ .../shared/src/test/scala/zio/ZPoolSpec.scala | 2 +- core/shared/src/main/scala/zio/Dequeue.scala | 6 + core/shared/src/main/scala/zio/Enqueue.scala | 6 + core/shared/src/main/scala/zio/Hub.scala | 113 ++++++++++------ core/shared/src/main/scala/zio/Queue.scala | 75 ++++++----- core/shared/src/main/scala/zio/ZIO.scala | 4 +- .../src/main/scala/zio/stream/platform.scala | 2 +- .../src/main/scala/zio/stream/ZStream.scala | 2 + 10 files changed, 322 insertions(+), 77 deletions(-) diff --git a/core-tests/shared/src/test/scala/zio/HubSpec.scala b/core-tests/shared/src/test/scala/zio/HubSpec.scala index 2c9291d7bdc9..2294b0366b23 100644 --- a/core-tests/shared/src/test/scala/zio/HubSpec.scala +++ b/core-tests/shared/src/test/scala/zio/HubSpec.scala @@ -60,6 +60,131 @@ object HubSpec extends ZIOBaseSpec { } } ), + suite("shutdown")( + test("shutdown with take fiber") { + for { + selfId <- ZIO.fiberId + hub <- Hub.bounded[Int](3) + f <- ZIO.scoped(hub.subscribe.flatMap(_.take)).fork + _ <- hub.shutdown + res <- f.join.sandbox.either + } yield assert(res.left.map(_.untraced))(isLeft(equalTo(Cause.interrupt(selfId)))) + }, + test("shutdown with publish fiber") { + for { + selfId <- ZIO.fiberId + hub <- Hub.bounded[Int](2) + _ <- hub.publish(1) + _ <- hub.publish(1) + f <- hub.publish(1).fork + _ <- hub.shutdown + res <- f.join.sandbox.either + } yield assert(res)(isLeft(equalTo(Cause.interrupt(selfId)))) + }, + test("shutdown with publish") { + for { + selfId <- ZIO.fiberId + hub <- Hub.bounded[Int](1) + _ <- hub.shutdown + res <- hub.publish(1).sandbox.either + } yield
assert(res)(isLeft(equalTo(Cause.interrupt(selfId)))) + }, + test("shutdown with publishAll") { + for { + selfId <- ZIO.fiberId + hub <- Hub.bounded[Int](1) + _ <- hub.shutdown + res <- hub.publishAll(List(1)).sandbox.either + } yield assert(res)(isLeft(equalTo(Cause.interrupt(selfId)))) + }, + test("shutdown with size") { + for { + selfId <- ZIO.fiberId + hub <- Hub.bounded[Int](1) + _ <- hub.shutdown + res <- hub.size.sandbox.either + } yield assert(res)(isLeft(equalTo(Cause.interrupt(selfId)))) + } + ), + suite("shutdownCause")( + test("shutdown with take fiber using Cause.die") { + for { + hub <- Hub.bounded[Int](3) + f <- ZIO.scoped(hub.subscribe.flatMap(_.take)).fork + cause = Cause.die(new RuntimeException("test")) + _ <- hub.shutdownCause(cause) + res <- f.join.sandbox.either + } yield assert(res.left.map(_.untraced))(isLeft(equalTo(cause))) + }, + test("shutdown with publish fiber using Cause.die") { + for { + hub <- Hub.bounded[Int](2) + _ <- hub.publish(1) + _ <- hub.publish(1) + f <- hub.publish(1).fork + cause = Cause.die(new RuntimeException("test")) + _ <- hub.shutdownCause(cause) + res <- f.join.sandbox.either + } yield assert(res.left.map(_.untraced))(isLeft(equalTo(cause))) + }, + test("shutdown with publish using Cause.die") { + for { + hub <- Hub.bounded[Int](1) + cause = Cause.die(new RuntimeException("test")) + _ <- hub.shutdownCause(cause) + res <- hub.publish(1).sandbox.either + } yield assert(res.left.map(_.untraced))(isLeft(equalTo(cause))) + }, + test("shutdown with publishAll using Cause.die") { + for { + hub <- Hub.bounded[Int](1) + cause = Cause.die(new RuntimeException("test")) + _ <- hub.shutdownCause(cause) + res <- hub.publishAll(List(1)).sandbox.either + } yield assert(res.left.map(_.untraced))(isLeft(equalTo(cause))) + }, + test("shutdown with size using Cause.die") { + for { + hub <- Hub.bounded[Int](1) + cause = Cause.die(new RuntimeException("test")) + _ <- hub.shutdownCause(cause) + res <- hub.size.sandbox.either + } yield
assert(res.left.map(_.untraced))(isLeft(equalTo(cause))) + } + ), + suite("awaitShutdown")( + test("single") { + for { + hub <- Hub.bounded[Int](3) + p <- Promise.make[Nothing, Boolean] + _ <- (hub.awaitShutdown *> p.succeed(true)).fork + _ <- hub.shutdown + res <- p.await + } yield assert(res)(isTrue) + }, + test("multiple") { + for { + hub <- Hub.bounded[Int](3) + p1 <- Promise.make[Nothing, Boolean] + p2 <- Promise.make[Nothing, Boolean] + _ <- (hub.awaitShutdown *> p1.succeed(true)).fork + _ <- (hub.awaitShutdown *> p2.succeed(true)).fork + _ <- hub.shutdown + res1 <- p1.await + res2 <- p2.await + } yield assert(res1)(isTrue) && + assert(res2)(isTrue) + }, + test("already shutdown") { + for { + hub <- Hub.bounded[Int](3) + _ <- hub.shutdown + p <- Promise.make[Nothing, Boolean] + _ <- (hub.awaitShutdown *> p.succeed(true)).fork + res <- p.await + } yield assert(res)(isTrue) + } + ), suite("concurrent publishers and subscribers")( test("one to one") { check(smallInt, Gen.listOf(smallInt)) { (n, as) => diff --git a/core-tests/shared/src/test/scala/zio/QueueSpec.scala b/core-tests/shared/src/test/scala/zio/QueueSpec.scala index d5e6987b86f0..430dea93e918 100644 --- a/core-tests/shared/src/test/scala/zio/QueueSpec.scala +++ b/core-tests/shared/src/test/scala/zio/QueueSpec.scala @@ -755,6 +755,70 @@ object QueueSpec extends ZIOBaseSpec { _ <- f.await } yield assertCompletes } @@ exceptJS(nonFlaky), + suite("shutdownCause")( + test("shutdown with take fiber using Cause.die") { + for { + queue <- Queue.bounded[Int](3) + f <- queue.take.fork + _ <- waitForSize(queue, -1) + cause = Cause.die(new RuntimeException("test")) + _ <- queue.shutdownCause(cause) + res <- f.join.sandbox.either + } yield assert(res.left.map(_.untraced))(isLeft(equalTo(cause))) + }, + test("shutdown with offer fiber using Cause.die") { + for { + queue <- Queue.bounded[Int](2) + _ <- queue.offer(1) + _ <- queue.offer(1) + f <- queue.offer(1).fork + _ <- waitForSize(queue, 3) + cause = Cause.die(new
RuntimeException("test")) + _ <- queue.shutdownCause(cause) + res <- f.join.sandbox.either + } yield assert(res.left.map(_.untraced))(isLeft(equalTo(cause))) + }, + test("shutdown with offer using Cause.die") { + for { + queue <- Queue.bounded[Int](1) + cause = Cause.die(new RuntimeException("test")) + _ <- queue.shutdownCause(cause) + res <- queue.offer(1).sandbox.either + } yield assert(res.left.map(_.untraced))(isLeft(equalTo(cause))) + }, + test("shutdown with take using Cause.die") { + for { + queue <- Queue.bounded[Int](1) + cause = Cause.die(new RuntimeException("test")) + _ <- queue.shutdownCause(cause) + res <- queue.take.sandbox.either + } yield assert(res.left.map(_.untraced))(isLeft(equalTo(cause))) + }, + test("shutdown with takeAll using Cause.die") { + for { + queue <- Queue.bounded[Int](1) + cause = Cause.die(new RuntimeException("test")) + _ <- queue.shutdownCause(cause) + res <- queue.takeAll.sandbox.either + } yield assert(res.left.map(_.untraced))(isLeft(equalTo(cause))) + }, + test("shutdown with takeUpTo using Cause.die") { + for { + queue <- Queue.bounded[Int](1) + cause = Cause.die(new RuntimeException("test")) + _ <- queue.shutdownCause(cause) + res <- queue.takeUpTo(1).sandbox.either + } yield assert(res.left.map(_.untraced))(isLeft(equalTo(cause))) + }, + test("shutdown with size using Cause.die") { + for { + queue <- Queue.bounded[Int](1) + cause = Cause.die(new RuntimeException("test")) + _ <- queue.shutdownCause(cause) + res <- queue.size.sandbox.either + } yield assert(res.left.map(_.untraced))(isLeft(equalTo(cause))) + } + ), suite("back-pressured bounded queue stress testing") { val genChunk = Gen.chunkOfBounded(20, 100)(smallInt) List( diff --git a/core-tests/shared/src/test/scala/zio/ZPoolSpec.scala b/core-tests/shared/src/test/scala/zio/ZPoolSpec.scala index a2c8a69a28c8..a3fb9ae14ae3 100644 --- a/core-tests/shared/src/test/scala/zio/ZPoolSpec.scala +++ b/core-tests/shared/src/test/scala/zio/ZPoolSpec.scala @@ -140,7 +140,7 @@
object ZPoolSpec extends ZIOBaseSpec { scope <- Scope.make pool <- scope.extend(ZPool.make(get, 10)) _ <- ZIO.scoped(pool.get).fork.repeatN(99) - _ <- scope.close(Exit.succeed(())) + _ <- scope.close(Exit.unit) _ <- count.get.repeatUntil(_ == 0) } yield assertCompletes } @@ exceptJS(nonFlaky) + diff --git a/core/shared/src/main/scala/zio/Dequeue.scala b/core/shared/src/main/scala/zio/Dequeue.scala index c4b1bf8451ac..38215cb50545 100644 --- a/core/shared/src/main/scala/zio/Dequeue.scala +++ b/core/shared/src/main/scala/zio/Dequeue.scala @@ -44,6 +44,12 @@ sealed trait Dequeue[+A] extends Serializable { */ def shutdown(implicit trace: Trace): UIO[Unit] + /** + * Shuts down the Dequeue with a specific Cause, either `Die` or `Interrupt`. + * Future `take*` calls fail immediately; this default ignores `cause` and delegates to `shutdown`. + */ + def shutdownCause(cause: Cause[Nothing])(implicit trace: Trace): UIO[Unit] = shutdown(trace) + /** * Retrieves the size of the queue. This may be negative if fibers are * suspended waiting for elements to be added to the queue or greater than the diff --git a/core/shared/src/main/scala/zio/Enqueue.scala b/core/shared/src/main/scala/zio/Enqueue.scala index 17733d8a9947..c8e422e10aae 100644 --- a/core/shared/src/main/scala/zio/Enqueue.scala +++ b/core/shared/src/main/scala/zio/Enqueue.scala @@ -68,6 +68,12 @@ sealed trait Enqueue[-A] extends Serializable { */ def shutdown(implicit trace: Trace): UIO[Unit] + /** + * Shuts down the queue with a specific Cause, either `Die` or `Interrupt`. + * Future `offer*`/`take*` calls fail immediately; this default ignores `cause` and delegates to `shutdown`. + */ + def shutdownCause(cause: Cause[Nothing])(implicit trace: Trace): UIO[Unit] = shutdown(trace) + /** * Retrieves the size of the queue.
This may be negative if fibers are * suspended waiting for elements to be added to the queue or greater than the diff --git a/core/shared/src/main/scala/zio/Hub.scala b/core/shared/src/main/scala/zio/Hub.scala index 76de187de686..df46f8bdfc10 100644 --- a/core/shared/src/main/scala/zio/Hub.scala +++ b/core/shared/src/main/scala/zio/Hub.scala @@ -127,15 +127,17 @@ object Hub { strategy: Strategy[A] ): Hub[A] = new Hub[A] { + private def interrupted(implicit trace: Trace): UIO[Nothing] = shutdownHook.await *> ZIO.interrupt + def awaitShutdown(implicit trace: Trace): UIO[Unit] = - shutdownHook.await + shutdownHook.await.foldCauseZIO(ZIO.unitZIOFn, ZIO.unitZIOFn) val capacity: Int = hub.capacity def isShutdown(implicit trace: Trace): UIO[Boolean] = ZIO.succeed(shutdownFlag.get) def publish(a: A)(implicit trace: Trace): UIO[Boolean] = ZIO.suspendSucceed { - if (shutdownFlag.get) ZIO.interrupt + if (shutdownFlag.get) interrupted else if (hub.publish(a)) { strategy.unsafeCompleteSubscribers(hub, subscribers) ZIO.succeed(true) @@ -145,7 +147,7 @@ object Hub { } def publishAll[A1 <: A](as: Iterable[A1])(implicit trace: Trace): UIO[Chunk[A1]] = ZIO.suspendSucceed { - if (shutdownFlag.get) ZIO.interrupt + if (shutdownFlag.get) interrupted else { val surplus = unsafePublishAll(hub, as) strategy.unsafeCompleteSubscribers(hub, subscribers) @@ -156,17 +158,21 @@ object Hub { } } } + private def shutdownUnsafe(cause: Cause[Nothing])(implicit trace: Trace, unsafe: Unsafe): UIO[Unit] = + if (shutdownFlag.compareAndSet(false, true) && shutdownHook.unsafe.failCause(cause)) { + scope.close(Exit.failCause(cause)) *> strategy.shutdownCause(cause) // failCause, not fail: close with the cause itself, and hand it to suspended publishers + } else { + Exit.unit + } + + override def shutdownCause(cause: Cause[Nothing])(implicit trace: Trace): UIO[Unit] = + ZIO.suspendSucceed(shutdownUnsafe(cause)(trace, Unsafe)).uninterruptible def shutdown(implicit trace: Trace): UIO[Unit] = - ZIO.fiberIdWith { fiberId => - shutdownFlag.set(true) - ZIO - .whenZIODiscard(shutdownHook.succeedUnit) { -
scope.close(Exit.interrupt(fiberId)) *> strategy.shutdown - } - }.uninterruptible + ZIO.fiberIdWith(fiberId => shutdownUnsafe(Cause.interrupt(fiberId))(trace, Unsafe)).uninterruptible + def size(implicit trace: Trace): UIO[Int] = ZIO.suspendSucceed { - if (shutdownFlag.get) ZIO.interrupt + if (shutdownFlag.get) interrupted else ZIO.succeed(hub.size()) } def subscribe(implicit trace: Trace): ZIO[Scope, Nothing, Dequeue[A]] = @@ -215,7 +221,7 @@ object Hub { ): Dequeue[A] = new Dequeue.Internal[A] { self => def awaitShutdown(implicit trace: Trace): UIO[Unit] = - shutdownHook.await + shutdownHook.await.foldCauseZIO(ZIO.unitZIOFn, ZIO.unitZIOFn) val capacity: Int = hub.capacity def isShutdown(implicit trace: Trace): UIO[Boolean] = @@ -226,59 +232,68 @@ object Hub { ZIO.succeed(Chunk.fromIterable(as)) def shutdown(implicit trace: Trace): UIO[Unit] = ZIO.fiberIdWith { fiberId => - shutdownFlag.set(true) - ZIO - .whenZIODiscard(shutdownHook.succeedUnit) { - ZIO.foreachParDiscard(unsafePollAll(pollers))(_.interruptAs(fiberId)) *> - ZIO.succeed { - subscribers.remove(subscription -> pollers) - subscription.unsubscribe() - strategy.unsafeOnHubEmptySpace(hub, subscribers) - } - } + shutdownUnsafe(Cause.interrupt(fiberId))(trace, Unsafe) + Exit.unit }.uninterruptible + + override def shutdownCause(cause: Cause[Nothing])(implicit trace: Trace): UIO[Unit] = + ZIO.succeed(shutdownUnsafe(cause)(trace, Unsafe)).uninterruptible + + private def interrupted(implicit trace: Trace): UIO[Nothing] = shutdownHook.await *> ZIO.interrupt + + private def shutdownUnsafe(cause: Cause[Nothing])(implicit trace: Trace, unsafe: Unsafe): Unit = + if (shutdownFlag.compareAndSet(false, true)) { + if (shutdownHook.unsafe.failCause(cause)) { + val _pollers = unsafePollAll(pollers) + _pollers.foreach(_.unsafe.failCause(cause)) + subscribers.remove(subscription -> pollers) + subscription.unsubscribe() + strategy.unsafeOnHubEmptySpace(hub, subscribers) + } + } + def size(implicit trace: Trace): UIO[Int] =
ZIO.suspendSucceed { - if (shutdownFlag.get) ZIO.interrupt - else ZIO.succeed(subscription.size()) + if (shutdownFlag.get) interrupted + else Exit.succeed(subscription.size()) } def take(implicit trace: Trace): UIO[A] = ZIO.fiberIdWith { fiberId => - if (shutdownFlag.get) ZIO.interrupt + if (shutdownFlag.get) interrupted else { val empty = null.asInstanceOf[A] val message = if (pollers.isEmpty()) subscription.poll(empty) else empty message match { case null => - val promise = Promise.unsafe.make[Nothing, A](fiberId)(Unsafe.unsafe) + val promise = Promise.unsafe.make[Nothing, A](fiberId)(Unsafe) ZIO.suspendSucceed { pollers.offer(promise) subscribers.add(subscription -> pollers) strategy.unsafeCompletePollers(hub, subscribers, subscription, pollers) - if (shutdownFlag.get) ZIO.interrupt else promise.await + if (shutdownFlag.get) interrupted else promise.await }.onInterrupt(ZIO.succeed(unsafeRemove(pollers, promise))) case a => strategy.unsafeOnHubEmptySpace(hub, subscribers) - ZIO.succeed(a) + Exit.succeed(a) } } } def takeAll(implicit trace: Trace): ZIO[Any, Nothing, Chunk[A]] = ZIO.suspendSucceed { - if (shutdownFlag.get) ZIO.interrupt + if (shutdownFlag.get) interrupted else { val as = if (pollers.isEmpty()) unsafePollAll(subscription) else Chunk.empty strategy.unsafeOnHubEmptySpace(hub, subscribers) - ZIO.succeed(as) + Exit.succeed(as) } } def takeUpTo(max: Int)(implicit trace: Trace): ZIO[Any, Nothing, Chunk[A]] = ZIO.suspendSucceed { - if (shutdownFlag.get) ZIO.interrupt + if (shutdownFlag.get) interrupted else { val as = if (pollers.isEmpty()) unsafePollN(subscription, max) else Chunk.empty strategy.unsafeOnHubEmptySpace(hub, subscribers) - ZIO.succeed(as) + Exit.succeed(as) } } } @@ -305,6 +320,11 @@ object Hub { */ def shutdown(implicit trace: Trace): UIO[Unit] + /** + * Like `shutdown`, but completes any publishers suspended by this strategy with `cause`; the default ignores `cause` and delegates to `shutdown`.
+ */ + def shutdownCause(cause: Cause[Nothing])(implicit trace: Trace): UIO[Unit] = shutdown(trace) + /** * Describes how subscribers should signal to publishers waiting for space * to become available in the hub that space may be available. @@ -383,23 +403,30 @@ object Hub { isShutDown: AtomicBoolean )(implicit trace: Trace): UIO[Boolean] = ZIO.fiberIdWith { fiberId => - val promise = Promise.unsafe.make[Nothing, Boolean](fiberId)(Unsafe.unsafe) + val promise = Promise.unsafe.make[Nothing, Boolean](fiberId)(Unsafe) ZIO.suspendSucceed { unsafeOffer(as, promise) unsafeOnHubEmptySpace(hub, subscribers) unsafeCompleteSubscribers(hub, subscribers) - if (isShutDown.get) ZIO.interrupt else promise.await + if (isShutDown.get) promise.interruptAs(fiberId) *> promise.await *> ZIO.interrupt else promise.await // shutdown may have drained publishers before our offer; interruptAs is a no-op if it already completed the promise, otherwise it prevents waiting forever }.onInterrupt(ZIO.succeed(unsafeRemove(promise))) } def shutdown(implicit trace: Trace): UIO[Unit] = - for { - fiberId <- ZIO.fiberId - publishers <- ZIO.succeed(unsafePollAll(publishers)) - _ <- ZIO.foreachParDiscard(publishers) { case (_, promise, last) => - if (last) promise.interruptAs(fiberId) else ZIO.unit - } - } yield () + ZIO.fiberIdWith { fiberId => + shutdownUnsafe(Cause.interrupt(fiberId))(trace, Unsafe) + Exit.unit + }.uninterruptible + + override def shutdownCause(cause: Cause[Nothing])(implicit trace: Trace): UIO[Unit] = + ZIO.succeed(shutdownUnsafe(cause)(trace, Unsafe)).uninterruptible + + private def shutdownUnsafe(cause: Cause[Nothing])(implicit trace: Trace, unsafe: Unsafe): Unit = { + val _publishers = unsafePollAll(publishers) + _publishers.foreach { case (_, promise, last) => + if (last) promise.unsafe.failCause(cause) + } + } def unsafeOnHubEmptySpace( hub: internal.Hub[A], @@ -463,6 +490,8 @@ object Hub { def shutdown(implicit trace: Trace): UIO[Unit] = ZIO.unit + override def shutdownCause(cause: Cause[Nothing])(implicit trace: Trace): UIO[Unit] = ZIO.unit + def unsafeOnHubEmptySpace( hub: internal.Hub[A], subscribers: Set[(internal.Hub.Subscription[A],
MutableConcurrentQueue[Promise[Nothing, A]])] @@ -511,6 +540,8 @@ object Hub { def shutdown(implicit trace: Trace): UIO[Unit] = ZIO.unit + override def shutdownCause(cause: Cause[Nothing])(implicit trace: Trace): UIO[Unit] = ZIO.unit + def unsafeOnHubEmptySpace( hub: internal.Hub[A], subscribers: Set[(internal.Hub.Subscription[A], MutableConcurrentQueue[Promise[Nothing, A]])] diff --git a/core/shared/src/main/scala/zio/Queue.scala b/core/shared/src/main/scala/zio/Queue.scala index 0e7341eb2e32..fbf0c58f0c3a 100644 --- a/core/shared/src/main/scala/zio/Queue.scala +++ b/core/shared/src/main/scala/zio/Queue.scala @@ -39,6 +39,12 @@ sealed abstract class Queue[A] extends Dequeue.Internal[A] with Enqueue.Internal */ override final def isFull(implicit trace: Trace): UIO[Boolean] = size.map(_ >= capacity) + + /** + * Shuts down the queue with a specific Cause, either `Die` or `Interrupt`. + * Future `offer*`/`take*` calls fail immediately; this base version ignores `cause` and delegates to `shutdown`. + */ + override def shutdownCause(cause: Cause[Nothing])(implicit trace: Trace): UIO[Unit] = shutdown(trace) } object Queue extends QueuePlatformSpecific { @@ -62,7 +68,7 @@ object Queue extends QueuePlatformSpecific { * `UIO[Queue[A]]` */ def bounded[A](requestedCapacity: => Int)(implicit trace: Trace): UIO[Queue[A]] = - ZIO.fiberId.map(unsafe.bounded(requestedCapacity, _)(Unsafe.unsafe)) + ZIO.fiberIdWith(id => Exit.succeed(unsafe.bounded(requestedCapacity, id)(Unsafe.unsafe))) /** * Makes a new bounded queue with the dropping strategy. When the capacity of @@ -81,7 +87,7 @@ object Queue extends QueuePlatformSpecific { * `UIO[Queue[A]]` */ def dropping[A](requestedCapacity: => Int)(implicit trace: Trace): UIO[Queue[A]] = - ZIO.fiberId.map(unsafe.dropping(requestedCapacity, _)(Unsafe.unsafe)) + ZIO.fiberIdWith(id => Exit.succeed(unsafe.dropping(requestedCapacity, id)(Unsafe.unsafe))) /** * Makes a new bounded queue with sliding strategy.
When the capacity of the @@ -101,7 +107,7 @@ object Queue extends QueuePlatformSpecific { * `UIO[Queue[A]]` */ def sliding[A](requestedCapacity: => Int)(implicit trace: Trace): UIO[Queue[A]] = - ZIO.fiberId.map(unsafe.sliding(requestedCapacity, _)(Unsafe.unsafe)) + ZIO.fiberIdWith(id => Exit.succeed(unsafe.sliding(requestedCapacity, id)(Unsafe.unsafe))) /** * Makes a new unbounded queue. @@ -112,7 +118,7 @@ object Queue extends QueuePlatformSpecific { * `UIO[Queue[A]]` */ def unbounded[A](implicit trace: Trace): UIO[Queue[A]] = - ZIO.fiberId.map(unsafe.unbounded(_)(Unsafe.unsafe)) + ZIO.fiberIdWith(id => Exit.succeed(unsafe.unbounded(id)(Unsafe.unsafe))) object unsafe { @@ -161,6 +167,8 @@ object Queue extends QueuePlatformSpecific { strategy: Strategy[A] ) extends Queue[A] { + private def interrupted(implicit trace: Trace): UIO[Nothing] = shutdownHook.await *> ZIO.interrupt + private def removeTaker(taker: Promise[Nothing, A])(implicit trace: Trace): UIO[Unit] = ZIO.succeed(takers.remove(taker)) @@ -168,7 +176,7 @@ object Queue extends QueuePlatformSpecific { override def offer(a: A)(implicit trace: Trace): UIO[Boolean] = ZIO.suspendSucceed { - if (shutdownFlag.get) ZIO.interrupt + if (shutdownFlag.get) interrupted else { val noRemaining = if (queue.isEmpty()) { @@ -197,7 +205,7 @@ object Queue extends QueuePlatformSpecific { override def offerAll[A1 <: A](as: Iterable[A1])(implicit trace: Trace): UIO[Chunk[A1]] = ZIO.suspendSucceed { - if (shutdownFlag.get) ZIO.interrupt + if (shutdownFlag.get) interrupted else { val pTakers = if (queue.isEmpty()) unsafePollN(takers, as.size) else Chunk.empty val (forTakers, remaining) = as.splitAt(pTakers.size) @@ -221,35 +229,40 @@ object Queue extends QueuePlatformSpecific { } } - override def awaitShutdown(implicit trace: Trace): UIO[Unit] = shutdownHook.await + override def awaitShutdown(implicit trace: Trace): UIO[Unit] = + shutdownHook.await.foldCauseZIO(ZIO.unitZIOFn, ZIO.unitZIOFn) override def size(implicit trace:
Trace): UIO[Int] = ZIO.suspendSucceed { - if (shutdownFlag.get) - ZIO.interrupt + if (shutdownFlag.get) interrupted else Exit.succeed(queue.size() - takers.size() + strategy.surplusSize) } override def shutdown(implicit trace: Trace): UIO[Unit] = ZIO.fiberIdWith { fiberId => - if (shutdownFlag.compareAndSet(false, true)) { - implicit val unsafe: Unsafe = Unsafe - shutdownHook.unsafe.succeedUnit - val it = unsafePollAll(takers).iterator - while (it.hasNext) { - it.next().unsafe.interruptAs(fiberId) - } - strategy.shutdown(fiberId) - } + shutdownUnsafe(Cause.interrupt(fiberId))(trace, Unsafe) Exit.unit }.uninterruptible + override def shutdownCause(cause: Cause[Nothing])(implicit trace: Trace): UIO[Unit] = + ZIO.succeed(shutdownUnsafe(cause)(trace, Unsafe)).uninterruptible + + private def shutdownUnsafe(cause: Cause[Nothing])(implicit trace: Trace, unsafe: Unsafe): Unit = + if (shutdownFlag.compareAndSet(false, true)) { + shutdownHook.unsafe.failCause(cause) + val it = unsafePollAll(takers).iterator + while (it.hasNext) { + it.next().unsafe.failCause(cause) + } + strategy.shutdown(cause) + } + override def isShutdown(implicit trace: Trace): UIO[Boolean] = ZIO.succeed(shutdownFlag.get) override def take(implicit trace: Trace): UIO[A] = ZIO.fiberIdWith { fiberId => - if (shutdownFlag.get) ZIO.interrupt + if (shutdownFlag.get) interrupted else { queue.poll(null.asInstanceOf[A]) match { case null => @@ -262,7 +275,7 @@ object Queue extends QueuePlatformSpecific { ZIO.suspendSucceed { takers.offer(p) strategy.unsafeCompleteTakers(queue, takers) - if (shutdownFlag.get) ZIO.interrupt else p.await + if (shutdownFlag.get) interrupted else p.await }.onInterrupt(removeTaker(p)) case item => @@ -274,8 +287,7 @@ object Queue extends QueuePlatformSpecific { override def takeAll(implicit trace: Trace): UIO[Chunk[A]] = ZIO.suspendSucceed { - if (shutdownFlag.get) - ZIO.interrupt + if (shutdownFlag.get) interrupted else { val as = unsafePollAll(queue) if (!as.isEmpty) { @@ -289,8
+301,7 @@ object Queue extends QueuePlatformSpecific { override def takeUpTo(max: Int)(implicit trace: Trace): UIO[Chunk[A]] = ZIO.suspendSucceed { - if (shutdownFlag.get) - ZIO.interrupt + if (shutdownFlag.get) interrupted else { val as = unsafePollN(queue, max) if (!as.isEmpty) { @@ -304,8 +315,7 @@ object Queue extends QueuePlatformSpecific { override def poll(implicit trace: Trace): UIO[Option[A]] = ZIO.suspendSucceed { - if (shutdownFlag.get) - ZIO.interrupt + if (shutdownFlag.get) interrupted else { queue.poll(null.asInstanceOf[A]) match { case null => Exit.none @@ -334,7 +344,8 @@ object Queue extends QueuePlatformSpecific { def surplusSize: Int - def shutdown(fiberId: FiberId)(implicit trace: Trace, unsafe: Unsafe): Unit + def shutdown(fiberId: FiberId)(implicit trace: Trace, unsafe: Unsafe): Unit = shutdown(Cause.interrupt(fiberId)) + def shutdown(cause: Cause[Nothing])(implicit trace: Trace, unsafe: Unsafe): Unit @tailrec final def unsafeCompleteTakers( @@ -398,7 +409,7 @@ object Queue extends QueuePlatformSpecific { unsafeOffer(as, p) unsafeOnQueueEmptySpace(queue, takers) unsafeCompleteTakers(queue, takers) - if (isShutdown.get) ZIO.interrupt else p.await + if (isShutdown.get) ZIO.fiberIdWith(id => p.interruptAs(id)) *> p.await *> ZIO.interrupt else p.await // shutdown may have drained putters before our offer; interruptAs is a no-op if it already completed p, otherwise it prevents waiting forever }.onInterrupt(ZIO.succeed(unsafeRemove(p))) } @@ -452,11 +463,11 @@ object Queue extends QueuePlatformSpecific { def surplusSize: Int = putters.size() - def shutdown(fiberId: FiberId)(implicit trace: Trace, unsafe: Unsafe): Unit = { + def shutdown(cause: Cause[Nothing])(implicit trace: Trace, unsafe: Unsafe): Unit = { var next = putters.poll() while (next ne null) { val (_, promise, isLast) = next - if (isLast) promise.unsafe.interruptAs(fiberId) + if (isLast) promise.unsafe.failCause(cause) next = putters.poll() } } @@ -478,7 +489,7 @@ object Queue extends QueuePlatformSpecific { def surplusSize: Int = 0 - def shutdown(fiberId: FiberId)(implicit trace: Trace, unsafe: Unsafe): Unit = () + def shutdown(cause: Cause[Nothing])(implicit trace:
Trace, unsafe: Unsafe): Unit = () } final case class Sliding[A]() extends Strategy[A] { @@ -519,7 +530,7 @@ object Queue extends QueuePlatformSpecific { def surplusSize: Int = 0 - def shutdown(fiberId: FiberId)(implicit trace: Trace, unsafe: Unsafe): Unit = () + def shutdown(cause: Cause[Nothing])(implicit trace: Trace, unsafe: Unsafe): Unit = () } } diff --git a/core/shared/src/main/scala/zio/ZIO.scala b/core/shared/src/main/scala/zio/ZIO.scala index 17d9064b75f4..0c63358b5ca6 100644 --- a/core/shared/src/main/scala/zio/ZIO.scala +++ b/core/shared/src/main/scala/zio/ZIO.scala @@ -5428,8 +5428,8 @@ object ZIO extends ZIOCompanionPlatformSpecific with ZIOCompanionVersionSpecific ) } - private[zio] val unitFn: Any => Unit = (_: Any) => () - private val unitZIOFn: Any => UIO[Unit] = (_: Any) => Exit.unit + private[zio] val unitFn: Any => Unit = (_: Any) => () + private[zio] val unitZIOFn: Any => UIO[Unit] = (_: Any) => Exit.unit implicit final class ZIOAutoCloseableOps[R, E, A <: AutoCloseable](private val io: ZIO[R, E, A]) extends AnyVal { diff --git a/streams/jvm/src/main/scala/zio/stream/platform.scala b/streams/jvm/src/main/scala/zio/stream/platform.scala index 04db7ce85ae9..4bdecef8ed99 100644 --- a/streams/jvm/src/main/scala/zio/stream/platform.scala +++ b/streams/jvm/src/main/scala/zio/stream/platform.scala @@ -85,7 +85,7 @@ private[stream] trait ZStreamPlatformSpecificConstructors { } } yield { eitherStream match { - case Right(value) => ZStream.unwrap(output.shutdown as value) + case Right(value) => ZStream.unwrap(output.shutdown.as(value)) case Left(canceler) => lazy val loop: ZChannel[Any, Any, Any, Any, E, Chunk[A], Unit] = ZChannel.unwrap( diff --git a/streams/shared/src/main/scala/zio/stream/ZStream.scala b/streams/shared/src/main/scala/zio/stream/ZStream.scala index e02aab0f1762..53f04bc45d0a 100644 --- a/streams/shared/src/main/scala/zio/stream/ZStream.scala +++ b/streams/shared/src/main/scala/zio/stream/ZStream.scala @@ -6200,6 +6200,8 @@ object
ZStream extends ZStreamPlatformSpecificConstructors { dequeue.isShutdown def shutdown(implicit trace: Trace): UIO[Unit] = dequeue.shutdown + override def shutdownCause(cause: Cause[Nothing])(implicit trace: Trace): UIO[Unit] = + dequeue.shutdownCause(cause) def size(implicit trace: Trace): UIO[Int] = dequeue.size def take(implicit trace: Trace): UIO[B] =