Commit

line length 80
cgarciae committed Jul 21, 2023
1 parent 3ab11fe commit 0f4e662
Showing 44 changed files with 359 additions and 142 deletions.
8 changes: 6 additions & 2 deletions .github/analytics/get_repo_metrics.py
@@ -331,7 +331,9 @@ def main(_):
   )
   issues.get()

-  df_issues = df_issues0 = pd.DataFrame(list(_get_issues_features(issues.raw_data)))
+  df_issues = df_issues0 = pd.DataFrame(
+      list(_get_issues_features(issues.raw_data))
+  )
   df_issues['issue_response_time'] = (
       df_issues['time_labeled_or_converted'] - df_issues['created_at']
   )
@@ -350,7 +352,9 @@ def main(_):
   prs.get()

   df_prs = df_prs0 = pd.DataFrame(list(_get_pr_features(prs.raw_data)))
-  time_response = df_prs[['time_labeled_or_assigned', 'time_review']].min(axis=1)
+  time_response = df_prs[['time_labeled_or_assigned', 'time_review']].min(
+      axis=1
+  )
   df_prs['pr_response_time'] = time_response - df_prs['ready_for_review_at']
   df_prs['pr_resolution_time'] = (
       df_prs['time_merged_or_closed'] - df_prs['ready_for_review_at']
2 changes: 1 addition & 1 deletion dev/update_requirements.py
@@ -38,7 +38,7 @@
 Alternatively, the list can also be provided from the local environment with:
-  python dev --versions="$(pip freeze | sed s/==/-/g) flax-0.3.6"
+  python dev --versions="$(pip freeze | sed s/==/-/g) flax-0.3.6"
 """

 import pathlib
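The `--versions` flag documented above takes a single whitespace-separated list of `name-version` tokens; the `sed s/==/-/g` rewrites `pip freeze`'s `name==version` lines into that form. As a rough illustration of how such a string decomposes, here is a hypothetical parser (not the script's actual implementation):

```python
def parse_versions(versions: str) -> dict:
  """Splits 'absl-py-1.4.0 jax-0.4.13 flax-0.3.6' into {name: version}."""
  packages = {}
  for token in versions.split():
    # Package names may themselves contain dashes, so split on the last one.
    name, _, version = token.rpartition('-')
    packages[name] = version
  return packages


print(parse_versions('absl-py-1.4.0 jax-0.4.13 flax-0.3.6'))
# {'absl-py': '1.4.0', 'jax': '0.4.13', 'flax': '0.3.6'}
```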
6 changes: 4 additions & 2 deletions examples/cloud/launch_gce.py
@@ -202,7 +202,8 @@ def launch_gce(*, vm_name: str, startup_script: str):


 def print_howto(login_args: Sequence[str]):
-  print(f"""
+  print(
+      f"""
 ###############################################################################
 ###############################################################################
@@ -226,7 +227,8 @@ def print_howto(login_args: Sequence[str]):
 ###############################################################################
 ###############################################################################
-""")
+"""
+  )


 def main(_):
4 changes: 3 additions & 1 deletion examples/imagenet/train_test.py
@@ -55,7 +55,9 @@ def test_create_model_local(self):
     Uses smaller inputs than `test_create_model` to due to higher compute.
     """
-    model = train.create_model(model_cls=models._ResNet1Local, half_precision=False)  # pylint: disable=protected-access
+    model = train.create_model(
+        model_cls=models._ResNet1Local, half_precision=False
+    )  # pylint: disable=protected-access
     params, batch_stats = train.initialized(random.PRNGKey(0), 64, model)
     variables = {'params': params, 'batch_stats': batch_stats}
     x = random.normal(random.PRNGKey(1), (1, 64, 64, 3))
4 changes: 3 additions & 1 deletion examples/lm1b/temperature_sampler_test.py
@@ -37,7 +37,9 @@ def tokens_to_logits(tokens, cache):
       logits = logits.squeeze(axis=1)
       return logits, cache

-    new_tokens = temperature_sample(tokens, cache, tokens_to_logits, key, topk=5)
+    new_tokens = temperature_sample(
+        tokens, cache, tokens_to_logits, key, topk=5
+    )

     np.testing.assert_array_equal(new_tokens, [[5, 6, 7, 8]])
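The test above drives `temperature_sample` with a stub `tokens_to_logits` and `topk=5`. As a rough, self-contained illustration of the sampling rule involved (top-k filtering followed by a temperature-scaled categorical draw), here is a single-step sketch; it is not the example's `temperature_sample`, which also loops over positions and threads the decoding cache:

```python
import jax
import jax.numpy as jnp


def sample_one_step(rng, logits, temperature=1.0, topk=5):
  """Draws one token id per batch row after top-k filtering and temperature scaling."""
  topk_vals, _ = jax.lax.top_k(logits, topk)  # [..., topk]
  cutoff = topk_vals[..., -1:]                # smallest logit that survives
  filtered = jnp.where(logits < cutoff, -jnp.inf, logits)
  return jax.random.categorical(rng, filtered / temperature, axis=-1)


logits = jax.random.normal(jax.random.PRNGKey(1), (1, 16))  # [batch, vocab]
token = sample_one_step(jax.random.PRNGKey(0), logits, temperature=0.7)
```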
12 changes: 9 additions & 3 deletions examples/lm1b/train.py
@@ -285,7 +285,9 @@ def per_host_sum_pmap(in_tree):
   host_psum = jax.pmap(lambda x: jax.lax.psum(x, "i"), "i", devices=devices)

   def pre_pmap(xs):
-    return jax.tree_util.tree_map(lambda x: jnp.broadcast_to(x, (1,) + x.shape), xs)
+    return jax.tree_util.tree_map(
+        lambda x: jnp.broadcast_to(x, (1,) + x.shape), xs
+    )

   def post_pmap(xs):
     return jax.tree_util.tree_map(lambda x: x[0], xs)
@@ -525,7 +527,9 @@ def encode_strings(strs, max_len):

     # Shard data to devices and do a training step.
     with jax.profiler.StepTraceAnnotation("train", step_num=step):
-      batch = common_utils.shard(jax.tree_util.tree_map(np.asarray, next(train_iter)))
+      batch = common_utils.shard(
+          jax.tree_util.tree_map(np.asarray, next(train_iter))
+      )
       state, metrics = p_train_step(state, batch, dropout_rng=dropout_rngs)
       train_metrics.append(metrics)

@@ -542,7 +546,9 @@ def encode_strings(strs, max_len):
         lr = train_metrics.pop("learning_rate").mean()
         metrics_sums = jax.tree_util.tree_map(jnp.sum, train_metrics)
         denominator = metrics_sums.pop("denominator")
-        summary = jax.tree_util.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
+        summary = jax.tree_util.tree_map(
+            lambda x: x / denominator, metrics_sums
+        )  # pylint: disable=cell-var-from-loop
         summary["learning_rate"] = lr
         summary["perplexity"] = jnp.clip(
             jnp.exp(summary["loss"]), a_max=1.0e4
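The `pre_pmap`/`post_pmap` pair reformatted above implements a common cross-host reduction trick: give every leaf a length-1 leading axis, `psum` it under `pmap`, then strip the axis again. Below is a single-host sketch of that pattern; it uses the default device set rather than the one-device-per-host list the real `per_host_sum_pmap` is built with:

```python
import jax
import jax.numpy as jnp


def naive_per_host_sum(in_tree):
  """Sums a pytree over the 'i' axis; with one participant it is a round trip."""
  host_psum = jax.pmap(lambda x: jax.lax.psum(x, 'i'), 'i')
  pre = jax.tree_util.tree_map(
      lambda x: jnp.broadcast_to(x, (1,) + x.shape), in_tree
  )
  summed = host_psum(pre)
  return jax.tree_util.tree_map(lambda x: x[0], summed)


metrics = {'loss': jnp.asarray(2.5), 'denominator': jnp.asarray(128.0)}
print(naive_per_host_sum(metrics))
```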
4 changes: 3 additions & 1 deletion examples/nlp_seq/train.py
@@ -354,7 +354,9 @@ def eval_step(params, batch):
   tick = time.time()
   best_dev_score = 0
   for step, batch in zip(range(num_train_steps), train_iter):
-    batch = common_utils.shard(jax.tree_util.tree_map(lambda x: x._numpy(), batch))  # pylint: disable=protected-access
+    batch = common_utils.shard(
+        jax.tree_util.tree_map(lambda x: x._numpy(), batch)
+    )  # pylint: disable=protected-access

     state, metrics = p_train_step(state, batch, dropout_rng=dropout_rngs)
     metrics_all.append(metrics)
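`common_utils.shard`, called right before `p_train_step` above, just reshapes each leaf so its leading batch axis becomes `[local_device_count, per_device_batch, ...]`, the layout `pmap` expects. A small sketch, assuming the global batch divides evenly across local devices:

```python
import jax
import numpy as np
from flax.training import common_utils

n_dev = jax.local_device_count()
batch = {
    'inputs': np.zeros((8 * n_dev, 16), np.int32),
    'labels': np.zeros((8 * n_dev,), np.int32),
}
sharded = common_utils.shard(batch)
print(sharded['inputs'].shape)  # (n_dev, 8, 16)
print(sharded['labels'].shape)  # (n_dev, 8)
```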
8 changes: 6 additions & 2 deletions examples/ppo/ppo_lib.py
@@ -195,7 +195,9 @@ def get_experience(
     sim_state = sim.conn.recv()
     sim_states.append(sim_state)
   sim_states = np.concatenate(sim_states, axis=0)
-  log_probs, values = agent.policy_action(state.apply_fn, state.params, sim_states)
+  log_probs, values = agent.policy_action(
+      state.apply_fn, state.params, sim_states
+  )
   log_probs, values = jax.device_get((log_probs, values))
   probs = np.exp(np.array(log_probs))
   for i, sim in enumerate(simulators):
@@ -343,7 +345,9 @@ def train(
       score = test_episodes.policy_test(1, state.apply_fn, state.params, game)
       frames = step * config.num_agents * config.actor_steps
       summary_writer.scalar('game_score', score, frames)
-      logging.info('Step %s:\nframes seen %s\nscore %s\n\n', step, frames, score)
+      logging.info(
+          'Step %s:\nframes seen %s\nscore %s\n\n', step, frames, score
+      )

     # Core training code.
     alpha = (
4 changes: 3 additions & 1 deletion examples/seq2seq/models.py
@@ -109,7 +109,9 @@ def __call__(
       encoding format).
     """
     # Encode inputs.
-    encoder = nn.RNN(nn.LSTMCell(self.hidden_size), return_carry=True, name='encoder')
+    encoder = nn.RNN(
+        nn.LSTMCell(self.hidden_size), return_carry=True, name='encoder'
+    )
     decoder = nn.RNN(
         DecoderLSTMCell(
             decoder_inputs.shape[-1], self.teacher_force, self.vocab_size
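The encoder above wraps an `nn.LSTMCell` in `nn.RNN` with `return_carry=True`, so calling it yields both the final carry (handed on to the decoder) and the per-step outputs. A minimal standalone sketch of that usage, with made-up sizes:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

hidden_size = 16
encoder = nn.RNN(nn.LSTMCell(hidden_size), return_carry=True, name='encoder')

x = jnp.ones((2, 10, 8))  # [batch, time, features]
variables = encoder.init(jax.random.PRNGKey(0), x)
final_carry, outputs = encoder.apply(variables, x)
# outputs.shape == (2, 10, 16); final_carry is the LSTM (c, h) state pair.
```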
8 changes: 6 additions & 2 deletions examples/sst2/models.py
@@ -197,12 +197,16 @@ def setup(self):

   def __call__(self, embedded_inputs, lengths):
     # Forward LSTM.
-    initial_state = self.forward_lstm.initialize_carry(embedded_inputs[:, 0].shape)
+    initial_state = self.forward_lstm.initialize_carry(
+        embedded_inputs[:, 0].shape
+    )
     _, forward_outputs = self.forward_lstm(initial_state, embedded_inputs)

     # Backward LSTM.
     reversed_inputs = flip_sequences(embedded_inputs, lengths)
-    initial_state = self.backward_lstm.initialize_carry(reversed_inputs[:, 0].shape)
+    initial_state = self.backward_lstm.initialize_carry(
+        reversed_inputs[:, 0].shape
+    )
     _, backward_outputs = self.backward_lstm(initial_state, reversed_inputs)
     backward_outputs = flip_sequences(backward_outputs, lengths)
20 changes: 12 additions & 8 deletions examples/wmt/decode.py
@@ -155,7 +155,9 @@ def beam_init(batch_size, beam_size, max_decode_len, cache):
   finished_seqs0 = jnp.zeros((batch_size, beam_size, max_decode_len), jnp.int32)
   finished_flags0 = jnp.zeros((batch_size, beam_size), jnp.bool_)
   # add beam dimension to attention cache pytree elements
-  beam_cache0 = jax.tree_util.tree_map(lambda x: add_beam_dim(x, beam_size), cache)
+  beam_cache0 = jax.tree_util.tree_map(
+      lambda x: add_beam_dim(x, beam_size), cache
+  )
   return BeamState(
       cur_index=cur_index0,
       live_logprobs=live_logprobs0,
@@ -350,13 +352,15 @@ def beam_search_loop_body_fn(state):
         [state.finished_flags, newly_finished], axis=1
     )
     # --> [batch, beams, length], [batch, beams], [batch, beams]
-    top_finished_seq, top_finished_scores, top_finished_flags = (
-        gather_topk_beams(
-            [finished_seqs, finished_scores, finished_flags],
-            finished_scores,
-            batch_size,
-            beam_size,
-        )
+    (
+        top_finished_seq,
+        top_finished_scores,
+        top_finished_flags,
+    ) = gather_topk_beams(
+        [finished_seqs, finished_scores, finished_flags],
+        finished_scores,
+        batch_size,
+        beam_size,
     )

     return BeamState(
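`beam_init` above tiles every leaf of the attention cache across a new beam axis with `jax.tree_util.tree_map`. A small sketch of the idea with a stand-in `add_beam_dim`; the real helper is defined earlier in `decode.py`:

```python
import jax
import jax.numpy as jnp


def add_beam_dim(x, beam_size):
  """[batch, ...] -> [batch, beam_size, ...] by inserting and tiling a beam axis."""
  x = jnp.expand_dims(x, axis=1)
  return jnp.tile(x, [1, beam_size] + [1] * (x.ndim - 2))


cache = {'layer_0': {'cached_key': jnp.zeros((2, 32, 4, 8))}}
beam_cache = jax.tree_util.tree_map(lambda x: add_beam_dim(x, 4), cache)
print(beam_cache['layer_0']['cached_key'].shape)  # (2, 4, 32, 4, 8)
```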
12 changes: 9 additions & 3 deletions examples/wmt/train.py
@@ -355,7 +355,9 @@ def per_host_sum_pmap(in_tree):
   host_psum = jax.pmap(lambda x: jax.lax.psum(x, "i"), "i", devices=devices)

   def pre_pmap(xs):
-    return jax.tree_util.tree_map(lambda x: jnp.broadcast_to(x, (1,) + x.shape), xs)
+    return jax.tree_util.tree_map(
+        lambda x: jnp.broadcast_to(x, (1,) + x.shape), xs
+    )

   def post_pmap(xs):
     return jax.tree_util.tree_map(lambda x: x[0], xs)
@@ -626,7 +628,9 @@ def decode_tokens(toks):

     # Shard data to devices and do a training step.
     with jax.profiler.StepTraceAnnotation("train", step_num=step):
-      batch = common_utils.shard(jax.tree_util.tree_map(np.asarray, next(train_iter)))
+      batch = common_utils.shard(
+          jax.tree_util.tree_map(np.asarray, next(train_iter))
+      )
       state, metrics = p_train_step(state, batch, dropout_rng=dropout_rngs)
       train_metrics.append(metrics)

@@ -643,7 +647,9 @@ def decode_tokens(toks):
         lr = train_metrics.pop("learning_rate").mean()
         metrics_sums = jax.tree_util.tree_map(jnp.sum, train_metrics)
         denominator = metrics_sums.pop("denominator")
-        summary = jax.tree_util.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
+        summary = jax.tree_util.tree_map(
+            lambda x: x / denominator, metrics_sums
+        )  # pylint: disable=cell-var-from-loop
         summary["learning_rate"] = lr
         summary = {"train_" + k: v for k, v in summary.items()}
         writer.write_scalars(step, summary)
4 changes: 3 additions & 1 deletion flax/core/axes_scan.py
@@ -156,7 +156,9 @@ def body_fn(c, xs, init_mode=False):
             'broadcasted variable has a data dependency on the scan body.'
         )
       out_flat.append(const)
-    broadcast_in, constants_out = jax.tree_util.tree_unflatten(out_tree(), out_flat)
+    broadcast_in, constants_out = jax.tree_util.tree_unflatten(
+        out_tree(), out_flat
+    )

     c, ys = lax.scan(
         body_fn, init, xs, length=length, reverse=reverse, unroll=unroll
4 changes: 3 additions & 1 deletion flax/core/meta.py
@@ -237,7 +237,9 @@ def body(mdl, c):

   value: Any
   names: LogicalNames = struct.field(pytree_node=False)
-  mesh: Optional[jax.sharding.Mesh] = struct.field(default=None, pytree_node=False)
+  mesh: Optional[jax.sharding.Mesh] = struct.field(
+      default=None, pytree_node=False
+  )

   def unbox(self, apply_constraint=True) -> Any:
     """Returns the wrapped value with the partitioning applied as a sharding constraint."""
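The `mesh` attribute above is declared with `struct.field(pytree_node=False)`, which keeps it as static metadata instead of a traced leaf when the dataclass is flattened. A tiny sketch of that mechanism on a made-up dataclass:

```python
import jax
import jax.numpy as jnp
from flax import struct


@struct.dataclass
class Boxed:
  value: jax.Array                                                 # pytree leaf
  name: str = struct.field(pytree_node=False, default='unnamed')   # static metadata


boxed = Boxed(value=jnp.arange(3.0), name='weights')
doubled = jax.tree_util.tree_map(lambda x: x * 2, boxed)
print(doubled.value, doubled.name)  # [0. 2. 4.] weights
```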
3 changes: 1 addition & 2 deletions flax/core/nn/linear.py
@@ -75,8 +75,7 @@ def dense_general(
     if set(batch_dims) != set(range(max_dim + 1)):
       raise ValueError(
           'batch_dims %s must be consecutive leading '
-          'dimensions starting from 0.'
-          % str(batch_dims)
+          'dimensions starting from 0.' % str(batch_dims)
       )

   ndim = inputs.ndim
4 changes: 3 additions & 1 deletion flax/errors.py
@@ -52,7 +52,9 @@ def __init__(self):
 class FlaxError(Exception):

   def __init__(self, message):
-    error_page = 'https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html'
+    error_page = (
+        'https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html'
+    )
     module_name = self.__class__.__module__
     class_name = self.__class__.__name__
     error_msg = f'{message} ({error_page}#{module_name}.{class_name})'
4 changes: 3 additions & 1 deletion flax/jax_utils.py
@@ -218,7 +218,9 @@ def transpose_out(x):

   def body_wrapper(c, xs):
     if keepdims:
-      xs = jax.tree_util.tree_map(lambda x: x.reshape((1,) * len(axis) + x.shape), xs)
+      xs = jax.tree_util.tree_map(
+          lambda x: x.reshape((1,) * len(axis) + x.shape), xs
+      )
     xs = jax.tree_util.tree_map(transpose_out, xs)
     c, ys = body_fn(c, xs)
     if keepdims:
13 changes: 9 additions & 4 deletions flax/linen/attention.py
@@ -237,7 +237,9 @@ class MultiHeadDotProductAttention(Module):
   deterministic: Optional[bool] = None
   precision: PrecisionLike = None
   kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init
-  bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros_init()
+  bias_init: Callable[
+      [PRNGKey, Shape, Dtype], Array
+  ] = initializers.zeros_init()
   use_bias: bool = True
   attention_fn: Callable[..., Array] = dot_product_attention
   decode: bool = False
@@ -316,9 +318,12 @@ def __call__(
           'cache', 'cache_index', lambda: jnp.array(0, dtype=jnp.int32)
       )
       if is_initialized:
-        *batch_dims, max_length, num_heads, depth_per_head = (
-            cached_key.value.shape
-        )
+        (
+            *batch_dims,
+            max_length,
+            num_heads,
+            depth_per_head,
+        ) = cached_key.value.shape
        # shape check of cached keys against query input
        expected_shape = tuple(batch_dims) + (1, num_heads, depth_per_head)
        if expected_shape != query.shape:
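The decode-cache hunk above only re-wraps a starred unpacking of `cached_key.value.shape`; the semantics are plain Python tuple unpacking, for example:

```python
# Example cache shape: leading batch dims, then length, heads, and head depth.
shape = (4, 2, 128, 8, 64)
*batch_dims, max_length, num_heads, depth_per_head = shape
assert batch_dims == [4, 2]
assert (max_length, num_heads, depth_per_head) == (128, 8, 64)
expected_shape = tuple(batch_dims) + (1, num_heads, depth_per_head)
assert expected_shape == (4, 2, 1, 8, 64)
```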
16 changes: 12 additions & 4 deletions flax/linen/experimental/layers_with_named_axes.py
@@ -39,7 +39,9 @@


 default_kernel_init = initializers.lecun_normal()
-default_embed_init = initializers.variance_scaling(1.0, 'fan_in', 'normal', out_axis=0)
+default_embed_init = initializers.variance_scaling(
+    1.0, 'fan_in', 'normal', out_axis=0
+)


 class Dense(nn.Module):
@@ -66,7 +68,9 @@ class Dense(nn.Module):
   param_dtype: DType = jnp.float32
   precision: PrecisionLike = None
   kernel_init: Callable[[PRNGKey, Shape, DType], Array] = default_kernel_init
-  bias_init: Callable[[PRNGKey, Shape, DType], Array] = initializers.zeros_init()
+  bias_init: Callable[
+      [PRNGKey, Shape, DType], Array
+  ] = initializers.zeros_init()
   kernel_axes: Tuple[str, ...] = ()
   dot_general: DotGeneralT = lax.dot_general
@@ -301,8 +305,12 @@ class LayerNorm(nn.Module):
   param_dtype: DType = jnp.float32
   use_bias: bool = True
   use_scale: bool = True
-  bias_init: Callable[[PRNGKey, Shape, DType], Array] = initializers.zeros_init()
-  scale_init: Callable[[PRNGKey, Shape, DType], Array] = initializers.ones_init()
+  bias_init: Callable[
+      [PRNGKey, Shape, DType], Array
+  ] = initializers.zeros_init()
+  scale_init: Callable[
+      [PRNGKey, Shape, DType], Array
+  ] = initializers.ones_init()

   @nn.compact
   def __call__(self, x):