Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#482 Fix reading of column description for tables that are specified in quotes #485

Merged
merged 2 commits into from
Sep 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,13 @@ trait SqlGenerator {
*/
def quote(identifier: String): String

/**
 * Unquotes an identifier name by removing quoting characters specific to SQL dialects.
 * If the identifier is not quoted, it is returned as is.
 * Partially quoted identifiers are supported.
 * E.g. '"my_catalog".my table' will be unquoted as 'my_catalog.my table'.
 *
 * @param identifier an identifier that may be fully quoted, partially quoted, or not quoted at all.
 * @return the identifier with the quoting characters removed.
 */
def unquote(identifier: String): String

/**
* Returns true if the SQL generator can only work if it has an active connection.
* This can be the case for database engines that do not support "SELECT * FROM table" and require an explicit list of columns.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,16 @@ abstract class SqlGeneratorBase(sqlConfig: SqlConfig) extends SqlGenerator {
}
}

/**
 * Removes the surrounding escape characters from a single (non-compound) identifier.
 *
 * The identifier is stripped of its first and last character only when it starts with the
 * begin escape character, ends with the end escape character, and is longer than 2 characters.
 * Otherwise it is returned unchanged.
 *
 * @param identifier a single identifier part.
 * @return the identifier without the surrounding escape characters, or the original identifier.
 */
def unquoteSingleIdentifier(identifier: String): String = {
  val (openingChar, closingChar) = beginEndEscapeChars
  val opening = s"$openingChar"
  val closing = s"$closingChar"
  val isWrapped = identifier.startsWith(opening) && identifier.endsWith(closing) && identifier.length > 2

  if (isWrapped) identifier.substring(1, identifier.length - 1) else identifier
}

/**
 * Builds an aliased SQL expression of the form `expression AS alias`.
 * The alias is passed through `escape` before it is appended.
 *
 * @param expression the SQL expression to alias.
 * @param alias      the alias name.
 * @return the expression followed by ` AS ` and the escaped alias.
 */
override def getAliasExpression(expression: String, alias: String): String = {
  val escapedAlias = escape(alias)
  s"$expression AS $escapedAlias"
}
Expand All @@ -53,6 +63,11 @@ abstract class SqlGeneratorBase(sqlConfig: SqlConfig) extends SqlGenerator {
splitComplexIdentifier(identifier).map(quoteSingleIdentifier).mkString(".")
}

/**
 * Unquotes a possibly compound identifier.
 *
 * The identifier is validated, split into its parts, each part is unquoted on its own,
 * and the parts are joined back with a dot.
 *
 * @param identifier an identifier that may consist of several dot-separated parts.
 * @return the identifier with each part unquoted.
 */
override final def unquote(identifier: String): String = {
  validateIdentifier(identifier)

  val parts = splitComplexIdentifier(identifier)
  val unquotedParts = parts.map(part => unquoteSingleIdentifier(part))

  unquotedParts.mkString(".")
}

override final def escape(identifier: String): String = {
if (needsEscaping(sqlConfig.identifierQuotingPolicy, identifier)) {
quote(identifier)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,11 @@ class TableReaderJdbc(jdbcReaderConfig: TableReaderJdbcConfig,
JdbcSparkUtils.withJdbcMetadata(jdbcReaderConfig.jdbcConfig, sql) { (connection, jdbcMetadata) =>
val schemaWithMetadata = JdbcSparkUtils.addMetadataFromJdbc(df.schema, jdbcMetadata)
val schemaWithColumnDescriptions = tableOpt match {
case Some(table) => JdbcSparkUtils.addColumnDescriptionsFromJdbc(schemaWithMetadata, table, connection)
case None => schemaWithMetadata
case Some(table) =>
log.info(s"Reading JDBC metadata descriptions the table: $table")
JdbcSparkUtils.addColumnDescriptionsFromJdbc(schemaWithMetadata, sqlGen.unquote(table), connection)
case None =>
schemaWithMetadata
}
df = spark.createDataFrame(df.rdd, schemaWithColumnDescriptions)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,13 @@ class TableReaderJdbcNative(jdbcReaderConfig: TableReaderJdbcConfig,
}

if (jdbcReaderConfig.enableSchemaMetadata) {
JdbcSparkUtils.withJdbcMetadata(jdbcReaderConfig.jdbcConfig, sql) { (connection, jdbcMetadata) =>
JdbcSparkUtils.withJdbcMetadata(jdbcReaderConfig.jdbcConfig, sql) { (connection, _) =>
val schemaWithColumnDescriptions = tableOpt match {
case Some(table) =>
log.info(s"Reading JDBC metadata descriptions the query: $sql")
JdbcSparkUtils.addColumnDescriptionsFromJdbc(df.schema, table, connection)
case None => df.schema
log.info(s"Reading JDBC metadata descriptions the table: $table")
JdbcSparkUtils.addColumnDescriptionsFromJdbc(df.schema, sqlGen.unquote(table), connection)
case None =>
df.schema
}
df = spark.createDataFrame(df.rdd, schemaWithColumnDescriptions)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,11 @@ class SqlGeneratorMicrosoft(sqlConfig: SqlConfig) extends SqlGenerator {
splitComplexIdentifier(identifier).map(quoteSingleIdentifier).mkString(".")
}

/**
 * Unquotes a possibly compound identifier.
 *
 * The identifier is validated, split into its parts, each part is unquoted on its own,
 * and the parts are joined back with a dot.
 *
 * @param identifier an identifier that may consist of several dot-separated parts.
 * @return the identifier with each part unquoted.
 */
override def unquote(identifier: String): String = {
  validateIdentifier(identifier)

  val parts = splitComplexIdentifier(identifier)
  val unquotedParts = parts.map(part => unquoteSingleIdentifier(part))

  unquotedParts.mkString(".")
}

override def escape(identifier: String): String = {
if (needsEscaping(sqlConfig.identifierQuotingPolicy, identifier)) {
quote(identifier)
Expand All @@ -138,6 +143,18 @@ class SqlGeneratorMicrosoft(sqlConfig: SqlConfig) extends SqlGenerator {
escape(sqlConfig.infoDateColumn)
}

/**
 * Removes the surrounding quoting characters from a single (non-compound) identifier.
 *
 * Two quoting forms are recognized: the begin/end escape character pair, and `escapeChar2`
 * used on both sides. In either case the identifier must be longer than 2 characters,
 * otherwise it is returned unchanged.
 *
 * @param identifier a single identifier part.
 * @return the identifier without its first and last character when it is quoted,
 *         or the original identifier otherwise.
 */
private def unquoteSingleIdentifier(identifier: String): String = {
  val (primaryBegin, primaryEnd) = beginEndEscapeChars
  val alternativeQuote = s"$escapeChar2"

  val wrappedInPrimary = identifier.startsWith(s"$primaryBegin") && identifier.endsWith(s"$primaryEnd")
  val wrappedInAlternative = identifier.startsWith(alternativeQuote) && identifier.endsWith(alternativeQuote)
  val longEnough = identifier.length > 2

  if (longEnough && (wrappedInPrimary || wrappedInAlternative)) {
    identifier.substring(1, identifier.length - 1)
  } else {
    identifier
  }
}

private def quoteSingleIdentifier(identifier: String): String = {
val (escapeBegin, escapeEnd) = beginEndEscapeChars

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ import za.co.absa.pramen.core.utils.impl.JdbcFieldMetadata

import java.sql.{Connection, DatabaseMetaData, ResultSet, ResultSetMetaData}
import scala.collection.mutable.ListBuffer
import scala.util.control.NonFatal

object JdbcSparkUtils {
private val log = LoggerFactory.getLogger(this.getClass)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,6 @@ class SqlGeneratorDummy(sqlConfig: SqlConfig) extends SqlGenerator {
// Stub implementation: ignores the identifier and always returns null.
override def escape(identifier: String): String = null

// Stub implementation: ignores the identifier and always returns null.
override def quote(identifier: String): String = null

// Stub implementation: ignores the identifier and always returns null.
override def unquote(identifier: String): String = null
}
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,14 @@ class SqlGeneratorLoaderSuite extends AnyWordSpec with RelationalDbFixture {
assert(actual == "\"System User\".\"Table Name\"")
}
}

"unquote" should {
"quote each subfields separately" in {
val actual = gen.unquote("System User.\"Table Name\"")

assert(actual == "System User.Table Name")
}
}
}

"Microsoft SQL generator" should {
Expand Down Expand Up @@ -341,6 +349,20 @@ class SqlGeneratorLoaderSuite extends AnyWordSpec with RelationalDbFixture {
}
}

"unquote" should {
"quote each subfields separately using quotes" in {
val actual = genDate.unquote("System User.\"Table Name\"")

assert(actual == "System User.Table Name")
}

"quote each subfields separately using brackets" in {
val actual = genDate.unquote("[System User].[Table Name]")

assert(actual == "System User.Table Name")
}
}

"splitComplexIdentifier" should {
"throw on an empty identifier" in {
assertThrows[IllegalArgumentException] {
Expand Down Expand Up @@ -576,6 +598,14 @@ class SqlGeneratorLoaderSuite extends AnyWordSpec with RelationalDbFixture {
}
}

"unquote" should {
"quote each subfields separately" in {
val actual = gen.unquote("System User.\"Table Name\"")

assert(actual == "System User.Table Name")
}
}

"splitComplexIdentifier" should {
"throw on an empty identifier" in {
assertThrows[IllegalArgumentException] {
Expand Down Expand Up @@ -1025,6 +1055,14 @@ class SqlGeneratorLoaderSuite extends AnyWordSpec with RelationalDbFixture {
assert(actual == "\"System User\".\"Table Name\"")
}
}

"unquote" should {
"quote each subfields separately" in {
val actual = genDate.unquote("System User.\"Table Name\"")

assert(actual == "System User.Table Name")
}
}
}

"MySQL SQL generator" should {
Expand Down Expand Up @@ -1147,6 +1185,14 @@ class SqlGeneratorLoaderSuite extends AnyWordSpec with RelationalDbFixture {
assert(actual == "`System User`.`Table Name`")
}
}

"unquote" should {
"quote each subfields separately" in {
val actual = genDate.unquote("System User.`Table Name`")

assert(actual == "System User.Table Name")
}
}
}

"DB2 SQL generator" should {
Expand Down Expand Up @@ -1269,6 +1315,14 @@ class SqlGeneratorLoaderSuite extends AnyWordSpec with RelationalDbFixture {
assert(actual == "\"System User\".\"Table Name\"")
}
}

"unquote" should {
"quote each subfields separately" in {
val actual = genDate.unquote("System User.\"Table Name\"")

assert(actual == "System User.Table Name")
}
}
}

"HSQL generator" should {
Expand Down Expand Up @@ -1560,6 +1614,14 @@ class SqlGeneratorLoaderSuite extends AnyWordSpec with RelationalDbFixture {
}
}
}

"unquote" should {
"quote each subfields separately" in {
val actual = genDate.unquote("System User.\"Table Name\"")

assert(actual == "System User.Table Name")
}
}
}

}
Loading