From 6312d86c54db2da8b9874163564a86637d5c869c Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 8 Oct 2021 17:51:27 +0800 Subject: [PATCH] Support specify select/omit columns with table --- statement.go | 7 +++++++ statement_test.go | 13 +++++++++++++ 2 files changed, 20 insertions(+) diff --git a/statement.go b/statement.go index 3b76f653a..bea4f7f07 100644 --- a/statement.go +++ b/statement.go @@ -6,6 +6,7 @@ import ( "database/sql/driver" "fmt" "reflect" + "regexp" "sort" "strconv" "strings" @@ -627,6 +628,8 @@ func (stmt *Statement) Changed(fields ...string) bool { return false } +var nameMatcher = regexp.MustCompile(`\.[\W]?(.+?)[\W]?$`) + // SelectAndOmitColumns get select and omit columns, select -> true, omit -> false func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) { results := map[string]bool{} @@ -647,6 +650,8 @@ func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) ( } } else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" { results[field.DBName] = true + } else if matches := nameMatcher.FindStringSubmatch(column); len(matches) == 2 { + results[matches[1]] = true } else { results[column] = true } @@ -662,6 +667,8 @@ func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) ( } } else if field := stmt.Schema.LookUpField(omit); field != nil && field.DBName != "" { results[field.DBName] = false + } else if matches := nameMatcher.FindStringSubmatch(omit); len(matches) == 2 { + results[matches[1]] = false } else { results[omit] = false } diff --git a/statement_test.go b/statement_test.go index 03ad81dc6..3f099d611 100644 --- a/statement_test.go +++ b/statement_test.go @@ -34,3 +34,16 @@ func TestWhereCloneCorruption(t *testing.T) { }) } } + +func TestNameMatcher(t *testing.T) { + for k, v := range map[string]string{ + "table.name": "name", + "`table`.`name`": "name", + "'table'.'name'": "name", + "'table'.name": "name", + } { + if matches := nameMatcher.FindStringSubmatch(k); len(matches) < 2 || matches[1] != v { + t.Errorf("failed to match value: %v, got %v, expect: %v", k, matches, v) + } + } +}