Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#301 Make SQL generators extensible so that custom generators could be plugged in without Pramen recompilation #352

Merged
merged 5 commits into from
Feb 12, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -556,6 +556,10 @@ is determined by the pipeline configuration.

# Specifies the maximum number of records to fetch. Good for testing purposes.
#limit.records = 100

# Optionally, you can specify a class for a custom SQL generator for your RDMS engine.
# The class whould extend 'za.co.absa.pramen.api.sql.SqlGenerator'
#sql.generator.class = "com.example.MySqlGenerator"
}
```

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* limitations under the License.
*/

package za.co.absa.pramen.core.reader.model
package za.co.absa.pramen.api.sql

sealed trait QuotingPolicy {
def name: String
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* limitations under the License.
*/

package za.co.absa.pramen.core.sql
package za.co.absa.pramen.api.sql

sealed trait SqlColumnType

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@
* limitations under the License.
*/

package za.co.absa.pramen.core.sql
package za.co.absa.pramen.api.sql

import za.co.absa.pramen.core.reader.model.QuotingPolicy
import com.typesafe.config.Config

case class SqlConfig(
infoDateColumn: String,
infoDateType: SqlColumnType,
dateFormatApp: String,
identifierQuotingPolicy: QuotingPolicy
identifierQuotingPolicy: QuotingPolicy,
sqlGeneratorClass: Option[String],
extraConfig: Config
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/*
* Copyright 2022 ABSA Group Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package za.co.absa.pramen.api.sql

import java.sql.Connection
import java.time.LocalDate

trait SqlGenerator {
/**
* Returns wrapped query that can be passed as .option("dtable", here) to the Spark JDBC reader.
* For example, given "SELECT * FROM abc", returns "(SELECT * FROM abc) tbl"
*/
def getDtable(sql: String): String

/**
* Generates a query that returns the record count of a table that does not have the information date field.
*/
def getCountQuery(tableName: String): String

/**
* Generates a query that returns the record count of a table for the given period when the table does have the information date field.
*/
def getCountQuery(tableName: String, infoDateBegin: LocalDate, infoDateEnd: LocalDate): String

/**
* Generates a query that returns data of a table that does not have the information date field.
*/
def getDataQuery(tableName: String, columns: Seq[String], limit: Option[Int]): String

/**
* Generates a query that returns the data of a table for the given period when the table does have the information date field.
*/
def getDataQuery(tableName: String, infoDateBegin: LocalDate, infoDateEnd: LocalDate, columns: Seq[String], limit: Option[Int]): String

/**
* Returns WHERE condition for table that has the information date field given the time period.
*/
def getWhere(dateBegin: LocalDate, dateEnd: LocalDate): String

/** Returns the date literal for the dialect of the SQL. */
def getDateLiteral(date: LocalDate): String

/**
* Validates and escapes an identifier (table or column name) if needed.
* Escaping happens according to the quoting policy.
*/
def escape(identifier: String): String

/**
* Quotes an identifier name with characters specific to SQL dialects.
* If the identifier is already wrapped, it won't be double wrapped.
* It supports partially quoted identifiers. E.g. '"my_catalog".my table' will be quoted as '"my_catalog"."my table"'.
*/
def quote(identifier: String): String

/**
* Returns true if the SQL generator can only work if it has an active connection.
* This can be for database engines that does not support "SELECT * FROM table" and require explicit list of columns.
* The connection can be used to query the list of columns first.
*/
def requiresConnection: Boolean = false

/**
* Sets the connection for the the SQL generator can only work if it has an active connection.
*/
def setConnection(connection: Connection): Unit = {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,7 @@
* limitations under the License.
*/

package za.co.absa.pramen.core.sql

import za.co.absa.pramen.core.reader.model.QuotingPolicy
package za.co.absa.pramen.api.sql

import scala.collection.mutable.ListBuffer

Expand Down Expand Up @@ -51,8 +49,7 @@ abstract class SqlGeneratorBase(sqlConfig: SqlConfig) extends SqlGenerator {
splitComplexIdentifier(identifier).map(quoteSingleIdentifier).mkString(".")
}

/** This validates and escapes an identifier (table or column name) if needed. Escaping does not happen always to maintain backwards compatibility. */
override final def escape(identifier: String): String = {
override final def escape(identifier: String): String = {
if (needsEscaping(sqlConfig.identifierQuotingPolicy, identifier)) {
quote(identifier)
} else {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Copyright 2022 ABSA Group Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package za.co.absa.pramen.api.sql

import org.scalatest.wordspec.AnyWordSpec

class QuotingPolicySuite extends AnyWordSpec {
"QuotingPolicy.fromString" should {
"return Never for 'never'" in {
assert(QuotingPolicy.fromString("never") == QuotingPolicy.Never)
assert(QuotingPolicy.fromString("never").name == "never")
}

"return Always for 'always'" in {
assert(QuotingPolicy.fromString("always") == QuotingPolicy.Always)
assert(QuotingPolicy.fromString("always").name == "always")
}

"return Auto for 'auto'" in {
assert(QuotingPolicy.fromString("auto") == QuotingPolicy.Auto)
assert(QuotingPolicy.fromString("auto").name == "auto")
}

"throw an exception for an unknown quoting policy" in {
assertThrows[IllegalArgumentException] {
QuotingPolicy.fromString("unknown")
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@ package za.co.absa.pramen.core.reader
import com.typesafe.config.Config
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.slf4j.LoggerFactory
import za.co.absa.pramen.api.sql.{SqlColumnType, SqlConfig}
import za.co.absa.pramen.api.{Query, TableReader}
import za.co.absa.pramen.core.config.Keys
import za.co.absa.pramen.core.reader.model.TableReaderJdbcConfig
import za.co.absa.pramen.core.sql.{SqlColumnType, SqlConfig, SqlGenerator}
import za.co.absa.pramen.core.sql.SqlGeneratorLoader
import za.co.absa.pramen.core.utils.JdbcNativeUtils.JDBC_WORDS_TO_REDACT
import za.co.absa.pramen.core.utils.{ConfigUtils, JdbcSparkUtils, TimeUtils}

Expand All @@ -46,9 +47,7 @@ class TableReaderJdbc(jdbcReaderConfig: TableReaderJdbcConfig,
logConfiguration()

private[core] lazy val sqlGen = {
val gen = SqlGenerator.fromDriverName(jdbcReaderConfig.jdbcConfig.driver,
getSqlConfig,
ConfigUtils.getExtraConfig(conf, "sql"))
val gen = SqlGeneratorLoader.getSqlGenerator(jdbcReaderConfig.jdbcConfig.driver, getSqlConfig)

if (gen.requiresConnection) {
val (connection, url) = jdbcUrlSelector.getWorkingConnection(jdbcRetries)
Expand Down Expand Up @@ -211,7 +210,9 @@ class TableReaderJdbc(jdbcReaderConfig: TableReaderJdbcConfig,
SqlConfig(jdbcReaderConfig.infoDateColumn,
infoDateType,
jdbcReaderConfig.infoDateFormat,
jdbcReaderConfig.identifierQuotingPolicy)
jdbcReaderConfig.identifierQuotingPolicy,
jdbcReaderConfig.sqlGeneratorClass,
ConfigUtils.getExtraConfig(conf, "sql"))
case None => throw new IllegalArgumentException(s"Unknown info date type specified (${jdbcReaderConfig.infoDateType}). " +
s"It should be one of: date, string, number")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package za.co.absa.pramen.core.reader.model

import com.typesafe.config.Config
import org.slf4j.LoggerFactory
import za.co.absa.pramen.api.sql.QuotingPolicy
import za.co.absa.pramen.core.utils.ConfigUtils

case class TableReaderJdbcConfig(
Expand All @@ -32,7 +33,8 @@ case class TableReaderJdbcConfig(
correctDecimalsFixPrecision: Boolean = false,
enableSchemaMetadata: Boolean = false,
useJdbcNative: Boolean = false,
identifierQuotingPolicy: QuotingPolicy = QuotingPolicy.Auto
identifierQuotingPolicy: QuotingPolicy = QuotingPolicy.Auto,
sqlGeneratorClass: Option[String] = None
)

object TableReaderJdbcConfig {
Expand All @@ -51,6 +53,7 @@ object TableReaderJdbcConfig {
val ENABLE_SCHEMA_METADATA_KEY = "enable.schema.metadata"
val USE_JDBC_NATIVE = "use.jdbc.native"
val IDENTIFIER_QUOTING_POLICY = "identifier.quoting.policy"
val SQL_GENERATOR_CLASS_KEY = "sql.generator.class"

def load(conf: Config, parent: String = ""): TableReaderJdbcConfig = {
ConfigUtils.validatePathsExistence(conf, parent, HAS_INFO_DATE :: Nil)
Expand Down Expand Up @@ -87,7 +90,8 @@ object TableReaderJdbcConfig {
correctDecimalsFixPrecision = ConfigUtils.getOptionBoolean(conf, CORRECT_DECIMALS_FIX_PRECISION).getOrElse(false),
enableSchemaMetadata = ConfigUtils.getOptionBoolean(conf, ENABLE_SCHEMA_METADATA_KEY).getOrElse(false),
useJdbcNative = ConfigUtils.getOptionBoolean(conf, USE_JDBC_NATIVE).getOrElse(false),
identifierQuotingPolicy = identifierQuotingPolicy
identifierQuotingPolicy = identifierQuotingPolicy,
sqlGeneratorClass = ConfigUtils.getOptionString(conf, SQL_GENERATOR_CLASS_KEY)
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

package za.co.absa.pramen.core.sql

import za.co.absa.pramen.api.sql.{SqlColumnType, SqlConfig, SqlGeneratorBase}

import java.time.LocalDate
import java.time.format.DateTimeFormatter

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

package za.co.absa.pramen.core.sql

import za.co.absa.pramen.api.sql.{SqlColumnType, SqlConfig, SqlGeneratorBase}

import java.time.LocalDate
import java.time.format.DateTimeFormatter

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

package za.co.absa.pramen.core.sql

import za.co.absa.pramen.api.sql.{SqlColumnType, SqlConfig, SqlGeneratorBase}

import java.time.LocalDate
import java.time.format.DateTimeFormatter

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package za.co.absa.pramen.core.sql

import org.apache.spark.sql.jdbc.JdbcDialects
import org.slf4j.LoggerFactory
import za.co.absa.pramen.api.sql.{SqlColumnType, SqlConfig, SqlGeneratorBase}
import za.co.absa.pramen.core.sql.impl.HiveDialect

import java.time.LocalDate
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

package za.co.absa.pramen.core.sql

import za.co.absa.pramen.api.sql.{SqlColumnType, SqlConfig, SqlGeneratorBase}

import java.time.LocalDate
import java.time.format.DateTimeFormatter

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,50 +16,31 @@

package za.co.absa.pramen.core.sql

import com.typesafe.config.Config
import org.slf4j.LoggerFactory
import za.co.absa.pramen.api.sql.{SqlConfig, SqlGenerator}

import java.sql.Connection
import java.time.LocalDate

trait SqlGenerator {
def getDtable(sql: String): String

def getCountQuery(tableName: String): String

def getCountQuery(tableName: String, infoDateBegin: LocalDate, infoDateEnd: LocalDate): String

def getDataQuery(tableName: String, columns: Seq[String], limit: Option[Int]): String

def getDataQuery(tableName: String, infoDateBegin: LocalDate, infoDateEnd: LocalDate, columns: Seq[String], limit: Option[Int]): String

def getWhere(dateBegin: LocalDate, dateEnd: LocalDate): String

/** Returns the date literal for the dialect of the SQL. */
def getDateLiteral(date: LocalDate): String

/**
* Quotes an identifier, if needed according to the generator configuration.
*/
def escape(identifier: String): String
object SqlGeneratorLoader {
private val log = LoggerFactory.getLogger(this.getClass)

/**
* Always quotes an identifier name with characters specific to SQL dialects.
* If the identifier is already wrapped, it won't be double wrapped.
* It supports partially quoted identifiers. E.g. '"my_catalog".my table' will be quoted as '"my_catalog"."my table"'.
* Loads an SQL generator, If SQL configuration contains a generator class name, it will be loaded.
* If not, the generator will be selected based on the driver name based on the internal mapping.
* @param driver The driver class.
* @param sqlConfig The SQL configuration.
* @return The SQL generator.
*/
def quote(identifier: String): String

def requiresConnection: Boolean = false

def setConnection(connection: Connection): Unit = {}
}
def getSqlGenerator(driver: String, sqlConfig: SqlConfig): SqlGenerator = {
val sqlGenerator = sqlConfig.sqlGeneratorClass match {
case Some(clazz) => fromClass(clazz, sqlConfig)
case None => fromDriverName(driver, sqlConfig)
}

object SqlGenerator {
private val log = LoggerFactory.getLogger(this.getClass)
log.info(s"Using SQL generator: ${sqlGenerator.getClass.getCanonicalName}")
sqlGenerator
}

def fromDriverName(driver: String, sqlConfig: SqlConfig, extraConfig: Config): SqlGenerator = {
val sqlGenerator = driver match {
def fromDriverName(driver: String, sqlConfig: SqlConfig): SqlGenerator = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you leave the methods fromDriverName and fromClass with public access on purpose?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, it indeed makes sense to make the private since the object has getSqlGenerator() method that can be used in both cases.

driver match {
case "org.postgresql.Driver" => new SqlGeneratorPostgreSQL(sqlConfig)
case "oracle.jdbc.OracleDriver" => new SqlGeneratorOracle(sqlConfig)
case "net.sourceforge.jtds.jdbc.Driver" => new SqlGeneratorMicrosoft(sqlConfig)
Expand All @@ -75,7 +56,12 @@ object SqlGenerator {
log.warn(s"Unsupported JDBC driver: '$d'. Trying to use a generic SQL generator.")
new SqlGeneratorGeneric(sqlConfig)
}
log.info(s"Using SQL generator: ${sqlGenerator.getClass.getCanonicalName}")
sqlGenerator
}

def fromClass(clazz: String, sqlConfig: SqlConfig): SqlGenerator = {
Class.forName(clazz)
.getConstructor(classOf[SqlConfig])
.newInstance(sqlConfig)
.asInstanceOf[SqlGenerator]
}
}
Loading
Loading