From 74df895e1bc0bcf92c383cd1a734f4569ea3f272 Mon Sep 17 00:00:00 2001 From: adamw Date: Thu, 7 Mar 2024 11:30:42 +0100 Subject: [PATCH] Suspend the evaluation of the BackendStub.send() effect in the target monad --- .../client4/testing/AbstractBackendStub.scala | 3 +- .../client4/testing/BackendStubTests.scala | 38 ++++++++++++++++++- 2 files changed, 39 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/sttp/client4/testing/AbstractBackendStub.scala b/core/src/main/scala/sttp/client4/testing/AbstractBackendStub.scala index 7e11ebd24e..2ddbedcd17 100644 --- a/core/src/main/scala/sttp/client4/testing/AbstractBackendStub.scala +++ b/core/src/main/scala/sttp/client4/testing/AbstractBackendStub.scala @@ -48,7 +48,7 @@ abstract class AbstractBackendStub[F[_], P]( withMatchers(matchers.orElse(wrappedPartial)) } - override def send[T](request: GenericRequest[T, P with Effect[F]]): F[Response[T]] = + override def send[T](request: GenericRequest[T, P with Effect[F]]): F[Response[T]] = monad.suspend { Try(matchers.lift(request)) match { case Success(Some(response)) => adjustExceptions(request)(tryAdjustResponseType(request.response, response.asInstanceOf[F[Response[T]]])(monad)) @@ -59,6 +59,7 @@ abstract class AbstractBackendStub[F[_], P]( } case Failure(e) => adjustExceptions(request)(monad.error(e)) } + } private def adjustExceptions[T](request: GenericRequest[_, _])(t: => F[T]): F[T] = SttpClientException.adjustExceptions(monad)(t)( diff --git a/core/src/test/scala/sttp/client4/testing/BackendStubTests.scala b/core/src/test/scala/sttp/client4/testing/BackendStubTests.scala index 34fb66788a..cbf7cba692 100644 --- a/core/src/test/scala/sttp/client4/testing/BackendStubTests.scala +++ b/core/src/test/scala/sttp/client4/testing/BackendStubTests.scala @@ -9,12 +9,13 @@ import sttp.client4.internal._ import sttp.client4.monad.IdMonad import sttp.client4.ws.async._ import sttp.model._ -import sttp.monad.{FutureMonad, TryMonad} +import sttp.monad.{FutureMonad, MonadError, TryMonad} import sttp.ws.WebSocketFrame import sttp.ws.testing.WebSocketStub import java.io.ByteArrayInputStream import java.util.concurrent.TimeoutException +import java.util.concurrent.atomic.AtomicInteger import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.duration._ import scala.util.{Failure, Success, Try} @@ -359,6 +360,41 @@ class BackendStubTests extends AnyFlatSpec with Matchers with ScalaFutures { result shouldBe Success(Right(1: Byte)) } + it should "evaluate side effects on each request" in { + // given + type Lazy[T] = () => T + object LazyMonad extends MonadError[Lazy] { + override def unit[T](t: T): Lazy[T] = () => t + override def map[T, T2](fa: Lazy[T])(f: T => T2): Lazy[T2] = () => f(fa()) + override def flatMap[T, T2](fa: Lazy[T])(f: T => Lazy[T2]): Lazy[T2] = () => f(fa())() + override def error[T](t: Throwable): Lazy[T] = () => throw t + override protected def handleWrappedError[T](rt: Lazy[T])(h: PartialFunction[Throwable, Lazy[T]]): Lazy[T] = + () => + try rt() + catch { case e if h.isDefinedAt(e) => h(e)() } + override def ensure[T](f: Lazy[T], e: => Lazy[Unit]): Lazy[T] = () => + try f() + finally e() + } + + val counter = new AtomicInteger(0) + val backend: Backend[Lazy] = BackendStub(LazyMonad).whenRequestMatchesPartial { case _ => + counter.getAndIncrement() + Response.ok("ok") + } + + // creating the "send effect" once ... + val result = basicRequest.get(uri"http://example.org").send(backend) + + // when + // ... and then using it twice + result().body shouldBe Right("ok") + result().body shouldBe Right("ok") + + // then + counter.get() shouldBe 2 + } + private val testingStubWithFallback = SyncBackendStub .withFallback(testingStub) .whenRequestMatches(_.uri.path.startsWith(List("c")))