diff --git a/dialect/mysql/mysql_test.go b/dialect/mysql/mysql_test.go index f4850a59..57f64bcf 100644 --- a/dialect/mysql/mysql_test.go +++ b/dialect/mysql/mysql_test.go @@ -539,6 +539,27 @@ func (mt *mysqlTest) TestWindowFunction() { mt.Error(ds.WithDialect("mysql").ScanStructs(&entries), "goqu: adapter does not support window function clause") } +func (mt *mysqlTest) TestInsertFromSelect() { + ds := mt.db.From("entry") + + subquery := goqu.Select( + goqu.V(11), + goqu.V(11), + goqu.C("float"), + goqu.C("string"), + goqu.C("time"), + goqu.C("bool"), + goqu.C("bytes"), + ).From(goqu.T("entry")).Where(goqu.C("int").Eq(9)) + + query := ds.Insert().Cols().FromQuery(subquery) + _, _, err := query.ToSQL() + + mt.NoError(err) + _, err = query.Executor().Exec() + mt.NoError(err) +} + func TestMysqlSuite(t *testing.T) { suite.Run(t, new(mysqlTest)) } diff --git a/insert_dataset.go b/insert_dataset.go index e8b106c7..64523021 100644 --- a/insert_dataset.go +++ b/insert_dataset.go @@ -1,6 +1,8 @@ package goqu import ( + "fmt" + "github.com/doug-martin/goqu/v9/exec" "github.com/doug-martin/goqu/v9/exp" "github.com/doug-martin/goqu/v9/internal/errors" @@ -143,6 +145,17 @@ func (id *InsertDataset) ColsAppend(cols ...interface{}) *InsertDataset { // Adds a subquery to the insert. See examples. func (id *InsertDataset) FromQuery(from exp.AppendableExpression) *InsertDataset { + if sds, ok := from.(*SelectDataset); ok { + if sds.dialect != GetDialect("default") && id.Dialect() != sds.dialect { + panic( + fmt.Errorf( + "incompatible dialects for INSERT (%q) and SELECT (%q)", + id.dialect.Dialect(), sds.dialect.Dialect(), + ), + ) + } + sds.dialect = id.dialect + } return id.copy(id.clauses.SetFrom(from)) } diff --git a/insert_dataset_test.go b/insert_dataset_test.go index 2ef25d0d..bf1eebc5 100644 --- a/insert_dataset_test.go +++ b/insert_dataset_test.go @@ -207,6 +207,43 @@ func (ids *insertDatasetSuite) TestFromQuery() { ) } +func (ids *insertDatasetSuite) TestFromQueryDialectInheritance() { + md := new(mocks.SQLDialect) + md.On("Dialect").Return("dialect") + + ids.Run("ok, default dialect is replaced with insert dialect", func() { + bd := Insert("items").SetDialect(md).FromQuery(From("other_items")) + ids.Require().Equal(md, bd.clauses.From().(*SelectDataset).Dialect()) + }) + + ids.Run("ok, insert and select dialects coincide", func() { + bd := Insert("items").SetDialect(md).FromQuery(From("other_items").SetDialect(md)) + ids.Require().Equal(md, bd.clauses.From().(*SelectDataset).Dialect()) + }) + + ids.Run("ok, insert and select dialects are default", func() { + bd := Insert("items").FromQuery(From("other_items")) + ids.Require().Equal(GetDialect("default"), bd.clauses.From().(*SelectDataset).Dialect()) + }) + + ids.Run("panic, insert and select dialects are different", func() { + defer func() { + r := recover() + if r == nil { + ids.Fail("there should be a panic") + } + ids.Require().Equal( + "incompatible dialects for INSERT (\"dialect\") and SELECT (\"other_dialect\")", + r.(error).Error(), + ) + }() + + otherDialect := new(mocks.SQLDialect) + otherDialect.On("Dialect").Return("other_dialect") + Insert("items").SetDialect(md).FromQuery(From("otherItems").SetDialect(otherDialect)) + }) +} + func (ids *insertDatasetSuite) TestVals() { val1 := []interface{}{ "a", "b",