Skip to content

Commit

Permalink
enhance: support save for Oracle (#3364)
Browse files Browse the repository at this point in the history
  • Loading branch information
oldme-git authored Mar 13, 2024
1 parent 11f7187 commit 409041b
Show file tree
Hide file tree
Showing 6 changed files with 204 additions and 11 deletions.
118 changes: 114 additions & 4 deletions contrib/drivers/oracle/oracle_do_insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import (
"context"
"database/sql"
"fmt"
"github.com/gogf/gf/v2/container/gset"
"github.com/gogf/gf/v2/text/gstr"
"strings"

"github.com/gogf/gf/v2/database/gdb"
Expand All @@ -24,10 +26,7 @@ func (d *Driver) DoInsert(
) (result sql.Result, err error) {
switch option.InsertOption {
case gdb.InsertOptionSave:
return nil, gerror.NewCode(
gcode.CodeNotSupported,
`Save operation is not supported by oracle driver`,
)
return d.doSave(ctx, link, table, list, option)

case gdb.InsertOptionReplace:
return nil, gerror.NewCode(
Expand Down Expand Up @@ -93,3 +92,114 @@ func (d *Driver) DoInsert(
}
return batchResult, nil
}

// doSave support upsert for Oracle
func (d *Driver) doSave(ctx context.Context,
link gdb.Link, table string, list gdb.List, option gdb.DoInsertOption,
) (result sql.Result, err error) {
if len(option.OnConflict) == 0 {
return nil, gerror.NewCode(
gcode.CodeMissingParameter, `Please specify conflict columns`,
)
}

if len(list) == 0 {
return nil, gerror.NewCode(
gcode.CodeInvalidRequest, `Save operation list is empty by oracle driver`,
)
}

var (
one = list[0]
charL, charR = d.GetChars()
valueCharL, valueCharR = "'", "'"

conflictKeys = option.OnConflict
conflictKeySet = gset.New(false)

// insertKeys: Handle valid keys that need to be inserted
// insertValues: Handle values that need to be inserted
// updateValues: Handle values that need to be updated
// queryValues: Handle data that need to be upsert
queryValues, insertKeys, insertValues, updateValues []string
)

// conflictKeys slice type conv to set type
for _, conflictKey := range conflictKeys {
conflictKeySet.Add(gstr.ToUpper(conflictKey))
}

for key, value := range one {
saveValue := gconv.String(value)
queryValues = append(
queryValues,
fmt.Sprintf(
valueCharL+"%s"+valueCharR+" AS "+charL+"%s"+charR,
saveValue, key,
),
)

insertKeys = append(insertKeys, charL+key+charR)
insertValues = append(insertValues, "T2."+charL+key+charR)

// filter conflict keys in updateValues
if !conflictKeySet.Contains(key) {
updateValues = append(
updateValues,
fmt.Sprintf(`T1.%s = T2.%s`, charL+key+charR, charL+key+charR),
)
}
}

batchResult := new(gdb.SqlResult)
sqlStr := parseSqlForUpsert(table, queryValues, insertKeys, insertValues, updateValues, conflictKeys)
r, err := d.DoExec(ctx, link, sqlStr)
if err != nil {
return r, err
}
if n, err := r.RowsAffected(); err != nil {
return r, err
} else {
batchResult.Result = r
batchResult.Affected += n
}
return batchResult, nil
}

// parseSqlForUpsert
// MERGE INTO {{table}} T1
// USING ( SELECT {{queryValues}} FROM DUAL T2
// ON (T1.{{duplicateKey}} = T2.{{duplicateKey}} AND ...)
// WHEN NOT MATCHED THEN
// INSERT {{insertKeys}} VALUES {{insertValues}}
// WHEN MATCHED THEN
// UPDATE SET {{updateValues}}
func parseSqlForUpsert(table string,
queryValues, insertKeys, insertValues, updateValues, duplicateKey []string,
) (sqlStr string) {
var (
queryValueStr = strings.Join(queryValues, ",")
insertKeyStr = strings.Join(insertKeys, ",")
insertValueStr = strings.Join(insertValues, ",")
updateValueStr = strings.Join(updateValues, ",")
duplicateKeyStr string
pattern = gstr.Trim(`MERGE INTO %s T1 USING (SELECT %s FROM DUAL) T2 ON (%s) WHEN NOT MATCHED THEN INSERT(%s) VALUES (%s) WHEN MATCHED THEN UPDATE SET %s`)
)

for index, keys := range duplicateKey {
if index != 0 {
duplicateKeyStr += " AND "
}
duplicateTmp := fmt.Sprintf("T1.%s = T2.%s", keys, keys)
duplicateKeyStr += duplicateTmp
}

return fmt.Sprintf(pattern,
table,
queryValueStr,
duplicateKeyStr,
insertKeyStr,
insertValueStr,
updateValueStr,
)
}
6 changes: 3 additions & 3 deletions contrib/drivers/oracle/oracle_z_unit_basic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import (
"github.com/gogf/gf/v2/test/gtest"
)

func TestTables(t *testing.T) {
func Test_Tables(t *testing.T) {
gtest.C(t, func(t *gtest.T) {
tables := []string{"t_user1", "pop", "haha"}

Expand Down Expand Up @@ -60,7 +60,7 @@ func TestTables(t *testing.T) {
})
}

func TestTableFields(t *testing.T) {
func Test_Table_Fields(t *testing.T) {
gtest.C(t, func(t *gtest.T) {
createTable("t_user")
defer dropTable("t_user")
Expand Down Expand Up @@ -107,7 +107,7 @@ func TestTableFields(t *testing.T) {
})
}

func TestDoInsert(t *testing.T) {
func Test_Do_Insert(t *testing.T) {
gtest.C(t, func(t *gtest.T) {
createTable("t_user")
defer dropTable("t_user")
Expand Down
80 changes: 78 additions & 2 deletions contrib/drivers/oracle/oracle_z_unit_model_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ func Test_Model_RightJoin(t *testing.T) {
})
}

func TestPage(t *testing.T) {
func Test_Page(t *testing.T) {
table := createInitTable()
defer dropTable(table)
result, err := db.Model(table).Page(1, 2).Order("ID").All()
Expand Down Expand Up @@ -162,7 +162,6 @@ func TestPage(t *testing.T) {
func Test_Model_Insert(t *testing.T) {
table := createTable()
defer dropTable(table)
// db.SetDebug(true)
gtest.C(t, func(t *gtest.T) {
user := db.Model(table)
result, err := user.Data(g.Map{
Expand Down Expand Up @@ -1101,6 +1100,83 @@ func Test_Model_WhereOrNotLike(t *testing.T) {
})
}

func Test_Model_Save(t *testing.T) {
table := createTable("test")
defer dropTable(table)
gtest.C(t, func(t *gtest.T) {
type User struct {
Id int
Passport string
Password string
NickName string
CreateTime *gtime.Time
}
var (
user User
count int
result sql.Result
createTime = gtime.Now().Format("Y-m-d")
err error
)

result, err = db.Model(table).Data(g.Map{
"id": 1,
"passport": "p1",
"password": "15d55ad283aa400af464c76d713c07ad",
"nickname": "n1",
"create_time": createTime,
}).OnConflict("id").Save()

t.AssertNil(err)
n, _ := result.RowsAffected()
t.Assert(n, 1)

err = db.Model(table).Scan(&user)
t.AssertNil(err)
t.Assert(user.Id, 1)
t.Assert(user.Passport, "p1")
t.Assert(user.Password, "15d55ad283aa400af464c76d713c07ad")
t.Assert(user.NickName, "n1")
t.Assert(user.CreateTime.Format("Y-m-d"), createTime)

_, err = db.Model(table).Data(g.Map{
"id": 1,
"passport": "p1",
"password": "25d55ad283aa400af464c76d713c07ad",
"nickname": "n2",
"create_time": createTime,
}).OnConflict("id").Save()
t.AssertNil(err)

err = db.Model(table).Scan(&user)
t.AssertNil(err)
t.Assert(user.Passport, "p1")
t.Assert(user.Password, "25d55ad283aa400af464c76d713c07ad")
t.Assert(user.NickName, "n2")
t.Assert(user.CreateTime.Format("Y-m-d"), createTime)

count, err = db.Model(table).Count()
t.AssertNil(err)
t.Assert(count, 1)
})
}

func Test_Model_Replace(t *testing.T) {
table := createTable()
defer dropTable(table)

gtest.C(t, func(t *gtest.T) {
_, err := db.Model(table).Data(g.Map{
"id": 1,
"passport": "t11",
"password": "25d55ad283aa400af464c76d713c07ad",
"nickname": "T11",
"create_time": "2018-10-24 10:00:00",
}).Replace()
t.Assert(err, "Replace operation is not supported by oracle driver")
})
}

/* not support the "AS"
func Test_Model_Raw(t *testing.T) {
table := createInitTable()
Expand Down
5 changes: 4 additions & 1 deletion contrib/drivers/pgsql/pgsql_format_upsert.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"fmt"

"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/errors/gcode"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/text/gstr"
"github.com/gogf/gf/v2/util/gconv"
Expand All @@ -19,7 +20,9 @@ import (
// For example: ON CONFLICT (id) DO UPDATE SET ...
func (d *Driver) FormatUpsert(columns []string, list gdb.List, option gdb.DoInsertOption) (string, error) {
if len(option.OnConflict) == 0 {
return "", gerror.New("Please specify conflict columns")
return "", gerror.NewCode(
gcode.CodeMissingParameter, `Please specify conflict columns`,
)
}

var onDuplicateStr string
Expand Down
5 changes: 4 additions & 1 deletion contrib/drivers/sqlite/sqlite_format_upsert.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"fmt"

"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/errors/gcode"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/text/gstr"
"github.com/gogf/gf/v2/util/gconv"
Expand All @@ -19,7 +20,9 @@ import (
// For example: ON CONFLICT (id) DO UPDATE SET ...
func (d *Driver) FormatUpsert(columns []string, list gdb.List, option gdb.DoInsertOption) (string, error) {
if len(option.OnConflict) == 0 {
return "", gerror.New("Please specify conflict columns")
return "", gerror.NewCode(
gcode.CodeMissingParameter, `Please specify conflict columns`,
)
}

var onDuplicateStr string
Expand Down
1 change: 1 addition & 0 deletions database/gdb/gdb_core_underlying.go
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,7 @@ func (c *Core) FormatUpsert(columns []string, list List, option DoInsertOption)
)
}
}

return InsertOnDuplicateKeyUpdate + " " + onDuplicateStr, nil
}

Expand Down

0 comments on commit 409041b

Please sign in to comment.