Skip to content

Commit 510ff8a

Browse files
authored
Merge pull request #2863 from Suor/locks
remote: protect all remote client/session creation code with locks
2 parents 593ab17 + 8af1aee commit 510ff8a

File tree

8 files changed

+54
-45
lines changed

8 files changed

+54
-45
lines changed

dvc/remote/azure.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
import logging
55
import os
66
import re
7-
from datetime import datetime
8-
from datetime import timedelta
7+
from datetime import datetime, timedelta
8+
import threading
99

10-
from funcy import cached_property
10+
from funcy import cached_property, wrap_prop
1111

1212
from dvc.config import Config
1313
from dvc.path_info import CloudURLInfo
@@ -64,6 +64,7 @@ def __init__(self, repo, config):
6464
else self.path_cls.from_parts(scheme=self.scheme, netloc=bucket)
6565
)
6666

67+
@wrap_prop(threading.Lock())
6768
@cached_property
6869
def blob_service(self):
6970
from azure.storage.blob import BlockBlobService

dvc/remote/gs.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55
from functools import wraps
66
import io
77
import os.path
8+
import threading
89

9-
from funcy import cached_property
10+
from funcy import cached_property, wrap_prop
1011

1112
from dvc.config import Config
1213
from dvc.exceptions import DvcException
@@ -91,6 +92,7 @@ def __init__(self, repo, config):
9192
self.projectname = config.get(Config.SECTION_GCP_PROJECTNAME, None)
9293
self.credentialpath = config.get(Config.SECTION_GCP_CREDENTIALPATH)
9394

95+
@wrap_prop(threading.Lock())
9496
@cached_property
9597
def gs(self):
9698
from google.cloud.storage import Client

dvc/remote/http.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from __future__ import unicode_literals
22

33
import logging
4+
import threading
45

5-
from funcy import cached_property
6+
from funcy import cached_property, wrap_prop
67

78
from dvc.config import Config
89
from dvc.config import ConfigError
@@ -81,6 +82,7 @@ def get_file_checksum(self, path_info):
8182

8283
return etag
8384

85+
@wrap_prop(threading.Lock())
8486
@cached_property
8587
def _session(self):
8688
import requests

dvc/remote/oss.py

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33

44
import logging
55
import os
6+
import threading
7+
8+
from funcy import cached_property, wrap_prop
69

710
from dvc.config import Config
811
from dvc.path_info import CloudURLInfo
@@ -61,30 +64,29 @@ def __init__(self, repo, config):
6164
or "defaultSecret"
6265
)
6366

64-
self._bucket = None
65-
66-
@property
67+
@wrap_prop(threading.Lock())
68+
@cached_property
6769
def oss_service(self):
6870
import oss2
6971

70-
if self._bucket is None:
71-
logger.debug("URL {}".format(self.path_info))
72-
logger.debug("key id {}".format(self.key_id))
73-
logger.debug("key secret {}".format(self.key_secret))
74-
auth = oss2.Auth(self.key_id, self.key_secret)
75-
self._bucket = oss2.Bucket(
76-
auth, self.endpoint, self.path_info.bucket
72+
logger.debug("URL {}".format(self.path_info))
73+
logger.debug("key id {}".format(self.key_id))
74+
logger.debug("key secret {}".format(self.key_secret))
75+
76+
auth = oss2.Auth(self.key_id, self.key_secret)
77+
bucket = oss2.Bucket(auth, self.endpoint, self.path_info.bucket)
78+
79+
# Ensure bucket exists
80+
try:
81+
bucket.get_bucket_info()
82+
except oss2.exceptions.NoSuchBucket:
83+
bucket.create_bucket(
84+
oss2.BUCKET_ACL_PUBLIC_READ,
85+
oss2.models.BucketCreateConfig(
86+
oss2.BUCKET_STORAGE_CLASS_STANDARD
87+
),
7788
)
78-
try: # verify that bucket exists
79-
self._bucket.get_bucket_info()
80-
except oss2.exceptions.NoSuchBucket:
81-
self._bucket.create_bucket(
82-
oss2.BUCKET_ACL_PUBLIC_READ,
83-
oss2.models.BucketCreateConfig(
84-
oss2.BUCKET_STORAGE_CLASS_STANDARD
85-
),
86-
)
87-
return self._bucket
89+
return bucket
8890

8991
def remove(self, path_info):
9092
if path_info.scheme != self.scheme:

dvc/remote/pool.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from collections import deque
22
from contextlib import contextmanager
3+
import threading
34

4-
from funcy import memoize
5+
from funcy import memoize, wrap_with
56

67

78
@contextmanager
@@ -17,6 +18,7 @@ def get_connection(conn_func, *args, **kwargs):
1718
pool.release(conn)
1819

1920

21+
@wrap_with(threading.Lock())
2022
@memoize
2123
def get_pool(conn_func, *args, **kwargs):
2224
return Pool(conn_func, *args, **kwargs)

dvc/remote/s3.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33

44
import logging
55
import os
6+
import threading
67

7-
from funcy import cached_property
8+
from funcy import cached_property, wrap_prop
89

910
from dvc.config import Config
1011
from dvc.exceptions import DvcException
@@ -56,6 +57,7 @@ def __init__(self, repo, config):
5657
if shared_creds:
5758
os.environ.setdefault("AWS_SHARED_CREDENTIALS_FILE", shared_creds)
5859

60+
@wrap_prop(threading.Lock())
5961
@cached_property
6062
def s3(self):
6163
import boto3

dvc/remote/ssh/__init__.py

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,9 @@
88
import os
99
import threading
1010
from concurrent.futures import ThreadPoolExecutor
11-
from contextlib import closing
12-
from contextlib import contextmanager
11+
from contextlib import closing, contextmanager
12+
13+
from funcy import memoize, wrap_with
1314

1415
import dvc.prompt as prompt
1516
from dvc.config import Config
@@ -24,8 +25,15 @@
2425
logger = logging.getLogger(__name__)
2526

2627

27-
saved_passwords = {}
28-
saved_passwords_lock = threading.Lock()
28+
@wrap_with(threading.Lock())
29+
@memoize
30+
def ask_password(host, user, port):
31+
return prompt.password(
32+
"Enter a private key passphrase or a password for "
33+
"host '{host}' port '{port}' user '{user}'".format(
34+
host=host, port=port, user=user
35+
)
36+
)
2937

3038

3139
class RemoteSSH(RemoteBASE):
@@ -120,21 +128,11 @@ def _try_get_ssh_config_keyfile(user_ssh_config):
120128
def ensure_credentials(self, path_info=None):
121129
if path_info is None:
122130
path_info = self.path_info
123-
host, user, port = path_info.host, path_info.user, path_info.port
131+
124132
# NOTE: we use the same password regardless of the server :(
125133
if self.ask_password and self.password is None:
126-
with saved_passwords_lock:
127-
server_key = (host, user, port)
128-
password = saved_passwords.get(server_key)
129-
130-
if password is None:
131-
saved_passwords[server_key] = password = prompt.password(
132-
"Enter a private key passphrase or a password for "
133-
"host '{host}' port '{port}' user '{user}'".format(
134-
host=host, port=port, user=user
135-
)
136-
)
137-
self.password = password
134+
host, user, port = path_info.host, path_info.user, path_info.port
135+
self.password = ask_password(host, user, port)
138136

139137
def ssh(self, path_info):
140138
self.ensure_credentials(path_info)

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def run(self):
7474
"humanize>=0.5.1",
7575
"PyYAML>=5.1.2",
7676
"ruamel.yaml>=0.16.1",
77-
"funcy>=1.12",
77+
"funcy>=1.14",
7878
"pathspec>=0.6.0",
7979
"shortuuid>=0.5.0",
8080
"tqdm>=4.38.0,<5",

0 commit comments

Comments
 (0)