-
Notifications
You must be signed in to change notification settings - Fork 2
/
utility.py
148 lines (117 loc) · 4.37 KB
/
utility.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
"""General standalone utilities."""
import contextlib
import datetime
import multiprocessing
from pathlib import Path
from typing import (
Any,
Callable,
ContextManager,
Dict,
Generator,
Iterable,
Sequence,
Tuple,
TypeVar,
Union,
)
import numpy as np
import tensorflow as tf
from tensorflow import keras
def split_seed(seed: int, n: int) -> Tuple[int, ...]:
    """Derive `n` independent integer seeds from a single parent seed.

    Uses numpy's SeedSequence spawning, so the children are statistically
    independent. The parent seed should be considered consumed — do not
    reuse it after calling this.
    """
    children = np.random.SeedSequence(seed).spawn(n)
    return tuple(int(child.generate_state(1)[0]) for child in children)
T = TypeVar("T")


def remove_keys(dict_: Dict[str, T], *keys: str) -> Dict[str, T]:
    """Return a shallow copy of `dict_` with the given keys omitted.

    The input dictionary is left unmodified; keys that are absent are ignored.
    """
    excluded = set(keys)
    return {key: value for key, value in dict_.items() if key not in excluded}
def to_jsonable(obj: Any) -> Any:
    """A decent default=? function for json.dump.

    Converts Path to str, numpy arrays/scalars to plain Python via tolist(),
    and dates/datetimes to ISO-8601 strings; anything else raises TypeError.
    """
    # Note: datetime.datetime is a subclass of datetime.date, so the single
    # date check covers both. The three type groups are disjoint, so the
    # order of this table does not affect the result.
    converters = (
        (Path, str),
        ((np.ndarray, np.number), lambda array: array.tolist()),
        (datetime.date, lambda date: date.isoformat()),
    )
    for types, convert in converters:
        if isinstance(obj, types):
            return convert(obj)
    raise TypeError(f"Type '{type(obj).__name__}' is not JSON-serialisable")
Logger = Callable[..., None]


@contextlib.contextmanager
def logging(
    *loggers: Union[ContextManager[Logger], Logger]
) -> Generator[Logger, None, None]:
    """A context manager that delegates logging calls to multiple "loggers".

    Each argument is either a plain callable, or a context manager whose
    `__enter__` returns a callable. Context-manager arguments are entered on
    entry and exited (in reverse order) on exit. The yielded function fans
    every call out to all loggers, in argument order.

    For example:

        @contextlib.contextmanager
        def log_to_file(path: Path) -> None:
            with path.open("w") as f:
                yield lambda item: print(item, file=f)

        with logging(print, log_to_file(Path("log.txt"))) as log:
            log("item one")
            log("item two")
    """
    with contextlib.ExitStack() as stack:
        sinks = []
        for logger in loggers:
            # Duck-type on __enter__ to distinguish context managers from
            # bare callables
            if hasattr(logger, "__enter__"):
                sinks.append(stack.enter_context(logger))  # type:ignore[arg-type]
            else:
                sinks.append(logger)

        def broadcast(*args: Any, **kwargs: Any) -> None:
            # Forward a single logging call to every sink
            for sink in sinks:
                sink(*args, **kwargs)

        yield broadcast
def named_layers(
    layer: Union[keras.layers.Layer, Sequence[keras.layers.Layer]],
    prefix: Tuple[str, ...] = (),
) -> Iterable[Tuple[str, keras.layers.Layer]]:
    """Recursively walk `layer`, yielding ("dotted.path", layer) pairs.

    Lists/tuples contribute their index as a path component; layers
    contribute their (public) attribute names. The root layer itself is
    yielded with name "" when called with the default prefix.
    """
    if isinstance(layer, (list, tuple)):
        for index, item in enumerate(layer):
            yield from named_layers(item, prefix + (str(index),))
    if isinstance(layer, keras.layers.Layer):
        yield (".".join(prefix), layer)
        for name, value in vars(layer).items():
            # Skip private attributes; recurse into anything layer-like
            is_public = not name.startswith("_")
            if is_public and isinstance(value, (list, tuple, keras.layers.Layer)):
                yield from named_layers(value, prefix + (name,))
def named_weights(
    layer: keras.layers.Layer, recursive: bool = True
) -> Iterable[Tuple[str, tf.Variable]]:
    """Walk a layer to find weight variables with full path names.

    recursive -- bool -- if `False`, only look at direct weights owned by this layer
    """
    scope = named_layers(layer) if recursive else [("", layer)]
    for prefix, sublayer in scope:
        for attr, value in vars(sublayer).items():
            # Only public attributes holding variables count as weights
            if attr.startswith("_") or not isinstance(value, tf.Variable):
                continue
            yield (f"{prefix}.{attr}" if prefix else attr, value)
def _runner(
queue: multiprocessing.Queue, # type:ignore[type-arg]
command: Callable[..., T],
args: Dict[str, Any],
) -> None:
queue.put_nowait(command(**args))
def run_in_subprocess(command: Callable[..., T], **args: Any) -> T:
    """Run a command synchronously in a (non-daemon) subprocess.

    Blocks until the child exits, then returns `command(**args)`'s result.
    Raises multiprocessing.ProcessError if the child exits with a nonzero
    code.
    """
    # We'd prefer to use a simple multiprocessing.Pool here, but I can't find a
    # way to make the workers non-daemonic
    results = multiprocessing.Manager().Queue()
    worker = multiprocessing.get_context("spawn").Process(
        target=_runner, args=(results, command, args)
    )
    worker.start()
    worker.join()
    code = worker.exitcode
    if code:
        raise multiprocessing.ProcessError(f"Process exited with code {code}")
    return results.get()  # type:ignore[no-any-return]