diff --git a/benchmark/src/main/scala/benchmark/wrapper/ldbc/Insert.scala b/benchmark/src/main/scala/benchmark/wrapper/ldbc/Insert.scala index 14f4357ac..66a89f590 100644 --- a/benchmark/src/main/scala/benchmark/wrapper/ldbc/Insert.scala +++ b/benchmark/src/main/scala/benchmark/wrapper/ldbc/Insert.scala @@ -36,7 +36,7 @@ class Insert: var query: Table[Test] = uninitialized @volatile - var queryRecords: List[(Int, String)] = List.empty + var records: NonEmptyList[(Int, String)] = uninitialized @volatile var dslRecords: SQL = uninitialized @@ -54,14 +54,11 @@ class Insert: connection = Resource.make(datasource.getConnection)(_.close()) - queryRecords = (1 to len).map(num => (num, s"record$num")).toList - dslRecords = comma( - NonEmptyList.fromListUnsafe((1 to len).map(num => parentheses(p"$num, ${ "record" + num }")).toList) - ) + records = NonEmptyList.fromListUnsafe((1 to len).map(num => (num, s"record$num")).toList) query = Table[Test]("ldbc_wrapper_query_test") - @Param(Array("10", "100", "1000", "2000")) + @Param(Array("10")) var len: Int = uninitialized @Benchmark @@ -70,7 +67,7 @@ class Insert: .use { conn => query .insertInto(test => (test.c1, test.c2)) - .values(queryRecords) + .values(records.toList) .update .commit(conn) } @@ -80,7 +77,7 @@ class Insert: def dslInsertN: Unit = connection .use { conn => - (sql"INSERT INTO `ldbc_wrapper_dsl_test` (`c1`, `c2`) VALUES" ++ dslRecords).update + (sql"INSERT INTO `ldbc_wrapper_dsl_test` (`c1`, `c2`) " ++ values(records)).update .commit(conn) } .unsafeRunSync() diff --git a/module/ldbc-dsl/src/main/scala/ldbc/dsl/package.scala b/module/ldbc-dsl/src/main/scala/ldbc/dsl/package.scala index 4cd6bc0f5..aa8bd10af 100644 --- a/module/ldbc-dsl/src/main/scala/ldbc/dsl/package.scala +++ b/module/ldbc-dsl/src/main/scala/ldbc/dsl/package.scala @@ -6,6 +6,8 @@ package ldbc +import scala.deriving.Mirror + import cats.{ Foldable, Functor, Reducible } import cats.data.NonEmptyList import cats.syntax.all.* @@ -16,7 +18,7 @@ import ldbc.dsl.syntax.* package object dsl: - private[ldbc] trait SyncSyntax[F[_]: Sync] extends StringContextSyntax[F]: + private[ldbc] trait SyncSyntax[F[_]: Temporal] extends StringContextSyntax[F]: /** * Function for setting parameters to be used as static strings. @@ -35,9 +37,20 @@ package object dsl: def values[M[_]: Reducible, T](vs: M[T])(using Parameter[T]): Mysql[F] = sql"VALUES" ++ comma(vs.toNonEmptyList.map(v => parentheses(p"$v"))) - /** Returns `VALUES (s0, s1, ...)`. */ - def values[M[_]: Reducible](vs: M[SQL]): Mysql[F] = - sql"VALUES" ++ parentheses(comma(vs.toNonEmptyList)) + /** Returns `VALUES (v0, v1), (v2, v3), ...`. */ + inline def values[M[_]: Reducible, T <: Product](vs: M[T])(using Mirror.ProductOf[T]): Mysql[F] = + sql"VALUES" ++ comma(vs.toNonEmptyList.map(v => parentheses(values(v)))) + + private inline def values[T <: Product](v: T)(using mirror: Mirror.ProductOf[T]): Mysql[F] = + val tuple = Parameter.fold[mirror.MirroredElemTypes] + val params = tuple.toList + Mysql[F]( + List.fill(params.size)("?").mkString(","), + (Tuple.fromProduct(v).toList zip params).map { + case (value, param) => + Parameter.DynamicBinder(value.asInstanceOf[Any])(using param.asInstanceOf[Parameter[Any]]) + } + ) /** Returns `(sql IN (v0, v1, ...))`. */ def in[T](s: SQL, v0: T, v1: T, vs: T*)(using Parameter[T]): Mysql[F] = diff --git a/module/ldbc-dsl/src/test/scala/ldbc/dsl/HelperFunctionTest.scala b/module/ldbc-dsl/src/test/scala/ldbc/dsl/HelperFunctionTest.scala index 2ad6045d7..d92777ec7 100644 --- a/module/ldbc-dsl/src/test/scala/ldbc/dsl/HelperFunctionTest.scala +++ b/module/ldbc-dsl/src/test/scala/ldbc/dsl/HelperFunctionTest.scala @@ -17,15 +17,22 @@ class HelperFunctionTest extends munit.CatsEffectSuite: test( "The statement that constructs VALUES with multiple values of the same type will be the same as the string specified." ) { - val sql = sql"INSERT INTO `table` (`column1`, `column2`) " ++ values(NonEmptyList.of(1, 2)) - assertEquals(sql.statement, "INSERT INTO `table` (`column1`, `column2`) VALUES(?),(?)") + val sql = sql"INSERT INTO `table` (`column1`, `column2`) " ++ values(NonEmptyList.of((1, 2), (3, 4), (5, 6))) + assertEquals(sql.statement, "INSERT INTO `table` (`column1`, `column2`) VALUES(?,?),(?,?),(?,?)") + } + + test( + "Statements that comprise VALUES with a single value of the same type will be the same as the specified string." + ) { + val sql = sql"INSERT INTO `table` (`column1`) " ++ values(NonEmptyList.of(1, 2, 3, 4)) + assertEquals(sql.statement, "INSERT INTO `table` (`column1`) VALUES(?),(?),(?),(?)") } test("A statement that constructs VALUES in multiple sql is the same as the specified string.") { case class Value(c1: Int, c2: String) - val vs: NonEmptyList[Value] = NonEmptyList.of(Value(1, "value1"), Value(2, "value2")) - val sql = - sql"INSERT INTO `table` (`column1`, `column2`) VALUES" ++ comma(vs.map(v => parentheses(p"${ v.c1 },${ v.c2 }"))) + val sql = sql"INSERT INTO `table` (`column1`, `column2`) " ++ values( + NonEmptyList.of(Value(1, "value1"), Value(2, "value2")) + ) assertEquals(sql.statement, "INSERT INTO `table` (`column1`, `column2`) VALUES(?,?),(?,?)") }