-
Notifications
You must be signed in to change notification settings - Fork 7
/
util.go
156 lines (149 loc) · 4.78 KB
/
util.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
package sqlite
import (
"context"
"database/sql"
"fmt"
"strings"
)
// DropAll deletes all the data from a database.
//
// It reads the name of every index, trigger, view, and table from the
// schema's sqlite_schema table and drops them in that order. Entries whose
// sql column is NULL or empty (automatic indexes) are skipped. SQLite
// refuses to drop its internal sqlite_sequence table, so when present it is
// emptied with DELETE instead.
//
// The schemaName parameter follows the SQLite PRAGMA schema-name conventions:
// https://sqlite.org/pragma.html#syntax
// An empty schemaName is treated as "main".
func DropAll(ctx context.Context, conn *sql.Conn, schemaName string) (err error) {
	defer func() {
		if err != nil {
			err = fmt.Errorf("sqlitedb.DropAll: %w", err)
		}
	}()
	if schemaName == "" {
		schemaName = "main"
	}
	// quote renders s as a double-quoted SQL identifier. SQLite escapes an
	// embedded double quote by doubling it; Go's %q verb would instead emit
	// backslash escapes, which SQLite treats as literal characters.
	quote := func(s string) string {
		return `"` + strings.ReplaceAll(s, `"`, `""`) + `"`
	}
	schema := quote(schemaName)

	var indexes, tables, triggers, views []string
	hasSequence := false
	// Filter on sql to avoid auto indexes.
	// See https://www.sqlite.org/schematab.html for sqlite_schema docs.
	rows, err := conn.QueryContext(ctx, "SELECT name, type FROM "+schema+".sqlite_schema WHERE sql != ''")
	if err != nil {
		return err
	}
	defer rows.Close()
	for rows.Next() {
		var name, sqlType string
		if err := rows.Scan(&name, &sqlType); err != nil {
			return err
		}
		switch sqlType {
		case "index":
			indexes = append(indexes, name)
		case "table":
			if name == "sqlite_sequence" {
				// DROP TABLE fails for this internal table; it is
				// cleared after the other tables are gone.
				hasSequence = true
			} else {
				tables = append(tables, name)
			}
		case "trigger":
			triggers = append(triggers, name)
		case "view":
			views = append(views, name)
		default:
			return fmt.Errorf("unknown sqlite schema type %q for %q", sqlType, name)
		}
	}
	// Finish reading sqlite_schema before the DROPs below modify it.
	rows.Close()
	if err := rows.Err(); err != nil {
		return err
	}

	// Indexes and triggers go first, then views, then the tables they
	// depend on.
	drops := []struct {
		verb  string
		names []string
	}{
		{"DROP INDEX", indexes},
		{"DROP TRIGGER", triggers},
		{"DROP VIEW", views},
		// IF EXISTS: dropping a virtual table can also drop its shadow
		// tables, which sqlite_schema lists as separate tables.
		{"DROP TABLE IF EXISTS", tables},
	}
	for _, d := range drops {
		for _, name := range d.names {
			if _, err := conn.ExecContext(ctx, d.verb+" "+schema+"."+quote(name)); err != nil {
				return err
			}
		}
	}
	if hasSequence {
		// Dropping an AUTOINCREMENT table already removes its counter;
		// this clears any rows left for tables that no longer exist.
		if _, err := conn.ExecContext(ctx, "DELETE FROM "+schema+".sqlite_sequence"); err != nil {
			return err
		}
	}
	return nil
}
// CopyAll copies the contents of one database to another.
//
// Traditionally this is done in sqlite by closing the database and copying
// the file. However it can be useful to do it online: a single exclusive
// transaction can cross multiple databases, and if multiple processes are
// using a file, this lets one replace the database without first
// communicating with the other processes, asking them to close the DB first.
//
// Every table, index, trigger, and view listed in the source's sqlite_schema
// is re-created in the destination, in sqlite_schema order, and each table's
// rows are copied with INSERT ... SELECT *. The destination is not cleared
// first (see DropAll); re-creating an object that already exists there fails.
// The internal sqlite_sequence table cannot be re-created directly, so its
// rows are copied after all tables instead. Virtual tables are rejected with
// an error.
//
// The dstSchemaName and srcSchemaName parameters follow the SQLite PRAGMA
// schema-name conventions: https://sqlite.org/pragma.html#syntax
// An empty name is treated as "main"; the two names must differ.
func CopyAll(ctx context.Context, conn *sql.Conn, dstSchemaName, srcSchemaName string) (err error) {
	defer func() {
		if err != nil {
			err = fmt.Errorf("sqlitedb.CopyAll: %w", err)
		}
	}()
	if dstSchemaName == "" {
		dstSchemaName = "main"
	}
	if srcSchemaName == "" {
		srcSchemaName = "main"
	}
	if dstSchemaName == srcSchemaName {
		return fmt.Errorf("source matches destination: %q", srcSchemaName)
	}
	// quote renders s as a double-quoted SQL identifier. SQLite escapes an
	// embedded double quote by doubling it; Go's %q verb would instead emit
	// backslash escapes, which SQLite treats as literal characters.
	quote := func(s string) string {
		return `"` + strings.ReplaceAll(s, `"`, `""`) + `"`
	}
	dst, src := quote(dstSchemaName), quote(srcSchemaName)

	// Regardless of the case or whitespace used in the original create
	// statement (or whether "IF NOT EXISTS" or a schema qualifier was used),
	// SQLite normalizes the text stored in sqlite_schema so that it starts
	// with one of these prefixes, directly followed by the object name.
	// Inserting the destination schema after the prefix retargets the
	// statement. Note that unique indexes are stored as
	// "CREATE UNIQUE INDEX". "CREATE VIRTUAL TABLE" is deliberately absent:
	// the shadow tables it creates are listed separately in sqlite_schema
	// and would clash when re-created.
	createPrefixes := map[string][]string{
		"index":   {"CREATE INDEX ", "CREATE UNIQUE INDEX "},
		"table":   {"CREATE TABLE "},
		"trigger": {"CREATE TRIGGER "},
		"view":    {"CREATE VIEW "},
	}

	type object struct{ name, typ, sql string }
	var objects []object
	copySequence := false
	// Filter on sql to avoid auto indexes.
	// See https://www.sqlite.org/schematab.html for sqlite_schema docs.
	rows, err := conn.QueryContext(ctx, "SELECT name, type, sql FROM "+src+".sqlite_schema WHERE sql != ''")
	if err != nil {
		return err
	}
	defer rows.Close()
	for rows.Next() {
		var o object
		if err := rows.Scan(&o.name, &o.typ, &o.sql); err != nil {
			return err
		}
		if _, ok := createPrefixes[o.typ]; !ok {
			return fmt.Errorf("unknown sqlite schema type %q for %q", o.typ, o.name)
		}
		if o.typ == "table" && o.name == "sqlite_sequence" {
			// Names starting with "sqlite_" are reserved, so this table
			// cannot be created explicitly; its rows are handled below.
			copySequence = true
			continue
		}
		objects = append(objects, o)
	}
	// Read the whole schema before writing anything, so an unexpected
	// entry is reported before any part of the copy has been made.
	rows.Close()
	if err := rows.Err(); err != nil {
		return err
	}

	for _, o := range objects {
		stmt := ""
		for _, prefix := range createPrefixes[o.typ] {
			if strings.HasPrefix(o.sql, prefix) {
				stmt = prefix + dst + "." + o.sql[len(prefix):]
				break
			}
		}
		if stmt == "" {
			return fmt.Errorf("cannot copy %s %q: unsupported definition %q", o.typ, o.name, o.sql)
		}
		if _, err := conn.ExecContext(ctx, stmt); err != nil {
			return err
		}
		if o.typ == "table" {
			ident := quote(o.name)
			if _, err := conn.ExecContext(ctx, "INSERT INTO "+dst+"."+ident+" SELECT * FROM "+src+"."+ident); err != nil {
				return err
			}
		}
	}

	if copySequence {
		// The destination gets its own sqlite_sequence once an AUTOINCREMENT
		// table is created there. The INSERTs above only raise its counters
		// to the largest copied rowids, while the source may hold higher
		// values left by deleted rows, so overwrite them with the source's.
		// If the destination has no sqlite_sequence, no copied table uses
		// AUTOINCREMENT and there is nothing to carry over.
		var n int
		q := "SELECT count(*) FROM " + dst + ".sqlite_schema WHERE type = 'table' AND name = 'sqlite_sequence'"
		if err := conn.QueryRowContext(ctx, q).Scan(&n); err != nil {
			return err
		}
		if n > 0 {
			if _, err := conn.ExecContext(ctx, "DELETE FROM "+dst+".sqlite_sequence"); err != nil {
				return err
			}
			if _, err := conn.ExecContext(ctx, "INSERT INTO "+dst+".sqlite_sequence SELECT * FROM "+src+".sqlite_sequence"); err != nil {
				return err
			}
		}
	}
	return nil
}