diff --git a/build.sbt b/build.sbt index 78733223906e..7276efbd78f2 100644 --- a/build.sbt +++ b/build.sbt @@ -42,7 +42,7 @@ addCommandAlias( ) addCommandAlias( "testJVM", - ";coreTestsJVM/test;stacktracerJVM/test;streamsTestsJVM/test;testTestsJVM/test;testMagnoliaTestsJVM/test;testRefinedJVM/test;testRunnerJVM/test;testRunnerJVM/Test/run;examplesJVM/Test/compile;benchmarks/Test/compile;macrosTestsJVM/test;testJunitRunnerTests/test;concurrentJVM/test;managedTestsJVM/test" + ";coreTestsJVM/test;stacktracerJVM/test;streamsTestsJVM/test;testTestsJVM/test;testMagnoliaTestsJVM/test;testRefinedJVM/test;testRunnerJVM/test;testRunnerJVM/Test/run;examplesJVM/Test/compile;benchmarks/Test/compile;macrosTestsJVM/test;concurrentJVM/test;managedTestsJVM/test;set ThisBuild/isSnapshot:=true;testJunitRunnerTests/test;testJunitEngineTests/test;reload" ) addCommandAlias( "testJVMNoBenchmarks", @@ -97,6 +97,8 @@ lazy val rootJVM213 = project scalafixTests, testJunitRunner, testJunitRunnerTests, + testJunitEngine, + testJunitEngineTests, testMagnolia.jvm, testMagnoliaTests.jvm, testRefined.jvm, @@ -111,7 +113,9 @@ lazy val rootJVM3 = project .aggregate( List[ProjectReference]( testJunitRunner, + testJunitEngine, // testJunitRunnerTests, TODO: fix test + testJunitEngineTests, testMagnolia.jvm, testMagnoliaTests.jvm, testRefined.jvm, @@ -161,7 +165,9 @@ lazy val root213 = project benchmarks, scalafixTests, testJunitRunner, - testJunitRunnerTests + testJunitEngine, + testJunitRunnerTests, + testJunitEngineTests )) * ) @@ -180,7 +186,9 @@ lazy val root3 = project ).flatMap(p => List[ProjectReference](p.jvm, p.js)) ++ List[ProjectReference]( testJunitRunner, - testJunitRunnerTests + testJunitEngine, + testJunitRunnerTests, + testJunitEngineTests )) * ) @@ -494,7 +502,7 @@ lazy val testScalaCheck = crossProject(JSPlatform, JVMPlatform, NativePlatform) .settings(crossProjectSettings) .settings( libraryDependencies ++= Seq( - ("org.scalacheck" %%% "scalacheck" % "1.18.0") + "org.scalacheck" %%% "scalacheck" % "1.18.0" ) ) .jsSettings(jsSettings) @@ -586,6 +594,62 @@ lazy val testJunitRunnerTests = project.module .value ) +lazy val testJunitEngine = project.module + .in(file("test-junit-engine")) + .settings(stdSettings("zio-test-junit-engine")) + .settings( + libraryDependencies ++= Seq( + "org.junit.platform" % "junit-platform-engine" % "1.11.0", + "org.scala-lang.modules" %% "scala-collection-compat" % "2.12.0" + ) + ) + .dependsOn(tests.jvm) + +lazy val testJunitEngineTests = project.module + .in(file("test-junit-engine-tests")) + .settings(stdSettings("test-junit-engine-tests")) + .settings(Test / fork := true) + .settings(Test / javaOptions ++= { + Seq( + s"-Dproject.dir=${baseDirectory.value}", + s"-Dproject.version=${version.value}", + s"-Dscala.version=${scalaVersion.value}", + s"-Dscala.compat.version=${scalaBinaryVersion.value}" + ) + }) + .settings(publish / skip := true) + .settings( + libraryDependencies ++= Seq( + "junit" % "junit" % "4.13.2" % Test, + "org.scala-lang.modules" %% "scala-xml" % "2.2.0" % Test, + // required to run embedded maven in the tests + "org.apache.maven" % "maven-embedder" % "3.9.6" % Test, + "org.apache.maven" % "maven-compat" % "3.9.6" % Test, + "com.google.inject" % "guice" % "4.0" % Test, + "org.eclipse.sisu" % "org.eclipse.sisu.inject" % "0.3.5" % Test, + "org.apache.maven.resolver" % "maven-resolver-connector-basic" % "1.9.18" % Test, + "org.apache.maven.resolver" % "maven-resolver-transport-http" % "1.9.18" % Test, + "org.codehaus.plexus" % "plexus-component-annotations" % "2.2.0" % Test, + "org.slf4j" % "slf4j-simple" % "1.7.36" % Test + ) + ) + .dependsOn( + tests.jvm, + testRunner.jvm + ) + // publish locally so embedded maven runs against locally compiled zio + .settings( + Test / Keys.test := + (Test / Keys.test) + .dependsOn(testJunitEngine / publishM2) + .dependsOn(tests.jvm / publishM2) + .dependsOn(core.jvm / publishM2) + .dependsOn(internalMacros.jvm / publishM2) + .dependsOn(streams.jvm / publishM2) + .dependsOn(stacktracer.jvm / publishM2) + .value + ) + lazy val concurrent = crossProject(JSPlatform, JVMPlatform, NativePlatform) .in(file("concurrent")) .dependsOn(core) @@ -881,6 +945,7 @@ lazy val docs = project.module concurrent.jvm, tests.jvm, testJunitRunner, + testJunitEngine, testMagnolia.jvm, testRefined.jvm, testScalaCheck.jvm, diff --git a/core-tests/shared/src/test/scala/zio/FiberRuntimeSpec.scala b/core-tests/shared/src/test/scala/zio/FiberRuntimeSpec.scala new file mode 100644 index 000000000000..be5d6972efcc --- /dev/null +++ b/core-tests/shared/src/test/scala/zio/FiberRuntimeSpec.scala @@ -0,0 +1,96 @@ +package zio + +import zio.internal.FiberScope +import zio.test._ + +import java.util.concurrent.atomic.AtomicInteger + +object FiberRuntimeSpec extends ZIOBaseSpec { + private implicit val unsafe: Unsafe = Unsafe.unsafe + + def spec = suite("FiberRuntimeSpec")( + suite("whileLoop")( + test("auto-yields every 10280 operations when no other yielding is performed") { + ZIO.suspendSucceed { + val nIters = 50000 + val nOpsPerYield = 1024 * 10 + val nOps = new AtomicInteger(0) + val latch = Promise.unsafe.make[Nothing, Unit](FiberId.None) + val supervisor = new YieldTrackingSupervisor(latch, nOps) + val f = ZIO.whileLoop(nOps.getAndIncrement() < nIters)(Exit.unit)(_ => ()) + ZIO + .withFiberRuntime[Any, Nothing, Unit] { (parentFib, status) => + val fiber = ZIO.unsafe.makeChildFiber(Trace.empty, f, parentFib, status.runtimeFlags, FiberScope.global) + fiber.setFiberRef(FiberRef.currentSupervisor, supervisor) + fiber.startConcurrently(f) + latch.await + } + .as { + val yieldedAt = supervisor.yieldedAt + assertTrue( + yieldedAt == List( + nIters + 1, + nOpsPerYield * 4 - 3, + nOpsPerYield * 3 - 2, + nOpsPerYield * 2 - 1, + nOpsPerYield + ) + ) + } + } + }, + test("doesn't auto-yield when effect itself yields") { + ZIO.suspendSucceed { + val nIters = 50000 + val nOps = new AtomicInteger(0) + val latch = Promise.unsafe.make[Nothing, Unit](FiberId.None) + val supervisor = new YieldTrackingSupervisor(latch, nOps) + val f = + ZIO.whileLoop(nOps.getAndIncrement() < nIters)(ZIO.when(nOps.get() % 10000 == 0)(ZIO.yieldNow))(_ => ()) + ZIO + .withFiberRuntime[Any, Nothing, Unit] { (parentFib, status) => + val fiber = ZIO.unsafe.makeChildFiber(Trace.empty, f, parentFib, status.runtimeFlags, FiberScope.global) + fiber.setFiberRef(FiberRef.currentSupervisor, supervisor) + fiber.startConcurrently(f) + latch.await + } + .as { + val yieldedAt = supervisor.yieldedAt + assertTrue( + yieldedAt == List(nIters + 1, 50000, 40000, 30000, 20000, 10000) + ) + } + } + } + ) + ) + + private final class YieldTrackingSupervisor( + latch: Promise[Nothing, Unit], + nOps: AtomicInteger + ) extends Supervisor[Unit] { + @volatile var yieldedAt = List.empty[Int] + @volatile private var onEndCalled = false + + def value(implicit trace: Trace): UIO[Unit] = ZIO.unit + + def onStart[R, E, A]( + environment: ZEnvironment[R], + effect: ZIO[R, E, A], + parent: Option[Fiber.Runtime[Any, Any]], + fiber: Fiber.Runtime[E, A] + )(implicit unsafe: Unsafe): Unit = () + + override def onEnd[R, E, A](value: Exit[E, A], fiber: Fiber.Runtime[E, A])(implicit unsafe: Unsafe): Unit = { + onEndCalled = true + () + } + + override def onSuspend[E, A](fiber: Fiber.Runtime[E, A])(implicit unsafe: Unsafe): Unit = { + yieldedAt ::= nOps.get() + if (onEndCalled) latch.unsafe.done(Exit.unit) // onEnd gets called before onSuspend + () + } + } + +} diff --git a/core-tests/shared/src/test/scala/zio/ZIOSpec.scala b/core-tests/shared/src/test/scala/zio/ZIOSpec.scala index 39e445b8a8d2..87f421b7732d 100644 --- a/core-tests/shared/src/test/scala/zio/ZIOSpec.scala +++ b/core-tests/shared/src/test/scala/zio/ZIOSpec.scala @@ -2,9 +2,9 @@ package zio import zio.Cause._ import zio.LatchOps._ -import zio.internal.Platform +import zio.internal.{FiberRuntime, Platform} import zio.test.Assertion._ -import zio.test.TestAspect.{exceptJS, flaky, forked, jvmOnly, nonFlaky, scala2Only} +import zio.test.TestAspect.{exceptJS, flaky, forked, jvmOnly, nonFlaky, scala2Only, timeout, withLiveClock} import zio.test._ import scala.annotation.tailrec @@ -1620,7 +1620,25 @@ object ZIOSpec extends ZIOBaseSpec { _ <- latch.await exit <- fiber.interrupt.map(_.mapErrorCauseExit((cause: Cause[Nothing]) => cause.untraced)) } yield assert(exit)(isInterrupted) - } @@ exceptJS(nonFlaky) + } @@ exceptJS(nonFlaky), + test("child fibers can be created on interrupted parents within unininterruptible regions") { + for { + latch1 <- Promise.make[Nothing, Unit] + isInterupted <- Promise.make[Nothing, Boolean] + parent <- (latch1.succeed(()) *> ZIO.never).onInterrupt { + for { + latch2 <- Promise.make[Nothing, Unit] + child <- latch2.await.fork + _ <- ZIO.sleep(5.millis) + _ <- isInterupted.done(Exit.succeed(child.asInstanceOf[FiberRuntime[?, ?]].isInterrupted())) + _ <- latch2.succeed(()) + } yield () + }.forkDaemon + _ <- latch1.await + _ <- parent.interrupt + res <- isInterupted.await + } yield assertTrue(!res) + } @@ withLiveClock @@ exceptJS(nonFlaky) @@ timeout(20.seconds) @@ zioTag(interruption) ), suite("negate")( test("on true returns false") { diff --git a/core-tests/shared/src/test/scala/zio/ZKeyedPoolSpec.scala b/core-tests/shared/src/test/scala/zio/ZKeyedPoolSpec.scala index 4b27b679af21..dd10417b1618 100644 --- a/core-tests/shared/src/test/scala/zio/ZKeyedPoolSpec.scala +++ b/core-tests/shared/src/test/scala/zio/ZKeyedPoolSpec.scala @@ -40,6 +40,24 @@ object ZKeyedPoolSpec extends ZIOBaseSpec { _ <- TestClock.adjust((15 * 400).millis) _ <- fiber.join } yield assertCompletes - } + }, + test("invalidate does not cause memory leaks (i9306)") { + ZKeyedPool + .make[String, Any, Nothing, Array[Int]]((_: String) => ZIO.succeed(Array.ofDim[Int](1000000)), size = 1) + .flatMap { pool => + ZIO + .foreachDiscard(1 to 10000)(_ => + ZIO.scoped { + for { + item1 <- pool.get("key0") + _ <- ZIO.foreachDiscard(1 to 5)(i => pool.get(s"key$i")) + _ <- pool.invalidate(item1) + } yield () + } + ) + } + .as(assertCompletes) + } @@ jvmOnly ) @@ exceptJS + } diff --git a/core-tests/shared/src/test/scala/zio/ZPoolSpec.scala b/core-tests/shared/src/test/scala/zio/ZPoolSpec.scala index 2ced091b5003..a2c8a69a28c8 100644 --- a/core-tests/shared/src/test/scala/zio/ZPoolSpec.scala +++ b/core-tests/shared/src/test/scala/zio/ZPoolSpec.scala @@ -206,6 +206,12 @@ object ZPoolSpec extends ZIOBaseSpec { _ <- ZIO.scoped(ZPool.make(incCounter <* ZIO.fail("oh no"), 10)) _ <- latch.await } yield assertCompletes - } @@ exceptJS(nonFlaky(1000)) + } @@ exceptJS(nonFlaky(1000)) + + test("calling invalidate with items not in the pool doesn't cause memory leaks") { + for { + pool <- ZPool.make(ZIO.succeed(Array.empty[Int]), 1) + _ <- ZIO.foreachDiscard(1 to 1000)(_ => pool.invalidate(Array.ofDim[Int](10000000))) + } yield assertCompletes + } @@ jvmOnly }.provideLayer(Scope.default) @@ timeout(30.seconds) } diff --git a/core/js/src/main/scala/zio/SystemPlatformSpecific.scala b/core/js/src/main/scala/zio/SystemPlatformSpecific.scala new file mode 100644 index 000000000000..045c39c05f09 --- /dev/null +++ b/core/js/src/main/scala/zio/SystemPlatformSpecific.scala @@ -0,0 +1,44 @@ +/* + * Copyright 2017-2024 John A. De Goes and the ZIO Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package zio + +import zio.internal.stacktracer.Tracer +import zio.stacktracer.TracingImplicits.disableAutoTrace + +import scala.collection.mutable +import scala.scalajs.js +import scala.scalajs.js.Dynamic.global + +private[zio] trait SystemPlatformSpecific { self: System.type => + + private[zio] val environmentProvider = new EnvironmentProvider { + private val envMap: mutable.Map[String, String] = { + if (js.typeOf(global.process) != "undefined" && js.typeOf(global.process.env) != "undefined") { + global.process.env.asInstanceOf[js.Dictionary[String]] + } else { + mutable.Map.empty + } + } + + override def env(variable: String): Option[String] = + envMap.get(variable) + + override def envs: Map[String, String] = + envMap.toMap + } + +} diff --git a/core/jvm-native/src/main/scala/zio/SystemPlatformSpecific.scala b/core/jvm-native/src/main/scala/zio/SystemPlatformSpecific.scala new file mode 100644 index 000000000000..9ac989aebf8d --- /dev/null +++ b/core/jvm-native/src/main/scala/zio/SystemPlatformSpecific.scala @@ -0,0 +1,37 @@ +/* + * Copyright 2017-2024 John A. De Goes and the ZIO Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package zio + +import zio.internal.stacktracer.Tracer +import zio.stacktracer.TracingImplicits.disableAutoTrace + +import java.lang.{System => JSystem} +import scala.annotation.nowarn +import scala.collection.JavaConverters._ + +private[zio] trait SystemPlatformSpecific { self: System.type => + + private[zio] val environmentProvider = new EnvironmentProvider { + override def env(variable: String): Option[String] = + Option(JSystem.getenv(variable)) + + @nowarn("msg=JavaConverters") + override def envs: Map[String, String] = + JSystem.getenv().asScala.toMap + } + +} diff --git a/core/shared/src/main/scala-2/zio/internal/macros/LayerMacroUtils.scala b/core/shared/src/main/scala-2/zio/internal/macros/LayerMacroUtils.scala index 755421ab643b..d7100cd451fb 100644 --- a/core/shared/src/main/scala-2/zio/internal/macros/LayerMacroUtils.scala +++ b/core/shared/src/main/scala-2/zio/internal/macros/LayerMacroUtils.scala @@ -43,40 +43,35 @@ private[zio] trait LayerMacroUtils { provideMethod: ProvideMethod ): Expr[ZLayer[R0, E, R]] = { verifyLayers(layers) - val remainderTypes = getRequirements[R0] - val targetTypes = getRequirements[R] - val debug = typeOf[ZLayer.Debug.type].termSymbol - var usesEnvironment = false - val trace = c.freshName(TermName("trace")) - val compose = c.freshName(TermName("compose")) - + val debug = typeOf[ZLayer.Debug.type].termSymbol val debugMap: PartialFunction[LayerExpr, ZLayer.Debug] = { case Expr(q"$prefix.tree") if prefix.symbol == debug => ZLayer.Debug.Tree case Expr(q"$prefix.mermaid") if prefix.symbol == debug => ZLayer.Debug.Mermaid } + var usesEnvironment = false + var usesCompose = false + def typeToNode(tpe: Type): Node[Type, LayerExpr] = { usesEnvironment = true - Node(Nil, List(tpe), c.Expr[ZLayer[_, _, _]](q"${reify(ZLayer)}.environment[$tpe]($trace)")) + Node(Nil, List(tpe), c.Expr(q"${reify(ZLayer)}.environment[$tpe]")) } def buildFinalTree(tree: LayerTree[LayerExpr]): LayerExpr = { - val memoList: List[(LayerExpr, LayerExpr)] = tree.toList.map { node => - val termName = c.freshName(TermName("layer")) - node -> c.Expr[ZLayer[_, _, _]](q"$termName") - } - + val compose = c.freshName(TermName("compose")) + val memoList: List[(LayerExpr, LayerExpr)] = + tree.toList.map(_ -> c.Expr[ZLayer[_, _, _]](q"${c.freshName(TermName("layer"))}")) val definitions = memoList.map { case (expr, memoizedNode) => q"val ${TermName(memoizedNode.tree.toString)} = $expr" } - var usesCompose = false - val memoMap = memoList.toMap - val layerSym = typeOf[ZLayer[_, _, _]].typeSymbol val layerExpr = tree.fold[LayerExpr]( z = reify(ZLayer.unit), - value = memoMap, - composeH = (lhs, rhs) => c.Expr(q"$lhs ++ $rhs"), + value = memoList.toMap, + composeH = { + case (lhs, Expr(rhs: Ident)) => c.Expr(q"$lhs ++ $rhs") + case (lhs, rhs) => c.Expr(q"$lhs +!+ $rhs") + }, composeV = (lhs, rhs) => { usesCompose = true c.Expr(q"$compose($lhs, $rhs)") @@ -84,20 +79,19 @@ private[zio] trait LayerMacroUtils { ) val traceVal = if (usesEnvironment || usesCompose) { - List(q"val $trace: ${typeOf[Trace]} = ${reify(Tracer)}.newTrace") + val trace = c.freshName(TermName("trace")) + List(q"implicit val $trace: ${typeOf[Trace]} = ${reify(Tracer)}.newTrace") } else { Nil } val composeDef = if (usesCompose) { - val R = c.freshName(TypeName("R")) - val E = c.freshName(TypeName("E")) - val O1 = c.freshName(TypeName("O1")) - val O2 = c.freshName(TypeName("O2")) - List(q""" - def $compose[$R, $E, $O1, $O2](lhs: $layerSym[$R, $E, $O1], rhs: $layerSym[$O1, $E, $O2]) = - lhs.to(rhs)($trace) - """) + val ZLayer = typeOf[ZLayer[_, _, _]].typeSymbol + val R = c.freshName(TypeName("R")) + val E = c.freshName(TypeName("E")) + val O1 = c.freshName(TypeName("O1")) + val O2 = c.freshName(TypeName("O2")) + List(q"def $compose[$R, $E, $O1, $O2](lhs: $ZLayer[$R, $E, $O1], rhs: $ZLayer[$O1, $E, $O2]) = lhs >>> rhs") } else { Nil } @@ -111,8 +105,8 @@ private[zio] trait LayerMacroUtils { } val builder = LayerBuilder[Type, LayerExpr]( - target0 = targetTypes, - remainder = remainderTypes, + target0 = getRequirements[R], + remainder = getRequirements[R0], providedLayers0 = layers.toList, layerToDebug = debugMap, sideEffectType = definitions.UnitTpe, diff --git a/core/shared/src/main/scala-3/zio/internal/macros/LayerMacroUtils.scala b/core/shared/src/main/scala-3/zio/internal/macros/LayerMacroUtils.scala index 4d84dd88ccc6..92af358887a6 100644 --- a/core/shared/src/main/scala-3/zio/internal/macros/LayerMacroUtils.scala +++ b/core/shared/src/main/scala-3/zio/internal/macros/LayerMacroUtils.scala @@ -42,31 +42,28 @@ private[zio] object LayerMacroUtils { loop(TypeRepr.of[T]) } - val targetTypes = getRequirements[R] - val remainderTypes = getRequirements[R0] val layerToDebug: PartialFunction[LayerExpr[E], ZLayer.Debug] = { case '{ ZLayer.Debug.tree } => ZLayer.Debug.Tree case '{ ZLayer.Debug.mermaid } => ZLayer.Debug.Mermaid } '{ - val trace: Trace = Tracer.newTrace + val trace = Tracer.newTrace + given Trace = trace ${ def typeToNode(tpe: TypeRepr): Node[TypeRepr, LayerExpr[E]] = - Node(Nil, List(tpe), tpe.asType match { case '[t] => '{ ZLayer.environment[t](trace) } }) + Node(Nil, List(tpe), tpe.asType match { case '[t] => '{ ZLayer.environment[t] } }) def composeH(lhs: LayerExpr[E], rhs: LayerExpr[E]): LayerExpr[E] = lhs match { case '{ $lhs: ZLayer[i, E, o] } => rhs match { case '{ $rhs: ZLayer[i2, E, o2] } => - val tag = Expr - .summon[EnvironmentTag[o2]] - .getOrElse( - report.errorAndAbort(s"Cannot find EnvironmentTag[${TypeRepr.of[o2].show}] in implicit scope") - ) - '{ $lhs.++($rhs)($tag) } + rhs.asTerm match { + case _: Ident => '{ $lhs.and($rhs)(summonInline) } + case _ => '{ $lhs +!+ $rhs } + } } } @@ -75,7 +72,7 @@ private[zio] object LayerMacroUtils { case '{ $lhs: ZLayer[i, E, o] } => rhs match { case '{ $rhs: ZLayer[`o`, E, o2] } => - '{ composeLayer($lhs, $rhs)(using trace) } + '{ composeLayer($lhs, $rhs) } } } @@ -90,8 +87,8 @@ private[zio] object LayerMacroUtils { } val builder = LayerBuilder[TypeRepr, LayerExpr[E]]( - target0 = targetTypes, - remainder = remainderTypes, + target0 = getRequirements[R], + remainder = getRequirements[R0], providedLayers0 = layers.toList, layerToDebug = layerToDebug, typeEquals = _ <:< _, diff --git a/core/shared/src/main/scala/zio/System.scala b/core/shared/src/main/scala/zio/System.scala index 78cf11179a48..809676a0b493 100644 --- a/core/shared/src/main/scala/zio/System.scala +++ b/core/shared/src/main/scala/zio/System.scala @@ -101,7 +101,7 @@ trait System extends Serializable { self => } } -object System extends Serializable { +object System extends SystemPlatformSpecific { val tag: Tag[System] = Tag[System] @@ -140,7 +140,7 @@ object System extends Serializable { @transient override val unsafe: UnsafeAPI = new UnsafeAPI { override def env(variable: String)(implicit unsafe: Unsafe): Option[String] = - Option(JSystem.getenv(variable)) + environmentProvider.env(variable) override def envOrElse(variable: String, alt: => String)(implicit unsafe: Unsafe): String = envOrElseWith(variable, alt)(env) @@ -148,9 +148,8 @@ object System extends Serializable { override def envOrOption(variable: String, alt: => Option[String])(implicit unsafe: Unsafe): Option[String] = envOrOptionWith(variable, alt)(env) - @nowarn("msg=JavaConverters") override def envs()(implicit unsafe: Unsafe): Map[String, String] = - JSystem.getenv.asScala.toMap + environmentProvider.envs override def lineSeparator()(implicit unsafe: Unsafe): String = JSystem.lineSeparator @@ -172,6 +171,11 @@ object System extends Serializable { } } + private[zio] trait EnvironmentProvider { + def env(variable: String): Option[String] + def envs: Map[String, String] + } + private[zio] def envOrElseWith(variable: String, alt: => String)(env: String => Option[String]): String = env(variable).getOrElse(alt) diff --git a/core/shared/src/main/scala/zio/VersionSpecific.scala b/core/shared/src/main/scala/zio/VersionSpecific.scala index a3d981cb21b9..ec5e18048158 100644 --- a/core/shared/src/main/scala/zio/VersionSpecific.scala +++ b/core/shared/src/main/scala/zio/VersionSpecific.scala @@ -16,12 +16,9 @@ package zio -import zio.internal.Platform import zio.stacktracer.TracingImplicits.disableAutoTrace -import izumi.reflect.macrortti.LightTypeTagRef -import java.util.{Map => JMap} -import scala.collection.mutable +import java.util.concurrent.ConcurrentHashMap private[zio] trait VersionSpecific { @@ -59,8 +56,18 @@ private[zio] trait VersionSpecific { type LightTypeTag = izumi.reflect.macrortti.LightTypeTag - private[zio] def taggedIsSubtype(left: LightTypeTag, right: LightTypeTag): Boolean = - taggedSubtypes.computeIfAbsent((left, right), taggedIsSubtypeFn).value + private[zio] def taggedIsSubtype(left: LightTypeTag, right: LightTypeTag): Boolean = { + // NOTE: Prefer get/putIfAbsent pattern as it offers better read performance at the cost of + // potentially computing `<:<` multiple times during app warmup + val k = (left, right) + taggedSubtypes.get(k) match { + case null => + val v = left <:< right + taggedSubtypes.putIfAbsent(k, v) + v + case v => v.booleanValue() + } + } private[zio] def taggedTagType[A](tagged: EnvironmentTag[A]): LightTypeTag = tagged.tag @@ -72,25 +79,25 @@ private[zio] trait VersionSpecific { * `Tag[A with B]` should produce `Set(Tag[A], Tag[B])` */ private[zio] def taggedGetServices[A](t: LightTypeTag): Set[LightTypeTag] = - taggedServices.computeIfAbsent(t, taggedServicesFn) - - private val taggedSubtypes: JMap[(LightTypeTag, LightTypeTag), BoxedBool] = - Platform.newConcurrentMap()(Unsafe.unsafe) - - private val taggedServices: JMap[LightTypeTag, Set[LightTypeTag]] = - Platform.newConcurrentMap()(Unsafe.unsafe) - - private[this] val taggedIsSubtypeFn = - new java.util.function.Function[(LightTypeTag, LightTypeTag), BoxedBool] { - override def apply(tags: (LightTypeTag, LightTypeTag)): BoxedBool = - if (tags._1 <:< tags._2) BoxedBool.True else BoxedBool.False + // NOTE: See `taggedIsSubtype` for implementation notes + taggedServices.get(t) match { + case null => + val v = t.decompose + taggedServices.putIfAbsent(t, v) + v + case v => v } - private[this] val taggedServicesFn = - new java.util.function.Function[LightTypeTag, Set[LightTypeTag]] { - override def apply(tag: LightTypeTag): Set[LightTypeTag] = - tag.decompose - } + private val taggedSubtypes: ConcurrentHashMap[(LightTypeTag, LightTypeTag), java.lang.Boolean] = { + /* + * '''NOTE''': Larger maps have lower chance of collision which offers better + * read performance and smaller chance of entering synchronized blocks during writes + */ + new ConcurrentHashMap[(LightTypeTag, LightTypeTag), java.lang.Boolean](1024) + } + + private val taggedServices: ConcurrentHashMap[LightTypeTag, Set[LightTypeTag]] = + new ConcurrentHashMap[LightTypeTag, Set[LightTypeTag]](256) private sealed trait BoxedBool { self => final def value: Boolean = self eq BoxedBool.True diff --git a/core/shared/src/main/scala/zio/ZEnvironment.scala b/core/shared/src/main/scala/zio/ZEnvironment.scala index d9a467d088a4..908c0ca0e78c 100644 --- a/core/shared/src/main/scala/zio/ZEnvironment.scala +++ b/core/shared/src/main/scala/zio/ZEnvironment.scala @@ -283,7 +283,7 @@ final class ZEnvironment[+R] private ( else value } - private[this] def getUnsafe[A](tag: LightTypeTag)(implicit unsafe: Unsafe): A = { + private[this] def getUnsafe[A](tag: LightTypeTag): A = { val fromCache = self.cache.get(tag) if (fromCache != null) fromCache.asInstanceOf[A] @@ -301,7 +301,7 @@ final class ZEnvironment[+R] private ( } } if (service != null) { - self.cache.put(tag, service) + self.cache.putIfAbsent(tag, service) } service } diff --git a/core/shared/src/main/scala/zio/ZPool.scala b/core/shared/src/main/scala/zio/ZPool.scala index 5a4f29c9ad98..6de09c3d0a55 100644 --- a/core/shared/src/main/scala/zio/ZPool.scala +++ b/core/shared/src/main/scala/zio/ZPool.scala @@ -16,6 +16,7 @@ package zio +import zio.ZIO.InterruptibilityRestorer import zio.stacktracer.TracingImplicits.disableAutoTrace /** @@ -115,7 +116,7 @@ object ZPool { down <- Ref.make(false) state <- Ref.make(State(0, 0)) items <- Queue.bounded[Attempted[E, A]](range.end) - inv <- Ref.make(Set.empty[A]) + alloc <- Ref.make(Set.empty[A]) initial <- strategy.initial pool = DefaultPool( get.provideSomeEnvironment[Scope](env.union[Scope](_)), @@ -123,7 +124,7 @@ object ZPool { down, state, items, - inv, + alloc, strategy.track(initial) ) _ <- restore(pool.initialize).foldCauseZIO( @@ -148,16 +149,27 @@ object ZPool { ZIO.done(result) } - private case class DefaultPool[R, E, A]( + private final case class DefaultPool[E, A]( creator: ZIO[Scope, E, A], range: Range, isShuttingDown: Ref[Boolean], state: Ref[State], items: Queue[Attempted[E, A]], - invalidated: Ref[Set[A]], + allocated: Ref[Set[A]], track: Exit[E, A] => UIO[Any] ) extends ZPool[E, A] { + private def allocate(implicit restore: InterruptibilityRestorer, trace: Trace): UIO[Any] = + for { + scope <- Scope.make + exit <- scope.extend(restore(creator)).exit + attempted <- ZIO.succeed(Attempted(exit, scope.close(exit))) + _ <- attempted.forEach(a => allocated.update(_ + a)) + _ <- items.offer(attempted) + _ <- track(attempted.result) + _ <- getAndShutdown.whenZIO(isShuttingDown.get) + } yield attempted + /** * Returns the number of items in the pool in excess of the minimum size. */ @@ -165,7 +177,7 @@ object ZPool { state.get.map { case State(free, size) => size - range.start min free } def get(implicit trace: Trace): ZIO[Scope, E, A] = - ZIO.InterruptibilityRestorer.make.flatMap { restore => + ZIO.InterruptibilityRestorer.make.flatMap { implicit restore => def acquire: UIO[Attempted[E, A]] = isShuttingDown.get.flatMap { down => if (down) ZIO.interrupt @@ -176,9 +188,9 @@ object ZPool { items.take.flatMap { attempted => attempted.result match { case Exit.Success(item) => - invalidated.get.flatMap { set => - if (set.contains(item)) finalizeInvalid(attempted) *> acquire - else Exit.succeed(attempted) + allocated.get.flatMap { set => + if (set.contains(item)) Exit.succeed(attempted) + else finalizeInvalid(attempted) *> acquire } case _ => state.modify { case State(size, free) => @@ -201,21 +213,20 @@ object ZPool { def release(attempted: Attempted[E, A]): UIO[Any] = attempted.result match { case Exit.Success(item) => - invalidated.get.flatMap { set => - if (set.contains(item)) finalizeInvalid(attempted) - else + allocated.get.flatMap { set => + if (set.contains(item)) state.update(state => state.copy(free = state.free + 1)) *> items.offer(attempted) *> track(attempted.result) *> getAndShutdown.whenZIO(isShuttingDown.get) + else finalizeInvalid(attempted) } case _ => Exit.unit // Handled during acquire } def finalizeInvalid(attempted: Attempted[E, A]): UIO[Any] = - attempted.forEach(a => invalidated.update(_ - a)) *> - attempted.finalizer *> + attempted.finalizer *> state.modify { case State(size, free) => if (size <= range.start || free < 0) allocate -> State(size, free + 1) @@ -223,38 +234,18 @@ object ZPool { ZIO.unit -> State(size - 1, free) }.flatten - def allocate: UIO[Any] = - for { - scope <- Scope.make - exit <- scope.extend(restore(creator)).exit - attempted <- ZIO.succeed(Attempted(exit, scope.close(exit))) - _ <- items.offer(attempted) - _ <- track(attempted.result) - _ <- getAndShutdown.whenZIO(isShuttingDown.get) - } yield attempted - ZIO.acquireRelease(acquire)(release).flatMap(_.result).disconnect } /** * Begins pre-allocating pool entries based on minimum pool size. */ - final def initialize(implicit trace: Trace): UIO[Unit] = - ZIO.uninterruptibleMask { restore => + def initialize(implicit trace: Trace): UIO[Unit] = + ZIO.uninterruptibleMask { implicit restore => ZIO.replicateZIODiscard(range.start) { state.modify { case State(size, free) => if (size < range.start && size >= 0) - ( - for { - scope <- Scope.make - exit <- scope.extend(restore(creator)).exit - attempted <- ZIO.succeed(Attempted(exit, scope.close(exit))) - _ <- items.offer(attempted) - _ <- track(attempted.result) - _ <- getAndShutdown.whenZIO(isShuttingDown.get) - } yield attempted, - State(size + 1, free + 1) - ) + allocate -> State(size + 1, free + 1) else ZIO.unit -> State(size, free) }.flatten @@ -262,7 +253,7 @@ object ZPool { } def invalidate(item: A)(implicit trace: zio.Trace): UIO[Unit] = - invalidated.update(_ + item) + allocated.update(_ - item) /** * Shrinks the pool down, but never to less than the minimum size. @@ -273,7 +264,7 @@ object ZPool { if (size > range.start && free > 0) ( items.take.flatMap { attempted => - attempted.forEach(a => invalidated.update(_ - a)) *> + attempted.forEach(a => allocated.update(_ - a)) *> attempted.finalizer *> state.update(state => state.copy(size = state.size - 1)) }, @@ -295,7 +286,7 @@ object ZPool { items.take.foldCauseZIO( _ => ZIO.unit, attempted => - attempted.forEach(a => invalidated.update(_ - a)) *> + attempted.forEach(a => allocated.update(_ - a)) *> attempted.finalizer *> state.update(state => state.copy(size = state.size - 1)) *> getAndShutdown diff --git a/core/shared/src/main/scala/zio/internal/FiberRuntime.scala b/core/shared/src/main/scala/zio/internal/FiberRuntime.scala index 9cb8e6f9afdb..ff38871a48aa 100644 --- a/core/shared/src/main/scala/zio/internal/FiberRuntime.scala +++ b/core/shared/src/main/scala/zio/internal/FiberRuntime.scala @@ -158,7 +158,7 @@ final class FiberRuntime[E, A](fiberId: FiberId.Runtime, fiberRefs0: FiberRefs, if (isAlive()) { getChildren().add(child) - if (isInterrupted()) + if (shouldInterrupt()) child.tellInterrupt(getInterruptedCause()) } else { child.tellInterrupt(getInterruptedCause()) @@ -170,7 +170,7 @@ final class FiberRuntime[E, A](fiberId: FiberId.Runtime, fiberRefs0: FiberRefs, if (isAlive()) { val childs = getChildren() - if (isInterrupted()) { + if (shouldInterrupt()) { val cause = getInterruptedCause() while (iter.hasNext) { val child = iter.next() @@ -410,7 +410,7 @@ final class FiberRuntime[E, A](fiberId: FiberId.Runtime, fiberRefs0: FiberRefs, } val exit = - runLoop(effect, 0, _stackSize, initialDepth).asInstanceOf[Exit[E, A]] + runLoop(effect, 0, _stackSize, initialDepth, 0).asInstanceOf[Exit[E, A]] if (null eq exit) { // Terminate this evaluation, async resumption will continue evaluation: @@ -962,13 +962,14 @@ final class FiberRuntime[E, A](fiberId: FiberId.Runtime, fiberRefs0: FiberRefs, effect: ZIO.Erased, minStackIndex: Int, startStackIndex: Int, - currentDepth: Int + currentDepth: Int, + currentOps: Int ): Exit[Any, Any] = { assert(DisableAssertions || running.get) // Note that assigning `cur` as the result of `try` or `if` can cause Scalac to box local variables. var cur = effect - var ops = 0 + var ops = currentOps var stackIndex = startStackIndex if (currentDepth >= FiberRuntime.MaxDepthBeforeTrampoline) { @@ -1064,7 +1065,8 @@ final class FiberRuntime[E, A](fiberId: FiberId.Runtime, fiberRefs0: FiberRefs, stackIndex = pushStackFrame(flatmap, stackIndex) - val result = runLoop(flatmap.first, stackIndex, stackIndex, currentDepth + 1) + val result = runLoop(flatmap.first, stackIndex, stackIndex, currentDepth + 1, ops) + ops += 1 if (null eq result) return null @@ -1094,7 +1096,9 @@ final class FiberRuntime[E, A](fiberId: FiberId.Runtime, fiberRefs0: FiberRefs, stackIndex = pushStackFrame(fold, stackIndex) - val result = runLoop(fold.first, stackIndex, stackIndex, currentDepth + 1) + val result = runLoop(fold.first, stackIndex, stackIndex, currentDepth + 1, ops) + ops += 1 + if (null eq result) return null else { @@ -1159,7 +1163,8 @@ final class FiberRuntime[E, A](fiberId: FiberId.Runtime, fiberRefs0: FiberRefs, stackIndex = pushStackFrame(k, stackIndex) - val exit = runLoop(update0.f(oldRuntimeFlags), stackIndex, stackIndex, currentDepth + 1) + val exit = runLoop(update0.f(oldRuntimeFlags), stackIndex, stackIndex, currentDepth + 1, ops) + ops += 1 if (null eq exit) return null @@ -1191,7 +1196,7 @@ final class FiberRuntime[E, A](fiberId: FiberId.Runtime, fiberRefs0: FiberRefs, cur = null while ((cur eq null) && check()) { - runLoop(iterate.body(), stackIndex, stackIndex, nextDepth) match { + runLoop(iterate.body(), stackIndex, stackIndex, nextDepth, ops) match { case s: Success[Any] => iterate.process(s.value) case null => @@ -1199,6 +1204,7 @@ final class FiberRuntime[E, A](fiberId: FiberId.Runtime, fiberRefs0: FiberRefs, case failure => cur = failure } + ops += 1 } stackIndex -= 1 diff --git a/docs/reference/service-pattern/service-pattern.md b/docs/reference/service-pattern/service-pattern.md index 2e38f06f7243..22ddeaedca1d 100644 --- a/docs/reference/service-pattern/service-pattern.md +++ b/docs/reference/service-pattern/service-pattern.md @@ -114,7 +114,7 @@ final class DocRepoLive( _ <- metadataRepo.put(id, metadata) } yield id - override def delete(id: String): ZIO[Any, Throwable, Unit] = (blobStorage.delete(id) <&> metadataRepo.delete(id)).unit + override def delete(id: String): ZIO[Any, Throwable, Unit] = blobStorage.delete(id) &> metadataRepo.delete(id).unit override def findByTitle(title: String): ZIO[Any, Throwable, List[Doc]] = for { diff --git a/project/MimaSettings.scala b/project/MimaSettings.scala index 0ebf3cacfbc7..43bdef4258e0 100644 --- a/project/MimaSettings.scala +++ b/project/MimaSettings.scala @@ -17,7 +17,9 @@ object MimaSettings { exclude[IncompatibleResultTypeProblem]("zio.stm.TRef.todo"), exclude[DirectMissingMethodProblem]("zio.stm.TRef.versioned_="), exclude[IncompatibleResultTypeProblem]("zio.stm.TRef.versioned"), - exclude[ReversedMissingMethodProblem]("zio.Fiber#Runtime#UnsafeAPI.zio$Fiber$Runtime$UnsafeAPI$$$outer") + exclude[ReversedMissingMethodProblem]("zio.Fiber#Runtime#UnsafeAPI.zio$Fiber$Runtime$UnsafeAPI$$$outer"), + exclude[FinalClassProblem]("zio.ZPool$DefaultPool"), + exclude[DirectMissingMethodProblem]("zio.ZPool#DefaultPool.invalidated") ), mimaFailOnProblem := failOnProblem ) diff --git a/streams-tests/shared/src/test/scala/zio/stream/ZStreamSpec.scala b/streams-tests/shared/src/test/scala/zio/stream/ZStreamSpec.scala index 2198b2993cf0..f27c3a7ff89b 100644 --- a/streams-tests/shared/src/test/scala/zio/stream/ZStreamSpec.scala +++ b/streams-tests/shared/src/test/scala/zio/stream/ZStreamSpec.scala @@ -2400,9 +2400,9 @@ object ZStreamSpec extends ZIOBaseSpec { .runDrain .fork _ <- requestQueue.offer("some message").forever.fork - _ <- counter.get.repeatUntil(_ >= 10) + _ <- (ZIO.yieldNow *> counter.get).repeatUntil(_ >= 10) } yield assertCompletes - } @@ exceptJS(nonFlaky) @@ TestAspect.timeout(10.seconds) + } @@ exceptJS(nonFlaky) @@ TestAspect.timeout(30.seconds) ), suite("interruptAfter")( test("interrupts after given duration") { @@ -2725,7 +2725,7 @@ object ZStreamSpec extends ZIOBaseSpec { count <- latch.count _ <- f.join } yield assertTrue(count == 0) - } @@ TestAspect.jvmOnly @@ nonFlaky(5), + } @@ TestAspect.jvmOnly @@ nonFlaky, test("accumulates parallel errors") { sealed abstract class DbError extends Product with Serializable case object Missing extends DbError diff --git a/streams/shared/src/main/scala/zio/stream/ZChannel.scala b/streams/shared/src/main/scala/zio/stream/ZChannel.scala index 64c74d566781..12c08b978edd 100644 --- a/streams/shared/src/main/scala/zio/stream/ZChannel.scala +++ b/streams/shared/src/main/scala/zio/stream/ZChannel.scala @@ -673,54 +673,52 @@ sealed trait ZChannel[-Env, -InErr, -InElem, -InDone, +OutErr, +OutElem, +OutDon permits <- Semaphore.make(n.toLong) failure <- Ref.make[Cause[OutErr1]](Cause.empty) pull <- (queueReader >>> self).toPullInAlt(scope) - _ <- ZIO.fiberIdWith { fiberId => - pull.flatMap { outElem => - val latch = Promise.unsafe.make[Nothing, Unit](fiberId)(Unsafe.unsafe) - for { - f <- permits - .withPermit( - latch.succeed(()) *> f(outElem) - .catchAllCause(cause => - failure.update(_ && cause).unless(cause.isInterrupted) *> - errorSignal.succeed(()) *> - ZChannel.failLeftUnit - ) - ) - .interruptible - .fork - _ <- latch.await - _ <- outgoing.offer(f) - } yield () - }.forever.interruptible - } - .catchAllCause(cause => - cause.failureOrCause match { - case Left(x: Left[OutErr, OutDone]) => - failure.update(_ && Cause.fail(x.value)) *> - outgoing.offer(Fiber.done(ZChannel.failLeftUnit)) *> - ZChannel.failUnit - case Left(x: Right[OutErr, OutDone]) => - permits.withPermits(n.toLong)(ZIO.unit).interruptible *> - outgoing.offer(Fiber.fail(x.asInstanceOf[Either[Unit, OutDone]])) - case Right(cause) => - failure.update(_ && cause).unless(cause.isInterrupted) *> - outgoing.offer(Fiber.done(ZChannel.failLeftUnit)) *> - ZChannel.failUnit - } - ) - .raceFirst(errorSignal.await.interruptible) - .forkIn(scope) + childScope <- scope.fork + fiberId <- ZIO.fiberId + _ <- + pull.flatMap { outElem => + val latch = Promise.unsafe.make[Nothing, Unit](fiberId)(Unsafe) + for { + f <- permits + .withPermit( + latch.succeed(()) *> f(outElem) + .catchAllCause(cause => + failure.update(_ && cause).unless(cause.isInterruptedOnly) *> + errorSignal.succeed(()) *> + ZChannel.failLeftUnit + ) + ) + .interruptible + .forkIn(childScope) + _ <- latch.await + _ <- outgoing.offer(f) + } yield () + }.forever.interruptible + .onError(_.failureOrCause match { + case Left(x: Left[OutErr, OutDone]) => + failure.update(_ && Cause.fail(x.value)) *> + outgoing.offer(Fiber.done(ZChannel.failLeftUnit)) + case Left(x: Right[OutErr, OutDone]) => + permits.withPermits(n.toLong)(ZIO.unit).interruptible *> + outgoing.offer(Fiber.fail(x.asInstanceOf[Either[Unit, OutDone]])) + case Right(cause) => + failure.update(_ && cause).unless(cause.isInterruptedOnly) *> + outgoing.offer(Fiber.done(ZChannel.failLeftUnit)) + }) + .raceFirst(ZChannel.awaitErrorSignal(childScope, fiberId)(errorSignal)) + .forkIn(scope) } yield { lazy val writer: ZChannel[Env1, Any, Any, Any, OutErr1, OutElem2, OutDone] = ZChannel.unwrap[Env1, Any, Any, Any, OutErr1, OutElem2, OutDone] { outgoing.take.flatMap(_.await).map { case s: Exit.Success[OutElem2] => ZChannel.write(s.value) *> writer case f: Exit.Failure[Either[Unit, OutDone]] => + def extractFailures = ZChannel.unwrap(failure.get.map(ZChannel.refailCause(_))) f.cause.failureOrCause match { - case Left(_: Left[Unit, OutDone]) => ZChannel.unwrap(failure.get.map(ZChannel.refailCause(_))) - case Left(x: Right[Unit, OutDone]) => ZChannel.succeedNow(x.value) - case Right(cause) if cause.isInterrupted => ZChannel.unwrap(failure.get.map(ZChannel.refailCause(_))) - case Right(cause) => ZChannel.refailCause(cause) + case Left(_: Left[Unit, OutDone]) => extractFailures + case Left(x: Right[Unit, OutDone]) => ZChannel.succeedNow(x.value) + case Right(c) if c.isInterruptedOnly => extractFailures + case Right(cause) => ZChannel.refailCause(cause) } } } @@ -752,54 +750,53 @@ sealed trait ZChannel[-Env, -InErr, -InElem, -InDone, +OutErr, +OutElem, +OutDon permits <- Semaphore.make(n.toLong) failure <- Ref.make[Cause[OutErr1]](Cause.empty) pull <- (queueReader >>> self).toPullInAlt(scope) - _ <- ZIO.fiberIdWith { fiberId => - pull.flatMap { outElem => - val latch = Promise.unsafe.make[Nothing, Unit](fiberId)(Unsafe.unsafe) - for { - _ <- permits - .withPermit( - latch.succeed(()) *> f(outElem) - .foldCauseZIO( - cause => - failure.update(_ && cause).unless(cause.isInterrupted) *> - errorSignal.succeed(()) *> - outgoing.offer(ZChannel.failLeftUnit), - elem => outgoing.offer(Exit.succeed(elem)) - ) - ) - .interruptible - .fork - _ <- latch.await - } yield () - }.forever.interruptible - } - .catchAllCause(cause => - cause.failureOrCause match { - case Left(x: Left[OutErr, OutDone]) => - failure.update(_ && Cause.fail(x.value)) *> - outgoing.offer(ZChannel.failLeftUnit) *> - ZChannel.failUnit - case Left(x: Right[OutErr, OutDone]) => - permits.withPermits(n.toLong)(ZIO.unit).interruptible *> - outgoing.offer(Exit.fail(x.asInstanceOf[Either[Unit, OutDone]])) - case Right(cause) => - failure.update(_ && cause).unless(cause.isInterrupted) *> - outgoing.offer(ZChannel.failLeftUnit) *> - ZChannel.failUnit - } - ) - .raceFirst(errorSignal.await.interruptible) - .forkIn(scope) + childScope <- scope.fork + fiberId <- ZIO.fiberId + _ <- + pull.flatMap { outElem => + val latch = Promise.unsafe.make[Nothing, Unit](fiberId)(Unsafe) + for { + _ <- permits + .withPermit( + latch.succeed(()) *> f(outElem) + .foldCauseZIO( + cause => + failure.update(_ && cause).unless(cause.isInterruptedOnly) *> + errorSignal.succeed(()) *> + outgoing.offer(ZChannel.failLeftUnit), + elem => outgoing.offer(Exit.succeed(elem)) + ) + ) + .interruptible + .forkIn(childScope) + _ <- latch.await + } yield () + }.forever.interruptible + .onError(_.failureOrCause match { + case Left(x: Left[OutErr, OutDone]) => + failure.update(_ && Cause.fail(x.value)) *> + outgoing.offer(ZChannel.failLeftUnit) + case Left(x: Right[OutErr, OutDone]) => + permits.withPermits(n.toLong)(ZIO.unit).interruptible *> + outgoing.offer(Exit.fail(x.asInstanceOf[Either[Unit, OutDone]])) + case Right(cause) => + failure.update(_ && cause).unless(cause.isInterruptedOnly) *> + outgoing.offer(ZChannel.failLeftUnit) + }) + .raceFirst(ZChannel.awaitErrorSignal(childScope, fiberId)(errorSignal)) + .forkIn(scope) } yield { lazy val writer: ZChannel[Env1, Any, Any, Any, OutErr1, OutElem2, OutDone] = ZChannel.unwrap[Env1, Any, Any, Any, OutErr1, OutElem2, OutDone] { outgoing.take.map { case s: Exit.Success[OutElem2] => ZChannel.write(s.value) *> writer case f: Exit.Failure[Either[Unit, OutDone]] => + def extractFailures = ZChannel.unwrap(failure.get.map(ZChannel.refailCause(_))) f.cause.failureOrCause match { - case Left(_: Left[Unit, OutDone]) => ZChannel.unwrap(failure.get.map(ZChannel.refailCause(_))) - case Left(x: Right[Unit, OutDone]) => ZChannel.succeedNow(x.value) - case Right(cause) => ZChannel.refailCause(cause) + case Left(_: Left[Unit, OutDone]) => extractFailures + case Left(x: Right[Unit, OutDone]) => ZChannel.succeedNow(x.value) + case Right(c) if c.isInterruptedOnly => extractFailures + case Right(cause) => ZChannel.refailCause(cause) } } } @@ -1684,6 +1681,17 @@ object ZChannel { .ensuringWith(ex => ref.get.flatMap(_.apply(ex))) } + private def awaitErrorSignal( + scope: Scope.Closeable, + fiberId: FiberId + )( + signal: Promise[Nothing, Unit] + )(implicit trace: Trace): UIO[Unit] = + signal.await.interruptible.onExit { + case _: Exit.Success[?] => scope.close(Exit.interrupt(fiberId)) + case _ => Exit.unit + } + /** * Creates a channel backed by a buffer. When the buffer is empty, the channel * will simply passthrough its input as output. However, when the buffer is @@ -1951,6 +1959,8 @@ object ZChannel { errorSignal <- Promise.make[Nothing, Unit] permits <- Semaphore.make(n.toLong) pull <- (incoming >>> channels).toPullInAlt(scope) + childScope <- scope.fork + fiberId <- ZIO.fiberId evaluatePull = (pull: ZIO[Env, Either[OutErr, OutDone], OutElem]) => pull .flatMap(outElem => outgoing.offer(Exit.succeed(outElem))) @@ -1983,7 +1993,7 @@ object ZChannel { _ <- permits .withPermit(latch.succeed(()) *> raceIOs) .interruptible - .fork + .forkIn(childScope) _ <- latch.await } yield () } @@ -2002,17 +2012,16 @@ object ZChannel { .toPullInAlt(scope) .flatMap(evaluatePull(_).race(canceler.await.interruptible)) } - childFiber <- permits - .withPermit(latch.succeed(()) *> raceIOs) - .interruptible - .fork + _ <- permits + .withPermit(latch.succeed(()) *> raceIOs) + .interruptible + .forkIn(childScope) _ <- latch.await } yield () } } - _ <- ZIO - .fiberIdWith(pullStrategy(_).forever) - .catchAllCause(cause => + _ <- pullStrategy(fiberId).forever + .onError(cause => cause.failureOrCause match { case Left(_: Left[OutErr, OutDone]) => outgoing.offer(Exit.failCause(cause)) case Left(x: Right[OutErr, OutDone]) => @@ -2024,7 +2033,7 @@ object ZChannel { case Right(cause) => outgoing.offer(Exit.failCause(cause.map(Left(_)))) } ) - .raceFirst(errorSignal.await.interruptible) + .raceFirst(awaitErrorSignal(childScope, fiberId)(errorSignal)) .forkIn(scope) } yield { lazy val consumer: ZChannel[Env, Any, Any, Any, OutErr, OutElem, OutDone] = diff --git a/test-junit-engine-tests/maven/pom.xml b/test-junit-engine-tests/maven/pom.xml new file mode 100644 index 000000000000..5c6718f8eecf --- /dev/null +++ b/test-junit-engine-tests/maven/pom.xml @@ -0,0 +1,123 @@ + + 4.0.0 + dev.zio + 1.0 + zio_test_junit_engine_test + pom + ${project.artifactId} + Testing ZIO Test Junit engine test project + 2024 + + + UTF-8 + 2.13.13 + 2.12 + 2.0.22 + + + + + org.junit.jupiter + junit-jupiter-engine + 5.11.0 + test + + + dev.zio + zio-test_${scala.compat.version} + ${zio.version} + test + + + dev.zio + zio_${scala.compat.version} + ${zio.version} + + + dev.zio + zio-test-junit-engine_${scala.compat.version} + ${zio.version} + test + + + org.apache.maven.plugins + maven-surefire-plugin + 3.5.0 + + + + + + scala2 + + + org.scala-lang + scala-library + ${scala.version} + + + + + scala3 + + + org.scala-lang + scala3-library_${scala.compat.version} + ${scala.version} + + + + + + + + + + + net.alchim31.maven + scala-maven-plugin + 4.8.1 + + + + compile + testCompile + + + + + -dependencyfile + ${project.build.directory}/.scala_dependencies + -Ywarn-value-discard + + + + + + + + org.apache.maven.plugins + maven-surefire-plugin + 3.2.5 + + test + methods + 10 + + **/*Spec.* + + + + + + test + + + + + + + + + + diff --git a/test-junit-engine-tests/maven/settings.xml b/test-junit-engine-tests/maven/settings.xml new file mode 100644 index 000000000000..f7b0df01d7ab --- /dev/null +++ b/test-junit-engine-tests/maven/settings.xml @@ -0,0 +1,5 @@ + + + + diff --git a/test-junit-engine-tests/maven/src/test/scala/zio/test/junit/maven/DefectSpec.scala b/test-junit-engine-tests/maven/src/test/scala/zio/test/junit/maven/DefectSpec.scala new file mode 100644 index 000000000000..975aa03697d8 --- /dev/null +++ b/test-junit-engine-tests/maven/src/test/scala/zio/test/junit/maven/DefectSpec.scala @@ -0,0 +1,38 @@ +package zio.test.junit.maven + +import zio.test.Assertion.equalTo +import zio.test.{Spec, TestEnvironment, assert, ZIOSpecDefault} +import zio.{Scope, Task, ZIO, ZLayer} + +trait Ops { + def targetHost: String +} + +object OpsTest extends Ops { + override def targetHost: String = null +} + +trait MyService { + def readData : Task[List[String]] +} + +class MyServiceTest(targetHostName: String) extends MyService { + + val url = s"https://${targetHostName.toLowerCase}/ws" // <- null pointer exception here + + override def readData: Task[List[String]] = { + ZIO.succeed(List("a","b")) + } +} + +object DefectSpec extends ZIOSpecDefault { + override def spec: Spec[TestEnvironment with Scope, Any] = suite("nul test")( + test("test with defect") { + for { + ms <- ZIO.service[MyService] + result <- ms.readData + } + yield assert(result.size)(equalTo(2)) + }.provideLayer(ZLayer.succeed(new MyServiceTest(OpsTest.targetHost))) + ) +} diff --git a/test-junit-engine-tests/maven/src/test/scala/zio/test/junit/maven/FailingSpec.scala b/test-junit-engine-tests/maven/src/test/scala/zio/test/junit/maven/FailingSpec.scala new file mode 100644 index 000000000000..9eb29f57487e --- /dev/null +++ b/test-junit-engine-tests/maven/src/test/scala/zio/test/junit/maven/FailingSpec.scala @@ -0,0 +1,19 @@ +package zio.test.junit.maven + +import zio.test.junit._ +import zio.test._ +import zio.test.Assertion._ + +object FailingSpec extends ZIOSpecDefault { + override def spec = suite("FailingSpec")( + test("should fail") { + assert(11)(equalTo(12)) + }, + test("should fail - isSome") { + assert(Some(11))(isSome(equalTo(12))) + }, + test("should succeed") { + assert(12)(equalTo(12)) + } + ) +} diff --git a/test-junit-engine-tests/src/test/scala/zio/test/junit/MavenJunitSpec.scala b/test-junit-engine-tests/src/test/scala/zio/test/junit/MavenJunitSpec.scala new file mode 100644 index 000000000000..f6daf2da0239 --- /dev/null +++ b/test-junit-engine-tests/src/test/scala/zio/test/junit/MavenJunitSpec.scala @@ -0,0 +1,121 @@ +package zio.test.junit + +import org.apache.maven.cli.MavenCli +import zio.test.Assertion._ +import zio.test.{ZIOSpecDefault, _} +import zio.{System => _, ZIO, Task} + +import java.io.File +import scala.collection.immutable +import scala.xml.XML + +/** + * when running from IDE run `sbt publishM2`, copy the snapshot version the + * artifacts were published under (something like: + * `1.0.2+0-37ee0765+20201006-1859-SNAPSHOT`) and put this into `VM Parameters`: + * `-Dproject.dir=\$PROJECT_DIR\$/test-junit-tests/jvm + * -Dproject.version=\$snapshotVersion` + */ +object MavenJunitSpec extends ZIOSpecDefault { + + def spec = suite("MavenJunitSpec")( + test("Spec results are properly reported") { + for { + mvn <- makeMaven + mvnResult <- mvn.clean() *> mvn.test() + report <- mvn.parseSurefireReport("zio.test.junit.maven.FailingSpec") + reportDefect <- mvn.parseSurefireReport("zio.test.junit.maven.DefectSpec") + } yield { + assert(mvnResult)(not(equalTo(0))) && + assert(report)( + containsFailure( + "should fail", + "11 was not equal to 12" + ) && + containsFailure( + "should fail - isSome", + "11 was not equal to 12" + ) && + containsSuccess("should succeed") + ) && + assertTrue(reportDefect.length == 1) // spec with defect is reported + } + } + ) @@ TestAspect.sequential /*@@ + // flaky: sometimes maven fails to download dependencies in CI + TestAspect.flaky(3)*/ + + def makeMaven: ZIO[Any, AssertionError, MavenDriver] = for { + projectDir <- + ZIO + .fromOption(sys.props.get("project.dir")) + .orElseFail( + new AssertionError( + "Missing project.dir system property\n" + + "when running from IDE put this into `VM Parameters`: `-Dproject.dir=$PROJECT_DIR$/test-junit-tests/jvm`" + ) + ) + projectVer <- + ZIO + .fromOption(sys.props.get("project.version")) + .orElseFail( + new AssertionError( + "Missing project.version system property\n" + + "when running from IDE put this into `VM Parameters`: `-Dproject.version=`" + ) + ) + scalaVersion = sys.props.get("scala.version").getOrElse("2.12.10") + scalaCompatVersion = sys.props.get("scala.compat.version").getOrElse("2.12") + } yield new MavenDriver(projectDir, projectVer, scalaVersion, scalaCompatVersion) + + class MavenDriver(projectDir: String, projectVersion: String, scalaVersion: String, scalaCompatVersion: String) { + val mvnRoot: String = new File(s"$projectDir/maven").getCanonicalPath + private val cli = new MavenCli + java.lang.System.setProperty("maven.multiModuleProjectDirectory", mvnRoot) + + def clean(): Task[Int] = run("clean") + + def test(): Task[Int] = run( + "test", + s"-Dzio.version=$projectVersion", + s"-Dscala.version=$scalaVersion", + s"-Dscala.compat.version=$scalaCompatVersion", + s"-Pscala${scalaVersion(0)}" + ) + def run(command: String*): Task[Int] = ZIO.attemptBlocking( + cli.doMain(command.toArray, mvnRoot, System.out, System.err) + ) + + def parseSurefireReport(testFQN: String): Task[immutable.Seq[TestCase]] = + ZIO + .attemptBlocking( + XML.load(scala.xml.Source.fromFile(new File(s"$mvnRoot/target/surefire-reports/TEST-$testFQN.xml"))) + ) + .map { report => + (report \ "testcase").map { tcNode => + TestCase( + tcNode \@ "name", + (tcNode \ "error").headOption + .map(error => TestError(error.text.linesIterator.map(_.trim).mkString("\n"), error \@ "type")), + (tcNode \ "failure").headOption + .map(error => TestError(error.text.linesIterator.map(_.trim).mkString("\n"), error \@ "type")) + ) + } + } + } + + def containsSuccess(label: String): Assertion[Iterable[TestCase]] = containsResult(label, error = None) + def containsFailure(label: String, error: String): Assertion[Iterable[TestCase]] = containsResult(label, Some(error)) + def containsResult(label: String, error: Option[String]): Assertion[Iterable[TestCase]] = + exists(assertion(s"check $label") { testCase => + testCase.name == label && + error + .map(err => testCase.errorAndFailure.exists(_.message.contains(err))) + .getOrElse(testCase.errorAndFailure.isEmpty) + }) + + case class TestCase(name: String, error: Option[TestError], failure: Option[TestError]) { + lazy val errorAndFailure: Seq[TestError] = (error ++ failure).toSeq + } + case class TestError(message: String, `type`: String) +} diff --git a/test-junit-engine/src/main/resources/META-INF/services/org.junit.platform.engine.TestEngine b/test-junit-engine/src/main/resources/META-INF/services/org.junit.platform.engine.TestEngine new file mode 100644 index 000000000000..aec86fcff1f7 --- /dev/null +++ b/test-junit-engine/src/main/resources/META-INF/services/org.junit.platform.engine.TestEngine @@ -0,0 +1 @@ +zio.test.junit.ZIOTestEngine \ No newline at end of file diff --git a/test-junit-engine/src/main/scala/zio/test/junit/ReflectionUtils.scala b/test-junit-engine/src/main/scala/zio/test/junit/ReflectionUtils.scala new file mode 100644 index 000000000000..27746300d016 --- /dev/null +++ b/test-junit-engine/src/main/scala/zio/test/junit/ReflectionUtils.scala @@ -0,0 +1,23 @@ +package zio.test.junit + +import scala.util.Try + +private[zio] object ReflectionUtils { + + /** + * Retrieves the companion object of the specified class, if it exists. Here + * we use plain java reflection as runtime reflection is not available for + * scala 3 + * + * @param klass + * The class for which to retrieve the companion object. + * @return + * the optional companion object. + */ + def getCompanionObject(klass: Class[_]): Option[Any] = + Try { + (if (klass.getName.endsWith("$")) klass else getClass.getClassLoader.loadClass(klass.getName + "$")) + .getDeclaredField("MODULE$") + .get(null) + }.toOption +} diff --git a/test-junit-engine/src/main/scala/zio/test/junit/TestFailed.scala b/test-junit-engine/src/main/scala/zio/test/junit/TestFailed.scala new file mode 100644 index 000000000000..ecc0a5f33477 --- /dev/null +++ b/test-junit-engine/src/main/scala/zio/test/junit/TestFailed.scala @@ -0,0 +1,30 @@ +/* + * Copyright 2024-2024 Vincent Raman and the ZIO Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package zio.test.junit + +import org.opentest4j.AssertionFailedError + +/** + * Represents a failure of a test assertion. It needs to extend + * AssertionFailedError for Junit5 platform to mark it as failure + * + * @param message + * A description of the failure. + * @param cause + * The underlying cause of the failure, if any. + */ +class TestFailed(message: String, cause: Throwable = null) extends AssertionFailedError(message, cause) diff --git a/test-junit-engine/src/main/scala/zio/test/junit/ZIOSuiteTestDescriptor.scala b/test-junit-engine/src/main/scala/zio/test/junit/ZIOSuiteTestDescriptor.scala new file mode 100644 index 000000000000..f715370e34d3 --- /dev/null +++ b/test-junit-engine/src/main/scala/zio/test/junit/ZIOSuiteTestDescriptor.scala @@ -0,0 +1,48 @@ +/* + * Copyright 2024-2024 Vincent Raman and the ZIO Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package zio.test.junit + +import org.junit.platform.engine.support.descriptor.AbstractTestDescriptor +import org.junit.platform.engine.{TestDescriptor, UniqueId} + +/** + * Describes a JUnit 5 test descriptor for a suite of ZIO tests. + * + * @constructor + * Creates an instance of ZIOSuiteTestDescriptor. + * @param parent + * The parent TestDescriptor. + * @param uniqueId + * The unique identifier for this test descriptor. + * @param label + * The display name of this test descriptor. + * @param testClass + * The class representing the suite of tests. + */ +class ZIOSuiteTestDescriptor( + parent: TestDescriptor, + uniqueId: UniqueId, + label: String, + testClass: Class[_] +) extends AbstractTestDescriptor(uniqueId, label, ZIOTestSource(testClass)) { + setParent(parent) + override def getType: TestDescriptor.Type = TestDescriptor.Type.CONTAINER +} + +object ZIOSuiteTestDescriptor { + val segmentType = "suite" +} diff --git a/test-junit-engine/src/main/scala/zio/test/junit/ZIOTestClassDescriptor.scala b/test-junit-engine/src/main/scala/zio/test/junit/ZIOTestClassDescriptor.scala new file mode 100644 index 000000000000..6cd216c50187 --- /dev/null +++ b/test-junit-engine/src/main/scala/zio/test/junit/ZIOTestClassDescriptor.scala @@ -0,0 +1,103 @@ +/* + * Copyright 2024-2024 Vincent Raman and the ZIO Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package zio.test.junit + +import org.junit.platform.engine.support.descriptor.AbstractTestDescriptor +import org.junit.platform.engine.{TestDescriptor, UniqueId} +import zio._ +import zio.test._ +import zio.test.junit.ReflectionUtils._ + +/** + * Represents a JUnit 5 test descriptor for a ZIO-based test class + * + * @param parent + * The parent TestDescriptor. + * @param uniqueId + * A unique identifier for this TestDescriptor. + * @param testClass + * The class representing the test. + */ +class ZIOTestClassDescriptor(parent: TestDescriptor, uniqueId: UniqueId, val testClass: Class[_]) + extends AbstractTestDescriptor(uniqueId, testClass.getName.stripSuffix("$"), ZIOTestSource(testClass)) { + + setParent(parent) + val className: String = testClass.getName + + // reflection to get the spec implementation + // by default, it would be an object, extending ZIOSpecDefault + // but this also support class that needs to be instantiated + val spec: ZIOSpecAbstract = getCompanionObject(testClass) + .getOrElse(testClass.getDeclaredConstructor().newInstance()) + .asInstanceOf[ZIOSpecAbstract] + + def traverse[R, E]( + spec: Spec[R, E], + description: TestDescriptor, + path: Vector[String] = Vector.empty + ): ZIO[R with Scope, Any, Unit] = + spec.caseValue match { + case Spec.ExecCase(_, spec: Spec[R, E]) => traverse(spec, description, path) + case Spec.LabeledCase(label, spec: Spec[R, E]) => + traverse(spec, description, path :+ label) + case Spec.ScopedCase(scoped) => scoped.flatMap((s: Spec[R, E]) => traverse(s, description, path)) + case Spec.MultipleCase(specs) => + val suiteDesc = new ZIOSuiteTestDescriptor( + description, + description.getUniqueId.append(ZIOSuiteTestDescriptor.segmentType, path.lastOption.getOrElse("")), + path.lastOption.getOrElse(""), + testClass + ) + ZIO.succeed(description.addChild(suiteDesc)) *> + ZIO.foreach(specs)((s: Spec[R, E]) => traverse(s, suiteDesc, path)).ignore + case Spec.TestCase(_, annotations) => + ZIO.succeed( + description.addChild( + new ZIOTestDescriptor( + description, + description.getUniqueId.append(ZIOTestDescriptor.segmentType, path.lastOption.getOrElse("")), + path.lastOption.getOrElse(""), + testClass, + annotations + ) + ) + ) + } + + lazy val scoped: ZIO[spec.Environment with TestEnvironment, Any, Unit] = + ZIO.scoped[spec.Environment with TestEnvironment]( + traverse(spec.spec, this) + ) + + Unsafe.unsafe { implicit unsafe => + Runtime.default.unsafe + .run( + scoped + .provide( + Scope.default >>> (liveEnvironment >>> TestEnvironment.live ++ ZLayer.environment[Scope]), + spec.bootstrap + ) + ) + .getOrThrowFiberFailure() + } + + override def getType: TestDescriptor.Type = TestDescriptor.Type.CONTAINER +} + +object ZIOTestClassDescriptor { + val segmentType = "class" +} diff --git a/test-junit-engine/src/main/scala/zio/test/junit/ZIOTestClassRunner.scala b/test-junit-engine/src/main/scala/zio/test/junit/ZIOTestClassRunner.scala new file mode 100644 index 000000000000..565994c6f508 --- /dev/null +++ b/test-junit-engine/src/main/scala/zio/test/junit/ZIOTestClassRunner.scala @@ -0,0 +1,152 @@ +/* + * Copyright 2024-2024 Vincent Raman and the ZIO Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package zio.test.junit + +import org.junit.platform.engine.{EngineExecutionListener, TestDescriptor, TestExecutionResult} +import zio.test.render.ConsoleRenderer +import zio.test.render.ExecutionResult.ResultType.Test +import zio.test.render.ExecutionResult.Status.Failed +import zio.test.render.LogLine.Message +import zio.test._ +import zio._ + +import scala.jdk.OptionConverters._ + +/** + * The `ZIOTestClassRunner` is responsible for running ZIO tests within a test + * class and reporting their results using the given JUnit 5 + * `EngineExecutionListener`. + * + * @param descriptor + * The ZIO test class JUnit 5 descriptor + */ +class ZIOTestClassRunner(descriptor: ZIOTestClassDescriptor) { + private val spec = descriptor.spec + + def run(notifier: EngineExecutionListener): IO[Any, Summary] = { + def instrumentedSpec[R, E](zspec: Spec[R, E]) = { + def loop( + spec: Spec[R, E], + description: TestDescriptor, + path: Vector[String] = Vector.empty + ): Spec.SpecCase[R, E, Spec[R, E]] = + spec.caseValue match { + case Spec.ExecCase(exec, spec: Spec[R, E]) => Spec.ExecCase(exec, Spec(loop(spec, description, path))) + case Spec.LabeledCase(label, spec) => + Spec.LabeledCase(label, Spec(loop(spec, description, path :+ label))) + case Spec.ScopedCase(scoped) => + Spec.ScopedCase[R, E, Spec[R, E]](scoped.map(spec => Spec(loop(spec, description, path)))) + case Spec.MultipleCase(specs) => + val uniqueId = + description.getUniqueId.append(ZIOSuiteTestDescriptor.segmentType, path.lastOption.getOrElse("")) + descriptor + .findByUniqueId(uniqueId) + .toScala + .map(suiteDescription => Spec.MultipleCase(specs.map(spec => Spec(loop(spec, suiteDescription, path))))) + .getOrElse( + // filtered out + Spec.MultipleCase(Chunk.empty) + ) + case Spec.TestCase(test, annotations) => + val uniqueId = description.getUniqueId.append(ZIOTestDescriptor.segmentType, path.lastOption.getOrElse("")) + descriptor + .findByUniqueId(uniqueId) + .toScala + .map(_ => Spec.TestCase(test, annotations)) + .getOrElse( + // filtered out + Spec.TestCase(ZIO.succeed(TestSuccess.Ignored()), annotations) + ) + } + + Spec(loop(zspec, descriptor)) + } + + def testDescriptorFromReversedLabel(labelsReversed: List[String]): Option[TestDescriptor] = { + val uniqueId = labelsReversed.reverse.zipWithIndex.foldLeft(descriptor.getUniqueId) { + case (uid, (label, labelIdx)) if labelIdx == labelsReversed.length - 1 => + uid.append(ZIOTestDescriptor.segmentType, label) + case (uid, (label, _)) => uid.append(ZIOSuiteTestDescriptor.segmentType, label) + } + descriptor.findByUniqueId(uniqueId).toScala + } + + def notifyTestFailure(testDescriptor: TestDescriptor, failure: TestFailure[_]): Unit = failure match { + case TestFailure.Assertion(result, _) => + notifier.executionFinished( + testDescriptor, + TestExecutionResult.failed( + new TestFailed(renderToString(renderFailureDetails(testDescriptor, result))) + ) + ) + case TestFailure.Runtime(cause, _) => + notifier.executionFinished( + testDescriptor, + TestExecutionResult.failed(cause.squashWith { + case t: Throwable => t + case _ => new TestFailed(renderToString(ConsoleRenderer.renderCause(cause, 0))) + }) + ) + } + + val instrumented: Spec[spec.Environment with TestEnvironment with Scope, Any] = instrumentedSpec(spec.spec) + + val eventHandler: ZTestEventHandler = { + case ExecutionEvent.TestStarted(labelsReversed, _, _, _, _) => + ZIO.succeed(testDescriptorFromReversedLabel(labelsReversed).foreach(notifier.executionStarted)) + case ExecutionEvent.Test(labelsReversed, test, _, _, _, _, _) => + ZIO.succeed(testDescriptorFromReversedLabel(labelsReversed).foreach { testDescriptor => + test match { + case Left(failure: TestFailure[_]) => + notifyTestFailure(testDescriptor, failure) + case Right(TestSuccess.Succeeded(_)) => + notifier.executionFinished(testDescriptor, TestExecutionResult.successful()) + case Right(TestSuccess.Ignored(_)) => + notifier.executionSkipped(testDescriptor, "Test skipped") + } + }) + case ExecutionEvent.RuntimeFailure(_, labelsReversed, failure, _) => + ZIO.succeed(testDescriptorFromReversedLabel(labelsReversed).foreach(notifyTestFailure(_, failure))) + case _ => ZIO.unit // unhandled events linked to the suite level + } + + spec + .runSpecAsApp(instrumented, TestArgs.empty, Console.ConsoleLive, eventHandler) + .provide( + Scope.default >>> (liveEnvironment >>> TestEnvironment.live ++ ZLayer.environment[Scope]), + spec.bootstrap + ) + } + + private def renderFailureDetails(descriptor: TestDescriptor, result: TestResult): Message = + Message( + ConsoleRenderer + .rendered( + Test, + descriptor.getDisplayName, + Failed, + 0, + ConsoleRenderer.renderAssertionResult(result.result, 0).lines: _* + ) + .streamingLines + ) + + private def renderToString(message: Message) = + message.lines + .map(_.fragments.map(_.text).fold("")(_ + _)) + .mkString("\n") +} diff --git a/test-junit-engine/src/main/scala/zio/test/junit/ZIOTestClassSelectorResolver.scala b/test-junit-engine/src/main/scala/zio/test/junit/ZIOTestClassSelectorResolver.scala new file mode 100644 index 000000000000..834f09165391 --- /dev/null +++ b/test-junit-engine/src/main/scala/zio/test/junit/ZIOTestClassSelectorResolver.scala @@ -0,0 +1,117 @@ +/* + * Copyright 2024-2024 Vincent Raman and the ZIO Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package zio.test.junit + +import org.junit.platform.commons.support.ReflectionSupport +import org.junit.platform.commons.util.ReflectionUtils.{isAbstract, isAssignableTo, isInnerClass, isPublic} +import org.junit.platform.engine.TestDescriptor +import org.junit.platform.engine.discovery.{ClassSelector, ClasspathRootSelector, ModuleSelector, PackageSelector} +import org.junit.platform.engine.support.discovery.SelectorResolver +import org.junit.platform.engine.support.discovery.SelectorResolver.{Match, Resolution} +import zio.test._ +import zio.test.junit.ReflectionUtils._ + +import java.util.Optional +import java.util.function.Predicate +import java.util.stream.Collectors +import scala.jdk.CollectionConverters._ + +/** + * JUnit 5 platform test class resolver for ZIO test implementation + */ +class ZIOTestClassSelectorResolver extends SelectorResolver { + private val isSuitePredicate: Predicate[Class[_]] = { (testClass: Class[_]) => + // valid test class are ones directly extending ZIOSpecAbstract or whose companion object + // extends ZIOSpecAbstract + isPublic(testClass) && !isAbstract(testClass) && !isInnerClass(testClass) && + (isAssignableTo(testClass, classOf[ZIOSpecAbstract]) || getCompanionObject(testClass).exists( + _.isInstanceOf[ZIOSpecAbstract] + )) + } + + private val alwaysTruePredicate: Predicate[String] = _ => true + + private def classDescriptorFunction( + testClass: Class[_] + ): java.util.function.Function[TestDescriptor, Optional[ZIOTestClassDescriptor]] = + (parentTestDescriptor: TestDescriptor) => { + val suiteUniqueId = + parentTestDescriptor.getUniqueId.append(ZIOTestClassDescriptor.segmentType, testClass.getName.stripSuffix("$")) + val newChild = parentTestDescriptor.getChildren.asScala.find(_.getUniqueId == suiteUniqueId) match { + case Some(_) => Optional.empty[ZIOTestClassDescriptor]() + case None => Optional.of(new ZIOTestClassDescriptor(parentTestDescriptor, suiteUniqueId, testClass)) + } + newChild + } + + private val toMatch: java.util.function.Function[TestDescriptor, java.util.stream.Stream[Match]] = + (td: TestDescriptor) => java.util.stream.Stream.of[Match](Match.exact(td)) + + private def addToParentFunction( + context: SelectorResolver.Context + ): java.util.function.Function[Class[_], java.util.stream.Stream[Match]] = (aClass: Class[_]) => { + context + .addToParent(classDescriptorFunction(aClass)) + .map[java.util.stream.Stream[Match]](toMatch) + .orElse(java.util.stream.Stream.empty()) + } + + override def resolve( + selector: ClasspathRootSelector, + context: SelectorResolver.Context + ): SelectorResolver.Resolution = { + val matches = + ReflectionSupport + .findAllClassesInClasspathRoot(selector.getClasspathRoot, isSuitePredicate, alwaysTruePredicate) + .stream() + .flatMap(addToParentFunction(context)) + .collect(Collectors.toSet()) + Resolution.matches(matches) + } + + override def resolve(selector: PackageSelector, context: SelectorResolver.Context): SelectorResolver.Resolution = { + val matches = + ReflectionSupport + .findAllClassesInPackage(selector.getPackageName, isSuitePredicate, alwaysTruePredicate) + .stream() + .flatMap(addToParentFunction(context)) + .collect(Collectors.toSet()) + Resolution.matches(matches) + } + + override def resolve(selector: ModuleSelector, context: SelectorResolver.Context): SelectorResolver.Resolution = { + val matches = + ReflectionSupport + .findAllClassesInModule(selector.getModuleName, isSuitePredicate, alwaysTruePredicate) + .stream() + .flatMap(addToParentFunction(context)) + .collect(Collectors.toSet()) + Resolution.matches(matches) + } + + override def resolve(selector: ClassSelector, context: SelectorResolver.Context): SelectorResolver.Resolution = { + val testClass = selector.getJavaClass + if (isSuitePredicate.test(testClass)) { + context + .addToParent(classDescriptorFunction(testClass)) + .map[Resolution]((td: TestDescriptor) => Resolution.`match`(Match.exact(td))) + .orElse(Resolution.unresolved()) + } else { + Resolution.unresolved() + } + } +} diff --git a/test-junit-engine/src/main/scala/zio/test/junit/ZIOTestDescriptor.scala b/test-junit-engine/src/main/scala/zio/test/junit/ZIOTestDescriptor.scala new file mode 100644 index 000000000000..4fd06ab334a4 --- /dev/null +++ b/test-junit-engine/src/main/scala/zio/test/junit/ZIOTestDescriptor.scala @@ -0,0 +1,52 @@ +/* + * Copyright 2024-2024 Vincent Raman and the ZIO Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package zio.test.junit + +import org.junit.platform.engine.{TestDescriptor, UniqueId} +import org.junit.platform.engine.support.descriptor.AbstractTestDescriptor +import zio.test.TestAnnotationMap + +/** + * Describes a JUnit 5 test descriptor for a single ZIO tests. + * + * @constructor + * Creates a new ZIOTestDescriptor instance. + * @param parent + * The parent descriptor of this test descriptor. + * @param uniqueId + * A unique identifier for this test descriptor. + * @param label + * The display name or label for this test descriptor. + * @param testClass + * The test class associated with this descriptor. + * @param annotations + * A map of annotations applied to this test descriptor. + */ +class ZIOTestDescriptor( + parent: TestDescriptor, + uniqueId: UniqueId, + label: String, + testClass: Class[_], + annotations: TestAnnotationMap +) extends AbstractTestDescriptor(uniqueId, label, ZIOTestSource(testClass, annotations)) { + setParent(parent) + override def getType: TestDescriptor.Type = TestDescriptor.Type.TEST +} + +object ZIOTestDescriptor { + val segmentType = "test" +} diff --git a/test-junit-engine/src/main/scala/zio/test/junit/ZIOTestEngine.scala b/test-junit-engine/src/main/scala/zio/test/junit/ZIOTestEngine.scala new file mode 100644 index 000000000000..3b0d3c95ec9f --- /dev/null +++ b/test-junit-engine/src/main/scala/zio/test/junit/ZIOTestEngine.scala @@ -0,0 +1,73 @@ +/* + * Copyright 2024-2024 Vincent Raman and the ZIO Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package zio.test.junit + +import org.junit.platform.engine.{ + EngineDiscoveryRequest, + ExecutionRequest, + TestDescriptor, + TestEngine, + TestExecutionResult, + UniqueId +} +import org.junit.platform.engine.support.descriptor.EngineDescriptor +import org.junit.platform.engine.support.discovery.EngineDiscoveryRequestResolver +import zio._ + +import scala.jdk.CollectionConverters._ + +/** + * A JUnit platform test engine implementation designed for running ZIO tests. + */ +class ZIOTestEngine extends TestEngine { + + override def getId: String = "zio" + + private lazy val discoverer = EngineDiscoveryRequestResolver + .builder[EngineDescriptor]() + .addSelectorResolver(new ZIOTestClassSelectorResolver) + .build() + + override def discover(discoveryRequest: EngineDiscoveryRequest, uniqueId: UniqueId): TestDescriptor = { + val engineDesc = new EngineDescriptor(uniqueId, "ZIO EngineDescriptor") + discoverer.resolve(discoveryRequest, engineDesc) + engineDesc + } + + override def execute(request: ExecutionRequest): Unit = + Unsafe.unsafe { implicit unsafe => + Runtime.default.unsafe.run { + val engineDesc = request.getRootTestDescriptor + val listener = request.getEngineExecutionListener + ZIO.logInfo("Start tests execution...") *> ZIO.foreachDiscard(engineDesc.getChildren.asScala) { + case clzDesc: ZIOTestClassDescriptor => + ZIO.logInfo(s"Start execution of test class ${clzDesc.className}...") *> ZIO.succeed( + listener.executionStarted(clzDesc) + ) *> new ZIOTestClassRunner(clzDesc).run(listener) *> ZIO.succeed( + listener.executionFinished(clzDesc, TestExecutionResult.successful()) + ) + case otherDesc => + // Do nothing for other descriptor, just log it. + ZIO.logWarning(s"Found test descriptor $otherDesc that is not supported, skipping.") + + } *> ZIO.succeed(listener.executionFinished(engineDesc, TestExecutionResult.successful())) *> ZIO.logInfo( + "Completed tests execution." + ) + } + .getOrThrowFiberFailure() + } +} diff --git a/test-junit-engine/src/main/scala/zio/test/junit/ZIOTestSource.scala b/test-junit-engine/src/main/scala/zio/test/junit/ZIOTestSource.scala new file mode 100644 index 000000000000..7b312b7654ae --- /dev/null +++ b/test-junit-engine/src/main/scala/zio/test/junit/ZIOTestSource.scala @@ -0,0 +1,42 @@ +/* + * Copyright 2024-2024 Vincent Raman and the ZIO Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package zio.test.junit + +import zio.test._ + +import java.io.File +import org.junit.platform.engine.TestSource +import org.junit.platform.engine.support.descriptor.{ClassSource, FilePosition, FileSource} + +/** + * ZIOTestSource is an object responsible for creating instances of JUnit + * TestSource based on provided test class and optional annotations. + */ +object ZIOTestSource { + def apply(testClass: Class[_], annotations: Option[TestAnnotationMap]): TestSource = + annotations + .map(_.get(TestAnnotation.trace)) + .collect { case location :: _ => + FileSource.from(new File(location.path), FilePosition.from(location.line)) + } + .getOrElse(ClassSource.from(testClass)) + + def apply(testClass: Class[_]): TestSource = ZIOTestSource(testClass, None) + + def apply(testClass: Class[_], annotations: TestAnnotationMap): TestSource = + ZIOTestSource(testClass, Some(annotations)) +} diff --git a/test-tests/shared/src/test/scala/zio/test/CheckSpec.scala b/test-tests/shared/src/test/scala/zio/test/CheckSpec.scala index 7ad0ed83bec8..edb0080a862a 100644 --- a/test-tests/shared/src/test/scala/zio/test/CheckSpec.scala +++ b/test-tests/shared/src/test/scala/zio/test/CheckSpec.scala @@ -1,7 +1,7 @@ package zio.test import zio.test.Assertion._ -import zio.test.TestAspect.failing +import zio.test.TestAspect._ import zio.{Chunk, Random, Ref, ZIO} object CheckSpec extends ZIOBaseSpec { @@ -88,6 +88,9 @@ object CheckSpec extends ZIOBaseSpec { checkAllPar(Gen.int, 2) { _ => assertZIO(ZIO.unit)(equalTo(())) } - } + }, + test("i9303") { + checkAllPar(Gen.int, 4)(_ => assertCompletes) + } @@ jvm(nonFlaky(1000)) ) } diff --git a/test-tests/shared/src/test/scala/zio/test/TestArrowSpec.scala b/test-tests/shared/src/test/scala/zio/test/TestArrowSpec.scala index f1e7eacdbf02..dd0ec89649f3 100644 --- a/test-tests/shared/src/test/scala/zio/test/TestArrowSpec.scala +++ b/test-tests/shared/src/test/scala/zio/test/TestArrowSpec.scala @@ -133,6 +133,32 @@ object TestArrowSpec extends ZIOBaseSpec { meta.meta(genFailureDetails = genFailureDetails2).asInstanceOf[Meta[Any, Nothing]].genFailureDetails assertTrue(res1.map(_.iterations == 1).getOrElse(false) && res2.map(_.iterations == 2).getOrElse(false)) } + ), + suite("Span substring bounds")( + test("correctly handles valid span bounds within string length") { + val span = Span(0, 3) + assertTrue(span.substring("foo bar baz") == "foo") + }, + test("clamps start when it is out of bounds") { + val span = Span(-3, 3) + assertTrue(span.substring("foo bar baz") == "foo") + }, + test("clamps end when it exceeds string length") { + val span = Span(0, 10) + assertTrue(span.substring("foo") == "foo") + }, + test("clamps both start and end when both are out of bounds") { + val span = Span(-5, 10) + assertTrue(span.substring("baz") == "baz") + }, + test("returns empty string when start equals end") { + val span = Span(3, 3) + assertTrue(span.substring("foo bar baz") == "") + }, + test("returns empty string when start is greater than end") { + val span = Span(4, 2) + assertTrue(span.substring("foo bar baz") == "") + } ) ) diff --git a/test/shared/src/main/scala/zio/test/TestArrow.scala b/test/shared/src/main/scala/zio/test/TestArrow.scala index 5c30f03806ee..f7f8359d750e 100644 --- a/test/shared/src/main/scala/zio/test/TestArrow.scala +++ b/test/shared/src/main/scala/zio/test/TestArrow.scala @@ -288,7 +288,11 @@ object TestArrow { } case class Span(start: Int, end: Int) { - def substring(str: String): String = str.substring(start, end) + def substring(str: String): String = { + val safeStart = math.max(0, math.min(start, str.length)) + val safeEnd = math.max(safeStart, math.min(end, str.length)) + str.substring(safeStart, safeEnd) + } } sealed case class Meta[-A, +B]( diff --git a/website/package.json b/website/package.json index e0ec129be30f..1bd109331d20 100644 --- a/website/package.json +++ b/website/package.json @@ -49,10 +49,10 @@ "@zio.dev/zio-interop-twitter": "2022.11.21-9f1a594d033d", "@zio.dev/zio-jdbc": "0.1.2", "@zio.dev/zio-json": "0.7.3", - "@zio.dev/zio-kafka": "2.8.3", + "@zio.dev/zio-kafka": "2.9.0", "@zio.dev/zio-keeper": "0.0.0--215-0a9dd0ea-SNAPSHOT", "@zio.dev/zio-lambda": "1.0.5", - "@zio.dev/zio-logging": "2.3.2", + "@zio.dev/zio-logging": "2.4.0", "@zio.dev/zio-memberlist": "0.0.0--40-a85dc5a1--20221121-1416-SNAPSHOT", "@zio.dev/zio-meta": "0.0.0--21-54bc2e8b-SNAPSHOT", "@zio.dev/zio-metrics-connectors": "2.3.1", @@ -60,7 +60,7 @@ "@zio.dev/zio-nio": "2.0.0", "@zio.dev/zio-optics": "0.2.2", "@zio.dev/zio-parser": "0.1.9", - "@zio.dev/zio-prelude": "1.0.0-RC31", + "@zio.dev/zio-prelude": "1.0.0-RC34", "@zio.dev/zio-process": "0.7.2", "@zio.dev/zio-profiling": "0.3.2", "@zio.dev/zio-query": "0.7.6", @@ -82,7 +82,7 @@ "blended-include-code-plugin": "0.1.2", "clsx": "2.0.0", "highlight.js": "11.8.0", - "postcss": "8.4.47", + "postcss": "8.4.49", "prism-react-renderer": "1.3.5", "prism-themes": "1.9.0", "prismjs": "^1.29.0", @@ -91,7 +91,7 @@ "react-icons": "5.3.0", "react-markdown": "9.0.1", "remark-kroki-plugin": "0.1.1", - "tailwindcss": "3.4.13", + "tailwindcss": "3.4.15", "tslib": "^2.4.0" }, "resolutions": {