|
3 | 3 | import unittest
|
4 | 4 | import sqlite3 as sqlite
|
5 | 5 |
|
| 6 | + |
6 | 7 | class DumpTests(unittest.TestCase):
|
7 | 8 | def setUp(self):
|
8 | 9 | self.cx = sqlite.connect(":memory:")
|
@@ -49,6 +50,51 @@ def test_table_dump(self):
|
49 | 50 | [self.assertEqual(expected_sqls[i], actual_sqls[i])
|
50 | 51 | for i in range(len(expected_sqls))]
|
51 | 52 |
|
| 53 | + def test_dump_autoincrement(self): |
| 54 | + expected = [ |
| 55 | + 'CREATE TABLE "t1" (id integer primary key autoincrement);', |
| 56 | + 'INSERT INTO "t1" VALUES(NULL);', |
| 57 | + 'CREATE TABLE "t2" (id integer primary key autoincrement);', |
| 58 | + ] |
| 59 | + self.cu.executescript("".join(expected)) |
| 60 | + |
| 61 | + # the NULL value should now be automatically be set to 1 |
| 62 | + expected[1] = expected[1].replace("NULL", "1") |
| 63 | + expected.insert(0, "BEGIN TRANSACTION;") |
| 64 | + expected.extend([ |
| 65 | + 'DELETE FROM "sqlite_sequence";', |
| 66 | + 'INSERT INTO "sqlite_sequence" VALUES(\'t1\',1);', |
| 67 | + 'COMMIT;', |
| 68 | + ]) |
| 69 | + |
| 70 | + actual = [stmt for stmt in self.cx.iterdump()] |
| 71 | + self.assertEqual(expected, actual) |
| 72 | + |
| 73 | + def test_dump_autoincrement_create_new_db(self): |
| 74 | + self.cu.execute("BEGIN TRANSACTION") |
| 75 | + self.cu.execute("CREATE TABLE t1 (id integer primary key autoincrement)") |
| 76 | + self.cu.execute("CREATE TABLE t2 (id integer primary key autoincrement)") |
| 77 | + self.cu.executemany("INSERT INTO t1 VALUES(?)", ((None,) for _ in range(9))) |
| 78 | + self.cu.executemany("INSERT INTO t2 VALUES(?)", ((None,) for _ in range(4))) |
| 79 | + self.cx.commit() |
| 80 | + |
| 81 | + cx2 = sqlite.connect(":memory:") |
| 82 | + query = "".join(self.cx.iterdump()) |
| 83 | + cx2.executescript(query) |
| 84 | + cu2 = cx2.cursor() |
| 85 | + |
| 86 | + dataset = ( |
| 87 | + ("t1", 9), |
| 88 | + ("t2", 4), |
| 89 | + ) |
| 90 | + for table, seq in dataset: |
| 91 | + with self.subTest(table=table, seq=seq): |
| 92 | + res = cu2.execute(""" |
| 93 | + SELECT "seq" FROM "sqlite_sequence" WHERE "name" == ? |
| 94 | + """, (table,)) |
| 95 | + rows = res.fetchall() |
| 96 | + self.assertEqual(rows[0][0], seq) |
| 97 | + |
52 | 98 | def test_unorderable_row(self):
|
53 | 99 | # iterdump() should be able to cope with unorderable row types (issue #15545)
|
54 | 100 | class UnorderableRow:
|
|
0 commit comments