diff --git a/driver_test.go b/driver_test.go index 63c9ac9..6621564 100644 --- a/driver_test.go +++ b/driver_test.go @@ -3,6 +3,7 @@ package zetasqlite_test import ( "context" "database/sql" + "strings" "testing" "github.com/google/go-cmp/cmp" @@ -186,6 +187,36 @@ CREATE TABLE IF NOT EXISTS Singers ( }) } +func TestCreateTable(t *testing.T) { + t.Run("primary keys", func(t *testing.T) { + db, err := sql.Open("zetasqlite", ":memory:") + if err != nil { + t.Fatal(err) + } + if _, err := db.Exec(` +CREATE TABLE IF NOT EXISTS Singers ( + SingerId INT64 NOT NULL PRIMARY KEY, + FirstName STRING(1024), + LastName STRING(1024) +)`); err != nil { + t.Fatal(err) + } + stmt, err := db.Prepare("INSERT Singers (SingerId, FirstName, LastName) VALUES (@SingerID, @FirstName, @LastName)") + if err != nil { + t.Fatal(err) + } + _, err = stmt.Exec(int64(1), "Kylie", "Minogue") + if err != nil { + t.Fatal(err) + } + + _, err = stmt.Exec(int64(1), "Miss", "Kitten") + if !strings.HasSuffix(err.Error(), "UNIQUE constraint failed: Singers.SingerId") { + t.Fatalf("expected failed unique constraint err, got: %s", err) + } + }) +} + func TestPreparedStatements(t *testing.T) { t.Run("prepared select", func(t *testing.T) { db, err := sql.Open("zetasqlite", ":memory:") diff --git a/internal/spec.go b/internal/spec.go index fb1c831..6261c6f 100644 --- a/internal/spec.go +++ b/internal/spec.go @@ -139,11 +139,23 @@ func (s *TableSpec) SQLiteSchema() string { columns = append(columns, c.SQLiteSchema()) } if len(s.PrimaryKey) != 0 { + primaryKeys := make([]string, len(s.PrimaryKey)) + + for i, key := range s.PrimaryKey { + primaryKeys[i] = fmt.Sprintf("%s COLLATE zetasqlite_collate", key) + } + columns = append( columns, - fmt.Sprintf("PRIMARY KEY (%s)", strings.Join(s.PrimaryKey, ",")), + fmt.Sprintf("PRIMARY KEY (%s)", strings.Join(primaryKeys, ",")), ) } + var clustering string + if len(s.PrimaryKey) > 0 { + clustering = "WITHOUT ROWID" + } else { + clustering = "" + } var stmt string switch s.CreateMode { case ast.CreateDefaultMode: @@ -153,7 +165,7 @@ func (s *TableSpec) SQLiteSchema() string { case ast.CreateIfNotExistsMode: stmt = "CREATE TABLE IF NOT EXISTS" } - return fmt.Sprintf("%s `%s` (%s)", stmt, s.TableName(), strings.Join(columns, ",")) + return fmt.Sprintf("%s `%s` (%s) %s", stmt, s.TableName(), strings.Join(columns, ","), clustering) } func viewSQLiteSchema(s *TableSpec) string {