From 634161fc36bf8bfa011f9e4ed380433628a0e9b2 Mon Sep 17 00:00:00 2001 From: Jacob Wang Date: Sun, 1 Sep 2024 10:07:19 +0100 Subject: [PATCH] Query cancellation for streaming queries `Stream.bracket` does not allow cancellation in the `acquire` step which is why we're using `bracketFull` instead --- .../src/main/scala/doobie/hi/connection.scala | 8 +++-- .../doobie/HikariQueryCancellationSuite.scala | 33 ++++++++++++++++--- 2 files changed, 33 insertions(+), 8 deletions(-) diff --git a/modules/core/src/main/scala/doobie/hi/connection.scala b/modules/core/src/main/scala/doobie/hi/connection.scala index 885eabddf..7eac3b276 100644 --- a/modules/core/src/main/scala/doobie/hi/connection.scala +++ b/modules/core/src/main/scala/doobie/hi/connection.scala @@ -218,9 +218,11 @@ object connection { for { ps <- Stream.bracket(runPreExecWithLogging(create, loggingInfo))(IFC.embed(_, IFPS.close)) _ <- Stream.eval(runPreExecWithLogging(IFC.embed(ps, IFPS.setFetchSize(chunkSize) *> prep), loggingInfo)) - resultSet <- Stream.bracket( - IFC.embed(ps, execLogged) - )(rs => IFC.embed(rs, IFRS.close)) + resultSet <- Stream.bracketFull[ConnectionIO, ResultSet](poll => + poll(WeakAsyncConnectionIO.cancelable( + IFC.embed(ps, execLogged), + IFC.embed(ps, IFPS.close) + )))((rs, _) => IFC.embed(rs, IFRS.close)) ele <- repeatEvalChunks(IFC.embed(resultSet, resultset.getNextChunk[A](chunkSize))) } yield ele } diff --git a/modules/hikari/src/test/scala/doobie/HikariQueryCancellationSuite.scala b/modules/hikari/src/test/scala/doobie/HikariQueryCancellationSuite.scala index 7c2a6af7c..5ab66f03e 100644 --- a/modules/hikari/src/test/scala/doobie/HikariQueryCancellationSuite.scala +++ b/modules/hikari/src/test/scala/doobie/HikariQueryCancellationSuite.scala @@ -10,6 +10,7 @@ import com.zaxxer.hikari.HikariConfig import doobie.hikari.HikariTransactor import doobie.implicits.* import doobie.util.transactor +import fs2.Stream import scala.concurrent.duration.DurationInt @@ -33,10 +34,10 @@ class HikariQueryCancellationSuite extends munit.FunSuite { test("Query cancel with Hikari") { val insert = for { - _ <- sql"CREATE TABLE if not exists blah (i text)".update.run - _ <- sql"truncate table blah".update.run - _ <- sql"INSERT INTO blah values ('1')".update.run - _ <- sql"INSERT INTO blah select concat(2, pg_sleep(1))".update.run + _ <- sql"CREATE TABLE if not exists query_cancel_test (i text)".update.run + _ <- sql"truncate table query_cancel_test".update.run + _ <- sql"INSERT INTO query_cancel_test values ('1')".update.run + _ <- sql"INSERT INTO query_cancel_test select concat(2, pg_sleep(1))".update.run } yield () val scenario = transactorRes.use { xa => for { @@ -44,7 +45,7 @@ class HikariQueryCancellationSuite extends munit.FunSuite { _ <- IO.sleep(200.millis) *> fiber.cancel _ <- IO.sleep(3.second) _ <- fiber.join.attempt - result <- sql"select * from blah order by i".query[String].to[List].transact(xa) + result <- sql"select * from query_cancel_test order by i".query[String].to[List].transact(xa) } yield { assertEquals(result, List("1")) } @@ -53,4 +54,26 @@ class HikariQueryCancellationSuite extends munit.FunSuite { scenario.unsafeRunSync() } + test("Stream query cancel with Hikari") { + val insert = for { + _ <- Stream.eval(sql"CREATE TABLE if not exists stream_cancel_test (i text)".update.run) + _ <- Stream.eval(sql"truncate table stream_cancel_test".update.run) + _ <- Stream.eval(sql"INSERT INTO stream_cancel_test values ('1')".update.run) + _ <- sql"INSERT INTO stream_cancel_test select concat(2, pg_sleep(1))".update.withGeneratedKeys[Int]("i") + } yield () + + val scenario = transactorRes.use { xa => + for { + fiber <- insert.transact(xa).compile.drain.start + _ <- IO.sleep(200.millis) *> fiber.cancel + _ <- IO.sleep(3.second) + _ <- fiber.join.attempt + result <- sql"select * from stream_cancel_test order by i".query[String].to[List].transact(xa) + } yield { + assertEquals(result, List("1")) + } + } + + scenario.unsafeRunSync() + } }