diff --git a/sqlite_utils/db.py b/sqlite_utils/db.py index 371eed9c..a60df432 100644 --- a/sqlite_utils/db.py +++ b/sqlite_utils/db.py @@ -601,6 +601,10 @@ def quote_default_value(self, value: str) -> str: return self.quote(value) + def database_names(self) -> List[str]: + "List of string database names available in this connection." + return [r[1] for r in self.execute("PRAGMA database_list").fetchall()] + def table_names(self, fts4: bool = False, fts5: bool = False) -> List[str]: """ List of string table names in this database. @@ -614,7 +618,20 @@ def table_names(self, fts4: bool = False, fts5: bool = False) -> List[str]: if fts5: where.append("sql like '%USING FTS5%'") sql = "select name from sqlite_master where {}".format(" AND ".join(where)) - return [r[0] for r in self.execute(sql).fetchall()] + + def _exec_in_db(db_name: str, sql: str) -> List[str]: + if db_name == "main": + db_name = "" + if db_name: + sql = sql.replace("sqlite_master", f"{db_name}.sqlite_master") + table_names = [r[0] for r in self.execute(sql).fetchall()] + if db_name: + return [f"{db_name}.{tbl_name}" for tbl_name in table_names] + return table_names + + return list( + itertools.chain(*[_exec_in_db(db_name, sql) for db_name in self.database_names()]) + ) def view_names(self) -> List[str]: "List of string view names in this database." @@ -1271,12 +1288,34 @@ def init_spatialite(self, path: Optional[str] = None) -> bool: return result and bool(result[0]) +def _split_names(fullname: str) -> Tuple[str, str]: + if '.' not in fullname: + return '', fullname + return fullname.split('.') + + +def dbname(fullname: str) -> str: + return _split_names(fullname)[0] + + +def tablename(fullname: str) -> str: + return _split_names(fullname)[1] + + +def escaped_name(fullname: str) -> str: + """This is how SQLite expects a database name joined to a table name to use the square-bracket escapes.""" + db, tbl = _split_names(fullname) + if not db: + return f'[{tbl}]' + return f'{db}.[{tbl}]' + + class Queryable: def exists(self) -> bool: "Does this table or view exist yet?" return False - def __init__(self, db, name): + def __init__(self, db, name: str): self.db = db self.name = name @@ -1292,7 +1331,7 @@ def count_where( :param where_args: Parameters to use with that fragment - an iterable for ``id > ?`` parameters, or a dictionary for ``id > :id`` """ - sql = "select count(*) from [{}]".format(self.name) + sql = "select count(*) from {}".format(escaped_name(self.name)) if where is not None: sql += " where " + where return self.db.execute(sql, where_args or []).fetchone()[0] @@ -1335,7 +1374,7 @@ def rows_where( """ if not self.exists(): return - sql = "select {} from [{}]".format(select, self.name) + sql = "select {} from {}".format(select, escaped_name(self.name)) if where is not None: sql += " where " + where if order_by is not None: @@ -1387,12 +1426,23 @@ def pks_and_rows_where( row_pk = row_pk[0] yield row_pk, row + @property + def is_attached(self) -> bool: + return dbname(self.name) not in {'', 'main'} + + @property + def _pragma_name(self) -> Tuple[str, str]: + if "." in self.name: + db, name = self.name.split(".") + return db + ".", name + return "", self.name + @property def columns(self) -> List["Column"]: "List of :ref:`Columns ` representing the columns in this table or view." if not self.exists(): return [] - rows = self.db.execute("PRAGMA table_info([{}])".format(self.name)).fetchall() + rows = self.db.execute("PRAGMA {}table_info([{}])".format(*self._pragma_name)).fetchall() return [Column(*row) for row in rows] @property @@ -1403,9 +1453,10 @@ def columns_dict(self) -> Dict[str, Any]: @property def schema(self) -> str: "SQL schema for this table or view." - return self.db.execute( - "select sql from sqlite_master where name = ?", (self.name,) - ).fetchone()[0] + db, name = self._pragma_name + return self.db.execute(f"select sql from {db}sqlite_master where name = ?", (name,)).fetchone()[ + 0 + ] class Table(Queryable): @@ -1544,7 +1595,7 @@ def foreign_keys(self) -> List["ForeignKey"]: "List of foreign keys defined on this table." fks = [] for row in self.db.execute( - "PRAGMA foreign_key_list([{}])".format(self.name) + "PRAGMA {}foreign_key_list([{}])".format(*self._pragma_name) ).fetchall(): if row is not None: id, seq, table_name, from_, to_, on_update, on_delete, match = row @@ -1569,7 +1620,8 @@ def virtual_table_using(self) -> Optional[str]: @property def indexes(self) -> List[Index]: "List of indexes defined on this table." - sql = 'PRAGMA index_list("{}")'.format(self.name) + db, table_name = self._pragma_name + sql = 'PRAGMA {}index_list("{}")'.format(db, table_name) indexes = [] for row in self.db.execute_returning_dicts(sql): index_name = row["name"] @@ -1578,7 +1630,7 @@ def indexes(self) -> List[Index]: if not index_name.startswith('"') else index_name ) - column_sql = "PRAGMA index_info({})".format(index_name_quoted) + column_sql = "PRAGMA {}index_info({})".format(db, index_name_quoted) columns = [] for seqno, cid, name in self.db.execute(column_sql).fetchall(): columns.append(name) @@ -1593,7 +1645,8 @@ def indexes(self) -> List[Index]: @property def xindexes(self) -> List[XIndex]: "List of indexes defined on this table using the more detailed ``XIndex`` format." - sql = 'PRAGMA index_list("{}")'.format(self.name) + db, table_name = self._pragma_name + sql = 'PRAGMA {}index_list("{}")'.format(db, table_name) indexes = [] for row in self.db.execute_returning_dicts(sql): index_name = row["name"] @@ -1602,7 +1655,7 @@ def xindexes(self) -> List[XIndex]: if not index_name.startswith('"') else index_name ) - column_sql = "PRAGMA index_xinfo({})".format(index_name_quoted) + column_sql = "PRAGMA {}index_xinfo({})".format(db, index_name_quoted) index_columns = [] for info in self.db.execute(column_sql).fetchall(): index_columns.append(XIndexColumn(*info)) @@ -1612,12 +1665,13 @@ def xindexes(self) -> List[XIndex]: @property def triggers(self) -> List[Trigger]: "List of triggers defined on this table." + db, table_name = self._pragma_name return [ Trigger(*r) for r in self.db.execute( - "select name, tbl_name, sql from sqlite_master where type = 'trigger'" + f"select name, tbl_name, sql from {db}sqlite_master where type = 'trigger'" " and tbl_name = ?", - (self.name,), + (table_name,), ).fetchall() ] @@ -1709,9 +1763,9 @@ def duplicate(self, new_name: str) -> "Table": if not self.exists(): raise NoTable(f"Table {self.name} does not exist") with self.db.conn: - sql = "CREATE TABLE [{new_table}] AS SELECT * FROM [{table}];".format( - new_table=new_name, - table=self.name, + sql = "CREATE TABLE {new_table} AS SELECT * FROM {table};".format( + new_table=escaped_name(new_name), + table=escaped_name(self.name), ) self.db.execute(sql) return self.db[new_name] @@ -1765,21 +1819,22 @@ def transform( column_order=column_order, keep_table=keep_table, ) - pragma_foreign_keys_was_on = self.db.execute("PRAGMA foreign_keys").fetchone()[ + db, _ = self._pragma_name + pragma_foreign_keys_was_on = self.db.execute(f"PRAGMA {db}foreign_keys").fetchone()[ 0 ] try: if pragma_foreign_keys_was_on: - self.db.execute("PRAGMA foreign_keys=0;") + self.db.execute(f"PRAGMA {db}foreign_keys=0;") with self.db.conn: for sql in sqls: self.db.execute(sql) # Run the foreign_key_check before we commit if pragma_foreign_keys_was_on: - self.db.execute("PRAGMA foreign_key_check;") + self.db.execute(f"PRAGMA {db}foreign_key_check;") finally: if pragma_foreign_keys_was_on: - self.db.execute("PRAGMA foreign_keys=1;") + self.db.execute(f"PRAGMA {db}foreign_keys=1;") return self def transform_sql( @@ -1944,9 +1999,9 @@ def transform_sql( if "rowid" not in new_cols: new_cols.insert(0, "rowid") old_cols.insert(0, "rowid") - copy_sql = "INSERT INTO [{new_table}] ({new_cols})\n SELECT {old_cols} FROM [{old_table}];".format( - new_table=new_table_name, - old_table=self.name, + copy_sql = "INSERT INTO {new_table} ({new_cols})\n SELECT {old_cols} FROM {old_table};".format( + new_table=escaped_name(new_table_name), + old_table=escaped_name(self.name), old_cols=", ".join("[{}]".format(col) for col in old_cols), new_cols=", ".join("[{}]".format(col) for col in new_cols), ) @@ -1954,13 +2009,13 @@ def transform_sql( # Drop (or keep) the old table if keep_table: sqls.append( - "ALTER TABLE [{}] RENAME TO [{}];".format(self.name, keep_table) + "ALTER TABLE {} RENAME TO {};".format(escaped_name(self.name), escaped_name(keep_table)) ) else: - sqls.append("DROP TABLE [{}];".format(self.name)) + sqls.append("DROP TABLE {};".format(escaped_name(self.name))) # Rename the new one sqls.append( - "ALTER TABLE [{}] RENAME TO [{}];".format(new_table_name, self.name) + "ALTER TABLE {} RENAME TO {};".format(escaped_name(new_table_name), escaped_name(self.name)) ) return sqls @@ -2023,11 +2078,11 @@ def extract( lookup_columns = [(rename.get(col) or col) for col in columns] lookup_table.create_index(lookup_columns, unique=True, if_not_exists=True) self.db.execute( - "INSERT OR IGNORE INTO [{lookup_table}] ({lookup_columns}) SELECT DISTINCT {table_cols} FROM [{table}]".format( - lookup_table=table, + "INSERT OR IGNORE INTO {lookup_table} ({lookup_columns}) SELECT DISTINCT {table_cols} FROM {table}".format( + lookup_table=escaped_name(table), lookup_columns=", ".join("[{}]".format(c) for c in lookup_columns), table_cols=", ".join("[{}]".format(c) for c in columns), - table=self.name, + table=escaped_name(self.name), ) ) @@ -2036,14 +2091,14 @@ def extract( # And populate it self.db.execute( - "UPDATE [{table}] SET [{magic_lookup_column}] = (SELECT id FROM [{lookup_table}] WHERE {where})".format( - table=self.name, + "UPDATE {table} SET [{magic_lookup_column}] = (SELECT id FROM {lookup_table} WHERE {where})".format( + table=escaped_name(self.name), magic_lookup_column=magic_lookup_column, - lookup_table=table, + lookup_table=escaped_name(table), where=" AND ".join( - "[{table}].[{column}] IS [{lookup_table}].[{lookup_column}]".format( - table=self.name, - lookup_table=table, + "{table}.[{column}] IS {lookup_table}.[{lookup_column}]".format( + table=escaped_name(self.name), + lookup_table=escaped_name(table), column=column, lookup_column=rename.get(column) or column, ) @@ -2117,13 +2172,13 @@ def create_index( textwrap.dedent( """ CREATE {unique}INDEX {if_not_exists}[{index_name}] - ON [{table_name}] ({columns}); + ON {table_name} ({columns}); """ ) .strip() .format( index_name=created_index_name, - table_name=self.name, + table_name=escaped_name(self.name), columns=", ".join(columns_sql), unique="UNIQUE " if unique else "", if_not_exists="IF NOT EXISTS " if if_not_exists else "", @@ -2193,8 +2248,8 @@ def add_column( not_null_sql = "NOT NULL DEFAULT {}".format( self.db.quote_default_value(not_null_default) ) - sql = "ALTER TABLE [{table}] ADD COLUMN [{col_name}] {col_type}{not_null_default};".format( - table=self.name, + sql = "ALTER TABLE {table} ADD COLUMN [{col_name}] {col_type}{not_null_default};".format( + table=escaped_name(self.name), col_name=col_name, col_type=fk_col_type or COLUMN_TYPE_MAPPING[col_type], not_null_default=(" " + not_null_sql) if not_null_sql else "", @@ -2211,7 +2266,7 @@ def drop(self, ignore: bool = False): :param ignore: Set to ``True`` to ignore the error if the table does not exist """ try: - self.db.execute("DROP TABLE [{}]".format(self.name)) + self.db.execute("DROP TABLE {}".format(escaped_name(self.name))) except sqlite3.OperationalError: if not ignore: raise @@ -2378,6 +2433,9 @@ def enable_fts( """ Enable SQLite full-text search against the specified columns. + Creates the FTS virtual table(s) in the `main` database, even if the + source table is in an attached database. + See :ref:`python_api_fts` for more details. :param columns: List of column names to include in the search index. @@ -2386,6 +2444,7 @@ def enable_fts( :param tokenize: Custom SQLite tokenizer to use, for example ``"porter"`` to enable Porter stemming. :param replace: Should any existing FTS index for this table be replaced by the new one? """ + table_name = tablename(self.name) create_fts_sql = ( textwrap.dedent( """ @@ -2397,19 +2456,19 @@ def enable_fts( ) .strip() .format( - table=self.name, + table=table_name, columns=", ".join("[{}]".format(c) for c in columns), fts_version=fts_version, tokenize="\n tokenize='{}',".format(tokenize) if tokenize else "", ) ) should_recreate = False - if replace and self.db["{}_fts".format(self.name)].exists(): + if replace and self.db["{}_fts".format(table_name)].exists(): # Does the table need to be recreated? - fts_schema = self.db["{}_fts".format(self.name)].schema + fts_schema = self.db["{}_fts".format(table_name)].schema if fts_schema != create_fts_sql: should_recreate = True - expected_triggers = {self.name + suffix for suffix in ("_ai", "_ad", "_au")} + expected_triggers = {table_name + suffix for suffix in ("_ai", "_ad", "_au")} existing_triggers = {t.name for t in self.triggers} has_triggers = existing_triggers.issuperset(expected_triggers) if has_triggers != create_triggers: @@ -2444,7 +2503,7 @@ def enable_fts( ) .strip() .format( - table=self.name, + table=table_name, columns=", ".join("[{}]".format(c) for c in columns), old_cols=old_cols, new_cols=new_cols, @@ -2469,7 +2528,7 @@ def populate_fts(self, columns: Iterable[str]) -> "Table": ) .strip() .format( - table=self.name, columns=", ".join("[{}]".format(c) for c in columns) + table=tablename(self.name), columns=", ".join("[{}]".format(c) for c in columns) ) ) self.db.executescript(sql) @@ -2505,9 +2564,9 @@ def rebuild_fts(self): fts_table = self.detect_fts() if fts_table is None: # Assume this is itself an FTS table - fts_table = self.name + fts_table = escaped_name(self.name) self.db.execute( - "INSERT INTO [{table}]([{table}]) VALUES('rebuild');".format( + "INSERT INTO {table}({table}) VALUES('rebuild');".format( table=fts_table ) ) @@ -2529,10 +2588,11 @@ def detect_fts(self) -> Optional[str]: ) """ ).strip() + table_name = tablename(self.name) args = { - "like": "%VIRTUAL TABLE%USING FTS%content=[{}]%".format(self.name), - "like2": '%VIRTUAL TABLE%USING FTS%content="{}"%'.format(self.name), - "table": self.name, + "like": "%VIRTUAL TABLE%USING FTS%content=[{}]%".format(table_name), + "like2": '%VIRTUAL TABLE%USING FTS%content="{}"%'.format(table_name), + "table": table_name, } rows = self.db.execute(sql, args).fetchall() if len(rows) == 0: @@ -2592,7 +2652,7 @@ def search_sql( select rowid, {columns} - from [{dbtable}]{where_clause} + from {dbtable}{where_clause} ) select {columns_with_prefix} @@ -2621,7 +2681,7 @@ def search_sql( if offset is not None: limit_offset += " offset {}".format(offset) return sql.format( - dbtable=self.name, + dbtable=escaped_name(self.name), where_clause="\n where {}".format(where) if where else "", original=original, columns=columns_sql, @@ -2692,8 +2752,8 @@ def delete(self, pk_values: Union[list, tuple, str, int, float]) -> "Table": pk_values = [pk_values] self.get(pk_values) wheres = ["[{}] = ?".format(pk_name) for pk_name in self.pks] - sql = "delete from [{table}] where {wheres}".format( - table=self.name, wheres=" and ".join(wheres) + sql = "delete from {table} where {wheres}".format( + table=escaped_name(self.name), wheres=" and ".join(wheres) ) with self.db.conn: self.db.execute(sql, pk_values) @@ -2717,7 +2777,7 @@ def delete_where( """ if not self.exists(): return self - sql = "delete from [{}]".format(self.name) + sql = f"delete from {escaped_name(self.name)}" if where is not None: sql += " where " + where self.db.execute(sql, where_args or []) @@ -2762,8 +2822,8 @@ def update( args.append(jsonify_if_needed(value)) wheres = ["[{}] = ?".format(pk_name) for pk_name in pks] args.extend(pk_values) - sql = "update [{table}] set {sets} where {wheres}".format( - table=self.name, sets=", ".join(sets), wheres=" and ".join(wheres) + sql = "update {table} set {sets} where {wheres}".format( + table=escaped_name(self.name), sets=", ".join(sets), wheres=" and ".join(wheres) ) with self.db.conn: try: @@ -2843,8 +2903,8 @@ def convert_value(v): if fn_name == "": fn_name = f"lambda_{abs(hash(fn))}" self.db.register_function(convert_value, name=fn_name) - sql = "update [{table}] set {sets}{where};".format( - table=self.name, + sql = "update {table} set {sets}{where};".format( + table=escaped_name(self.name), sets=", ".join( [ "[{output_column}] = {fn_name}([{column}])".format( @@ -2965,8 +3025,8 @@ def build_insert_queries_and_params( # them since it ignores the resulting integrity errors if not_null: placeholders.extend(not_null) - sql = "INSERT OR IGNORE INTO [{table}]({cols}) VALUES({placeholders});".format( - table=self.name, + sql = "INSERT OR IGNORE INTO {table}({cols}) VALUES({placeholders});".format( + table=escaped_name(self.name), cols=", ".join(["[{}]".format(p) for p in placeholders]), placeholders=", ".join(["?" for p in placeholders]), ) @@ -2976,8 +3036,8 @@ def build_insert_queries_and_params( # UPDATE [book] SET [name] = 'Programming' WHERE [id] = 1001; set_cols = [col for col in all_columns if col not in pks] if set_cols: - sql2 = "UPDATE [{table}] SET {pairs} WHERE {wheres}".format( - table=self.name, + sql2 = "UPDATE {table} SET {pairs} WHERE {wheres}".format( + table=escaped_name(self.name), pairs=", ".join( "[{}] = {}".format(col, conversions.get(col, "?")) for col in set_cols @@ -3004,10 +3064,10 @@ def build_insert_queries_and_params( elif ignore: or_what = "OR IGNORE " sql = """ - INSERT {or_what}INTO [{table}] ({columns}) VALUES {rows}; + INSERT {or_what}INTO {table} ({columns}) VALUES {rows}; """.strip().format( or_what=or_what, - table=self.name, + table=escaped_name(self.name), columns=", ".join("[{}]".format(c) for c in all_columns), rows=", ".join( "({placeholders})".format( @@ -3265,7 +3325,7 @@ def insert_all( self.last_rowid = None self.last_pk = None if truncate and self.exists(): - self.db.execute("DELETE FROM [{}];".format(self.name)) + self.db.execute("DELETE FROM {};".format(escaped_name(self.name))) for chunk in chunks(itertools.chain([first_record], records), batch_size): chunk = list(chunk) num_records_processed += len(chunk) @@ -3776,7 +3836,7 @@ def drop(self, ignore=False): """ try: - self.db.execute("DROP VIEW [{}]".format(self.name)) + self.db.execute("DROP VIEW {}".format(escaped_name(self.name))) except sqlite3.OperationalError: if not ignore: raise