Skip to content

Commit

Permalink
bugfix: every autoimport gets its own connection
Browse files Browse the repository at this point in the history
  • Loading branch information
tkrabel committed Nov 1, 2023
1 parent bf8d3a3 commit db187bd
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 5 deletions.
10 changes: 5 additions & 5 deletions rope/contrib/autoimport/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ def filter_package(package: Package) -> bool:


_deprecated_default: bool = object() # type: ignore
thread_local = local()


class AutoImport:
Expand Down Expand Up @@ -127,6 +126,7 @@ def __init__(
"`AutoImport(memory=True)` explicitly.",
DeprecationWarning,
)
self.thread_local = local()
self.connection = self.create_database_connection(
project=project,
memory=memory,
Expand Down Expand Up @@ -169,16 +169,16 @@ def connection(self):
This makes sure AutoImport can be shared across threads.
"""
if not hasattr(thread_local, "connection"):
thread_local.connection = self.create_database_connection(
if not hasattr(self.thread_local, "connection"):
self.thread_local.connection = self.create_database_connection(
project=self.project,
memory=self.memory,
)
return thread_local.connection
return self.thread_local.connection

@connection.setter
def connection(self, value: sqlite3.Connection):
thread_local.connection = value
self.thread_local.connection = value

def _setup_db(self):
models.Metadata.create_table(self.connection)
Expand Down
7 changes: 7 additions & 0 deletions ropetest/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,13 @@ def project_path(project):
yield pathlib.Path(project.address)


@pytest.fixture
def project2():
project = testutils.sample_project("sample_project2")
yield project
testutils.remove_project(project)


"""
Standard project structure for pytest fixtures
/mod1.py -- mod1
Expand Down
9 changes: 9 additions & 0 deletions ropetest/contrib/autoimport/autoimporttest.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,15 @@ def foo():
assert [("from pkg1 import foo", "foo")] == results


def test_connection(project: Project, project2: Project):
ai1 = AutoImport(project)
ai2 = AutoImport(project)
ai3 = AutoImport(project2)

assert ai1.connection is not ai2.connection
assert ai1.connection is not ai3.connection


@contextmanager
def assert_database_is_reset(conn):
conn.execute("ALTER TABLE names ADD COLUMN deprecated_column")
Expand Down

0 comments on commit db187bd

Please sign in to comment.