-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
189 lines (151 loc) · 6.59 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
"""Utils."""
import collections
from typing import Any, Callable, List, Tuple
import jax
import jax.numpy as jnp
from absl import logging
from clu import metric_writers
PyTree = Any
def compute_flops(n_ctx: int,
n_vocab: int,
n_layers: int,
n_heads: int,
d_model: int,
ff_ratio: float = 4,
fuse_multiply_add: bool = False) -> float:
"""Computes forward pass FLOPs per sequence based on Deepmind's Chinchilla paper.
Ref: https://arxiv.org/pdf/2203.15556
"""
fma_factor = 1 if fuse_multiply_add else 2
flops = 0
flops += fma_factor * (3 * n_ctx * d_model**2) # QKV projection.
flops += fma_factor * (n_ctx**2 * d_model) # Q @ K.T
flops += 3 * n_ctx**2 * n_heads # Softmax
flops += fma_factor * (n_ctx**2 * d_model) # Attention reduction.
flops += fma_factor * (n_ctx * d_model**2) # Out projection
flops += fma_factor * (n_ctx * d_model) * (2 * ff_ratio * d_model) # MLPBlock.
flops *= n_layers
flops += fma_factor * (n_ctx * d_model) # Embeddings forward.
flops += fma_factor * (n_ctx * n_vocab * d_model) # Project to n_vocab logits.
return flops
def compute_flops_v2(apply_fn: Callable,
dummy_inputs: list,
fuse_multiply_add: bool = True) -> float:
"""Compute the number of FLOPs of a Flax model.
Do not use.
This interferes with backends (e.g. if cuDNN SDPA API is enabled) and gives inconsistent counts.
Possibly replace `compute_flops` with this method if JAX inconsistency of FLOP analysis is resolved.
"""
analysis = jax.jit(apply_fn, backend="cpu").lower(*dummy_inputs).cost_analysis()
# Not all JAX backends return analysis.
# See: https://jax.readthedocs.io/en/latest/aot.html#debug-information-and-analyses-when-available
# Ideally we should be able to get flops analysis with `CPU` backend to save GPU memory.
flops = 0 if analysis["flops"] == -1.0 else analysis["flops"]
if fuse_multiply_add:
flops = flops / 2
return flops
def recover_tree(keys, values, sep: str = "."):
"""Unflatten key-value pairs to a nested dictionary where each key is `sep` path separated."""
tree = {}
sub_trees = collections.defaultdict(list)
for k, v in zip(keys, values):
if sep not in k:
tree[k] = v
else:
left, right = k.split(sep, 1)
sub_trees[left].append((right, v))
for k, kv_pairs in sub_trees.items():
tree[k] = recover_tree(*zip(*kv_pairs))
return tree
def unreplicate_and_get(tree: PyTree) -> PyTree:
"""Fetches to CPU the first local copy of a `pmap` replicated tree."""
return jax.device_get(jax.tree_util.tree_map(lambda x: x[0], tree))
def tf_to_numpy(batch: PyTree) -> PyTree:
"""Zero-copy numpy conversion."""
return jax.tree.map(lambda x: x._numpy(), batch)
def shard_batches(batch: PyTree, num_devices: int = None) -> PyTree:
"""Shard batch to `num_devices` or as inferred from local device count."""
num_devices = num_devices or jax.local_device_count()
return jax.tree.map(lambda x: x.reshape((num_devices, -1) + x.shape[1:]),
batch)
def get_cosine_lr_schedule(max_lr: float, min_lr: float, max_steps: int,
warmup_steps: int) -> Callable[[int], float]:
"""Cosine learning rate schedule.
Args:
max_lr: Peak learning rate.
min_lr: Minimum constant learning rate after cosine decay.
max_steps: Number of steps to decay over the entire training (including warmup).
warmup_steps: Number of steps to linearly increase learning rate from 0 to `max_lr`.
Returns:
A function that returns lr for requested step.
"""
def sched_fn(step: int) -> float:
decay_ratio = (step - warmup_steps) / (max_steps - warmup_steps)
decay_ratio = jnp.clip(decay_ratio, 0.0, 1.0)
lr = min_lr + (max_lr - min_lr) * 0.5 * (1 + jnp.cos(decay_ratio * jnp.pi))
lr = jnp.minimum(lr, max_lr * step / warmup_steps)
return lr
return sched_fn
def log_summary(step: int,
metrics: List[dict],
extra_logs: dict = None,
writer: metric_writers.MetricWriter = None,
prefix: str = "train"):
"""Logs train summary and optionally writes summaries.
Args:
step: Integer, current step.
metrics: A list of metric dictionaries collected over steps.
extra_logs: A dict of addl. logs (e.g. learning rates).
writer: Optional metric writer to write summaries to a file.
prefix: Prefix to be applied on metric keys.
Returns:
Nothing.
"""
# Transpose: list of dicts to dict of lists.
metrics = jax.tree.map(lambda *vals: jnp.stack(vals), *metrics)
# Log only on main host.
if jax.process_index() == 0:
summaries = extra_logs or {}
summaries.update({
"/".join((prefix, key)): val.mean() for key, val in metrics.items()
})
# Log to stdout
for name, value in summaries.items():
logging.info(f"\u001b[35m[{step}]\u001b[0m {name}={float(value):.5f}")
if writer is not None:
writer.write_scalars(step, summaries)
def accumulate_gradient(value_and_grad_fn,
params: PyTree,
batch: PyTree,
accum_steps: int = 1) -> Tuple[jnp.ndarray, PyTree]:
"""Accumulates gradients over given steps.
Args:
value_and_grad_fn: Gradient function that does not return aux values.
params: Parameters, passed as first argument to `value_and_grad_fn`.
batch: Batch, passed as second argument to `value_and_grad_fn`.
accum_steps: Number of micro batches to accumulate over. Defaults to 1,
which means no gradients are accumulated.
Returns:
Tuple (loss, grads).
"""
if accum_steps > 1:
bs = next(iter(jax.tree.leaves(batch))).shape[0]
assert bs % accum_steps == 0, (
f"Invalid accum_steps {accum_steps} for batch size `{bs}")
microbatch_size = bs // accum_steps
logging.info("Accumulating with microbatch_size %d over %d steps.",
microbatch_size, accum_steps)
def get_microbatch(batch, i):
return jax.tree.map(
lambda t: jnp.reshape(t, (accum_steps, -1) + (t.shape[1:]))[i], batch)
# Initialize accumulator.
l, g = value_and_grad_fn(params, get_microbatch(batch, 0))
def accumulate(i, l_and_g):
l, g = l_and_g
l_i, g_i = value_and_grad_fn(params, get_microbatch(batch, i))
return (l + l_i, jax.tree.map(jnp.add, g, g_i))
# Average over accum_steps.
loss, grads = jax.lax.fori_loop(1, accum_steps, accumulate, (l, g))
return jax.tree.map(lambda x: x / accum_steps, (loss, grads))
else:
return value_and_grad_fn(params, batch)