Skip to content

Commit

Permalink
Merge pull request #2023 from tpolecat/better_column_typecheck
Browse files Browse the repository at this point in the history
More precise typechecking using vendor types
  • Loading branch information
jatcwang authored Apr 21, 2024
2 parents ba69656 + 86c28d6 commit a785dc0
Show file tree
Hide file tree
Showing 27 changed files with 447 additions and 185 deletions.
24 changes: 21 additions & 3 deletions docker-compose.yml
Original file line number Diff line number Diff line change
@@ -1,14 +1,32 @@
version: '3.1'

services:

postgres:
image: postgis/postgis:11-3.3
image: postgis/postgis:16-3.4
environment:
POSTGRES_USER: postgres
POSTGRES_PASSWORD: password
POSTGRES_DB: world
ports:
- 5432:5432
volumes:
- ./init/:/docker-entrypoint-initdb.d/
- ./init/postgres/:/docker-entrypoint-initdb.d/
deploy:
resources:
limits:
memory: 500M


mysql:
image: mysql:8.0-debian
environment:
MYSQL_ROOT_PASSWORD: password
MYSQL_DATABASE: world
ports:
- 3306:3306
volumes:
- ./init/mysql/:/docker-entrypoint-initdb.d/
deploy:
resources:
limits:
memory: 500M
11 changes: 11 additions & 0 deletions init/mysql/test-table.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@

CREATE TABLE IF NOT EXISTS test (
c_integer INTEGER NOT NULL,
c_varchar VARCHAR(1024) NOT NULL,
c_date DATE NOT NULL,
c_datetime DATETIME(6) NOT NULL,
c_time TIME(6) NOT NULL,
c_timestamp TIMESTAMP(6) NOT NULL
);
INSERT INTO test(c_integer, c_varchar, c_date, c_datetime, c_time, c_timestamp)
VALUES (123, 'str', '2019-02-13', '2019-02-13 22:03:21.051', '22:03:21.051', '2019-02-13 22:03:21.051');
File renamed without changes.
51 changes: 29 additions & 22 deletions modules/core/src/main/scala/doobie/hi/connection.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,28 @@

package doobie.hi

import doobie.util.compat.propertiesToScala
import cats.Foldable
import cats.data.Ior
import cats.effect.kernel.syntax.monadCancel._
import cats.syntax.all._
import doobie.enumerated.AutoGeneratedKeys
import doobie.enumerated.Holdability
import doobie.enumerated.ResultSetType
import doobie.enumerated.Nullability
import doobie.enumerated.ResultSetConcurrency
import doobie.enumerated.ResultSetType
import doobie.enumerated.TransactionIsolation
import doobie.enumerated.AutoGeneratedKeys
import doobie.util.{ Read, Write }
import doobie.util.analysis.Analysis
import doobie.util.analysis.ColumnMeta
import doobie.util.analysis.ParameterMeta
import doobie.util.compat.propertiesToScala
import doobie.util.stream.repeatEvalChunks
import doobie.util.{ Get, Put, Read, Write }
import fs2.Stream
import fs2.Stream.{ eval, bracket }

import java.sql.{ Savepoint, PreparedStatement, ResultSet }

import scala.collection.immutable.Map

import cats.Foldable
import cats.syntax.all._
import cats.effect.kernel.syntax.monadCancel._
import fs2.Stream
import fs2.Stream.{ eval, bracket }

/**
* Module of high-level constructors for `ConnectionIO` actions.
* @group Modules
Expand Down Expand Up @@ -92,24 +94,29 @@ object connection {
* readable resultset row type `B`.
*/
def prepareQueryAnalysis[A: Write, B: Read](sql: String): ConnectionIO[Analysis] =
prepareStatement(sql) {
(HPS.getParameterMappings[A], HPS.getColumnMappings[B]) mapN (Analysis(sql, _, _))
}
prepareAnalysis(sql, HPS.getParameterMappings[A], HPS.getColumnMappings[B])

def prepareQueryAnalysis0[B: Read](sql: String): ConnectionIO[Analysis] =
prepareStatement(sql) {
HPS.getColumnMappings[B] map (cm => Analysis(sql, Nil, cm))
}
prepareAnalysis(sql, FPS.pure(Nil), HPS.getColumnMappings[B])

def prepareUpdateAnalysis[A: Write](sql: String): ConnectionIO[Analysis] =
prepareStatement(sql) {
HPS.getParameterMappings[A] map (pm => Analysis(sql, pm, Nil))
}
prepareAnalysis(sql, HPS.getParameterMappings[A], FPS.pure(Nil))

def prepareUpdateAnalysis0(sql: String): ConnectionIO[Analysis] =
prepareStatement(sql) {
Analysis(sql, Nil, Nil).pure[PreparedStatementIO]
prepareAnalysis(sql, FPS.pure(Nil), FPS.pure(Nil))

private def prepareAnalysis(
sql: String,
params: PreparedStatementIO[List[(Put[_], Nullability.NullabilityKnown) Ior ParameterMeta]],
columns: PreparedStatementIO[List[(Get[_], Nullability.NullabilityKnown) Ior ColumnMeta]],
) = {
val mappings = prepareStatement(sql) {
(params, columns).tupled
}
(HC.getMetaData(FDMD.getDriverName), mappings).mapN { case (driver, (p, c)) =>
Analysis(driver, sql, p, c)
}
}


/** @group Statements */
Expand Down
16 changes: 16 additions & 0 deletions modules/core/src/main/scala/doobie/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
// This software is licensed under the MIT License (MIT).
// For more information see LICENSE or https://opensource.org/licenses/MIT

import doobie.util.meta.{LegacyMeta, TimeMetaInstances}
// Copyright (c) 2013-2020 Rob Norris and Contributors
// This software is licensed under the MIT License (MIT).
// For more information see LICENSE or https://opensource.org/licenses/MIT

/**
* Top-level import, providing aliases for the most commonly used types and modules from
* doobie-free and doobie-core. A typical starting set of imports would be something like this.
Expand All @@ -21,13 +26,24 @@ package object doobie
object implicits
extends free.Instances
with generic.AutoDerivation
with LegacyMeta
with syntax.AllSyntax {

// re-export these instances so `Meta` takes priority, must be in the object
implicit def metaProjectionGet[A](implicit m: Meta[A]): Get[A] = Get.metaProjection
implicit def metaProjectionPut[A](implicit m: Meta[A]): Put[A] = Put.metaProjectionWrite
implicit def fromGetRead[A](implicit G: Get[A]): Read[A] = Read.fromGet
implicit def fromPutWrite[A](implicit P: Put[A]): Write[A] = Write.fromPut

/**
* Only use this import if:
* 1. You're NOT using one of the database doobie has direct java.time isntances for
* (PostgreSQL / MySQL). (They have more accurate column type checks)
* 2. Your driver natively supports java.time.* types
*
* If your driver doesn't support java.time.* types, use [[doobie.implicits.legacy.instant/localdate]] instead
*/
object javatimedrivernative extends TimeMetaInstances
}

}
53 changes: 22 additions & 31 deletions modules/core/src/main/scala/doobie/util/analysis.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,28 +20,9 @@ object analysis {

/** Metadata for the JDBC end of a column/parameter mapping. */
final case class ColumnMeta(jdbcType: JdbcType, vendorTypeName: String, nullability: Nullability, name: String)
object ColumnMeta {
def apply(jdbcType: JdbcType, vendorTypeName: String, nullability: Nullability, name: String): ColumnMeta = {
new ColumnMeta(tweakJdbcType(jdbcType, vendorTypeName), vendorTypeName, nullability, name)
}
}

/** Metadata for the JDBC end of a column/parameter mapping. */
final case class ParameterMeta(jdbcType: JdbcType, vendorTypeName: String, nullability: Nullability, mode: ParameterMode)
object ParameterMeta {
def apply(jdbcType: JdbcType, vendorTypeName: String, nullability: Nullability, mode: ParameterMode): ParameterMeta = {
new ParameterMeta(tweakJdbcType(jdbcType, vendorTypeName), vendorTypeName, nullability, mode)
}
}

private def tweakJdbcType(jdbcType: JdbcType, vendorTypeName: String) = jdbcType match {
// the Postgres driver does not return *WithTimezone types but they are pretty much required for proper analysis
// https://github.com/pgjdbc/pgjdbc/issues/2485
// https://github.com/pgjdbc/pgjdbc/issues/1766
case JdbcType.Time if vendorTypeName.compareToIgnoreCase("timetz") == 0 => JdbcType.TimeWithTimezone
case JdbcType.Timestamp if vendorTypeName.compareToIgnoreCase("timestamptz") == 0 => JdbcType.TimestampWithTimezone
case t => t
}

sealed trait AlignmentError extends Product with Serializable {
def tag: String
Expand Down Expand Up @@ -100,10 +81,10 @@ object analysis {
override val tag = "C"
override def msg =
s"""|${schema.jdbcType.show.toUpperCase} (${schema.vendorTypeName}) is not
|coercible to ${typeName(get.typeStack.last, n)} according to the JDBC specification or any defined
|coercible to ${typeName(get.typeStack.last, n)} (${get.vendorTypeNames.mkString(",")}) according to the JDBC specification or any defined
|mapping.
|Fix this by changing the schema type to
|${get.jdbcSources.toList.map(_.show.toUpperCase).toList.mkString(" or ") }; or the
|${get.jdbcSources.toList.map(_.show.toUpperCase).mkString(" or ") }; or the
|Scala type to an appropriate ${if (schema.jdbcType === JdbcType.Array) "array" else "object"}
|type.
|""".stripMargin.linesIterator.mkString(" ")
Expand All @@ -122,34 +103,46 @@ object analysis {

/** Compatibility analysis for the given statement and aligned mappings. */
final case class Analysis(
driver: String,
sql: String,
parameterAlignment: List[(Put[_], NullabilityKnown) Ior ParameterMeta],
columnAlignment: List[(Get[_], NullabilityKnown) Ior ColumnMeta]) {
columnAlignment: List[(Get[_], NullabilityKnown) Ior ColumnMeta]
) {

def parameterMisalignments: List[ParameterMisalignment] =
parameterAlignment.zipWithIndex.collect {
case (Ior.Left(_), n) => ParameterMisalignment(n + 1, None)
case (Ior.Right(p), n) => ParameterMisalignment(n + 1, Some(p))
}

private def hasParameterTypeErrors[A](put: Put[A], paramMeta: ParameterMeta): Boolean = {
val jdbcTypeMatches = put.jdbcTargets.contains_(paramMeta.jdbcType)
val vendorTypeMatches = put.vendorTypeNames.isEmpty || put.vendorTypeNames.contains_(paramMeta.vendorTypeName)

!jdbcTypeMatches || !vendorTypeMatches
}

def parameterTypeErrors: List[ParameterTypeError] =
parameterAlignment.zipWithIndex.collect {
case (Ior.Both((j, n1), p), n) if !j.jdbcTargets.contains_(p.jdbcType) =>
ParameterTypeError(n + 1, j, n1, p.jdbcType, p.vendorTypeName)
case (Ior.Both((put, n1), paramMeta), n) if hasParameterTypeErrors(put, paramMeta)=>
ParameterTypeError(n + 1, put, n1, paramMeta.jdbcType, paramMeta.vendorTypeName)
}

def columnMisalignments: List[ColumnMisalignment] =
columnAlignment.zipWithIndex.collect {
case (Ior.Left(j), n) => ColumnMisalignment(n + 1, Left(j))
case (Ior.Right(p), n) => ColumnMisalignment(n + 1, Right(p))
}


private def hasColumnTypeError[A](get: Get[A], columnMeta: ColumnMeta): Boolean = {
val jdbcTypeMatches = (get.jdbcSources.toList ++ get.jdbcSourceSecondary).contains_(columnMeta.jdbcType)
val vendorTypeMatches = get.vendorTypeNames.isEmpty || get.vendorTypeNames.contains_(columnMeta.vendorTypeName)
!jdbcTypeMatches || !vendorTypeMatches
}
def columnTypeErrors: List[ColumnTypeError] =
columnAlignment.zipWithIndex.collect {
case (Ior.Both((j, n1), p), n) if !(j.jdbcSources.toList ++ j.jdbcSourceSecondary).contains_(p.jdbcType) =>
ColumnTypeError(n + 1, j, n1, p)
case (Ior.Both((j, n1), p), n) if (p.jdbcType === JdbcType.JavaObject || p.jdbcType === JdbcType.Other) && !j.schemaTypes.headOption.contains_(p.vendorTypeName) =>
ColumnTypeError(n + 1, j, n1, p)
case (Ior.Both((get, n1), p), n) if hasColumnTypeError(get, p) =>
ColumnTypeError(n + 1, get, n1, p)
}

def columnTypeWarnings: List[ColumnTypeWarning] =
Expand Down Expand Up @@ -224,6 +217,4 @@ object analysis {
case Nullable => "NULL"
case NullableUnknown => "NULL?"
}


}
Loading

0 comments on commit a785dc0

Please sign in to comment.