From c0f36e12e6289d21a7a9639d5352c69c6052e0d5 Mon Sep 17 00:00:00 2001 From: Juraj Bubniak Date: Fri, 24 Sep 2021 13:36:50 +0200 Subject: [PATCH] feat: add support for specifying tables to be locked in ForUpdate, ForNoKeyUpdate, ForKeyShare, ForShare --- README.md | 4 +-- dialect/sqlite3/sqlite3.go | 1 + dialect/sqlserver/sqlserver.go | 1 + docs/selecting.md | 26 ++++++++++++++++ exp/lock.go | 9 +++++- select_dataset.go | 20 ++++++------ select_dataset_example_test.go | 32 +++++++++++++++++++ select_dataset_test.go | 48 +++++++++++++++++++++++++++++ sqlgen/select_sql_generator.go | 17 +++++++++- sqlgen/select_sql_generator_test.go | 13 ++++++++ sqlgen/sql_dialect_options.go | 3 ++ 11 files changed, 160 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index ee99c476..1f470fdb 100644 --- a/README.md +++ b/README.md @@ -277,7 +277,7 @@ New features and/or enhancements are great and I encourage you to either submit 1. The use case 2. A short example -If you are issuing a PR also also include the following +If you are issuing a PR also include the following 1. Tests - otherwise the PR will not be merged 2. Documentation - otherwise the PR will not be merged @@ -297,7 +297,7 @@ go test -v -race ./... You can also run the tests in a container using [docker-compose](https://docs.docker.com/compose/). ```sh -GO_VERSION=latest docker-compose run goqu +MYSQL_VERSION=8 POSTGRES_VERSION=13.4 SQLSERVER_VERSION=2017-CU8-ubuntu GO_VERSION=latest docker-compose run goqu ``` ## License diff --git a/dialect/sqlite3/sqlite3.go b/dialect/sqlite3/sqlite3.go index e70222fb..f415bc4b 100644 --- a/dialect/sqlite3/sqlite3.go +++ b/dialect/sqlite3/sqlite3.go @@ -60,6 +60,7 @@ func DialectOptions() *goqu.SQLDialectOptions { opts.ConflictDoUpdateFragment = []byte(" DO UPDATE SET ") opts.ConflictDoNothingFragment = []byte(" DO NOTHING ") opts.ForUpdateFragment = []byte("") + opts.OfFragment = []byte("") opts.NowaitFragment = []byte("") return opts } diff --git a/dialect/sqlserver/sqlserver.go b/dialect/sqlserver/sqlserver.go index 58f9ad22..db34938c 100644 --- a/dialect/sqlserver/sqlserver.go +++ b/dialect/sqlserver/sqlserver.go @@ -80,6 +80,7 @@ func DialectOptions() *goqu.SQLDialectOptions { 0x1a: []byte("\\x1a"), } + opts.OfFragment = []byte("") opts.ConflictFragment = []byte("") opts.ConflictDoUpdateFragment = []byte("") opts.ConflictDoNothingFragment = []byte("") diff --git a/docs/selecting.md b/docs/selecting.md index ebab06c4..f35f2aa1 100644 --- a/docs/selecting.md +++ b/docs/selecting.md @@ -14,6 +14,7 @@ * [`Window`](#window) * [`With`](#with) * [`SetError`](#seterror) + * [`ForUpdate`](#forupdate) * Executing Queries * [`ScanStructs`](#scan-structs) - Scans rows into a slice of structs * [`ScanStruct`](#scan-struct) - Scans a row into a slice a struct, returns false if a row wasnt found @@ -875,6 +876,31 @@ name is empty name is empty ``` + +**[`ForUpdate`](https://godoc.org/github.com/doug-martin/goqu/#SelectDataset.ForUpdate)** + +```go +sql, _, _ := goqu.From("test").ForUpdate(exp.Wait).ToSQL() +fmt.Println(sql) +``` + +Output: +```sql +SELECT * FROM "test" FOR UPDATE +``` + +If your dialect supports FOR UPDATE OF you provide tables to be locked as variable arguments to the ForUpdate method. + +```go +sql, _, _ := goqu.From("test").ForUpdate(exp.Wait, goqu.T("test")).ToSQL() +fmt.Println(sql) +``` + +Output: +```sql +SELECT * FROM "test" FOR UPDATE OF "test" +``` + ## Executing Queries To execute your query use [`goqu.Database#From`](https://godoc.org/github.com/doug-martin/goqu/#Database.From) to create your dataset diff --git a/exp/lock.go b/exp/lock.go index e4548a22..9b8bf72e 100644 --- a/exp/lock.go +++ b/exp/lock.go @@ -6,10 +6,12 @@ type ( Lock interface { Strength() LockStrength WaitOption() WaitOption + Of() []IdentifierExpression } lock struct { strength LockStrength waitOption WaitOption + of []IdentifierExpression } ) @@ -25,10 +27,11 @@ const ( SkipLocked ) -func NewLock(strength LockStrength, option WaitOption) Lock { +func NewLock(strength LockStrength, option WaitOption, of ...IdentifierExpression) Lock { return lock{ strength: strength, waitOption: option, + of: of, } } @@ -39,3 +42,7 @@ func (l lock) Strength() LockStrength { func (l lock) WaitOption() WaitOption { return l.waitOption } + +func (l lock) Of() []IdentifierExpression { + return l.of +} diff --git a/select_dataset.go b/select_dataset.go index d027a372..775c387d 100644 --- a/select_dataset.go +++ b/select_dataset.go @@ -359,27 +359,27 @@ func (sd *SelectDataset) ClearWhere() *SelectDataset { } // Adds a FOR UPDATE clause. See examples. -func (sd *SelectDataset) ForUpdate(waitOption exp.WaitOption) *SelectDataset { - return sd.withLock(exp.ForUpdate, waitOption) +func (sd *SelectDataset) ForUpdate(waitOption exp.WaitOption, of ...exp.IdentifierExpression) *SelectDataset { + return sd.withLock(exp.ForUpdate, waitOption, of...) } // Adds a FOR NO KEY UPDATE clause. See examples. -func (sd *SelectDataset) ForNoKeyUpdate(waitOption exp.WaitOption) *SelectDataset { - return sd.withLock(exp.ForNoKeyUpdate, waitOption) +func (sd *SelectDataset) ForNoKeyUpdate(waitOption exp.WaitOption, of ...exp.IdentifierExpression) *SelectDataset { + return sd.withLock(exp.ForNoKeyUpdate, waitOption, of...) } // Adds a FOR KEY SHARE clause. See examples. -func (sd *SelectDataset) ForKeyShare(waitOption exp.WaitOption) *SelectDataset { - return sd.withLock(exp.ForKeyShare, waitOption) +func (sd *SelectDataset) ForKeyShare(waitOption exp.WaitOption, of ...exp.IdentifierExpression) *SelectDataset { + return sd.withLock(exp.ForKeyShare, waitOption, of...) } // Adds a FOR SHARE clause. See examples. -func (sd *SelectDataset) ForShare(waitOption exp.WaitOption) *SelectDataset { - return sd.withLock(exp.ForShare, waitOption) +func (sd *SelectDataset) ForShare(waitOption exp.WaitOption, of ...exp.IdentifierExpression) *SelectDataset { + return sd.withLock(exp.ForShare, waitOption, of...) } -func (sd *SelectDataset) withLock(strength exp.LockStrength, option exp.WaitOption) *SelectDataset { - return sd.copy(sd.clauses.SetLock(exp.NewLock(strength, option))) +func (sd *SelectDataset) withLock(strength exp.LockStrength, option exp.WaitOption, of ...exp.IdentifierExpression) *SelectDataset { + return sd.copy(sd.clauses.SetLock(exp.NewLock(strength, option, of...))) } // Adds a GROUP BY clause. See examples. diff --git a/select_dataset_example_test.go b/select_dataset_example_test.go index 35f62794..e8d7fc2d 100644 --- a/select_dataset_example_test.go +++ b/select_dataset_example_test.go @@ -9,6 +9,7 @@ import ( "time" "github.com/doug-martin/goqu/v9" + "github.com/doug-martin/goqu/v9/exp" "github.com/lib/pq" ) @@ -1651,3 +1652,34 @@ func ExampleSelectDataset_Executor_scannerScanVal() { // Sally // Vinita } + +func ExampleForUpdate() { + sql, args, _ := goqu.From("test").ForUpdate(exp.Wait).ToSQL() + fmt.Println(sql, args) + + // Output: + // SELECT * FROM "test" FOR UPDATE [] +} + +func ExampleForUpdate_of() { + sql, args, _ := goqu.From("test").ForUpdate(exp.Wait, goqu.T("test")).ToSQL() + fmt.Println(sql, args) + + // Output: + // SELECT * FROM "test" FOR UPDATE OF "test" [] +} + +func ExampleForUpdate_ofMultiple() { + sql, args, _ := goqu.From("table1").Join( + goqu.T("table2"), + goqu.On(goqu.I("table2.id").Eq(goqu.I("table1.id"))), + ).ForUpdate( + exp.Wait, + goqu.T("table1"), + goqu.T("table2"), + ).ToSQL() + fmt.Println(sql, args) + + // Output: + // SELECT * FROM "table1" INNER JOIN "table2" ON ("table2"."id" = "table1"."id") FOR UPDATE OF "table1", "table2" [] +} diff --git a/select_dataset_test.go b/select_dataset_test.go index 63fdd997..6579cfad 100644 --- a/select_dataset_test.go +++ b/select_dataset_test.go @@ -674,6 +674,18 @@ func (sds *selectDatasetSuite) TestForUpdate() { SetFrom(exp.NewColumnListExpression("test")). SetLock(exp.NewLock(exp.ForUpdate, goqu.NoWait)), }, + selectTestCase{ + ds: bd.ForUpdate(goqu.NoWait, goqu.T("table1")), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")). + SetLock(exp.NewLock(exp.ForUpdate, goqu.NoWait, goqu.T("table1"))), + }, + selectTestCase{ + ds: bd.ForUpdate(goqu.NoWait, goqu.T("table1"), goqu.T("table2")), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")). + SetLock(exp.NewLock(exp.ForUpdate, goqu.NoWait, goqu.T("table1"), goqu.T("table2"))), + }, selectTestCase{ ds: bd, clauses: exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")), @@ -690,6 +702,18 @@ func (sds *selectDatasetSuite) TestForNoKeyUpdate() { SetFrom(exp.NewColumnListExpression("test")). SetLock(exp.NewLock(exp.ForNoKeyUpdate, goqu.NoWait)), }, + selectTestCase{ + ds: bd.ForNoKeyUpdate(goqu.NoWait, goqu.T("table1")), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")). + SetLock(exp.NewLock(exp.ForNoKeyUpdate, goqu.NoWait, goqu.T("table1"))), + }, + selectTestCase{ + ds: bd.ForNoKeyUpdate(goqu.NoWait, goqu.T("table1"), goqu.T("table2")), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")). + SetLock(exp.NewLock(exp.ForNoKeyUpdate, goqu.NoWait, goqu.T("table1"), goqu.T("table2"))), + }, selectTestCase{ ds: bd, clauses: exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")), @@ -706,6 +730,18 @@ func (sds *selectDatasetSuite) TestForKeyShare() { SetFrom(exp.NewColumnListExpression("test")). SetLock(exp.NewLock(exp.ForKeyShare, goqu.NoWait)), }, + selectTestCase{ + ds: bd.ForKeyShare(goqu.NoWait, goqu.T("table1")), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")). + SetLock(exp.NewLock(exp.ForKeyShare, goqu.NoWait, goqu.T("table1"))), + }, + selectTestCase{ + ds: bd.ForKeyShare(goqu.NoWait, goqu.T("table1"), goqu.T("table2")), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")). + SetLock(exp.NewLock(exp.ForKeyShare, goqu.NoWait, goqu.T("table1"), goqu.T("table2"))), + }, selectTestCase{ ds: bd, clauses: exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")), @@ -722,6 +758,18 @@ func (sds *selectDatasetSuite) TestForShare() { SetFrom(exp.NewColumnListExpression("test")). SetLock(exp.NewLock(exp.ForShare, goqu.NoWait)), }, + selectTestCase{ + ds: bd.ForShare(goqu.NoWait, goqu.T("table1")), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")). + SetLock(exp.NewLock(exp.ForShare, goqu.NoWait, goqu.T("table1"))), + }, + selectTestCase{ + ds: bd.ForShare(goqu.NoWait, goqu.T("table1"), goqu.T("table2")), + clauses: exp.NewSelectClauses(). + SetFrom(exp.NewColumnListExpression("test")). + SetLock(exp.NewLock(exp.ForShare, goqu.NoWait, goqu.T("table1"), goqu.T("table2"))), + }, selectTestCase{ ds: bd, clauses: exp.NewSelectClauses().SetFrom(exp.NewColumnListExpression("test")), diff --git a/sqlgen/select_sql_generator.go b/sqlgen/select_sql_generator.go index 7bbb4142..de322910 100644 --- a/sqlgen/select_sql_generator.go +++ b/sqlgen/select_sql_generator.go @@ -196,8 +196,23 @@ func (ssg *selectSQLGenerator) ForSQL(b sb.SQLBuilder, lockingClause exp.Lock) { case exp.ForKeyShare: b.Write(ssg.DialectOptions().ForKeyShareFragment) } + + of := lockingClause.Of() + if ofLen := len(of); ofLen > 0 { + if ofFragment := ssg.DialectOptions().OfFragment; len(ofFragment) > 0 { + b.Write(ofFragment) + for i, table := range of { + ssg.ExpressionSQLGenerator().Generate(b, table) + if i < ofLen-1 { + b.WriteRunes(ssg.DialectOptions().CommaRune, ssg.DialectOptions().SpaceRune) + } + } + b.WriteRunes(ssg.DialectOptions().SpaceRune) + } + } + // the WAIT case is the default in Postgres, and is what you get if you don't specify NOWAIT or - // SKIP LOCKED. There's no special syntax for it in PG, so we don't do anything for it here + // SKIP LOCKED. There's no special syntax for it in PG, so we don't do anything for it here switch lockingClause.WaitOption() { case exp.Wait: return diff --git a/sqlgen/select_sql_generator_test.go b/sqlgen/select_sql_generator_test.go index ce048545..90394b18 100644 --- a/sqlgen/select_sql_generator_test.go +++ b/sqlgen/select_sql_generator_test.go @@ -3,6 +3,7 @@ package sqlgen_test import ( "testing" + "github.com/doug-martin/goqu/v9" "github.com/doug-martin/goqu/v9/exp" "github.com/doug-martin/goqu/v9/internal/errors" "github.com/doug-martin/goqu/v9/internal/sb" @@ -506,6 +507,7 @@ func (ssgs *selectSQLGeneratorSuite) TestToSelectSQL_withFor() { opts.ForNoKeyUpdateFragment = []byte(" for no key update ") opts.ForShareFragment = []byte(" for share ") opts.ForKeyShareFragment = []byte(" for key share ") + opts.OfFragment = []byte("of ") opts.NowaitFragment = []byte("nowait") opts.SkipLockedFragment = []byte("skip locked") @@ -513,10 +515,13 @@ func (ssgs *selectSQLGeneratorSuite) TestToSelectSQL_withFor() { scFnW := sc.SetLock(exp.NewLock(exp.ForNolock, exp.Wait)) scFnNw := sc.SetLock(exp.NewLock(exp.ForNolock, exp.NoWait)) scFnSl := sc.SetLock(exp.NewLock(exp.ForNolock, exp.SkipLocked)) + scFnSlOf := sc.SetLock(exp.NewLock(exp.ForNolock, exp.SkipLocked, goqu.T("my_table"))) scFsW := sc.SetLock(exp.NewLock(exp.ForShare, exp.Wait)) scFsNw := sc.SetLock(exp.NewLock(exp.ForShare, exp.NoWait)) scFsSl := sc.SetLock(exp.NewLock(exp.ForShare, exp.SkipLocked)) + scFsSlOf := sc.SetLock(exp.NewLock(exp.ForShare, exp.SkipLocked, goqu.T("my_table"))) + scFsSlOfMulti := sc.SetLock(exp.NewLock(exp.ForShare, exp.SkipLocked, goqu.T("my_table"), goqu.T("table2"))) scFksW := sc.SetLock(exp.NewLock(exp.ForKeyShare, exp.Wait)) scFksNw := sc.SetLock(exp.NewLock(exp.ForKeyShare, exp.NoWait)) @@ -539,6 +544,8 @@ func (ssgs *selectSQLGeneratorSuite) TestToSelectSQL_withFor() { selectTestCase{clause: scFnSl, sql: `SELECT * FROM "test"`}, selectTestCase{clause: scFnSl, sql: `SELECT * FROM "test"`, isPrepared: true}, + selectTestCase{clause: scFnSlOf, sql: `SELECT * FROM "test"`}, + selectTestCase{clause: scFnSlOf, sql: `SELECT * FROM "test"`, isPrepared: true, args: []interface{}{}}, selectTestCase{clause: scFsW, sql: `SELECT * FROM "test" for share `}, selectTestCase{clause: scFsW, sql: `SELECT * FROM "test" for share `, isPrepared: true}, @@ -549,6 +556,12 @@ func (ssgs *selectSQLGeneratorSuite) TestToSelectSQL_withFor() { selectTestCase{clause: scFsSl, sql: `SELECT * FROM "test" for share skip locked`}, selectTestCase{clause: scFsSl, sql: `SELECT * FROM "test" for share skip locked`, isPrepared: true}, + selectTestCase{clause: scFsSlOf, sql: `SELECT * FROM "test" for share of "my_table" skip locked`}, + selectTestCase{clause: scFsSlOf, sql: `SELECT * FROM "test" for share of "my_table" skip locked`, isPrepared: true}, + + selectTestCase{clause: scFsSlOfMulti, sql: `SELECT * FROM "test" for share of "my_table", "table2" skip locked`}, + selectTestCase{clause: scFsSlOfMulti, sql: `SELECT * FROM "test" for share of "my_table", "table2" skip locked`, isPrepared: true}, + selectTestCase{clause: scFksW, sql: `SELECT * FROM "test" for key share `}, selectTestCase{clause: scFksW, sql: `SELECT * FROM "test" for key share `, isPrepared: true}, diff --git a/sqlgen/sql_dialect_options.go b/sqlgen/sql_dialect_options.go index b8292e5e..8aab4b41 100644 --- a/sqlgen/sql_dialect_options.go +++ b/sqlgen/sql_dialect_options.go @@ -119,6 +119,8 @@ type ( ForNoKeyUpdateFragment []byte // The SQL FOR SHARE fragment(DEFAULT=[]byte(" FOR SHARE ")) ForShareFragment []byte + // The SQL OF fragment(DEFAULT=[]byte("OF ")) + OfFragment []byte // The SQL FOR KEY SHARE fragment(DEFAULT=[]byte(" FOR KEY SHARE ")) ForKeyShareFragment []byte // The SQL NOWAIT fragment(DEFAULT=[]byte("NOWAIT")) @@ -450,6 +452,7 @@ func DefaultDialectOptions() *SQLDialectOptions { ForNoKeyUpdateFragment: []byte(" FOR NO KEY UPDATE "), ForShareFragment: []byte(" FOR SHARE "), ForKeyShareFragment: []byte(" FOR KEY SHARE "), + OfFragment: []byte("OF "), NowaitFragment: []byte("NOWAIT"), SkipLockedFragment: []byte("SKIP LOCKED"), LateralFragment: []byte("LATERAL "),