Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
.vscode
8 changes: 4 additions & 4 deletions oracle/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -448,19 +448,19 @@ func writeQuotedIdentifier(builder *strings.Builder, identifier string) {
// - plsqlBuilder: The builder to write the PL/SQL code into.
// - dbNames: The slice containing the column names.
// - table: The table name
func writeTableRecordCollectionDecl(plsqlBuilder *strings.Builder, dbNames []string, table string) {
func writeTableRecordCollectionDecl(db *gorm.DB, plsqlBuilder *strings.Builder, dbNames []string, table string) {
// Declare a record where each element has the same structure as a row from the given table
plsqlBuilder.WriteString(" TYPE t_record IS RECORD (\n")
for i, field := range dbNames {
if i > 0 {
plsqlBuilder.WriteString(",\n")
}
plsqlBuilder.WriteString(" ")
writeQuotedIdentifier(plsqlBuilder, field)
db.QuoteTo(plsqlBuilder, field)
plsqlBuilder.WriteString(" ")
writeQuotedIdentifier(plsqlBuilder, table)
db.QuoteTo(plsqlBuilder, table)
plsqlBuilder.WriteString(".")
writeQuotedIdentifier(plsqlBuilder, field)
db.QuoteTo(plsqlBuilder, field)
plsqlBuilder.WriteString("%TYPE")
}
plsqlBuilder.WriteString("\n")
Expand Down
44 changes: 22 additions & 22 deletions oracle/create.go
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ func buildBulkMergePLSQL(db *gorm.DB, createValues clause.Values, onConflictClau

// Start PL/SQL block
plsqlBuilder.WriteString("DECLARE\n")
writeTableRecordCollectionDecl(&plsqlBuilder, stmt.Schema.DBNames, stmt.Table)
writeTableRecordCollectionDecl(db, &plsqlBuilder, stmt.Schema.DBNames, stmt.Table)
plsqlBuilder.WriteString(" l_affected_records t_records;\n")

// Create array types and variables for each column
Expand Down Expand Up @@ -323,7 +323,7 @@ func buildBulkMergePLSQL(db *gorm.DB, createValues clause.Values, onConflictClau
// FORALL with MERGE and RETURNING BULK COLLECT INTO
plsqlBuilder.WriteString(fmt.Sprintf(" FORALL i IN 1..%d\n", len(createValues.Values)))
plsqlBuilder.WriteString(" MERGE INTO ")
writeQuotedIdentifier(&plsqlBuilder, stmt.Table)
db.QuoteTo(&plsqlBuilder, stmt.Table)
plsqlBuilder.WriteString(" t\n")
// Build USING clause
plsqlBuilder.WriteString(" USING (SELECT ")
Expand All @@ -332,7 +332,7 @@ func buildBulkMergePLSQL(db *gorm.DB, createValues clause.Values, onConflictClau
plsqlBuilder.WriteString(", ")
}
plsqlBuilder.WriteString(fmt.Sprintf("l_col_%d_array(i) AS ", idx))
writeQuotedIdentifier(&plsqlBuilder, column.Name)
db.QuoteTo(&plsqlBuilder, column.Name)
}
plsqlBuilder.WriteString(" FROM DUAL) s\n")

Expand All @@ -344,9 +344,9 @@ func buildBulkMergePLSQL(db *gorm.DB, createValues clause.Values, onConflictClau
plsqlBuilder.WriteString(" AND ")
}
plsqlBuilder.WriteString("t.")
writeQuotedIdentifier(&plsqlBuilder, conflictCol.Name)
db.QuoteTo(&plsqlBuilder, conflictCol.Name)
plsqlBuilder.WriteString(" = s.")
writeQuotedIdentifier(&plsqlBuilder, conflictCol.Name)
db.QuoteTo(&plsqlBuilder, conflictCol.Name)
}
plsqlBuilder.WriteString(")\n")

Expand All @@ -371,9 +371,9 @@ func buildBulkMergePLSQL(db *gorm.DB, createValues clause.Values, onConflictClau
plsqlBuilder.WriteString(", ")
}
plsqlBuilder.WriteString("t.")
writeQuotedIdentifier(&plsqlBuilder, column.Name)
db.QuoteTo(&plsqlBuilder, column.Name)
plsqlBuilder.WriteString(" = s.")
writeQuotedIdentifier(&plsqlBuilder, column.Name)
db.QuoteTo(&plsqlBuilder, column.Name)
updateCount++
}
}
Expand Down Expand Up @@ -405,9 +405,9 @@ func buildBulkMergePLSQL(db *gorm.DB, createValues clause.Values, onConflictClau
plsqlBuilder.WriteString(", ")
}
plsqlBuilder.WriteString("t.")
writeQuotedIdentifier(&plsqlBuilder, column.Name)
db.QuoteTo(&plsqlBuilder, column.Name)
plsqlBuilder.WriteString(" = s.")
writeQuotedIdentifier(&plsqlBuilder, column.Name)
db.QuoteTo(&plsqlBuilder, column.Name)
updateCount++
}
}
Expand All @@ -427,9 +427,9 @@ func buildBulkMergePLSQL(db *gorm.DB, createValues clause.Values, onConflictClau
}
}
plsqlBuilder.WriteString(" WHEN MATCHED THEN UPDATE SET t.")
writeQuotedIdentifier(&plsqlBuilder, noopCol)
db.QuoteTo(&plsqlBuilder, noopCol)
plsqlBuilder.WriteString(" = s.")
writeQuotedIdentifier(&plsqlBuilder, noopCol)
db.QuoteTo(&plsqlBuilder, noopCol)
plsqlBuilder.WriteString("\n")
}

Expand All @@ -444,7 +444,7 @@ func buildBulkMergePLSQL(db *gorm.DB, createValues clause.Values, onConflictClau
if insertCount > 0 {
plsqlBuilder.WriteString(", ")
}
writeQuotedIdentifier(&plsqlBuilder, column.Name)
db.QuoteTo(&plsqlBuilder, column.Name)
insertCount++
}
}
Expand All @@ -459,7 +459,7 @@ func buildBulkMergePLSQL(db *gorm.DB, createValues clause.Values, onConflictClau
plsqlBuilder.WriteString(", ")
}
plsqlBuilder.WriteString("s.")
writeQuotedIdentifier(&plsqlBuilder, column.Name)
db.QuoteTo(&plsqlBuilder, column.Name)
insertCount++
}
}
Expand All @@ -475,7 +475,7 @@ func buildBulkMergePLSQL(db *gorm.DB, createValues clause.Values, onConflictClau
if insertCount > 0 {
plsqlBuilder.WriteString(", ")
}
writeQuotedIdentifier(&plsqlBuilder, column.Name)
db.QuoteTo(&plsqlBuilder, column.Name)
insertCount++
}
}
Expand All @@ -489,7 +489,7 @@ func buildBulkMergePLSQL(db *gorm.DB, createValues clause.Values, onConflictClau
plsqlBuilder.WriteString(", ")
}
plsqlBuilder.WriteString("s.")
writeQuotedIdentifier(&plsqlBuilder, column.Name)
db.QuoteTo(&plsqlBuilder, column.Name)
insertCount++
}
}
Expand All @@ -503,7 +503,7 @@ func buildBulkMergePLSQL(db *gorm.DB, createValues clause.Values, onConflictClau
if i > 0 {
plsqlBuilder.WriteString(", ")
}
writeQuotedIdentifier(&plsqlBuilder, column)
db.QuoteTo(&plsqlBuilder, column)
}
plsqlBuilder.WriteString("\n BULK COLLECT INTO l_affected_records;\n")

Expand All @@ -514,7 +514,7 @@ func buildBulkMergePLSQL(db *gorm.DB, createValues clause.Values, onConflictClau
if field := findFieldByDBName(schema, column); field != nil {
stmt.Vars = append(stmt.Vars, sql.Out{Dest: createTypedDestination(field)})
plsqlBuilder.WriteString(fmt.Sprintf(" IF l_affected_records.COUNT > %d THEN :%d := l_affected_records(%d).", rowIdx, outParamIndex+1, rowIdx+1))
writeQuotedIdentifier(&plsqlBuilder, column)
db.QuoteTo(&plsqlBuilder, column)
plsqlBuilder.WriteString("; END IF;\n")
outParamIndex++
}
Expand Down Expand Up @@ -548,7 +548,7 @@ func buildBulkInsertOnlyPLSQL(db *gorm.DB, createValues clause.Values) {

// Start PL/SQL block
plsqlBuilder.WriteString("DECLARE\n")
writeTableRecordCollectionDecl(&plsqlBuilder, stmt.Schema.DBNames, stmt.Table)
writeTableRecordCollectionDecl(db, &plsqlBuilder, stmt.Schema.DBNames, stmt.Table)
plsqlBuilder.WriteString(" l_inserted_records t_records;\n")

// Create array types and variables for each column
Expand Down Expand Up @@ -582,14 +582,14 @@ func buildBulkInsertOnlyPLSQL(db *gorm.DB, createValues clause.Values) {
// FORALL with RETURNING BULK COLLECT INTO
plsqlBuilder.WriteString(fmt.Sprintf(" FORALL i IN 1..%d\n", len(createValues.Values)))
plsqlBuilder.WriteString(" INSERT INTO ")
writeQuotedIdentifier(&plsqlBuilder, stmt.Table)
db.QuoteTo(&plsqlBuilder, stmt.Table)
plsqlBuilder.WriteString(" (")
// Add column names
for i, column := range createValues.Columns {
if i > 0 {
plsqlBuilder.WriteString(", ")
}
writeQuotedIdentifier(&plsqlBuilder, column.Name)
db.QuoteTo(&plsqlBuilder, column.Name)
}
plsqlBuilder.WriteString(") VALUES (")

Expand All @@ -609,7 +609,7 @@ func buildBulkInsertOnlyPLSQL(db *gorm.DB, createValues clause.Values) {
if i > 0 {
plsqlBuilder.WriteString(", ")
}
writeQuotedIdentifier(&plsqlBuilder, column)
db.QuoteTo(&plsqlBuilder, column)
}
plsqlBuilder.WriteString("\n BULK COLLECT INTO l_inserted_records;\n")

Expand All @@ -618,7 +618,7 @@ func buildBulkInsertOnlyPLSQL(db *gorm.DB, createValues clause.Values) {
for rowIdx := 0; rowIdx < len(createValues.Values); rowIdx++ {
for _, column := range allColumns {
var columnBuilder strings.Builder
writeQuotedIdentifier(&columnBuilder, column)
db.QuoteTo(&columnBuilder, column)
quotedColumn := columnBuilder.String()

if field := findFieldByDBName(schema, column); field != nil {
Expand Down
16 changes: 8 additions & 8 deletions oracle/delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -255,13 +255,13 @@ func buildBulkDeletePLSQL(db *gorm.DB) {

// Start PL/SQL block
plsqlBuilder.WriteString("DECLARE\n")
writeTableRecordCollectionDecl(&plsqlBuilder, stmt.Schema.DBNames, stmt.Table)
writeTableRecordCollectionDecl(db, &plsqlBuilder, stmt.Schema.DBNames, stmt.Table)
plsqlBuilder.WriteString(" l_deleted_records t_records;\n")
plsqlBuilder.WriteString("BEGIN\n")

// Build DELETE statement
plsqlBuilder.WriteString(" DELETE FROM ")
writeQuotedIdentifier(&plsqlBuilder, stmt.Table)
db.QuoteTo(&plsqlBuilder, stmt.Table)

// Add WHERE clause if it exists
if whereClause, hasWhere := stmt.Clauses["WHERE"]; hasWhere {
Expand All @@ -278,7 +278,7 @@ func buildBulkDeletePLSQL(db *gorm.DB) {
if i > 0 {
plsqlBuilder.WriteString(", ")
}
writeQuotedIdentifier(&plsqlBuilder, column)
db.QuoteTo(&plsqlBuilder, column)

}
plsqlBuilder.WriteString("\n BULK COLLECT INTO l_deleted_records;\n")
Expand All @@ -297,7 +297,7 @@ func buildBulkDeletePLSQL(db *gorm.DB) {

plsqlBuilder.WriteString(fmt.Sprintf(" IF l_deleted_records.COUNT > %d THEN\n", rowIdx))
plsqlBuilder.WriteString(fmt.Sprintf(" :%d := l_deleted_records(%d).", outParamIndex+1, rowIdx+1))
writeQuotedIdentifier(&plsqlBuilder, column)
db.QuoteTo(&plsqlBuilder, column)
plsqlBuilder.WriteString(";\n")
plsqlBuilder.WriteString(" END IF;\n")
outParamIndex++
Expand All @@ -324,9 +324,9 @@ func buildWhereClause(db *gorm.DB, plsqlBuilder *strings.Builder, expressions []
case clause.Eq:
// Write the column name
if columnName, ok := e.Column.(string); ok {
writeQuotedIdentifier(plsqlBuilder, columnName)
db.QuoteTo(plsqlBuilder, columnName)
} else if columnExpr, ok := e.Column.(clause.Column); ok {
writeQuotedIdentifier(plsqlBuilder, columnExpr.Name)
db.QuoteTo(plsqlBuilder, columnExpr.Name)
} else {
plsqlBuilder.WriteString(fmt.Sprintf("%v", e.Column))
}
Expand All @@ -342,9 +342,9 @@ func buildWhereClause(db *gorm.DB, plsqlBuilder *strings.Builder, expressions []

case clause.IN:
if columnName, ok := e.Column.(string); ok {
writeQuotedIdentifier(plsqlBuilder, columnName)
db.QuoteTo(plsqlBuilder, columnName)
} else if columnExpr, ok := e.Column.(clause.Column); ok {
writeQuotedIdentifier(plsqlBuilder, columnExpr.Name)
db.QuoteTo(plsqlBuilder, columnExpr.Name)
} else {
plsqlBuilder.WriteString(fmt.Sprintf("%v", e.Column))
}
Expand Down
31 changes: 24 additions & 7 deletions oracle/oracle.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,11 @@ import (
)

type Config struct {
DriverName string
DataSourceName string
Conn *sql.DB
DefaultStringSize uint
DriverName string
DataSourceName string
Conn *sql.DB
DefaultStringSize uint
SkipQuoteIdentifiers bool
}

type Dialector struct {
Expand Down Expand Up @@ -104,6 +105,18 @@ func (d Dialector) Initialize(db *gorm.DB) (err error) {
callback.Update().Replace("gorm:update", Update)
callback.Query().Before("gorm:query").Register("oracle:before_query", BeforeQuery)

if d.SkipQuoteIdentifiers {
// When identifiers are not quoted, columns are returned by Oracle in uppercase.
// Fields in the models may be lower case for compatibility with other databases.
// Match them up with the fields using the column mapping.
oracleCaseHandler := "oracle:case_handler"
if callback.Query().Get(oracleCaseHandler) == nil {
if err := callback.Query().Before("gorm:query").Register(oracleCaseHandler, MismatchedCaseHandler); err != nil {
return err
}
}
}

maps.Copy(db.ClauseBuilders, OracleClauseBuilders())

if d.Conn == nil {
Expand Down Expand Up @@ -237,9 +250,13 @@ func (d Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v inter

// Manages quoting of identifiers
func (d Dialector) QuoteTo(writer clause.Writer, str string) {
var builder strings.Builder
writeQuotedIdentifier(&builder, str)
writer.WriteString(builder.String())
out := str
if !d.SkipQuoteIdentifiers {
var builder strings.Builder
writeQuotedIdentifier(&builder, str)
out = builder.String()
}
_, _ = writer.WriteString(out)
}

var numericPlaceholder = regexp.MustCompile(`:(\d+)`)
Expand Down
19 changes: 18 additions & 1 deletion oracle/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,10 @@
package oracle

import (
"gorm.io/gorm"
"regexp"
"strings"

"gorm.io/gorm"
)

// Identifies the table name alias provided as
Expand All @@ -65,3 +66,19 @@ func BeforeQuery(db *gorm.DB) {
}
return
}

// MismatchedCaseHandler handles Oracle Case Insensitivity.
// When identifiers are not quoted, columns are returned by Oracle in uppercase.
// Fields in the models may be lower case for compatibility with other databases.
// Match them up with the fields using the column mapping.
func MismatchedCaseHandler(gormDB *gorm.DB) {
if gormDB.Statement == nil || gormDB.Statement.Schema == nil {
return
}
if len(gormDB.Statement.Schema.Fields) > 0 && gormDB.Statement.ColumnMapping == nil {
gormDB.Statement.ColumnMapping = map[string]string{}
}
for _, field := range gormDB.Statement.Schema.Fields {
gormDB.Statement.ColumnMapping[strings.ToUpper(field.DBName)] = field.Name
}
}
10 changes: 5 additions & 5 deletions oracle/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -476,21 +476,21 @@ func buildUpdatePLSQL(db *gorm.DB) {

// Start PL/SQL block
plsqlBuilder.WriteString("DECLARE\n")
writeTableRecordCollectionDecl(&plsqlBuilder, stmt.Schema.DBNames, stmt.Table)
writeTableRecordCollectionDecl(db, &plsqlBuilder, stmt.Schema.DBNames, stmt.Table)
plsqlBuilder.WriteString(" l_updated_records t_records;\n")
plsqlBuilder.WriteString("BEGIN\n")

// Build UPDATE statement
plsqlBuilder.WriteString(" UPDATE ")
writeQuotedIdentifier(&plsqlBuilder, stmt.Table)
db.QuoteTo(&plsqlBuilder, stmt.Table)
plsqlBuilder.WriteString(" SET ")

// Add SET assignments - handle both regular values and expressions
for i, assignment := range set {
if i > 0 {
plsqlBuilder.WriteString(", ")
}
writeQuotedIdentifier(&plsqlBuilder, assignment.Column.Name)
db.QuoteTo(&plsqlBuilder, assignment.Column.Name)
plsqlBuilder.WriteString(" = ")

// Check if the value is a clause.Expr (like gorm.Expr)
Expand Down Expand Up @@ -528,7 +528,7 @@ func buildUpdatePLSQL(db *gorm.DB) {
if i > 0 {
plsqlBuilder.WriteString(", ")
}
writeQuotedIdentifier(&plsqlBuilder, column)
db.QuoteTo(&plsqlBuilder, column)
}
plsqlBuilder.WriteString("\n BULK COLLECT INTO l_updated_records;\n")

Expand Down Expand Up @@ -559,7 +559,7 @@ func buildUpdatePLSQL(db *gorm.DB) {
// Add the assignment to PL/SQL with correct parameter reference
plsqlBuilder.WriteString(fmt.Sprintf(" IF l_updated_records.COUNT > %d THEN\n", rowIdx))
plsqlBuilder.WriteString(fmt.Sprintf(" :%d := l_updated_records(%d).", paramIndex, rowIdx+1))
writeQuotedIdentifier(&plsqlBuilder, column)
db.QuoteTo(&plsqlBuilder, column)
plsqlBuilder.WriteString(";\n")
plsqlBuilder.WriteString(" END IF;\n")
}
Expand Down
2 changes: 1 addition & 1 deletion tests/.gitignore
Original file line number Diff line number Diff line change
@@ -1 +1 @@
go.sum
passed-tests.txt.new
Loading