cache.py

import base64
import hashlib
import importlib
import json
import os
import uuid
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Dict, List, Optional
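
# All cache locations honor environment overrides: TRITON_HOME moves the
# ~/.triton root, while TRITON_CACHE_DIR, TRITON_OVERRIDE_DIR and
# TRITON_DUMP_DIR replace the individual store directories.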
def get_home_dir():
    return os.getenv("TRITON_HOME", Path.home())


def default_cache_dir():
    return os.path.join(get_home_dir(), ".triton", "cache")


def default_override_dir():
    return os.path.join(get_home_dir(), ".triton", "override")


def default_dump_dir():
    return os.path.join(get_home_dir(), ".triton", "dump")


class CacheManager(ABC):
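    """Abstract interface for kernel-artifact caches.

    A manager is scoped to one compilation key and maps file names to
    materialized paths, both for single files and for named groups of files.
    """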

    def __init__(self, key):
        pass

    @abstractmethod
    def get_file(self, filename) -> Optional[str]:
        pass

    @abstractmethod
    def put(self, data, filename, binary=True) -> str:
        pass

    @abstractmethod
    def get_group(self, filename: str) -> Optional[Dict[str, str]]:
        pass

    @abstractmethod
    def put_group(self, filename: str, group: Dict[str, str]):
        pass


class FileCacheManager(CacheManager):
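    """Cache manager backed by the local filesystem.

    Files live under `<cache_dir>/<key>/`; writes go through a per-process
    temporary directory and `os.replace` so readers never see partial files.
    """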

    def __init__(self, key, override=False, dump=False):
        self.key = key
        self.lock_path = None
        if dump:
            self.cache_dir = os.getenv("TRITON_DUMP_DIR", "").strip() or default_dump_dir()
            self.cache_dir = os.path.join(self.cache_dir, self.key)
            self.lock_path = os.path.join(self.cache_dir, "lock")
            os.makedirs(self.cache_dir, exist_ok=True)
        elif override:
            self.cache_dir = os.getenv("TRITON_OVERRIDE_DIR", "").strip() or default_override_dir()
            self.cache_dir = os.path.join(self.cache_dir, self.key)
        else:
            # create cache directory if it doesn't exist
            self.cache_dir = os.getenv("TRITON_CACHE_DIR", "").strip() or default_cache_dir()
            if self.cache_dir:
                self.cache_dir = os.path.join(self.cache_dir, self.key)
                self.lock_path = os.path.join(self.cache_dir, "lock")
                os.makedirs(self.cache_dir, exist_ok=True)
            else:
                raise RuntimeError("Could not create or locate cache dir")

    def _make_path(self, filename) -> str:
        return os.path.join(self.cache_dir, filename)

    def has_file(self, filename) -> bool:
        if not self.cache_dir:
            raise RuntimeError("Could not create or locate cache dir")
        return os.path.exists(self._make_path(filename))

    def get_file(self, filename) -> Optional[str]:
        if self.has_file(filename):
            return self._make_path(filename)
        else:
            return None

    def get_group(self, filename: str) -> Optional[Dict[str, str]]:
        grp_filename = f"__grp__{filename}"
        if not self.has_file(grp_filename):
            return None
        grp_filepath = self._make_path(grp_filename)
        with open(grp_filepath) as f:
            grp_data = json.load(f)
        child_paths = grp_data.get("child_paths", None)
        # Invalid group data.
        if child_paths is None:
            return None
        result = {}
        for c, p in child_paths.items():
            if os.path.exists(p):
                result[c] = p
        return result

    # Record the files that make up a group so they can be retrieved together.
    def put_group(self, filename: str, group: Dict[str, str]) -> str:
        if not self.cache_dir:
            raise RuntimeError("Could not create or locate cache dir")
        grp_contents = json.dumps({"child_paths": group})
        grp_filename = f"__grp__{filename}"
        return self.put(grp_contents, grp_filename, binary=False)

    def put(self, data, filename, binary=True) -> str:
        if not self.cache_dir:
            raise RuntimeError("Could not create or locate cache dir")
        # `binary` is inferred from the data type, overriding the argument.
        binary = isinstance(data, bytes)
        if not binary:
            data = str(data)
        assert self.lock_path is not None
        filepath = self._make_path(filename)
        # Random ID to avoid any collisions.
        rnd_id = str(uuid.uuid4())
        # Include the PID so that, when several temp dirs exist at once, it is
        # clear which process created each one.
        pid = os.getpid()
        # Write to a temp dir first to be robust against program interruptions.
        temp_dir = os.path.join(self.cache_dir, f"tmp.pid_{pid}_{rnd_id}")
        os.makedirs(temp_dir, exist_ok=True)
        temp_path = os.path.join(temp_dir, filename)
        mode = "wb" if binary else "w"
        with open(temp_path, mode) as f:
            f.write(data)
        # os.replace is guaranteed to be atomic on POSIX systems if it succeeds,
        # so filepath can never be seen with a partial write.
        os.replace(temp_path, filepath)
        os.removedirs(temp_dir)
        return filepath
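

# A minimal usage sketch (hypothetical key and contents, for illustration):
#     fcm = FileCacheManager("0123abcd")
#     path = fcm.put(b"...compiled bits...", "kernel.so")
#     assert fcm.get_file("kernel.so") == path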


class RemoteCacheBackend(ABC):
    """
    A backend implementation for accessing a remote/distributed cache.
    """

    def __init__(self, key: str):
        pass

    @abstractmethod
    def get(self, filenames: List[str]) -> Dict[str, bytes]:
        pass

    @abstractmethod
    def put(self, filename: str, data: bytes):
        pass


class RedisRemoteCacheBackend(RemoteCacheBackend):
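    """Remote cache backend that stores blobs in Redis.

    The connection and key layout come from TRITON_REDIS_HOST,
    TRITON_REDIS_PORT and TRITON_REDIS_KEY_FORMAT.
    """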

    def __init__(self, key):
        import redis
        self._key = key
        self._key_fmt = os.environ.get("TRITON_REDIS_KEY_FORMAT", "triton:{key}:{filename}")
        self._redis = redis.Redis(
            host=os.environ.get("TRITON_REDIS_HOST", "localhost"),
            port=int(os.environ.get("TRITON_REDIS_PORT", 6379)),
        )

    def _get_key(self, filename: str) -> str:
        return self._key_fmt.format(key=self._key, filename=filename)

    def get(self, filenames: List[str]) -> Dict[str, bytes]:
        results = self._redis.mget([self._get_key(f) for f in filenames])
        return {filename: result for filename, result in zip(filenames, results) if result is not None}

    def put(self, filename: str, data: bytes):
        self._redis.set(self._get_key(filename), data)


class RemoteCacheManager(CacheManager):
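    """Cache manager that fronts a remote backend.

    The backend class is named by TRITON_REMOTE_CACHE_BACKEND as
    "module.path:ClassName"; fetched blobs are materialized locally through a
    backing FileCacheManager.
    """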

    def __init__(self, key, override=False, dump=False):
        # Set up the backend pointed to by `TRITON_REMOTE_CACHE_BACKEND`.
        remote_cache_manager = os.environ["TRITON_REMOTE_CACHE_BACKEND"]
        module_path, clz_nme = remote_cache_manager.split(":")
        module = importlib.import_module(module_path)
        remote_cache_cls = getattr(module, clz_nme)
        self._backend = remote_cache_cls(key)
        self._override = override
        self._dump = dump
        # Use a `FileCacheManager` to materialize remote cache paths locally.
        self._file_cache_manager = FileCacheManager(key, override=override, dump=dump)

    def _materialize(self, filename: str, data: bytes):
        # We use a backing `FileCacheManager` to provide the materialized data.
        return self._file_cache_manager.put(data, filename, binary=True)

    def get_file(self, filename: str) -> Optional[str]:
        # We don't handle the dump/override cases.
        if self._dump or self._override:
            return self._file_cache_manager.get_file(filename)
        # We always check the remote cache backend -- even if our internal file-
        # based cache has the item -- to make sure LRU accounting works as
        # expected.
        results = self._backend.get([filename])
        if len(results) == 0:
            return None
        (_, data), = results.items()
        return self._materialize(filename, data)

    def put(self, data, filename: str, binary=True) -> str:
        # We don't handle the dump/override cases.
        if self._dump or self._override:
            return self._file_cache_manager.put(data, filename, binary=binary)
        if not isinstance(data, bytes):
            data = str(data).encode("utf-8")
        self._backend.put(filename, data)
        return self._materialize(filename, data)

    def get_group(self, filename: str) -> Optional[Dict[str, str]]:
        # We don't handle the dump/override cases.
        if self._dump or self._override:
            return self._file_cache_manager.get_group(filename)
        grp_filename = f"__grp__{filename}"
        grp_filepath = self.get_file(grp_filename)
        if grp_filepath is None:
            return None
        with open(grp_filepath) as f:
            grp_data = json.load(f)
        child_paths = grp_data.get("child_paths", None)
        result = None
        # Found group data.
        if child_paths is not None:
            result = {}
            for child_path, data in self._backend.get(child_paths).items():
                result[child_path] = self._materialize(child_path, data)
        return result

    def put_group(self, filename: str, group: Dict[str, str]):
        # We don't handle the dump/override cases.
        if self._dump or self._override:
            return self._file_cache_manager.put_group(filename, group)
        grp_contents = json.dumps({"child_paths": sorted(list(group.keys()))})
        grp_filename = f"__grp__{filename}"
        return self.put(grp_contents, grp_filename)


__cache_cls = FileCacheManager
__cache_cls_nme = "DEFAULT"


def _base32(key):
    # Assume key is a hex string.
    return base64.b32encode(bytes.fromhex(key)).decode("utf-8").rstrip("=")
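

# For example, _base32("deadbeef") == "32W353Y" (base32 with padding stripped).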


def get_cache_manager(key) -> CacheManager:
    user_cache_manager = os.environ.get("TRITON_CACHE_MANAGER", None)
    global __cache_cls
    global __cache_cls_nme
    if user_cache_manager is not None and user_cache_manager != __cache_cls_nme:
        module_path, clz_nme = user_cache_manager.split(":")
        module = importlib.import_module(module_path)
        __cache_cls = getattr(module, clz_nme)
        __cache_cls_nme = user_cache_manager
    return __cache_cls(_base32(key))


def get_override_manager(key) -> CacheManager:
    return __cache_cls(_base32(key), override=True)


def get_dump_manager(key) -> CacheManager:
    return __cache_cls(_base32(key), dump=True)


def make_so_cache_key(version_hash, signature, constants, ids, **kwargs):
    # Get unique key for the compiled code
    signature = {k: 'ptr' if v[0] == '*' else v for k, v in signature.items()}
    key = f"{version_hash}-{''.join(signature.values())}-{constants}-{ids}"
    for kw in kwargs:
        key = f"{key}-{kwargs.get(kw)}"
    key = hashlib.sha256(key.encode("utf-8")).hexdigest()
    return _base32(key)
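

# A minimal end-to-end sketch (hypothetical values, for illustration only):
#     key = hashlib.sha256(b"kernel-source-and-options").hexdigest()
#     cache = get_cache_manager(key)
#     path = cache.put(b"...compiled bits...", "kernel.so", binary=True)
#     assert cache.get_file("kernel.so") == path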