diff --git a/.github/workflows/ci-kotlin.yml b/.github/workflows/ci-kotlin.yml new file mode 100644 index 0000000000..f864c2fc7e --- /dev/null +++ b/.github/workflows/ci-kotlin.yml @@ -0,0 +1,36 @@ +name: sqlc kotlin test suite +on: [push, pull_request] +jobs: + + build: + name: Build And Test + runs-on: ubuntu-latest + + services: + postgres: + image: postgres:11 + env: + POSTGRES_USER: postgres + POSTGRES_PASSWORD: postgres + POSTGRES_DB: postgres + ports: + - 5432:5432 + # needed because the postgres container does not provide a healthcheck + options: --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5 + + steps: + - uses: actions/checkout@master + - uses: actions/setup-java@v1 + with: + java-version: 11 + - uses: eskatos/gradle-command-action@v1 + env: + PG_USER: postgres + PG_HOST: localhost + PG_DATABASE: postgres + PG_PASSWORD: postgres + PG_PORT: ${{ job.services.postgres.ports['5432'] }} + with: + build-root-directory: examples/kotlin + wrapper-directory: examples/kotlin + arguments: test --scan diff --git a/examples/kotlin/.gitignore b/examples/kotlin/.gitignore new file mode 100644 index 0000000000..fbb16c8de7 --- /dev/null +++ b/examples/kotlin/.gitignore @@ -0,0 +1,4 @@ +/.gradle/ +/.idea/ +/build/ +/out/ diff --git a/examples/kotlin/README.md b/examples/kotlin/README.md new file mode 100644 index 0000000000..ed910aa966 --- /dev/null +++ b/examples/kotlin/README.md @@ -0,0 +1,17 @@ +# Kotlin examples + +This is a Kotlin gradle project configured to compile and test all examples. Currently tests have only been written for the `authors` example. + +To run tests: + +```shell script +./gradlew clean test +``` + +The project can be easily imported into Intellij. + +1. Install Java if you don't already have it +1. Download Intellij IDEA Community Edition +1. In the "Welcome" modal, click "Import Project" +1. Open the `build.gradle` file adjacent to this README file +1. 
Wait for Intellij to sync the gradle modules and complete indexing diff --git a/examples/kotlin/build.gradle b/examples/kotlin/build.gradle new file mode 100644 index 0000000000..fd331077ae --- /dev/null +++ b/examples/kotlin/build.gradle @@ -0,0 +1,33 @@ +plugins { + id 'org.jetbrains.kotlin.jvm' version '1.3.60' +} + +group 'com.example' +version '1.0-SNAPSHOT' + +repositories { + mavenCentral() +} + +dependencies { + implementation 'org.postgresql:postgresql:42.2.9' + implementation "org.jetbrains.kotlin:kotlin-stdlib-jdk8" + testImplementation 'org.junit.jupiter:junit-jupiter-api:5.3.1' + testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine:5.3.1' +} + +test { + useJUnitPlatform() +} + +compileKotlin { + kotlinOptions.jvmTarget = "1.8" +} +compileTestKotlin { + kotlinOptions.jvmTarget = "1.8" +} + +buildScan { + termsOfServiceUrl = "https://gradle.com/terms-of-service" + termsOfServiceAgree = "yes" +} diff --git a/examples/kotlin/gradle.properties b/examples/kotlin/gradle.properties new file mode 100644 index 0000000000..29e08e8ca8 --- /dev/null +++ b/examples/kotlin/gradle.properties @@ -0,0 +1 @@ +kotlin.code.style=official \ No newline at end of file diff --git a/examples/kotlin/gradle/wrapper/gradle-wrapper.jar b/examples/kotlin/gradle/wrapper/gradle-wrapper.jar new file mode 100644 index 0000000000..87b738cbd0 Binary files /dev/null and b/examples/kotlin/gradle/wrapper/gradle-wrapper.jar differ diff --git a/examples/kotlin/gradle/wrapper/gradle-wrapper.properties b/examples/kotlin/gradle/wrapper/gradle-wrapper.properties new file mode 100644 index 0000000000..b5354905d6 --- /dev/null +++ b/examples/kotlin/gradle/wrapper/gradle-wrapper.properties @@ -0,0 +1,6 @@ +#Sat Jan 25 10:45:34 EST 2020 +distributionUrl=https\://services.gradle.org/distributions/gradle-6.1.1-all.zip +distributionBase=GRADLE_USER_HOME +distributionPath=wrapper/dists +zipStorePath=wrapper/dists +zipStoreBase=GRADLE_USER_HOME diff --git a/examples/kotlin/gradlew b/examples/kotlin/gradlew new file mode 100755 index 0000000000..af6708ff22 --- /dev/null +++ b/examples/kotlin/gradlew @@ -0,0 +1,172 @@ +#!/usr/bin/env sh + +############################################################################## +## +## Gradle start up script for UN*X +## +############################################################################## + +# Attempt to set APP_HOME +# Resolve links: $0 may be a link +PRG="$0" +# Need this for relative symlinks. +while [ -h "$PRG" ] ; do + ls=`ls -ld "$PRG"` + link=`expr "$ls" : '.*-> \(.*\)$'` + if expr "$link" : '/.*' > /dev/null; then + PRG="$link" + else + PRG=`dirname "$PRG"`"/$link" + fi +done +SAVED="`pwd`" +cd "`dirname \"$PRG\"`/" >/dev/null +APP_HOME="`pwd -P`" +cd "$SAVED" >/dev/null + +APP_NAME="Gradle" +APP_BASE_NAME=`basename "$0"` + +# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +DEFAULT_JVM_OPTS='"-Xmx64m"' + +# Use the maximum available, or set MAX_FD != -1 to use that value. +MAX_FD="maximum" + +warn () { + echo "$*" +} + +die () { + echo + echo "$*" + echo + exit 1 +} + +# OS specific support (must be 'true' or 'false'). +cygwin=false +msys=false +darwin=false +nonstop=false +case "`uname`" in + CYGWIN* ) + cygwin=true + ;; + Darwin* ) + darwin=true + ;; + MINGW* ) + msys=true + ;; + NONSTOP* ) + nonstop=true + ;; +esac + +CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar + +# Determine the Java command to use to start the JVM. 
+if [ -n "$JAVA_HOME" ] ; then + if [ -x "$JAVA_HOME/jre/sh/java" ] ; then + # IBM's JDK on AIX uses strange locations for the executables + JAVACMD="$JAVA_HOME/jre/sh/java" + else + JAVACMD="$JAVA_HOME/bin/java" + fi + if [ ! -x "$JAVACMD" ] ; then + die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." + fi +else + JAVACMD="java" + which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." +fi + +# Increase the maximum file descriptors if we can. +if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then + MAX_FD_LIMIT=`ulimit -H -n` + if [ $? -eq 0 ] ; then + if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then + MAX_FD="$MAX_FD_LIMIT" + fi + ulimit -n $MAX_FD + if [ $? -ne 0 ] ; then + warn "Could not set maximum file descriptor limit: $MAX_FD" + fi + else + warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT" + fi +fi + +# For Darwin, add options to specify how the application appears in the dock +if $darwin; then + GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\"" +fi + +# For Cygwin, switch paths to Windows format before running java +if $cygwin ; then + APP_HOME=`cygpath --path --mixed "$APP_HOME"` + CLASSPATH=`cygpath --path --mixed "$CLASSPATH"` + JAVACMD=`cygpath --unix "$JAVACMD"` + + # We build the pattern for arguments to be converted via cygpath + ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null` + SEP="" + for dir in $ROOTDIRSRAW ; do + ROOTDIRS="$ROOTDIRS$SEP$dir" + SEP="|" + done + OURCYGPATTERN="(^($ROOTDIRS))" + # Add a user-defined pattern to the cygpath arguments + if [ "$GRADLE_CYGPATTERN" != "" ] ; then + OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)" + fi + # Now convert the arguments - kludge to limit ourselves to /bin/sh + i=0 + for arg in "$@" ; do + CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -` + CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option + + if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition + eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"` + else + eval `echo args$i`="\"$arg\"" + fi + i=$((i+1)) + done + case $i in + (0) set -- ;; + (1) set -- "$args0" ;; + (2) set -- "$args0" "$args1" ;; + (3) set -- "$args0" "$args1" "$args2" ;; + (4) set -- "$args0" "$args1" "$args2" "$args3" ;; + (5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;; + (6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;; + (7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;; + (8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;; + (9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;; + esac +fi + +# Escape application args +save () { + for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done + echo " " +} +APP_ARGS=$(save "$@") + +# Collect all arguments for the java command, following the shell quoting and substitution rules +eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS" + +# by default we should be in the correct project dir, but when run from Finder on Mac, the cwd is wrong +if [ 
"$(uname)" = "Darwin" ] && [ "$HOME" = "$PWD" ]; then + cd "$(dirname "$0")" +fi + +exec "$JAVACMD" "$@" diff --git a/examples/kotlin/gradlew.bat b/examples/kotlin/gradlew.bat new file mode 100644 index 0000000000..6d57edc706 --- /dev/null +++ b/examples/kotlin/gradlew.bat @@ -0,0 +1,84 @@ +@if "%DEBUG%" == "" @echo off +@rem ########################################################################## +@rem +@rem Gradle startup script for Windows +@rem +@rem ########################################################################## + +@rem Set local scope for the variables with windows NT shell +if "%OS%"=="Windows_NT" setlocal + +set DIRNAME=%~dp0 +if "%DIRNAME%" == "" set DIRNAME=. +set APP_BASE_NAME=%~n0 +set APP_HOME=%DIRNAME% + +@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +set DEFAULT_JVM_OPTS="-Xmx64m" + +@rem Find java.exe +if defined JAVA_HOME goto findJavaFromJavaHome + +set JAVA_EXE=java.exe +%JAVA_EXE% -version >NUL 2>&1 +if "%ERRORLEVEL%" == "0" goto init + +echo. +echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. +echo. +echo Please set the JAVA_HOME variable in your environment to match the +echo location of your Java installation. + +goto fail + +:findJavaFromJavaHome +set JAVA_HOME=%JAVA_HOME:"=% +set JAVA_EXE=%JAVA_HOME%/bin/java.exe + +if exist "%JAVA_EXE%" goto init + +echo. +echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% +echo. +echo Please set the JAVA_HOME variable in your environment to match the +echo location of your Java installation. + +goto fail + +:init +@rem Get command-line arguments, handling Windows variants + +if not "%OS%" == "Windows_NT" goto win9xME_args + +:win9xME_args +@rem Slurp the command line arguments. +set CMD_LINE_ARGS= +set _SKIP=2 + +:win9xME_args_slurp +if "x%~1" == "x" goto execute + +set CMD_LINE_ARGS=%* + +:execute +@rem Setup the command line + +set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar + +@rem Execute Gradle +"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS% + +:end +@rem End local scope for the variables with windows NT shell +if "%ERRORLEVEL%"=="0" goto mainEnd + +:fail +rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of +rem the _cmd.exe /c_ return code! +if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1 +exit /b 1 + +:mainEnd +if "%OS%"=="Windows_NT" endlocal + +:omega diff --git a/examples/kotlin/settings.gradle b/examples/kotlin/settings.gradle new file mode 100644 index 0000000000..5a094655c6 --- /dev/null +++ b/examples/kotlin/settings.gradle @@ -0,0 +1,5 @@ +plugins { + id("com.gradle.enterprise").version("3.1.1") +} + +rootProject.name = 'dbtest' diff --git a/examples/kotlin/src/main/kotlin/com/example/authors/Models.kt b/examples/kotlin/src/main/kotlin/com/example/authors/Models.kt new file mode 100644 index 0000000000..916a88d7c8 --- /dev/null +++ b/examples/kotlin/src/main/kotlin/com/example/authors/Models.kt @@ -0,0 +1,10 @@ +// Code generated by sqlc. DO NOT EDIT. + +package com.example.authors + +data class Author ( + val id: Long, + val name: String, + val bio: String? 
+) + diff --git a/examples/kotlin/src/main/kotlin/com/example/authors/QueriesImpl.kt b/examples/kotlin/src/main/kotlin/com/example/authors/QueriesImpl.kt new file mode 100644 index 0000000000..9ee42a6abf --- /dev/null +++ b/examples/kotlin/src/main/kotlin/com/example/authors/QueriesImpl.kt @@ -0,0 +1,104 @@ +// Code generated by sqlc. DO NOT EDIT. + +package com.example.authors + +import java.sql.Connection +import java.sql.SQLException + +const val createAuthor = """-- name: createAuthor :one +INSERT INTO authors ( + name, bio +) VALUES ( + ?, ? +) +RETURNING id, name, bio +""" + +const val deleteAuthor = """-- name: deleteAuthor :exec +DELETE FROM authors +WHERE id = ? +""" + +const val getAuthor = """-- name: getAuthor :one +SELECT id, name, bio FROM authors +WHERE id = ? LIMIT 1 +""" + +const val listAuthors = """-- name: listAuthors :many +SELECT id, name, bio FROM authors +ORDER BY name +""" + +class QueriesImpl(private val conn: Connection) { + + @Throws(SQLException::class) + fun createAuthor(name: String, bio: String?): Author { + return conn.prepareStatement(createAuthor).use { stmt -> + stmt.setString(1, name) + stmt.setString(2, bio) + + val results = stmt.executeQuery() + if (!results.next()) { + throw SQLException("no rows in result set") + } + val ret = Author( + results.getLong(1), + results.getString(2), + results.getString(3) + ) + if (results.next()) { + throw SQLException("expected one row in result set, but got many") + } + ret + } + } + + @Throws(SQLException::class) + fun deleteAuthor(id: Long) { + conn.prepareStatement(deleteAuthor).use { stmt -> + stmt.setLong(1, id) + + stmt.execute() + } + } + + @Throws(SQLException::class) + fun getAuthor(id: Long): Author { + return conn.prepareStatement(getAuthor).use { stmt -> + stmt.setLong(1, id) + + val results = stmt.executeQuery() + if (!results.next()) { + throw SQLException("no rows in result set") + } + val ret = Author( + results.getLong(1), + results.getString(2), + results.getString(3) + ) + if (results.next()) { + throw SQLException("expected one row in result set, but got many") + } + ret + } + } + + @Throws(SQLException::class) + fun listAuthors(): List { + return conn.prepareStatement(listAuthors).use { stmt -> + + val results = stmt.executeQuery() + val ret = mutableListOf() + while (results.next()) { + ret.add(Author( + results.getLong(1), + results.getString(2), + results.getString(3) + )) + } + ret + } + } + +} + diff --git a/examples/kotlin/src/main/kotlin/com/example/booktest/postgresql/Models.kt b/examples/kotlin/src/main/kotlin/com/example/booktest/postgresql/Models.kt new file mode 100644 index 0000000000..066918ee6b --- /dev/null +++ b/examples/kotlin/src/main/kotlin/com/example/booktest/postgresql/Models.kt @@ -0,0 +1,32 @@ +// Code generated by sqlc. DO NOT EDIT. 
+ +package com.example.booktest.postgresql + +import java.time.OffsetDateTime + +enum class BookType(val value: String) { + FICTION("FICTION"), + NONFICTION("NONFICTION"); + + companion object { + private val map = BookType.values().associateBy(BookType::value) + fun lookup(value: String) = map[value] + } +} + +data class Author ( + val authorId: Int, + val name: String +) + +data class Book ( + val bookId: Int, + val authorId: Int, + val isbn: String, + val booktype: BookType, + val title: String, + val year: Int, + val available: OffsetDateTime, + val tags: List +) + diff --git a/examples/kotlin/src/main/kotlin/com/example/booktest/postgresql/QueriesImpl.kt b/examples/kotlin/src/main/kotlin/com/example/booktest/postgresql/QueriesImpl.kt new file mode 100644 index 0000000000..f6e13ab8db --- /dev/null +++ b/examples/kotlin/src/main/kotlin/com/example/booktest/postgresql/QueriesImpl.kt @@ -0,0 +1,279 @@ +// Code generated by sqlc. DO NOT EDIT. + +package com.example.booktest.postgresql + +import java.sql.Connection +import java.sql.SQLException +import java.sql.Types +import java.time.OffsetDateTime + +const val booksByTags = """-- name: booksByTags :many +SELECT + book_id, + title, + name, + isbn, + tags +FROM books +LEFT JOIN authors ON books.author_id = authors.author_id +WHERE tags && ?::varchar[] +""" + +data class BooksByTagsRow ( + val bookId: Int, + val title: String, + val name: String, + val isbn: String, + val tags: List +) + +const val booksByTitleYear = """-- name: booksByTitleYear :many +SELECT book_id, author_id, isbn, booktype, title, year, available, tags FROM books +WHERE title = ? AND year = ? +""" + +const val createAuthor = """-- name: createAuthor :one +INSERT INTO authors (name) VALUES (?) +RETURNING author_id, name +""" + +const val createBook = """-- name: createBook :one +INSERT INTO books ( + author_id, + isbn, + booktype, + title, + year, + available, + tags +) VALUES ( + ?, + ?, + ?, + ?, + ?, + ?, + ? +) +RETURNING book_id, author_id, isbn, booktype, title, year, available, tags +""" + +const val deleteBook = """-- name: deleteBook :exec +DELETE FROM books +WHERE book_id = ? +""" + +const val getAuthor = """-- name: getAuthor :one +SELECT author_id, name FROM authors +WHERE author_id = ? +""" + +const val getBook = """-- name: getBook :one +SELECT book_id, author_id, isbn, booktype, title, year, available, tags FROM books +WHERE book_id = ? +""" + +const val updateBook = """-- name: updateBook :exec +UPDATE books +SET title = ?, tags = ? +WHERE book_id = ? +""" + +const val updateBookISBN = """-- name: updateBookISBN :exec +UPDATE books +SET title = ?, tags = ?, isbn = ? +WHERE book_id = ? 
+""" + +class QueriesImpl(private val conn: Connection) { + + @Throws(SQLException::class) + fun booksByTags(dollar1: List): List { + return conn.prepareStatement(booksByTags).use { stmt -> + stmt.setArray(1, conn.createArrayOf("pg_catalog.varchar", dollar1.toTypedArray())) + + val results = stmt.executeQuery() + val ret = mutableListOf() + while (results.next()) { + ret.add(BooksByTagsRow( + results.getInt(1), + results.getString(2), + results.getString(3), + results.getString(4), + (results.getArray(5).array as Array).toList() + )) + } + ret + } + } + + @Throws(SQLException::class) + fun booksByTitleYear(title: String, year: Int): List { + return conn.prepareStatement(booksByTitleYear).use { stmt -> + stmt.setString(1, title) + stmt.setInt(2, year) + + val results = stmt.executeQuery() + val ret = mutableListOf() + while (results.next()) { + ret.add(Book( + results.getInt(1), + results.getInt(2), + results.getString(3), + BookType.lookup(results.getString(4))!!, + results.getString(5), + results.getInt(6), + results.getObject(7, OffsetDateTime::class.java), + (results.getArray(8).array as Array).toList() + )) + } + ret + } + } + + @Throws(SQLException::class) + fun createAuthor(name: String): Author { + return conn.prepareStatement(createAuthor).use { stmt -> + stmt.setString(1, name) + + val results = stmt.executeQuery() + if (!results.next()) { + throw SQLException("no rows in result set") + } + val ret = Author( + results.getInt(1), + results.getString(2) + ) + if (results.next()) { + throw SQLException("expected one row in result set, but got many") + } + ret + } + } + + @Throws(SQLException::class) + fun createBook( + authorId: Int, + isbn: String, + booktype: BookType, + title: String, + year: Int, + available: OffsetDateTime, + tags: List): Book { + return conn.prepareStatement(createBook).use { stmt -> + stmt.setInt(1, authorId) + stmt.setString(2, isbn) + stmt.setObject(3, booktype.value, Types.OTHER) + stmt.setString(4, title) + stmt.setInt(5, year) + stmt.setObject(6, available) + stmt.setArray(7, conn.createArrayOf("pg_catalog.varchar", tags.toTypedArray())) + + val results = stmt.executeQuery() + if (!results.next()) { + throw SQLException("no rows in result set") + } + val ret = Book( + results.getInt(1), + results.getInt(2), + results.getString(3), + BookType.lookup(results.getString(4))!!, + results.getString(5), + results.getInt(6), + results.getObject(7, OffsetDateTime::class.java), + (results.getArray(8).array as Array).toList() + ) + if (results.next()) { + throw SQLException("expected one row in result set, but got many") + } + ret + } + } + + @Throws(SQLException::class) + fun deleteBook(bookId: Int) { + conn.prepareStatement(deleteBook).use { stmt -> + stmt.setInt(1, bookId) + + stmt.execute() + } + } + + @Throws(SQLException::class) + fun getAuthor(authorId: Int): Author { + return conn.prepareStatement(getAuthor).use { stmt -> + stmt.setInt(1, authorId) + + val results = stmt.executeQuery() + if (!results.next()) { + throw SQLException("no rows in result set") + } + val ret = Author( + results.getInt(1), + results.getString(2) + ) + if (results.next()) { + throw SQLException("expected one row in result set, but got many") + } + ret + } + } + + @Throws(SQLException::class) + fun getBook(bookId: Int): Book { + return conn.prepareStatement(getBook).use { stmt -> + stmt.setInt(1, bookId) + + val results = stmt.executeQuery() + if (!results.next()) { + throw SQLException("no rows in result set") + } + val ret = Book( + results.getInt(1), + results.getInt(2), + 
results.getString(3), + BookType.lookup(results.getString(4))!!, + results.getString(5), + results.getInt(6), + results.getObject(7, OffsetDateTime::class.java), + (results.getArray(8).array as Array).toList() + ) + if (results.next()) { + throw SQLException("expected one row in result set, but got many") + } + ret + } + } + + @Throws(SQLException::class) + fun updateBook( + title: String, + tags: List, + bookId: Int) { + conn.prepareStatement(updateBook).use { stmt -> + stmt.setString(1, title) + stmt.setArray(2, conn.createArrayOf("pg_catalog.varchar", tags.toTypedArray())) + stmt.setInt(3, bookId) + + stmt.execute() + } + } + + @Throws(SQLException::class) + fun updateBookISBN( + title: String, + tags: List, + isbn: String, + bookId: Int) { + conn.prepareStatement(updateBookISBN).use { stmt -> + stmt.setString(1, title) + stmt.setArray(2, conn.createArrayOf("pg_catalog.varchar", tags.toTypedArray())) + stmt.setString(3, isbn) + stmt.setInt(4, bookId) + + stmt.execute() + } + } + +} + diff --git a/examples/kotlin/src/main/kotlin/com/example/jets/Models.kt b/examples/kotlin/src/main/kotlin/com/example/jets/Models.kt new file mode 100644 index 0000000000..d2bced3778 --- /dev/null +++ b/examples/kotlin/src/main/kotlin/com/example/jets/Models.kt @@ -0,0 +1,27 @@ +// Code generated by sqlc. DO NOT EDIT. + +package com.example.jets + +data class Jet ( + val id: Int, + val pilotId: Int, + val age: Int, + val name: String, + val color: String +) + +data class Language ( + val id: Int, + val language: String +) + +data class Pilot ( + val id: Int, + val name: String +) + +data class PilotLanguage ( + val pilotId: Int, + val languageId: Int +) + diff --git a/examples/kotlin/src/main/kotlin/com/example/jets/QueriesImpl.kt b/examples/kotlin/src/main/kotlin/com/example/jets/QueriesImpl.kt new file mode 100644 index 0000000000..04ab619fca --- /dev/null +++ b/examples/kotlin/src/main/kotlin/com/example/jets/QueriesImpl.kt @@ -0,0 +1,64 @@ +// Code generated by sqlc. DO NOT EDIT. + +package com.example.jets + +import java.sql.Connection +import java.sql.SQLException + +const val countPilots = """-- name: countPilots :one +SELECT COUNT(*) FROM pilots +""" + +const val deletePilot = """-- name: deletePilot :exec +DELETE FROM pilots WHERE id = ? 
+""" + +const val listPilots = """-- name: listPilots :many +SELECT id, name FROM pilots LIMIT 5 +""" + +class QueriesImpl(private val conn: Connection) { + + @Throws(SQLException::class) + fun countPilots(): Long { + return conn.prepareStatement(countPilots).use { stmt -> + + val results = stmt.executeQuery() + if (!results.next()) { + throw SQLException("no rows in result set") + } + val ret = results.getLong(1) + if (results.next()) { + throw SQLException("expected one row in result set, but got many") + } + ret + } + } + + @Throws(SQLException::class) + fun deletePilot(id: Int) { + conn.prepareStatement(deletePilot).use { stmt -> + stmt.setInt(1, id) + + stmt.execute() + } + } + + @Throws(SQLException::class) + fun listPilots(): List { + return conn.prepareStatement(listPilots).use { stmt -> + + val results = stmt.executeQuery() + val ret = mutableListOf() + while (results.next()) { + ret.add(Pilot( + results.getInt(1), + results.getString(2) + )) + } + ret + } + } + +} + diff --git a/examples/kotlin/src/main/kotlin/com/example/ondeck/Models.kt b/examples/kotlin/src/main/kotlin/com/example/ondeck/Models.kt new file mode 100644 index 0000000000..e4dd8e7db7 --- /dev/null +++ b/examples/kotlin/src/main/kotlin/com/example/ondeck/Models.kt @@ -0,0 +1,37 @@ +// Code generated by sqlc. DO NOT EDIT. + +package com.example.ondeck + +import java.time.LocalDateTime + +// Venues can be either open or closed +enum class Status(val value: String) { + OPEN("op!en"), + CLOSED("clo@sed"); + + companion object { + private val map = Status.values().associateBy(Status::value) + fun lookup(value: String) = map[value] + } +} + +data class City ( + val slug: String, + val name: String +) + +// Venues are places where muisc happens +data class Venue ( + val id: Int, + val status: Status, + val statuses: List, + // This value appears in public URLs + val slug: String, + val name: String, + val city: String, + val spotifyPlaylist: String, + val songkickId: String?, + val tags: List, + val createdAt: LocalDateTime +) + diff --git a/examples/kotlin/src/main/kotlin/com/example/ondeck/Queries.kt b/examples/kotlin/src/main/kotlin/com/example/ondeck/Queries.kt new file mode 100644 index 0000000000..69debb00a7 --- /dev/null +++ b/examples/kotlin/src/main/kotlin/com/example/ondeck/Queries.kt @@ -0,0 +1,49 @@ +// Code generated by sqlc. DO NOT EDIT. 
+ +package com.example.ondeck + +import java.sql.Connection +import java.sql.SQLException +import java.sql.Types +import java.time.LocalDateTime + +interface Queries { + @Throws(SQLException::class) + fun createCity(name: String, slug: String): City + + @Throws(SQLException::class) + fun createVenue( + slug: String, + name: String, + city: String, + spotifyPlaylist: String, + status: Status, + statuses: List, + tags: List): Int + + @Throws(SQLException::class) + fun deleteVenue(slug: String) + + @Throws(SQLException::class) + fun getCity(slug: String): City + + @Throws(SQLException::class) + fun getVenue(slug: String, city: String): Venue + + @Throws(SQLException::class) + fun listCities(): List + + @Throws(SQLException::class) + fun listVenues(city: String): List + + @Throws(SQLException::class) + fun updateCityName(name: String, slug: String) + + @Throws(SQLException::class) + fun updateVenueName(name: String, slug: String): Int + + @Throws(SQLException::class) + fun venueCountByCity(): List + +} + diff --git a/examples/kotlin/src/main/kotlin/com/example/ondeck/QueriesImpl.kt b/examples/kotlin/src/main/kotlin/com/example/ondeck/QueriesImpl.kt new file mode 100644 index 0000000000..774b67cb4d --- /dev/null +++ b/examples/kotlin/src/main/kotlin/com/example/ondeck/QueriesImpl.kt @@ -0,0 +1,301 @@ +// Code generated by sqlc. DO NOT EDIT. + +package com.example.ondeck + +import java.sql.Connection +import java.sql.SQLException +import java.sql.Types +import java.time.LocalDateTime + +const val createCity = """-- name: createCity :one +INSERT INTO city ( + name, + slug +) VALUES ( + ?, + ? +) RETURNING slug, name +""" + +const val createVenue = """-- name: createVenue :one +INSERT INTO venue ( + slug, + name, + city, + created_at, + spotify_playlist, + status, + statuses, + tags +) VALUES ( + ?, + ?, + ?, + NOW(), + ?, + ?, + ?, + ? +) RETURNING id +""" + +const val deleteVenue = """-- name: deleteVenue :exec +DELETE FROM venue +WHERE slug = ? AND slug = ? +""" + +const val getCity = """-- name: getCity :one +SELECT slug, name +FROM city +WHERE slug = ? +""" + +const val getVenue = """-- name: getVenue :one +SELECT id, status, statuses, slug, name, city, spotify_playlist, songkick_id, tags, created_at +FROM venue +WHERE slug = ? AND city = ? +""" + +const val listCities = """-- name: listCities :many +SELECT slug, name +FROM city +ORDER BY name +""" + +const val listVenues = """-- name: listVenues :many +SELECT id, status, statuses, slug, name, city, spotify_playlist, songkick_id, tags, created_at +FROM venue +WHERE city = ? +ORDER BY name +""" + +const val updateCityName = """-- name: updateCityName :exec +UPDATE city +SET name = ? +WHERE slug = ? +""" + +const val updateVenueName = """-- name: updateVenueName :one +UPDATE venue +SET name = ? +WHERE slug = ? +RETURNING id +""" + +const val venueCountByCity = """-- name: venueCountByCity :many +SELECT + city, + count(*) +FROM venue +GROUP BY 1 +ORDER BY 1 +""" + +data class VenueCountByCityRow ( + val city: String, + val count: Long +) + +class QueriesImpl(private val conn: Connection) : Queries { + +// Create a new city. The slug must be unique. 
+// This is the second line of the comment +// This is the third line + + @Throws(SQLException::class) + override fun createCity(name: String, slug: String): City { + return conn.prepareStatement(createCity).use { stmt -> + stmt.setString(1, name) + stmt.setString(2, slug) + + val results = stmt.executeQuery() + if (!results.next()) { + throw SQLException("no rows in result set") + } + val ret = City( + results.getString(1), + results.getString(2) + ) + if (results.next()) { + throw SQLException("expected one row in result set, but got many") + } + ret + } + } + + @Throws(SQLException::class) + override fun createVenue( + slug: String, + name: String, + city: String, + spotifyPlaylist: String, + status: Status, + statuses: List, + tags: List): Int { + return conn.prepareStatement(createVenue).use { stmt -> + stmt.setString(1, slug) + stmt.setString(2, name) + stmt.setString(3, city) + stmt.setString(4, spotifyPlaylist) + stmt.setObject(5, status.value, Types.OTHER) + stmt.setArray(6, conn.createArrayOf("status", statuses.map { v -> v.value }.toTypedArray())) + stmt.setArray(7, conn.createArrayOf("text", tags.toTypedArray())) + + val results = stmt.executeQuery() + if (!results.next()) { + throw SQLException("no rows in result set") + } + val ret = results.getInt(1) + if (results.next()) { + throw SQLException("expected one row in result set, but got many") + } + ret + } + } + + @Throws(SQLException::class) + override fun deleteVenue(slug: String) { + conn.prepareStatement(deleteVenue).use { stmt -> + stmt.setString(1, slug) + stmt.setString(2, slug) + + stmt.execute() + } + } + + @Throws(SQLException::class) + override fun getCity(slug: String): City { + return conn.prepareStatement(getCity).use { stmt -> + stmt.setString(1, slug) + + val results = stmt.executeQuery() + if (!results.next()) { + throw SQLException("no rows in result set") + } + val ret = City( + results.getString(1), + results.getString(2) + ) + if (results.next()) { + throw SQLException("expected one row in result set, but got many") + } + ret + } + } + + @Throws(SQLException::class) + override fun getVenue(slug: String, city: String): Venue { + return conn.prepareStatement(getVenue).use { stmt -> + stmt.setString(1, slug) + stmt.setString(2, city) + + val results = stmt.executeQuery() + if (!results.next()) { + throw SQLException("no rows in result set") + } + val ret = Venue( + results.getInt(1), + Status.lookup(results.getString(2))!!, + (results.getArray(3).array as Array).map { v -> Status.lookup(v)!! 
}.toList(), + results.getString(4), + results.getString(5), + results.getString(6), + results.getString(7), + results.getString(8), + (results.getArray(9).array as Array).toList(), + results.getObject(10, LocalDateTime::class.java) + ) + if (results.next()) { + throw SQLException("expected one row in result set, but got many") + } + ret + } + } + + @Throws(SQLException::class) + override fun listCities(): List { + return conn.prepareStatement(listCities).use { stmt -> + + val results = stmt.executeQuery() + val ret = mutableListOf() + while (results.next()) { + ret.add(City( + results.getString(1), + results.getString(2) + )) + } + ret + } + } + + @Throws(SQLException::class) + override fun listVenues(city: String): List { + return conn.prepareStatement(listVenues).use { stmt -> + stmt.setString(1, city) + + val results = stmt.executeQuery() + val ret = mutableListOf() + while (results.next()) { + ret.add(Venue( + results.getInt(1), + Status.lookup(results.getString(2))!!, + (results.getArray(3).array as Array).map { v -> Status.lookup(v)!! }.toList(), + results.getString(4), + results.getString(5), + results.getString(6), + results.getString(7), + results.getString(8), + (results.getArray(9).array as Array).toList(), + results.getObject(10, LocalDateTime::class.java) + )) + } + ret + } + } + + @Throws(SQLException::class) + override fun updateCityName(name: String, slug: String) { + conn.prepareStatement(updateCityName).use { stmt -> + stmt.setString(1, name) + stmt.setString(2, slug) + + stmt.execute() + } + } + + @Throws(SQLException::class) + override fun updateVenueName(name: String, slug: String): Int { + return conn.prepareStatement(updateVenueName).use { stmt -> + stmt.setString(1, name) + stmt.setString(2, slug) + + val results = stmt.executeQuery() + if (!results.next()) { + throw SQLException("no rows in result set") + } + val ret = results.getInt(1) + if (results.next()) { + throw SQLException("expected one row in result set, but got many") + } + ret + } + } + + @Throws(SQLException::class) + override fun venueCountByCity(): List { + return conn.prepareStatement(venueCountByCity).use { stmt -> + + val results = stmt.executeQuery() + val ret = mutableListOf() + while (results.next()) { + ret.add(VenueCountByCityRow( + results.getString(1), + results.getLong(2) + )) + } + ret + } + } + +} + diff --git a/examples/kotlin/src/main/resources/authors/query.sql b/examples/kotlin/src/main/resources/authors/query.sql new file mode 100644 index 0000000000..75e38b2caf --- /dev/null +++ b/examples/kotlin/src/main/resources/authors/query.sql @@ -0,0 +1,19 @@ +-- name: GetAuthor :one +SELECT * FROM authors +WHERE id = $1 LIMIT 1; + +-- name: ListAuthors :many +SELECT * FROM authors +ORDER BY name; + +-- name: CreateAuthor :one +INSERT INTO authors ( + name, bio +) VALUES ( + $1, $2 +) +RETURNING *; + +-- name: DeleteAuthor :exec +DELETE FROM authors +WHERE id = $1; diff --git a/examples/kotlin/src/main/resources/authors/schema.sql b/examples/kotlin/src/main/resources/authors/schema.sql new file mode 100644 index 0000000000..b4fad78497 --- /dev/null +++ b/examples/kotlin/src/main/resources/authors/schema.sql @@ -0,0 +1,5 @@ +CREATE TABLE authors ( + id BIGSERIAL PRIMARY KEY, + name text NOT NULL, + bio text +); diff --git a/examples/kotlin/src/main/resources/booktest/postgresql/query.sql b/examples/kotlin/src/main/resources/booktest/postgresql/query.sql new file mode 100644 index 0000000000..f4537c603e --- /dev/null +++ b/examples/kotlin/src/main/resources/booktest/postgresql/query.sql @@ -0,0 
+1,60 @@ +-- name: GetAuthor :one +SELECT * FROM authors +WHERE author_id = $1; + +-- name: GetBook :one +SELECT * FROM books +WHERE book_id = $1; + +-- name: DeleteBook :exec +DELETE FROM books +WHERE book_id = $1; + +-- name: BooksByTitleYear :many +SELECT * FROM books +WHERE title = $1 AND year = $2; + +-- name: BooksByTags :many +SELECT + book_id, + title, + name, + isbn, + tags +FROM books +LEFT JOIN authors ON books.author_id = authors.author_id +WHERE tags && $1::varchar[]; + +-- name: CreateAuthor :one +INSERT INTO authors (name) VALUES ($1) +RETURNING *; + +-- name: CreateBook :one +INSERT INTO books ( + author_id, + isbn, + booktype, + title, + year, + available, + tags +) VALUES ( + $1, + $2, + $3, + $4, + $5, + $6, + $7 +) +RETURNING *; + +-- name: UpdateBook :exec +UPDATE books +SET title = $1, tags = $2 +WHERE book_id = $3; + +-- name: UpdateBookISBN :exec +UPDATE books +SET title = $1, tags = $2, isbn = $4 +WHERE book_id = $3; diff --git a/examples/kotlin/src/main/resources/booktest/postgresql/schema.sql b/examples/kotlin/src/main/resources/booktest/postgresql/schema.sql new file mode 100644 index 0000000000..0816931a81 --- /dev/null +++ b/examples/kotlin/src/main/resources/booktest/postgresql/schema.sql @@ -0,0 +1,37 @@ +DROP TABLE IF EXISTS books CASCADE; +DROP TYPE IF EXISTS book_type CASCADE; +DROP TABLE IF EXISTS authors CASCADE; +DROP FUNCTION IF EXISTS say_hello(text) CASCADE; + +CREATE TABLE authors ( + author_id SERIAL PRIMARY KEY, + name text NOT NULL DEFAULT '' +); + +CREATE INDEX authors_name_idx ON authors(name); + +CREATE TYPE book_type AS ENUM ( + 'FICTION', + 'NONFICTION' +); + +CREATE TABLE books ( + book_id SERIAL PRIMARY KEY, + author_id integer NOT NULL REFERENCES authors(author_id), + isbn text NOT NULL DEFAULT '' UNIQUE, + booktype book_type NOT NULL DEFAULT 'FICTION', + title text NOT NULL DEFAULT '', + year integer NOT NULL DEFAULT 2000, + available timestamp with time zone NOT NULL DEFAULT 'NOW()', + tags varchar[] NOT NULL DEFAULT '{}' +); + +CREATE INDEX books_title_idx ON books(title, year); + +CREATE FUNCTION say_hello(text) RETURNS text AS $$ +BEGIN + RETURN CONCAT('hello ', $1); +END; +$$ LANGUAGE plpgsql; + +CREATE INDEX books_title_lower_idx ON books(title); diff --git a/examples/kotlin/src/main/resources/ondeck/query/city.sql b/examples/kotlin/src/main/resources/ondeck/query/city.sql new file mode 100644 index 0000000000..f34dc9961e --- /dev/null +++ b/examples/kotlin/src/main/resources/ondeck/query/city.sql @@ -0,0 +1,26 @@ +-- name: ListCities :many +SELECT * +FROM city +ORDER BY name; + +-- name: GetCity :one +SELECT * +FROM city +WHERE slug = $1; + +-- name: CreateCity :one +-- Create a new city. The slug must be unique. 
+-- This is the second line of the comment +-- This is the third line +INSERT INTO city ( + name, + slug +) VALUES ( + $1, + $2 +) RETURNING *; + +-- name: UpdateCityName :exec +UPDATE city +SET name = $2 +WHERE slug = $1; diff --git a/examples/kotlin/src/main/resources/ondeck/query/venue.sql b/examples/kotlin/src/main/resources/ondeck/query/venue.sql new file mode 100644 index 0000000000..8c6bd02664 --- /dev/null +++ b/examples/kotlin/src/main/resources/ondeck/query/venue.sql @@ -0,0 +1,49 @@ +-- name: ListVenues :many +SELECT * +FROM venue +WHERE city = $1 +ORDER BY name; + +-- name: DeleteVenue :exec +DELETE FROM venue +WHERE slug = $1 AND slug = $1; + +-- name: GetVenue :one +SELECT * +FROM venue +WHERE slug = $1 AND city = $2; + +-- name: CreateVenue :one +INSERT INTO venue ( + slug, + name, + city, + created_at, + spotify_playlist, + status, + statuses, + tags +) VALUES ( + $1, + $2, + $3, + NOW(), + $4, + $5, + $6, + $7 +) RETURNING id; + +-- name: UpdateVenueName :one +UPDATE venue +SET name = $2 +WHERE slug = $1 +RETURNING id; + +-- name: VenueCountByCity :many +SELECT + city, + count(*) +FROM venue +GROUP BY 1 +ORDER BY 1; diff --git a/examples/kotlin/src/main/resources/ondeck/schema/0001_city.sql b/examples/kotlin/src/main/resources/ondeck/schema/0001_city.sql new file mode 100644 index 0000000000..af38f16bb5 --- /dev/null +++ b/examples/kotlin/src/main/resources/ondeck/schema/0001_city.sql @@ -0,0 +1,4 @@ +CREATE TABLE city ( + slug text PRIMARY KEY, + name text NOT NULL +) diff --git a/examples/kotlin/src/main/resources/ondeck/schema/0002_venue.sql b/examples/kotlin/src/main/resources/ondeck/schema/0002_venue.sql new file mode 100644 index 0000000000..940de7a5a8 --- /dev/null +++ b/examples/kotlin/src/main/resources/ondeck/schema/0002_venue.sql @@ -0,0 +1,18 @@ +CREATE TYPE status AS ENUM ('op!en', 'clo@sed'); +COMMENT ON TYPE status IS 'Venues can be either open or closed'; + +CREATE TABLE venues ( + id SERIAL primary key, + dropped text, + status status not null, + statuses status[], + slug text not null, + name varchar(255) not null, + city text not null references city(slug), + spotify_playlist varchar not null, + songkick_id text, + tags text[] +); +COMMENT ON TABLE venues IS 'Venues are places where muisc happens'; +COMMENT ON COLUMN venues.slug IS 'This value appears in public URLs'; + diff --git a/examples/kotlin/src/main/resources/ondeck/schema/0003_add_column.sql b/examples/kotlin/src/main/resources/ondeck/schema/0003_add_column.sql new file mode 100644 index 0000000000..9b334bccce --- /dev/null +++ b/examples/kotlin/src/main/resources/ondeck/schema/0003_add_column.sql @@ -0,0 +1,3 @@ +ALTER TABLE venues RENAME TO venue; +ALTER TABLE venue ADD COLUMN created_at TIMESTAMP NOT NULL DEFAULT NOW(); +ALTER TABLE venue DROP COLUMN dropped; diff --git a/examples/kotlin/src/test/kotlin/com/example/authors/QueriesImplTest.kt b/examples/kotlin/src/test/kotlin/com/example/authors/QueriesImplTest.kt new file mode 100644 index 0000000000..93608a1dfc --- /dev/null +++ b/examples/kotlin/src/test/kotlin/com/example/authors/QueriesImplTest.kt @@ -0,0 +1,60 @@ +package com.example.authors + +import com.example.dbtest.DbTestExtension +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.extension.RegisterExtension + +class QueriesImplTest() { + + companion object { + @JvmField + @RegisterExtension + val dbtest = DbTestExtension("src/main/resources/authors/schema.sql") + } + + @Test + fun testCreateAuthor() { + val db = 
QueriesImpl(dbtest.getConnection()) + + val initialAuthors = db.listAuthors() + assert(initialAuthors.isEmpty()) + + val name = "Brian Kernighan" + val bio = "Co-author of The C Programming Language and The Go Programming Language" + val insertedAuthor = db.createAuthor( + name = name, + bio = bio + ) + val expectedAuthor = Author(insertedAuthor.id, name, bio) + assertEquals(expectedAuthor, insertedAuthor) + + val fetchedAuthor = db.getAuthor(insertedAuthor.id) + assertEquals(expectedAuthor, fetchedAuthor) + + val listedAuthors = db.listAuthors() + assertEquals(1, listedAuthors.size) + assertEquals(expectedAuthor, listedAuthors[0]) + } + + @Test + fun testNull() { + val db = QueriesImpl(dbtest.getConnection()) + + val initialAuthors = db.listAuthors() + assert(initialAuthors.isEmpty()) + + val name = "Brian Kernighan" + val bio = null + val insertedAuthor = db.createAuthor(name, bio) + val expectedAuthor = Author(insertedAuthor.id, name, bio) + assertEquals(expectedAuthor, insertedAuthor) + + val fetchedAuthor = db.getAuthor(insertedAuthor.id) + assertEquals(expectedAuthor, fetchedAuthor) + + val listedAuthors = db.listAuthors() + assertEquals(1, listedAuthors.size) + assertEquals(expectedAuthor, listedAuthors[0]) + } +} diff --git a/examples/kotlin/src/test/kotlin/com/example/booktest/postgresql/QueriesImplTest.kt b/examples/kotlin/src/test/kotlin/com/example/booktest/postgresql/QueriesImplTest.kt new file mode 100644 index 0000000000..d85747431e --- /dev/null +++ b/examples/kotlin/src/test/kotlin/com/example/booktest/postgresql/QueriesImplTest.kt @@ -0,0 +1,97 @@ +package com.example.booktest.postgresql + +import com.example.dbtest.DbTestExtension +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.extension.RegisterExtension +import java.time.OffsetDateTime +import java.time.format.DateTimeFormatter + +class QueriesImplTest { + companion object { + @JvmField @RegisterExtension val dbtest = DbTestExtension("src/main/resources/booktest/postgresql/schema.sql") + } + + @Test + fun testQueries() { + val conn = dbtest.getConnection() + val db = QueriesImpl(conn) + val author = db.createAuthor("Unknown Master") + + // Start a transaction + conn.autoCommit = false + db.createBook( + authorId = author.authorId, + isbn = "1", + title = "my book title", + booktype = BookType.NONFICTION, + year = 2016, + available = OffsetDateTime.now(), + tags = listOf() + ) + + val b1 = db.createBook( + authorId = author.authorId, + isbn = "2", + title = "the second book", + booktype = BookType.NONFICTION, + year = 2016, + available = OffsetDateTime.now(), + tags = listOf("cool", "unique") + ) + + db.updateBook( + bookId = b1.bookId, + title = "changed second title", + tags = listOf("cool", "disastor") + ) + + val b3 = db.createBook( + authorId = author.authorId, + isbn = "3", + title = "the third book", + booktype = BookType.NONFICTION, + year = 2001, + available = OffsetDateTime.now(), + tags = listOf("cool") + ) + + db.createBook( + authorId = author.authorId, + isbn = "4", + title = "4th place finisher", + booktype = BookType.NONFICTION, + year = 2011, + available = OffsetDateTime.now(), + tags = listOf("other") + ) + + // Commit transaction + conn.commit() + conn.autoCommit = true + + // ISBN update fails because parameters are not in sequential order. After changing $N to ?, ordering is lost, + // and the parameters are filled into the wrong slots. 
+ db.updateBookISBN( + bookId = b3.bookId, + isbn = "NEW ISBN", + title = "never ever gonna finish, a quatrain", + tags = listOf("someother") + ) + + val books0 = db.booksByTitleYear("my book title", 2016) + + val formatter = DateTimeFormatter.ISO_DATE_TIME + for (book in books0) { + println("Book ${book.bookId} (${book.booktype}): ${book.title} available: ${book.available.format(formatter)}") + val author = db.getAuthor(book.authorId) + println("Book ${book.bookId} author: ${author.name}") + } + + // find a book with either "cool" or "other" tag + println("---------\\nTag search results:\\n") + val res = db.booksByTags(listOf("cool", "other", "someother")) + for (ab in res) { + println("Book ${ab.bookId}: '${ab.title}', Author: '${ab.name}', ISBN: '${ab.isbn}' Tags: '${ab.tags.toList()}'") + } + } +} \ No newline at end of file diff --git a/examples/kotlin/src/test/kotlin/com/example/dbtest/DbTestExtension.kt b/examples/kotlin/src/test/kotlin/com/example/dbtest/DbTestExtension.kt new file mode 100644 index 0000000000..66f831477e --- /dev/null +++ b/examples/kotlin/src/test/kotlin/com/example/dbtest/DbTestExtension.kt @@ -0,0 +1,49 @@ +package com.example.dbtest + +import org.junit.jupiter.api.extension.AfterEachCallback +import org.junit.jupiter.api.extension.BeforeEachCallback +import org.junit.jupiter.api.extension.ExtensionContext +import java.nio.file.Files +import java.nio.file.Paths +import java.sql.Connection +import java.sql.DriverManager +import kotlin.streams.toList + +const val schema = "dinosql_test" + +class DbTestExtension(private val migrationsPath: String) : BeforeEachCallback, AfterEachCallback { + private val schemaConn: Connection + private val url: String + + init { + val user = System.getenv("PG_USER") ?: "postgres" + val pass = System.getenv("PG_PASSWORD") ?: "mysecretpassword" + val host = System.getenv("PG_HOST") ?: "127.0.0.1" + val port = System.getenv("PG_PORT") ?: "5432" + val db = System.getenv("PG_DATABASE") ?: "dinotest" + url = "jdbc:postgresql://$host:$port/$db?user=$user&password=$pass&sslmode=disable" + + schemaConn = DriverManager.getConnection(url) + } + + override fun beforeEach(context: ExtensionContext) { + schemaConn.createStatement().execute("CREATE SCHEMA $schema") + val path = Paths.get(migrationsPath) + val migrations = if (Files.isDirectory(path)) { + Files.list(path).filter{ it.toString().endsWith(".sql")}.sorted().map { Files.readString(it) }.toList() + } else { + listOf(Files.readString(path)) + } + migrations.forEach { + getConnection().createStatement().execute(it) + } + } + + override fun afterEach(context: ExtensionContext) { + schemaConn.createStatement().execute("DROP SCHEMA $schema CASCADE") + } + + fun getConnection(): Connection { + return DriverManager.getConnection("$url¤tSchema=$schema") + } +} \ No newline at end of file diff --git a/examples/kotlin/src/test/kotlin/com/example/ondeck/QueriesImplTest.kt b/examples/kotlin/src/test/kotlin/com/example/ondeck/QueriesImplTest.kt new file mode 100644 index 0000000000..2ab1ba3ac5 --- /dev/null +++ b/examples/kotlin/src/test/kotlin/com/example/ondeck/QueriesImplTest.kt @@ -0,0 +1,46 @@ +package com.example.ondeck + +import com.example.dbtest.DbTestExtension +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.extension.RegisterExtension + +class QueriesImplTest { + companion object { + @JvmField @RegisterExtension val dbtest = DbTestExtension("src/main/resources/ondeck/schema") + } + + @Test + fun testQueries() { + val q 
= QueriesImpl(dbtest.getConnection()) + val city = q.createCity( + slug = "san-francisco", + name = "San Francisco" + ) + val venueId = q.createVenue( + slug = "the-fillmore", + name = "The Fillmore", + city = city.slug, + spotifyPlaylist = "spotify=uri", + status = Status.OPEN, + statuses = listOf(Status.OPEN, Status.CLOSED), + tags = listOf("rock", "punk") + ) + val venue = q.getVenue( + slug = "the-fillmore", + city = city.slug + ) + assertEquals(venueId, venue.id) + + assertEquals(city, q.getCity(city.slug)) + assertEquals(listOf(VenueCountByCityRow(city.slug, 1)), q.venueCountByCity()) + assertEquals(listOf(city), q.listCities()) + assertEquals(listOf(venue), q.listVenues(city.slug)) + + q.updateCityName(slug = city.slug, name = "SF") + val id = q.updateVenueName(slug = venue.slug, name = "Fillmore") + assertEquals(venue.id, id) + + q.deleteVenue(venue.slug) + } +} \ No newline at end of file diff --git a/examples/sqlc.json b/examples/sqlc.json index 9662784e93..f892a0bb11 100644 --- a/examples/sqlc.json +++ b/examples/sqlc.json @@ -35,6 +35,39 @@ "schema": "booktest/mysql/schema.sql", "queries": "booktest/mysql/query.sql", "engine": "mysql" + }, + { + "name": "com.example.authors", + "path": "kotlin/src/main/kotlin/com/example/authors", + "schema": "kotlin/src/main/resources/authors/schema.sql", + "queries": "kotlin/src/main/resources/authors/query.sql", + "engine": "postgresql", + "language": "kotlin" + }, + { + "name": "com.example.ondeck", + "path": "kotlin/src/main/kotlin/com/example/ondeck", + "schema": "ondeck/schema", + "queries": "ondeck/query", + "engine": "postgresql", + "emit_interface": true, + "language": "kotlin" + }, + { + "name": "com.example.jets", + "path": "kotlin/src/main/kotlin/com/example/jets", + "schema": "jets/schema.sql", + "queries": "jets/query-building.sql", + "engine": "postgresql", + "language": "kotlin" + }, + { + "name": "com.example.booktest.postgresql", + "path": "kotlin/src/main/kotlin/com/example/booktest/postgresql", + "schema": "kotlin/src/main/resources/booktest/postgresql/schema.sql", + "queries": "kotlin/src/main/resources/booktest/postgresql/query.sql", + "engine": "postgresql", + "language": "kotlin" } ] } diff --git a/internal/cmd/cmd.go b/internal/cmd/cmd.go index 487deb07ef..3e4f7fa301 100644 --- a/internal/cmd/cmd.go +++ b/internal/cmd/cmd.go @@ -9,11 +9,11 @@ import ( "os/exec" "path/filepath" - "github.com/kyleconroy/sqlc/internal/dinosql" - "github.com/davecgh/go-spew/spew" pg "github.com/lfittl/pg_query_go" "github.com/spf13/cobra" + + "github.com/kyleconroy/sqlc/internal/dinosql" ) // Do runs the command logic. 
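The `examples/sqlc.json` entries above opt each example package into the new `"language": "kotlin"` output, and the generated classes are plain JDBC wrappers around a `java.sql.Connection`. A minimal usage sketch for the generated `com.example.authors` package follows; the JDBC URL and credentials are placeholders, everything else (class and method names, parameter shapes) is taken from the generated code in this diff.

```kotlin
import com.example.authors.QueriesImpl
import java.sql.DriverManager

fun main() {
    // Placeholder connection string; the test suite builds its URL from the
    // PG_* environment variables instead (see DbTestExtension in this diff).
    val conn = DriverManager.getConnection(
        "jdbc:postgresql://localhost:5432/postgres?user=postgres&password=postgres"
    )
    conn.use {
        val db = QueriesImpl(it)

        // createAuthor runs the INSERT ... RETURNING statement and maps the
        // single returned row into an Author data class.
        val author = db.createAuthor(name = "Brian Kernighan", bio = null)

        // getAuthor and listAuthors wrap plain SELECTs.
        println(db.getAuthor(author.id))
        println(db.listAuthors().size)

        // deleteAuthor is a :exec query and returns nothing.
        db.deleteAuthor(author.id)
    }
}
```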
diff --git a/internal/cmd/generate.go b/internal/cmd/generate.go index da4cfa8c4b..1b8e592c4e 100644 --- a/internal/cmd/generate.go +++ b/internal/cmd/generate.go @@ -8,6 +8,7 @@ import ( "path/filepath" "github.com/kyleconroy/sqlc/internal/dinosql" + "github.com/kyleconroy/sqlc/internal/dinosql/kotlin" "github.com/kyleconroy/sqlc/internal/mysql" ) @@ -106,11 +107,22 @@ func Generate(dir string, stderr io.Writer) (map[string]string, error) { errored = true continue } - result = q + result = &kotlin.Result{Result: q} } - files, err := dinosql.Generate(result, combo) + var files map[string]string + switch pkg.Language { + case dinosql.LanguageGo: + files, err = dinosql.Generate(result, combo) + case dinosql.LanguageKotlin: + ktRes, ok := result.(kotlin.KtGenerateable) + if !ok { + err = fmt.Errorf("kotlin not supported for engine %s", pkg.Engine) + break + } + files, err = kotlin.KtGenerate(ktRes, combo) + } if err != nil { fmt.Fprintf(stderr, "# package %s\n", name) fmt.Fprintf(stderr, "error generating code: %s\n", err) diff --git a/internal/dinosql/config.go b/internal/dinosql/config.go index e14dc9a2f0..8ae8c6b76f 100644 --- a/internal/dinosql/config.go +++ b/internal/dinosql/config.go @@ -41,9 +41,17 @@ const ( EnginePostgreSQL Engine = "postgresql" ) +type Language string + +const ( + LanguageGo Language = "go" + LanguageKotlin Language = "kotlin" +) + type PackageSettings struct { Name string `json:"name"` Engine Engine `json:"engine,omitempty"` + Language Language `json:"language,omitempty"` Path string `json:"path"` Schema string `json:"schema"` Queries string `json:"queries"` @@ -51,6 +59,8 @@ type PackageSettings struct { EmitJSONTags bool `json:"emit_json_tags"` EmitPreparedQueries bool `json:"emit_prepared_queries"` Overrides []Override `json:"overrides"` + // HACK: this is only set in tests, only here till Kotlin support can be merged. 
+ rewriteParams bool } type Override struct { @@ -196,6 +206,11 @@ func ParseConfig(rd io.Reader) (GenerateSettings, error) { if config.Packages[j].Engine == "" { config.Packages[j].Engine = EnginePostgreSQL } + if config.Packages[j].Language == "" { + config.Packages[j].Language = LanguageGo + } else if config.Packages[j].Language == "kotlin" { + config.Packages[j].rewriteParams = true + } } return config, nil } diff --git a/internal/dinosql/gen.go b/internal/dinosql/gen.go index 325eed5dbe..4a20999022 100644 --- a/internal/dinosql/gen.go +++ b/internal/dinosql/gen.go @@ -719,6 +719,11 @@ func (r Result) goInnerType(col core.Column, settings CombinedSettings) string { } } +type goColumn struct { + id int + core.Column +} + // It's possible that this method will generate duplicate JSON tag values // // Columns: count, count, count_2 @@ -726,21 +731,31 @@ func (r Result) goInnerType(col core.Column, settings CombinedSettings) string { // JSON tags: count, count_2, count_2 // // This is unlikely to happen, so don't fix it yet -func (r Result) columnsToStruct(name string, columns []core.Column, settings CombinedSettings) *GoStruct { +func (r Result) columnsToStruct(name string, columns []goColumn, settings CombinedSettings) *GoStruct { gs := GoStruct{ Name: name, } seen := map[string]int{} + suffixes := map[int]int{} for i, c := range columns { tagName := c.Name - fieldName := StructName(columnName(c, i), settings) - if v := seen[c.Name]; v > 0 { - tagName = fmt.Sprintf("%s_%d", tagName, v+1) - fieldName = fmt.Sprintf("%s_%d", fieldName, v+1) + fieldName := StructName(columnName(c.Column, i), settings) + // Track suffixes by the ID of the column, so that columns referring to the same numbered parameter can be + // reused. + suffix := 0 + if o, ok := suffixes[c.id]; ok { + suffix = o + } else if v := seen[c.Name]; v > 0 { + suffix = v+1 + } + suffixes[c.id] = suffix + if suffix > 0 { + tagName = fmt.Sprintf("%s_%d", tagName, suffix) + fieldName = fmt.Sprintf("%s_%d", fieldName, suffix) } gs.Fields = append(gs.Fields, GoField{ Name: fieldName, - Type: r.goType(c, settings), + Type: r.goType(c.Column, settings), Tags: map[string]string{"json:": tagName}, }) seen[c.Name]++ @@ -815,9 +830,12 @@ func (r Result) GoQueries(settings CombinedSettings) []GoQuery { Typ: r.goType(p.Column, settings), } } else if len(query.Params) > 1 { - var cols []core.Column + var cols []goColumn for _, p := range query.Params { - cols = append(cols, p.Column) + cols = append(cols, goColumn{ + id: p.Number, + Column: p.Column, + }) } gq.Arg = GoQueryValue{ Emit: true, @@ -858,7 +876,14 @@ func (r Result) GoQueries(settings CombinedSettings) []GoQuery { } if gs == nil { - gs = r.columnsToStruct(gq.MethodName+"Row", query.Columns, settings) + var columns []goColumn + for i, c := range query.Columns { + columns = append(columns, goColumn{ + id: i, + Column: c, + }) + } + gs = r.columnsToStruct(gq.MethodName+"Row", columns, settings) emit = true } gq.Ret = GoQueryValue{ diff --git a/internal/dinosql/kotlin/gen.go b/internal/dinosql/kotlin/gen.go new file mode 100644 index 0000000000..f30b43ad5f --- /dev/null +++ b/internal/dinosql/kotlin/gen.go @@ -0,0 +1,1030 @@ +package kotlin + +import ( + "bufio" + "bytes" + "fmt" + "log" + "regexp" + "sort" + "strings" + "text/template" + + "github.com/kyleconroy/sqlc/internal/dinosql" + core "github.com/kyleconroy/sqlc/internal/pg" + + "github.com/jinzhu/inflection" +) + +var ktIdentPattern = regexp.MustCompile("[^a-zA-Z0-9_]+") + +type KtConstant struct { + Name string + Type string 
+ Value string +} + +type KtEnum struct { + Name string + Comment string + Constants []KtConstant +} + +type KtField struct { + Name string + Type ktType + Comment string +} + +type KtStruct struct { + Table core.FQN + Name string + Fields []KtField + JDBCParamBindings []KtField + Comment string +} + +type KtQueryValue struct { + Emit bool + Name string + Struct *KtStruct + Typ ktType + JDBCParamBindCount int +} + +func (v KtQueryValue) EmitStruct() bool { + return v.Emit +} + +func (v KtQueryValue) IsStruct() bool { + return v.Struct != nil +} + +func (v KtQueryValue) isEmpty() bool { + return v.Typ == (ktType{}) && v.Name == "" && v.Struct == nil +} + +func (v KtQueryValue) Type() string { + if v.Typ != (ktType{}) { + return v.Typ.String() + } + if v.Struct != nil { + return v.Struct.Name + } + panic("no type for KtQueryValue: " + v.Name) +} + +func jdbcSet(t ktType, idx int, name string) string { + if t.IsEnum && t.IsArray { + return fmt.Sprintf(`stmt.setArray(%d, conn.createArrayOf("%s", %s.map { v -> v.value }.toTypedArray()))`, idx, t.DataType, name) + } + if t.IsEnum { + return fmt.Sprintf("stmt.setObject(%d, %s.value, %s)", idx, name, "Types.OTHER") + } + if t.IsArray { + return fmt.Sprintf(`stmt.setArray(%d, conn.createArrayOf("%s", %s.toTypedArray()))`, idx, t.DataType, name) + } + if t.IsTime() { + return fmt.Sprintf("stmt.setObject(%d, %s)", idx, name) + } + return fmt.Sprintf("stmt.set%s(%d, %s)", t.Name, idx, name) +} + +type KtParams struct { + Struct *KtStruct +} + +func (v KtParams) isEmpty() bool { + return len(v.Struct.Fields) == 0 +} + +func (v KtParams) Args() string { + if v.isEmpty() { + return "" + } + var out []string + for _, f := range v.Struct.Fields { + out = append(out, f.Name+": "+f.Type.String()) + } + if len(out) < 3 { + return strings.Join(out, ", ") + } + return "\n" + indent(strings.Join(out, ",\n"), 6, -1) +} + +func (v KtParams) Bindings() string { + if v.isEmpty() { + return "" + } + var out []string + for i, f := range v.Struct.JDBCParamBindings { + out = append(out, jdbcSet(f.Type, i+1, f.Name)) + } + return indent(strings.Join(out, "\n"), 6, 0) +} + +func jdbcGet(t ktType, idx int) string { + if t.IsEnum && t.IsArray { + return fmt.Sprintf(`(results.getArray(%d).array as Array).map { v -> %s.lookup(v)!! 
}.toList()`, idx, t.Name) + } + if t.IsEnum { + return fmt.Sprintf("%s.lookup(results.getString(%d))!!", t.Name, idx) + } + if t.IsArray { + return fmt.Sprintf(`(results.getArray(%d).array as Array<%s>).toList()`, idx, t.Name) + } + if t.IsTime() { + return fmt.Sprintf(`results.getObject(%d, %s::class.java)`, idx, t.Name) + } + return fmt.Sprintf(`results.get%s(%d)`, t.Name, idx) +} + +func (v KtQueryValue) ResultSet() string { + var out []string + if v.Struct == nil { + return jdbcGet(v.Typ, 1) + } + for i, f := range v.Struct.Fields { + out = append(out, jdbcGet(f.Type, i+1)) + } + ret := indent(strings.Join(out, ",\n"), 4, -1) + ret = indent(v.Struct.Name+"(\n"+ret+"\n)", 8, 0) + return ret +} + +func indent(s string, n int, firstIndent int) string { + lines := strings.Split(s, "\n") + buf := bytes.NewBuffer(nil) + for i, l := range lines { + indent := n + if i == 0 && firstIndent != -1 { + indent = firstIndent + } + if i != 0 { + buf.WriteRune('\n') + } + for i := 0; i < indent; i++ { + buf.WriteRune(' ') + } + buf.WriteString(l) + } + return buf.String() +} + +// A struct used to generate methods and fields on the Queries struct +type KtQuery struct { + ClassName string + Cmd string + Comments []string + MethodName string + FieldName string + ConstantName string + SQL string + SourceName string + Ret KtQueryValue + Arg KtParams +} + +type KtGenerateable interface { + KtDataClasses(settings dinosql.CombinedSettings) []KtStruct + KtQueries(settings dinosql.CombinedSettings) []KtQuery + KtEnums(settings dinosql.CombinedSettings) []KtEnum +} + +func KtUsesType(r KtGenerateable, typ string, settings dinosql.CombinedSettings) bool { + for _, strct := range r.KtDataClasses(settings) { + for _, f := range strct.Fields { + if f.Type.Name == typ { + return true + } + } + } + return false +} + +func KtImports(r KtGenerateable, settings dinosql.CombinedSettings) func(string) [][]string { + return func(filename string) [][]string { + if filename == "Models.kt" { + return ModelKtImports(r, settings) + } + + if filename == "Querier.kt" { + return InterfaceKtImports(r, settings) + } + + return QueryKtImports(r, settings, filename) + } +} + +func InterfaceKtImports(r KtGenerateable, settings dinosql.CombinedSettings) [][]string { + gq := r.KtQueries(settings) + uses := func(name string) bool { + for _, q := range gq { + if !q.Ret.isEmpty() { + if strings.HasPrefix(q.Ret.Type(), name) { + return true + } + } + if !q.Arg.isEmpty() { + for _, f := range q.Arg.Struct.Fields { + if strings.HasPrefix(f.Type.Name, name) { + return true + } + } + } + } + return false + } + + std := map[string]struct{}{ + "java.sql.Connection": {}, + "java.sql.SQLException": {}, + } + if uses("LocalDate") { + std["java.time.LocalDate"] = struct{}{} + } + if uses("LocalTime") { + std["java.time.LocalTime"] = struct{}{} + } + if uses("LocalDateTime") { + std["java.time.LocalDateTime"] = struct{}{} + } + if uses("OffsetDateTime") { + std["java.time.OffsetDateTime"] = struct{}{} + } + + stds := make([]string, 0, len(std)) + for s, _ := range std { + stds = append(stds, s) + } + + sort.Strings(stds) + return [][]string{stds} +} + +func ModelKtImports(r KtGenerateable, settings dinosql.CombinedSettings) [][]string { + std := make(map[string]struct{}) + if KtUsesType(r, "LocalDate", settings) { + std["java.time.LocalDate"] = struct{}{} + } + if KtUsesType(r, "LocalTime", settings) { + std["java.time.LocalTime"] = struct{}{} + } + if KtUsesType(r, "LocalDateTime", settings) { + std["java.time.LocalDateTime"] = struct{}{} + } + if 
KtUsesType(r, "OffsetDateTime", settings) { + std["java.time.OffsetDateTime"] = struct{}{} + } + + stds := make([]string, 0, len(std)) + for s, _ := range std { + stds = append(stds, s) + } + + sort.Strings(stds) + return [][]string{stds} +} + +func QueryKtImports(r KtGenerateable, settings dinosql.CombinedSettings, filename string) [][]string { + // for _, strct := range r.KtDataClasses() { + // for _, f := range strct.Fields { + // if strings.HasPrefix(f.Type, "[]") { + // return true + // } + // } + // } + var gq []KtQuery + for _, query := range r.KtQueries(settings) { + gq = append(gq, query) + } + + uses := func(name string) bool { + for _, q := range gq { + if !q.Ret.isEmpty() { + if q.Ret.Struct != nil { + for _, f := range q.Ret.Struct.Fields { + if f.Type.Name == name { + return true + } + } + } + if q.Ret.Type() == name { + return true + } + } + if !q.Arg.isEmpty() { + for _, f := range q.Arg.Struct.Fields { + if f.Type.Name == name { + return true + } + } + } + } + return false + } + + hasEnum := func() bool { + for _, q := range gq { + if !q.Arg.isEmpty() { + for _, f := range q.Arg.Struct.Fields { + if f.Type.IsEnum { + return true + } + } + } + } + return false + } + + std := map[string]struct{}{ + "java.sql.Connection": {}, + "java.sql.SQLException": {}, + } + if uses("LocalDate") { + std["java.time.LocalDate"] = struct{}{} + } + if uses("LocalTime") { + std["java.time.LocalTime"] = struct{}{} + } + if uses("LocalDateTime") { + std["java.time.LocalDateTime"] = struct{}{} + } + if uses("OffsetDateTime") { + std["java.time.OffsetDateTime"] = struct{}{} + } + if hasEnum() { + std["java.sql.Types"] = struct{}{} + } + + pkg := make(map[string]struct{}) + + pkgs := make([]string, 0, len(pkg)) + for p, _ := range pkg { + pkgs = append(pkgs, p) + } + + stds := make([]string, 0, len(std)) + for s, _ := range std { + stds = append(stds, s) + } + + sort.Strings(stds) + sort.Strings(pkgs) + return [][]string{stds, pkgs} +} + +func ktEnumValueName(value string) string { + id := strings.Replace(value, "-", "_", -1) + id = strings.Replace(id, ":", "_", -1) + id = strings.Replace(id, "/", "_", -1) + id = ktIdentPattern.ReplaceAllString(id, "") + return strings.ToUpper(id) +} + +// Result is a wrapper around *dinosql.Result that extends it with Kotlin support. +// It can be used to generate both Go and Kotlin code. +// TODO: This is a temporary hack to ensure minimal chance of merge conflicts while Kotlin support is forked. +// Once it is merged upstream, we can split out Go support from the core dinosql.Result.
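+// The embedded *dinosql.Result provides the parsed catalog and queries; the Kt* methods below (KtEnums, KtDataClasses, KtQueries) mirror the Go code paths with Kotlin naming and JDBC parameter bindings.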
+type Result struct { + *dinosql.Result +} + +func (r Result) KtEnums(settings dinosql.CombinedSettings) []KtEnum { + var enums []KtEnum + for name, schema := range r.Catalog.Schemas { + if name == "pg_catalog" { + continue + } + for _, enum := range schema.Enums { + var enumName string + if name == "public" { + enumName = enum.Name + } else { + enumName = name + "_" + enum.Name + } + e := KtEnum{ + Name: KtDataClassName(enumName, settings), + Comment: enum.Comment, + } + for _, v := range enum.Vals { + e.Constants = append(e.Constants, KtConstant{ + Name: ktEnumValueName(v), + Value: v, + Type: e.Name, + }) + } + enums = append(enums, e) + } + } + if len(enums) > 0 { + sort.Slice(enums, func(i, j int) bool { return enums[i].Name < enums[j].Name }) + } + return enums +} + +func KtDataClassName(name string, settings dinosql.CombinedSettings) string { + if rename := settings.Global.Rename[name]; rename != "" { + return rename + } + out := "" + for _, p := range strings.Split(name, "_") { + out += strings.Title(p) + } + return out +} + +func KtMemberName(name string, settings dinosql.CombinedSettings) string { + return dinosql.LowerTitle(KtDataClassName(name, settings)) +} + +func (r Result) KtDataClasses(settings dinosql.CombinedSettings) []KtStruct { + var structs []KtStruct + for name, schema := range r.Catalog.Schemas { + if name == "pg_catalog" { + continue + } + for _, table := range schema.Tables { + var tableName string + if name == "public" { + tableName = table.Name + } else { + tableName = name + "_" + table.Name + } + s := KtStruct{ + Table: core.FQN{Schema: name, Rel: table.Name}, + Name: inflection.Singular(KtDataClassName(tableName, settings)), + Comment: table.Comment, + } + for _, column := range table.Columns { + s.Fields = append(s.Fields, KtField{ + Name: KtMemberName(column.Name, settings), + Type: r.ktType(column, settings), + Comment: column.Comment, + }) + } + structs = append(structs, s) + } + } + if len(structs) > 0 { + sort.Slice(structs, func(i, j int) bool { return structs[i].Name < structs[j].Name }) + } + return structs +} + +type ktType struct { + Name string + IsEnum bool + IsArray bool + IsNull bool + DataType string +} + +func (t ktType) String() string { + v := t.Name + if t.IsArray { + v = fmt.Sprintf("List<%s>", v) + } else if t.IsNull { + v += "?" 
+ } + return v +} + +func (t ktType) jdbcSetter() string { + return "set" + t.jdbcType() +} + +func (t ktType) jdbcType() string { + if t.IsArray { + return "Array" + } + if t.IsEnum || t.IsTime() { + return "Object" + } + return t.Name +} + +func (t ktType) IsTime() bool { + return t.Name == "LocalDate" || t.Name == "LocalDateTime" || t.Name == "LocalTime" || t.Name == "OffsetDateTime" +} + +func (r Result) ktType(col core.Column, settings dinosql.CombinedSettings) ktType { + typ, isEnum := r.ktInnerType(col, settings) + return ktType{ + Name: typ, + IsEnum: isEnum, + IsArray: col.IsArray, + IsNull: !col.NotNull, + DataType: col.DataType, + } +} + +func (r Result) ktInnerType(col core.Column, settings dinosql.CombinedSettings) (string, bool) { + columnType := col.DataType + + switch columnType { + case "serial", "pg_catalog.serial4": + return "Int", false + + case "bigserial", "pg_catalog.serial8": + return "Long", false + + case "smallserial", "pg_catalog.serial2": + return "Short", false + + case "integer", "int", "int4", "pg_catalog.int4": + return "Int", false + + case "bigint", "pg_catalog.int8": + return "Long", false + + case "smallint", "pg_catalog.int2": + return "Short", false + + case "float", "double precision", "pg_catalog.float8": + return "Double", false + + case "real", "pg_catalog.float4": + return "Float", false + + case "pg_catalog.numeric": + return "java.math.BigDecimal", false + + case "bool", "pg_catalog.bool": + return "Boolean", false + + case "jsonb": + // TODO: support json and byte types + return "String", false + + case "bytea", "blob", "pg_catalog.bytea": + return "String", false + + case "date": + // Date and time mappings from https://jdbc.postgresql.org/documentation/head/java8-date-time.html + return "LocalDate", false + + case "pg_catalog.time", "pg_catalog.timetz": + return "LocalTime", false + + case "pg_catalog.timestamp": + return "LocalDateTime", false + + case "pg_catalog.timestamptz", "timestamptz": + // TODO + return "OffsetDateTime", false + + case "text", "pg_catalog.varchar", "pg_catalog.bpchar", "string": + return "String", false + + case "uuid": + // TODO + return "uuid.UUID", false + + case "inet": + // TODO + return "net.IP", false + + case "void": + // TODO + // A void value always returns NULL. 
Since there is no built-in NULL + // value in the SQL package, we'll use sql.NullBool + return "sql.NullBool", false + + case "any": + // TODO + return "Any", false + + default: + for name, schema := range r.Catalog.Schemas { + if name == "pg_catalog" { + continue + } + for _, enum := range schema.Enums { + if columnType == enum.Name { + if name == "public" { + return KtDataClassName(enum.Name, settings), true + } + + return KtDataClassName(name+"_"+enum.Name, settings), true + } + } + } + log.Printf("unknown PostgreSQL type: %s\n", columnType) + return "interface{}", false + } +} + +type goColumn struct { + id int + core.Column +} + +func (r Result) ktColumnsToStruct(name string, columns []goColumn, settings dinosql.CombinedSettings, namer func(core.Column, int) string) *KtStruct { + gs := KtStruct{ + Name: name, + } + idSeen := map[int]KtField{} + nameSeen := map[string]int{} + for _, c := range columns { + if binding, ok := idSeen[c.id]; ok { + gs.JDBCParamBindings = append(gs.JDBCParamBindings, binding) + continue + } + fieldName := KtMemberName(namer(c.Column, c.id), settings) + if v := nameSeen[c.Name]; v > 0 { + fieldName = fmt.Sprintf("%s_%d", fieldName, v+1) + } + field := KtField{ + Name: fieldName, + Type: r.ktType(c.Column, settings), + } + gs.Fields = append(gs.Fields, field) + gs.JDBCParamBindings = append(gs.JDBCParamBindings, field) + nameSeen[c.Name]++ + idSeen[c.id] = field + } + return &gs +} + +func ktArgName(name string) string { + out := "" + for i, p := range strings.Split(name, "_") { + if i == 0 { + out += strings.ToLower(p) + } else { + out += strings.Title(p) + } + } + return out +} + +func ktParamName(c core.Column, number int) string { + if c.Name != "" { + return ktArgName(c.Name) + } + return fmt.Sprintf("dollar_%d", number) +} + +func ktColumnName(c core.Column, pos int) string { + if c.Name != "" { + return c.Name + } + return fmt.Sprintf("column_%d", pos+1) +} + +var jdbcSQLRe = regexp.MustCompile(`\B\$\d+\b`) + +// HACK: JDBC doesn't support numbered parameters, so we need to transform them to question marks... +// But there's no access to the SQL parser here, so we just do a dumb regexp replace instead. This won't work if +// string literals contain matching values, but it's good enough for a prototype.
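+// For example, "SELECT * FROM venue WHERE slug = $1 AND city = $2" is rewritten to +// "SELECT * FROM venue WHERE slug = ? AND city = ?"; a "$1" inside a string literal would be rewritten as well.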
+func jdbcSQL(s string) string { + return jdbcSQLRe.ReplaceAllString(s, "?") +} + +func (r Result) KtQueries(settings dinosql.CombinedSettings) []KtQuery { + structs := r.KtDataClasses(settings) + + qs := make([]KtQuery, 0, len(r.Queries)) + for _, query := range r.Queries { + if query.Name == "" { + continue + } + if query.Cmd == "" { + continue + } + + gq := KtQuery{ + Cmd: query.Cmd, + ClassName: strings.Title(query.Name), + ConstantName: dinosql.LowerTitle(query.Name), + FieldName: dinosql.LowerTitle(query.Name) + "Stmt", + MethodName: dinosql.LowerTitle(query.Name), + SourceName: query.Filename, + SQL: jdbcSQL(query.SQL), + Comments: query.Comments, + } + + var cols []goColumn + for _, p := range query.Params { + cols = append(cols, goColumn{ + id: p.Number, + Column: p.Column, + }) + } + params := r.ktColumnsToStruct(gq.ClassName+"Bindings", cols, settings, ktParamName) + gq.Arg = KtParams{ + Struct: params, + } + + if len(query.Columns) == 1 { + c := query.Columns[0] + gq.Ret = KtQueryValue{ + Name: "results", + Typ: r.ktType(c, settings), + } + } else if len(query.Columns) > 1 { + var gs *KtStruct + var emit bool + + for _, s := range structs { + if len(s.Fields) != len(query.Columns) { + continue + } + same := true + for i, f := range s.Fields { + c := query.Columns[i] + sameName := f.Name == KtMemberName(ktColumnName(c, i), settings) + sameType := f.Type == r.ktType(c, settings) + sameTable := s.Table.Catalog == c.Table.Catalog && s.Table.Schema == c.Table.Schema && s.Table.Rel == c.Table.Rel + + if !sameName || !sameType || !sameTable { + same = false + } + } + if same { + gs = &s + break + } + } + + if gs == nil { + var columns []goColumn + for i, c := range query.Columns { + columns = append(columns, goColumn{ + id: i, + Column: c, + }) + } + gs = r.ktColumnsToStruct(gq.ClassName+"Row", columns, settings, ktColumnName) + emit = true + } + gq.Ret = KtQueryValue{ + Emit: emit, + Name: "results", + Struct: gs, + } + } + + qs = append(qs, gq) + } + sort.Slice(qs, func(i, j int) bool { return qs[i].MethodName < qs[j].MethodName }) + return qs +} + +var ktIfaceTmpl = `// Code generated by sqlc. DO NOT EDIT. + +package {{.Package}} + +{{range imports .SourceName}} +{{range .}}import {{.}} +{{end}} +{{end}} + +interface Queries { + {{- range .KtQueries}} + @Throws(SQLException::class) + {{- if eq .Cmd ":one"}} + fun {{.MethodName}}({{.Arg.Args}}): {{.Ret.Type}} + {{- end}} + {{- if eq .Cmd ":many"}} + fun {{.MethodName}}({{.Arg.Args}}): List<{{.Ret.Type}}> + {{- end}} + {{- if eq .Cmd ":exec"}} + fun {{.MethodName}}({{.Arg.Args}}) + {{- end}} + {{- if eq .Cmd ":execrows"}} + fun {{.MethodName}}({{.Arg.Args}}): Int + {{- end}} + {{end}} +} +` + +var ktModelsTmpl = `// Code generated by sqlc. DO NOT EDIT. + +package {{.Package}} + +{{range imports .SourceName}} +{{range .}}import {{.}} +{{end}} +{{end}} + +{{range .Enums}} +{{if .Comment}}// {{.Comment}}{{end}} +enum class {{.Name}}(val value: String) { + {{- range $i, $e := .Constants}} + {{- if $i }},{{end}} + {{.Name}}("{{.Value}}") + {{- end}}; + + companion object { + private val map = {{.Name}}.values().associateBy({{.Name}}::value) + fun lookup(value: String) = map[value] + } +} +{{end}} + +{{range .KtDataClasses}} +{{if .Comment}}// {{.Comment}}{{end}} +data class {{.Name}} ( {{- range $i, $e := .Fields}} + {{- if $i }},{{end}} + {{- if .Comment}} + // {{.Comment}}{{else}} + {{- end}} + val {{.Name}}: {{.Type}} + {{- end}} +) +{{end}} +` + +var ktSqlTmpl = `// Code generated by sqlc. DO NOT EDIT. 
+ +package {{.Package}} + +{{range imports .SourceName}} +{{range .}}import {{.}} +{{end}} +{{end}} + +{{range .KtQueries}} +const val {{.ConstantName}} = {{$.Q}}-- name: {{.MethodName}} {{.Cmd}} +{{.SQL}} +{{$.Q}} + +{{if .Ret.EmitStruct}} +data class {{.Ret.Type}} ( {{- range $i, $e := .Ret.Struct.Fields}} + {{- if $i }},{{end}} + val {{.Name}}: {{.Type}} + {{- end}} +) +{{end}} +{{end}} + +class QueriesImpl(private val conn: Connection){{ if .EmitInterface }} : Queries{{end}} { +{{range .KtQueries}} +{{if eq .Cmd ":one"}} +{{range .Comments}}//{{.}} +{{end}} + @Throws(SQLException::class) + {{ if $.EmitInterface }}override {{ end -}} + fun {{.MethodName}}({{.Arg.Args}}): {{.Ret.Type}} { + return conn.prepareStatement({{.ConstantName}}).use { stmt -> + {{.Arg.Bindings}} + + val results = stmt.executeQuery() + if (!results.next()) { + throw SQLException("no rows in result set") + } + val ret = {{.Ret.ResultSet}} + if (results.next()) { + throw SQLException("expected one row in result set, but got many") + } + ret + } + } +{{end}} + +{{if eq .Cmd ":many"}} +{{range .Comments}}//{{.}} +{{end}} + @Throws(SQLException::class) + {{ if $.EmitInterface }}override {{ end -}} + fun {{.MethodName}}({{.Arg.Args}}): List<{{.Ret.Type}}> { + return conn.prepareStatement({{.ConstantName}}).use { stmt -> + {{.Arg.Bindings}} + + val results = stmt.executeQuery() + val ret = mutableListOf<{{.Ret.Type}}>() + while (results.next()) { + ret.add({{.Ret.ResultSet}}) + } + ret + } + } +{{end}} + +{{if eq .Cmd ":exec"}} +{{range .Comments}}//{{.}} +{{end}} + @Throws(SQLException::class) + {{ if $.EmitInterface }}override {{ end -}} + fun {{.MethodName}}({{.Arg.Args}}) { + conn.prepareStatement({{.ConstantName}}).use { stmt -> + {{ .Arg.Bindings }} + + stmt.execute() + } + } +{{end}} + +{{if eq .Cmd ":execrows"}} +{{range .Comments}}//{{.}} +{{end}} + @Throws(SQLException::class) + {{ if $.EmitInterface }}override {{ end -}} + fun {{.MethodName}}({{.Arg.Args}}): Int { + return conn.prepareStatement({{.ConstantName}}).use { stmt -> + {{ .Arg.Bindings }} + + stmt.execute() + stmt.updateCount + } + } +{{end}} +{{end}} +} +` + +type ktTmplCtx struct { + Q string + Package string + Enums []KtEnum + KtDataClasses []KtStruct + KtQueries []KtQuery + Settings dinosql.GenerateSettings + + // TODO: Race conditions + SourceName string + + EmitJSONTags bool + EmitPreparedQueries bool + EmitInterface bool +} + +func Offset(v int) int { + return v + 1 +} + +func ktFormat(s string) string { + // TODO: do more than just skip multiple blank lines, like maybe run ktlint to format + skipNextSpace := false + var lines []string + for _, l := range strings.Split(s, "\n") { + isSpace := len(strings.TrimSpace(l)) == 0 + if !isSpace || !skipNextSpace { + lines = append(lines, l) + } + skipNextSpace = isSpace + } + o := strings.Join(lines, "\n") + o += "\n" + return o +} + +func KtGenerate(r KtGenerateable, settings dinosql.CombinedSettings) (map[string]string, error) { + funcMap := template.FuncMap{ + "lowerTitle": dinosql.LowerTitle, + "imports": KtImports(r, settings), + "offset": Offset, + } + + modelsFile := template.Must(template.New("table").Funcs(funcMap).Parse(ktModelsTmpl)) + sqlFile := template.Must(template.New("table").Funcs(funcMap).Parse(ktSqlTmpl)) + ifaceFile := template.Must(template.New("table").Funcs(funcMap).Parse(ktIfaceTmpl)) + + pkg := settings.Package + tctx := ktTmplCtx{ + Settings: settings.Global, + EmitInterface: pkg.EmitInterface, + EmitJSONTags: pkg.EmitJSONTags, + EmitPreparedQueries: pkg.EmitPreparedQueries, 
+ Q: `"""`, + Package: pkg.Name, + KtQueries: r.KtQueries(settings), + Enums: r.KtEnums(settings), + KtDataClasses: r.KtDataClasses(settings), + } + + output := map[string]string{} + + execute := func(name string, t *template.Template) error { + var b bytes.Buffer + w := bufio.NewWriter(&b) + tctx.SourceName = name + err := t.Execute(w, tctx) + w.Flush() + if err != nil { + return err + } + if !strings.HasSuffix(name, ".kt") { + name += ".kt" + } + output[name] = ktFormat(b.String()) + return nil + } + + if err := execute("Models.kt", modelsFile); err != nil { + return nil, err + } + if pkg.EmitInterface { + if err := execute("Queries.kt", ifaceFile); err != nil { + return nil, err + } + } + if err := execute("QueriesImpl.kt", sqlFile); err != nil { + return nil, err + } + + return output, nil +} diff --git a/internal/dinosql/parser.go b/internal/dinosql/parser.go index 8154f1a5c6..1106fb4e38 100644 --- a/internal/dinosql/parser.go +++ b/internal/dinosql/parser.go @@ -229,7 +229,8 @@ func ParseQueries(c core.Catalog, pkg PackageSettings) (*Result, error) { continue } for _, stmt := range tree.Statements { - query, err := parseQuery(c, stmt, source) + rewriteParameters := pkg.rewriteParams + query, err := parseQuery(c, stmt, source, rewriteParameters) if err == errUnsupportedStatementType { continue } @@ -407,7 +408,7 @@ func validateCmd(n nodes.Node, name, cmd string) error { var errUnsupportedStatementType = errors.New("parseQuery: unsupported statement type") -func parseQuery(c core.Catalog, stmt nodes.Node, source string) (*Query, error) { +func parseQuery(c core.Catalog, stmt nodes.Node, source string, rewriteParameters bool) (*Query, error) { if err := validateParamRef(stmt); err != nil { return nil, err } @@ -443,6 +444,16 @@ func parseQuery(c core.Catalog, stmt nodes.Node, source string) (*Query, error) } rvs := rangeVars(raw.Stmt) refs := findParameters(raw.Stmt) + var edits []edit + if rewriteParameters { + edits, err = rewriteNumberedParameters(refs, raw, rawSQL) + if err != nil { + return nil, err + } + } else { + refs = uniqueParamRefs(refs) + sort.Slice(refs, func(i, j int) bool { return refs[i].ref.Number < refs[j].ref.Number }) + } params, err := resolveCatalogRefs(c, rvs, refs) if err != nil { return nil, err @@ -452,7 +463,13 @@ func parseQuery(c core.Catalog, stmt nodes.Node, source string) (*Query, error) if err != nil { return nil, err } - expanded, err := expand(c, raw, rawSQL) + expandEdits, err := expand(c, raw, rawSQL) + if err != nil { + return nil, err + } + edits = append(edits, expandEdits...) 
+ + expanded, err := editQuery(rawSQL, edits) if err != nil { return nil, err } @@ -472,6 +489,18 @@ func parseQuery(c core.Catalog, stmt nodes.Node, source string) (*Query, error) }, nil } +func rewriteNumberedParameters(refs []paramRef, raw nodes.RawStmt, sql string) ([]edit, error) { + edits := make([]edit, len(refs)) + for i, ref := range refs { + edits[i] = edit{ + Location: ref.ref.Location - raw.StmtLocation, + Old: fmt.Sprintf("$%d", ref.ref.Number), + New: "?", + } + } + return edits, nil +} + func stripComments(sql string) (string, []string, error) { s := bufio.NewScanner(strings.NewReader(sql)) var lines, comments []string @@ -494,7 +523,7 @@ type edit struct { New string } -func expand(c core.Catalog, raw nodes.RawStmt, sql string) (string, error) { +func expand(c core.Catalog, raw nodes.RawStmt, sql string) ([]edit, error) { list := search(raw, func(node nodes.Node) bool { switch node.(type) { case nodes.DeleteStmt: @@ -507,17 +536,17 @@ func expand(c core.Catalog, raw nodes.RawStmt, sql string) (string, error) { return true }) if len(list.Items) == 0 { - return sql, nil + return nil, nil } var edits []edit for _, item := range list.Items { edit, err := expandStmt(c, raw, item) if err != nil { - return "", err + return nil, err } edits = append(edits, edit...) } - return editQuery(sql, edits) + return edits, nil } func expandStmt(c core.Catalog, raw nodes.RawStmt, node nodes.Node) ([]edit, error) { @@ -958,7 +987,8 @@ type paramRef struct { type paramSearch struct { parent nodes.Node rangeVar *nodes.RangeVar - refs map[int]paramRef + refs *[]paramRef + seen map[int]struct{} // XXX: Gross state hack for limit limitCount nodes.Node @@ -1005,7 +1035,8 @@ func (p paramSearch) Visit(node nodes.Node) Visitor { continue } // TODO: Out-of-bounds panic - p.refs[ref.Number] = paramRef{parent: n.Cols.Items[i], ref: ref, rv: p.rangeVar} + *p.refs = append(*p.refs, paramRef{parent: n.Cols.Items[i], ref: ref, rv: p.rangeVar}) + p.seen[ref.Location] = struct{}{} } for _, vl := range s.ValuesLists { for i, v := range vl { @@ -1014,7 +1045,8 @@ func (p paramSearch) Visit(node nodes.Node) Visitor { continue } // TODO: Out-of-bounds panic - p.refs[ref.Number] = paramRef{parent: n.Cols.Items[i], ref: ref, rv: p.rangeVar} + *p.refs = append(*p.refs, paramRef{parent: n.Cols.Items[i], ref: ref, rv: p.rangeVar}) + p.seen[ref.Location] = struct{}{} } } } @@ -1050,7 +1082,7 @@ func (p paramSearch) Visit(node nodes.Node) Visitor { parent = limitOffset{} } } - if _, found := p.refs[n.Number]; found { + if _, found := p.seen[n.Location]; found { break } @@ -1072,7 +1104,8 @@ func (p paramSearch) Visit(node nodes.Node) Visitor { } if set { - p.refs[n.Number] = paramRef{parent: parent, ref: n, rv: p.rangeVar} + *p.refs = append(*p.refs, paramRef{parent: parent, ref: n, rv: p.rangeVar}) + p.seen[n.Location] = struct{}{} } return nil } @@ -1080,13 +1113,9 @@ func (p paramSearch) Visit(node nodes.Node) Visitor { } func findParameters(root nodes.Node) []paramRef { - v := paramSearch{refs: map[int]paramRef{}} - Walk(v, root) refs := make([]paramRef, 0) - for _, r := range v.refs { - refs = append(refs, r) - } - sort.Slice(refs, func(i, j int) bool { return refs[i].ref.Number < refs[j].ref.Number }) + v := paramSearch{seen: make(map[int]struct{}), refs: &refs} + Walk(v, root) return refs } @@ -1348,3 +1377,15 @@ func resolveCatalogRefs(c core.Catalog, rvs []nodes.RangeVar, args []paramRef) ( } return a, nil } + +func uniqueParamRefs(in []paramRef) []paramRef { + m := make(map[int]struct{}, len(in)) + o := 
make([]paramRef, 0, len(in)) + for _, v := range in { + if _, ok := m[v.ref.Number]; !ok { + m[v.ref.Number] = struct{}{} + o = append(o, v) + } + } + return o +} diff --git a/internal/dinosql/parser_test.go b/internal/dinosql/parser_test.go index e53c59955c..d1741cc6be 100644 --- a/internal/dinosql/parser_test.go +++ b/internal/dinosql/parser_test.go @@ -3,6 +3,7 @@ package dinosql import ( "testing" + "github.com/google/go-cmp/cmp" pg "github.com/lfittl/pg_query_go" nodes "github.com/lfittl/pg_query_go/nodes" ) @@ -87,13 +88,14 @@ func TestLineColumn(t *testing.T) { func TestExtractArgs(t *testing.T) { queries := []struct { - query string - count int + query string + bindNumbers []int }{ - {"SELECT * FROM venue WHERE slug = $1 AND city = $2", 2}, - {"SELECT * FROM venue WHERE slug = $1", 1}, - {"SELECT * FROM venue LIMIT $1", 1}, - {"SELECT * FROM venue OFFSET $1", 1}, + {"SELECT * FROM venue WHERE slug = $1 AND city = $2", []int{1, 2}}, + {"SELECT * FROM venue WHERE slug = $1 AND region = $2 AND city = $3 AND country = $2", []int{1, 2, 3, 2}}, + {"SELECT * FROM venue WHERE slug = $1", []int{1}}, + {"SELECT * FROM venue LIMIT $1", []int{1}}, + {"SELECT * FROM venue OFFSET $1", []int{1}}, } for _, q := range queries { tree, err := pg.Parse(q.query) @@ -105,8 +107,46 @@ func TestExtractArgs(t *testing.T) { if err != nil { t.Error(err) } - if len(refs) != q.count { - t.Errorf("expected %d refs, got %d", q.count, len(refs)) + nums := make([]int, len(refs)) + for i, n := range refs { + nums[i] = n.ref.Number + } + if diff := cmp.Diff(q.bindNumbers, nums); diff != "" { + t.Errorf("expected bindings %v, got %v", q.bindNumbers, nums) + } + } + } +} + +func TestRewriteParameters(t *testing.T) { + queries := []struct { + orig string + new string + }{ + {"SELECT * FROM venue WHERE slug = $1 AND city = $3 AND bar = $2", "SELECT * FROM venue WHERE slug = ? AND city = ? AND bar = ?"}, + {"DELETE FROM venue WHERE slug = $1 AND slug = $1", "DELETE FROM venue WHERE slug = ? 
AND slug = ?"}, + {"SELECT * FROM venue LIMIT $1", "SELECT * FROM venue LIMIT ?"}, + } + for _, q := range queries { + tree, err := pg.Parse(q.orig) + if err != nil { + t.Fatal(err) + } + for _, stmt := range tree.Statements { + refs := findParameters(stmt) + if err != nil { + t.Error(err) + } + edits, err := rewriteNumberedParameters(refs, stmt.(nodes.RawStmt), q.orig) + if err != nil { + t.Error(err) + } + rewritten, err := editQuery(q.orig, edits) + if err != nil { + t.Error(err) + } + if rewritten != q.new { + t.Errorf("expected %q, got %q", q.new, rewritten) } } } diff --git a/internal/dinosql/query_test.go b/internal/dinosql/query_test.go index c80b1d85e6..8ee13ac451 100644 --- a/internal/dinosql/query_test.go +++ b/internal/dinosql/query_test.go @@ -21,7 +21,7 @@ func parseSQL(in string) (Query, error) { return Query{}, err } - q, err := parseQuery(c, tree.Statements[len(tree.Statements)-1], in) + q, err := parseQuery(c, tree.Statements[len(tree.Statements)-1], in, false) if q == nil { return Query{}, err } diff --git a/internal/endtoend/endtoend_test.go b/internal/endtoend/endtoend_test.go index ac69188b04..abccbc3cc1 100644 --- a/internal/endtoend/endtoend_test.go +++ b/internal/endtoend/endtoend_test.go @@ -58,10 +58,10 @@ func cmpDirectory(t *testing.T, dir string, actual map[string]string) { if file.IsDir() { return nil } - if !strings.HasSuffix(path, ".go") { + if !strings.HasSuffix(path, ".go") && !strings.HasSuffix(path, ".kt") { return nil } - if strings.HasSuffix(path, "_test.go") { + if strings.HasSuffix(path, "_test.go") || strings.Contains(path, "src/test/") { return nil } blob, err := ioutil.ReadFile(path) diff --git a/internal/sqltest/postgres.go b/internal/sqltest/postgres.go index be86df4e65..5a283ac33a 100644 --- a/internal/sqltest/postgres.go +++ b/internal/sqltest/postgres.go @@ -42,6 +42,10 @@ func PostgreSQL(t *testing.T, migrations string) (*sql.DB, func()) { pgUser = "postgres" } + if pgPass == "" { + pgPass = "mysecretpassword" + } + if pgPort == "" { pgPort = "5432" }
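For context, a minimal sketch of how the generated Kotlin API is meant to be consumed, based on the `ktSqlTmpl` template above: `QueriesImpl` wraps a `java.sql.Connection`, binds JDBC parameters for each query, and `:one` methods throw `SQLException` when no row comes back. The `Author` data class and `getAuthor` method below are assumptions modeled on the authors example, not part of this diff, and the connection details are placeholders.

```kotlin
import java.sql.DriverManager
import java.sql.SQLException

// Usage sketch only; assumes sqlc generated an `Author` data class and a
// `:one` query `getAuthor(id: Long): Author` exposed on QueriesImpl.
fun main() {
    // Placeholder connection details.
    val conn = DriverManager.getConnection(
        "jdbc:postgresql://localhost:5432/postgres", "postgres", "postgres"
    )
    val queries = QueriesImpl(conn)
    try {
        // :one methods bind parameters via JDBC and map the single row onto the data class.
        val author = queries.getAuthor(1L)
        println(author)
    } catch (e: SQLException) {
        // Generated :one methods throw when the result set is empty.
        println("no author found: ${e.message}")
    } finally {
        conn.close()
    }
}
```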