Skip to content

PYTHON-5306: [v4.12] - Fix use of public MongoClient attributes before connection (#2285) #2311

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 30 additions & 9 deletions pymongo/asynchronous/mongo_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@
)
from pymongo.read_preferences import ReadPreference, _ServerMode
from pymongo.results import ClientBulkWriteResult
from pymongo.server_description import ServerDescription
from pymongo.server_selectors import writable_server_selector
from pymongo.server_type import SERVER_TYPE
from pymongo.topology_description import TOPOLOGY_TYPE, TopologyDescription
Expand Down Expand Up @@ -779,7 +780,7 @@ def __init__(
keyword_opts["document_class"] = doc_class
self._resolve_srv_info: dict[str, Any] = {"keyword_opts": keyword_opts}

seeds = set()
self._seeds = set()
is_srv = False
username = None
password = None
Expand All @@ -804,18 +805,18 @@ def __init__(
srv_max_hosts=srv_max_hosts,
)
is_srv = entity.startswith(SRV_SCHEME)
seeds.update(res["nodelist"])
self._seeds.update(res["nodelist"])
username = res["username"] or username
password = res["password"] or password
dbase = res["database"] or dbase
opts = res["options"]
fqdn = res["fqdn"]
else:
seeds.update(split_hosts(entity, self._port))
if not seeds:
self._seeds.update(split_hosts(entity, self._port))
if not self._seeds:
raise ConfigurationError("need to specify at least one host")

for hostname in [node[0] for node in seeds]:
for hostname in [node[0] for node in self._seeds]:
if _detect_external_db(hostname):
break

Expand All @@ -838,7 +839,7 @@ def __init__(
srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME)

srv_max_hosts = srv_max_hosts or opts.get("srvmaxhosts")
opts = self._normalize_and_validate_options(opts, seeds)
opts = self._normalize_and_validate_options(opts, self._seeds)

# Username and password passed as kwargs override user info in URI.
username = opts.get("username", username)
Expand All @@ -857,7 +858,7 @@ def __init__(
"username": username,
"password": password,
"dbase": dbase,
"seeds": seeds,
"seeds": self._seeds,
"fqdn": fqdn,
"srv_service_name": srv_service_name,
"pool_class": pool_class,
Expand All @@ -873,8 +874,7 @@ def __init__(
self._options.read_concern,
)

if not is_srv:
self._init_based_on_options(seeds, srv_max_hosts, srv_service_name)
self._init_based_on_options(self._seeds, srv_max_hosts, srv_service_name)

self._opened = False
self._closed = False
Expand Down Expand Up @@ -975,6 +975,7 @@ def _init_based_on_options(
srv_service_name=srv_service_name,
srv_max_hosts=srv_max_hosts,
server_monitoring_mode=self._options.server_monitoring_mode,
topology_id=self._topology_settings._topology_id if self._topology_settings else None,
)
if self._options.auto_encryption_opts:
from pymongo.asynchronous.encryption import _Encrypter
Expand Down Expand Up @@ -1205,6 +1206,16 @@ def topology_description(self) -> TopologyDescription:

.. versionadded:: 4.0
"""
if self._topology is None:
servers = {(host, port): ServerDescription((host, port)) for host, port in self._seeds}
return TopologyDescription(
TOPOLOGY_TYPE.Unknown,
servers,
None,
None,
None,
self._topology_settings,
)
return self._topology.description

@property
Expand All @@ -1218,6 +1229,8 @@ def nodes(self) -> FrozenSet[_Address]:
to any servers, or a network partition causes it to lose connection
to all servers.
"""
if self._topology is None:
return frozenset()
description = self._topology.description
return frozenset(s.address for s in description.known_servers)

Expand Down Expand Up @@ -1576,6 +1589,8 @@ async def address(self) -> Optional[tuple[str, int]]:

.. versionadded:: 3.0
"""
if self._topology is None:
await self._get_topology()
topology_type = self._topology._description.topology_type
if (
topology_type == TOPOLOGY_TYPE.Sharded
Expand All @@ -1598,6 +1613,8 @@ async def primary(self) -> Optional[tuple[str, int]]:
.. versionadded:: 3.0
AsyncMongoClient gained this property in version 3.0.
"""
if self._topology is None:
await self._get_topology()
return await self._topology.get_primary() # type: ignore[return-value]

@property
Expand All @@ -1611,6 +1628,8 @@ async def secondaries(self) -> set[_Address]:
.. versionadded:: 3.0
AsyncMongoClient gained this property in version 3.0.
"""
if self._topology is None:
await self._get_topology()
return await self._topology.get_secondaries()

@property
Expand All @@ -1621,6 +1640,8 @@ async def arbiters(self) -> set[_Address]:
connected to a replica set, there are no arbiters, or this client was
created without the `replicaSet` option.
"""
if self._topology is None:
await self._get_topology()
return await self._topology.get_arbiters()

@property
Expand Down
7 changes: 5 additions & 2 deletions pymongo/asynchronous/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def __init__(
srv_service_name: str = common.SRV_SERVICE_NAME,
srv_max_hosts: int = 0,
server_monitoring_mode: str = common.SERVER_MONITORING_MODE,
topology_id: Optional[ObjectId] = None,
):
"""Represent MongoClient's configuration.

Expand Down Expand Up @@ -78,8 +79,10 @@ def __init__(
self._srv_service_name = srv_service_name
self._srv_max_hosts = srv_max_hosts or 0
self._server_monitoring_mode = server_monitoring_mode

self._topology_id = ObjectId()
if topology_id is not None:
self._topology_id = topology_id
else:
self._topology_id = ObjectId()
# Store the allocation traceback to catch unclosed clients in the
# test suite.
self._stack = "".join(traceback.format_stack()[:-2])
Expand Down
39 changes: 30 additions & 9 deletions pymongo/synchronous/mongo_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@
)
from pymongo.read_preferences import ReadPreference, _ServerMode
from pymongo.results import ClientBulkWriteResult
from pymongo.server_description import ServerDescription
from pymongo.server_selectors import writable_server_selector
from pymongo.server_type import SERVER_TYPE
from pymongo.synchronous import client_session, database, uri_parser
Expand Down Expand Up @@ -777,7 +778,7 @@ def __init__(
keyword_opts["document_class"] = doc_class
self._resolve_srv_info: dict[str, Any] = {"keyword_opts": keyword_opts}

seeds = set()
self._seeds = set()
is_srv = False
username = None
password = None
Expand All @@ -802,18 +803,18 @@ def __init__(
srv_max_hosts=srv_max_hosts,
)
is_srv = entity.startswith(SRV_SCHEME)
seeds.update(res["nodelist"])
self._seeds.update(res["nodelist"])
username = res["username"] or username
password = res["password"] or password
dbase = res["database"] or dbase
opts = res["options"]
fqdn = res["fqdn"]
else:
seeds.update(split_hosts(entity, self._port))
if not seeds:
self._seeds.update(split_hosts(entity, self._port))
if not self._seeds:
raise ConfigurationError("need to specify at least one host")

for hostname in [node[0] for node in seeds]:
for hostname in [node[0] for node in self._seeds]:
if _detect_external_db(hostname):
break

Expand All @@ -836,7 +837,7 @@ def __init__(
srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME)

srv_max_hosts = srv_max_hosts or opts.get("srvmaxhosts")
opts = self._normalize_and_validate_options(opts, seeds)
opts = self._normalize_and_validate_options(opts, self._seeds)

# Username and password passed as kwargs override user info in URI.
username = opts.get("username", username)
Expand All @@ -855,7 +856,7 @@ def __init__(
"username": username,
"password": password,
"dbase": dbase,
"seeds": seeds,
"seeds": self._seeds,
"fqdn": fqdn,
"srv_service_name": srv_service_name,
"pool_class": pool_class,
Expand All @@ -871,8 +872,7 @@ def __init__(
self._options.read_concern,
)

if not is_srv:
self._init_based_on_options(seeds, srv_max_hosts, srv_service_name)
self._init_based_on_options(self._seeds, srv_max_hosts, srv_service_name)

self._opened = False
self._closed = False
Expand Down Expand Up @@ -973,6 +973,7 @@ def _init_based_on_options(
srv_service_name=srv_service_name,
srv_max_hosts=srv_max_hosts,
server_monitoring_mode=self._options.server_monitoring_mode,
topology_id=self._topology_settings._topology_id if self._topology_settings else None,
)
if self._options.auto_encryption_opts:
from pymongo.synchronous.encryption import _Encrypter
Expand Down Expand Up @@ -1203,6 +1204,16 @@ def topology_description(self) -> TopologyDescription:

.. versionadded:: 4.0
"""
if self._topology is None:
servers = {(host, port): ServerDescription((host, port)) for host, port in self._seeds}
return TopologyDescription(
TOPOLOGY_TYPE.Unknown,
servers,
None,
None,
None,
self._topology_settings,
)
return self._topology.description

@property
Expand All @@ -1216,6 +1227,8 @@ def nodes(self) -> FrozenSet[_Address]:
to any servers, or a network partition causes it to lose connection
to all servers.
"""
if self._topology is None:
return frozenset()
description = self._topology.description
return frozenset(s.address for s in description.known_servers)

Expand Down Expand Up @@ -1570,6 +1583,8 @@ def address(self) -> Optional[tuple[str, int]]:

.. versionadded:: 3.0
"""
if self._topology is None:
self._get_topology()
topology_type = self._topology._description.topology_type
if (
topology_type == TOPOLOGY_TYPE.Sharded
Expand All @@ -1592,6 +1607,8 @@ def primary(self) -> Optional[tuple[str, int]]:
.. versionadded:: 3.0
MongoClient gained this property in version 3.0.
"""
if self._topology is None:
self._get_topology()
return self._topology.get_primary() # type: ignore[return-value]

@property
Expand All @@ -1605,6 +1622,8 @@ def secondaries(self) -> set[_Address]:
.. versionadded:: 3.0
MongoClient gained this property in version 3.0.
"""
if self._topology is None:
self._get_topology()
return self._topology.get_secondaries()

@property
Expand All @@ -1615,6 +1634,8 @@ def arbiters(self) -> set[_Address]:
connected to a replica set, there are no arbiters, or this client was
created without the `replicaSet` option.
"""
if self._topology is None:
self._get_topology()
return self._topology.get_arbiters()

@property
Expand Down
7 changes: 5 additions & 2 deletions pymongo/synchronous/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def __init__(
srv_service_name: str = common.SRV_SERVICE_NAME,
srv_max_hosts: int = 0,
server_monitoring_mode: str = common.SERVER_MONITORING_MODE,
topology_id: Optional[ObjectId] = None,
):
"""Represent MongoClient's configuration.

Expand Down Expand Up @@ -78,8 +79,10 @@ def __init__(
self._srv_service_name = srv_service_name
self._srv_max_hosts = srv_max_hosts or 0
self._server_monitoring_mode = server_monitoring_mode

self._topology_id = ObjectId()
if topology_id is not None:
self._topology_id = topology_id
else:
self._topology_id = ObjectId()
# Store the allocation traceback to catch unclosed clients in the
# test suite.
self._stack = "".join(traceback.format_stack()[:-2])
Expand Down
Loading