Skip to content

Commit

Permalink
fix remaining pyink issues
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Jul 21, 2023
1 parent 3ab11fe commit 717be27
Show file tree
Hide file tree
Showing 112 changed files with 460 additions and 1,280 deletions.
13 changes: 3 additions & 10 deletions .github/analytics/get_repo_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,7 @@ def send_query(query, query_type, cursor=None):
# TODO: Expand this, either by parsing the query type from the query
# directly or manually adding more query_types to the set
if query_type not in {'issues', 'pullRequests'}:
raise ValueError(
"Only 'issues' and 'pullRequests' queries are currently supported"
)
raise ValueError("Only 'issues' and 'pullRequests' queries are currently supported")
# TODO: Generalize this
# WARNING: The cursor injection depends on the specific structure of the
# query, this is the main reason why query types are limited to issues/PRs
Expand Down Expand Up @@ -162,9 +160,7 @@ def __init__(self, query_fname, query_type, repo_owner, repo_name):
self.load_query()

def load_query(self):
self.query = load_query_from_file(
self.query_fname, self.repo_owner, self.repo_name
)
self.query = load_query_from_file(self.query_fname, self.repo_owner, self.repo_name)

def get(self):
self.raw_data = get_all_responses(self.query, self.query_type)
Expand Down Expand Up @@ -228,10 +224,7 @@ def _get_pr_features(prs):
):
time_labeled_or_assigned = _to_datetime(event['createdAt'])

if (
time_labeled_or_assigned is None
and event['__typename'] == 'AssignedEvent'
):
if time_labeled_or_assigned is None and event['__typename'] == 'AssignedEvent':
time_labeled_or_assigned = _to_datetime(event['createdAt'])

if event['__typename'] in {'ClosedEvent', 'MergedEvent'}:
Expand Down
6 changes: 2 additions & 4 deletions dev/update_requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
Alternatively, the list can also be provided from the local environment with:
python dev --versions="$(pip freeze | sed s/==/-/g) flax-0.3.6"
python dev --versions="$(pip freeze | sed s/==/-/g) flax-0.3.6"
"""

import pathlib
Expand Down Expand Up @@ -87,9 +87,7 @@ def main(argv):
del argv

versions = {
pkg_version[: pkg_version.rindex('-')]: pkg_version[
pkg_version.rindex('-') + 1 :
]
pkg_version[: pkg_version.rindex('-')]: pkg_version[pkg_version.rindex('-') + 1 :]
for pkg_version in FLAGS.versions.replace('\n', ' ').split(' ')
if '-' in pkg_version
}
Expand Down
8 changes: 2 additions & 6 deletions docs/_ext/codediff.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,7 @@ def parse(
test_code = lines[idx + 1 :]
code_right = self._code_block(test_code)

output = self._tabs(
(title_left, code_left), (title_right, code_right), sync=sync
)
output = self._tabs((title_left, code_left), (title_right, code_right), sync=sync)

return output, test_code

Expand Down Expand Up @@ -109,9 +107,7 @@ class CodeDiffDirective(SphinxDirective):
}

def run(self):
table_code, test_code = CodeDiffParser().parse(
list(self.content), **self.options
)
table_code, test_code = CodeDiffParser().parse(list(self.content), **self.options)

# Create a test node as a comment node so it won't show up in the docs.
# We add attribute "testnodetype" so it is be picked up by the doctest
Expand Down
4 changes: 1 addition & 3 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,7 @@
'repository_url': 'https://github.com/google/flax',
'use_repository_button': True, # add a 'link to repository' button
'use_issues_button': False, # add an 'Open an Issue' button
'path_to_docs': (
'docs'
), # used to compute the path to launch notebooks in colab
'path_to_docs': ('docs'), # used to compute the path to launch notebooks in colab
'launch_buttons': {
'colab_url': 'https://colab.research.google.com/',
},
Expand Down
16 changes: 4 additions & 12 deletions docs/conf_sphinx_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,7 @@ def generate_autosummary_content(

def skip_member(obj: Any, name: str, objtype: str) -> bool:
try:
return app.emit_firstresult(
'autodoc-skip-member', objtype, name, obj, False, {}
)
return app.emit_firstresult('autodoc-skip-member', objtype, name, obj, False, {})
except Exception as exc:
ag.logger.warning(
__(
Expand All @@ -61,9 +59,7 @@ def skip_member(obj: Any, name: str, objtype: str) -> bool:
return False

def get_class_members(obj: Any) -> Dict[str, Any]:
members = sphinx.ext.autodoc.get_class_members(
obj, [qualname], ag.safe_getattr
)
members = sphinx.ext.autodoc.get_class_members(obj, [qualname], ag.safe_getattr)
return {name: member.object for name, member in members.items()}

def get_module_members(obj: Any) -> Dict[str, Any]:
Expand Down Expand Up @@ -163,12 +159,8 @@ def get_modules(obj: Any) -> Tuple[List[str], List[str]]:
elif doc.objtype == 'class':
ns['members'] = dir(obj)
ns['inherited_members'] = set(dir(obj)) - set(obj.__dict__.keys())
ns['methods'], ns['all_methods'] = get_members(
obj, {'method'}, ['__init__']
)
ns['attributes'], ns['all_attributes'] = get_members(
obj, {'attribute', 'property'}
)
ns['methods'], ns['all_methods'] = get_members(obj, {'method'}, ['__init__'])
ns['attributes'], ns['all_attributes'] = get_members(obj, {'attribute', 'property'})
ns['annotations'] = list(getattr(obj, '__annotations__', {}).keys())

if modname is None or qualname is None:
Expand Down
22 changes: 8 additions & 14 deletions examples/cloud/launch_gce.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,7 @@
'down the VM. Set to 0 to disable.'
),
)
flags.DEFINE_integer(
'accelerator_count', 8, help='Number of accelerators to use.'
)
flags.DEFINE_integer('accelerator_count', 8, help='Number of accelerators to use.')

# GCS configuration.
flags.DEFINE_string(
Expand All @@ -97,15 +95,11 @@
)

# Repo configuration.
flags.DEFINE_string(
'repo', 'https://github.com/google/flax', help='Git repository'
)
flags.DEFINE_string('repo', 'https://github.com/google/flax', help='Git repository')
flags.DEFINE_string('branch', 'main', help='Git repository')

# Example configuration.
flags.DEFINE_string(
'example', None, help='Name of Flax example (e.g. "imagenet").'
)
flags.DEFINE_string('example', None, help='Name of Flax example (e.g. "imagenet").')
flags.DEFINE_string(
'args',
'',
Expand Down Expand Up @@ -202,7 +196,8 @@ def launch_gce(*, vm_name: str, startup_script: str):


def print_howto(login_args: Sequence[str]):
print(f"""
print(
f"""
###############################################################################
###############################################################################
Expand All @@ -226,7 +221,8 @@ def print_howto(login_args: Sequence[str]):
###############################################################################
###############################################################################
""")
"""
)


def main(_):
Expand Down Expand Up @@ -274,9 +270,7 @@ def main(_):
login_true_args = login_args[:-1] + ['true']
while True:
try:
result = subprocess.run(
login_true_args, timeout=10, stderr=subprocess.PIPE
)
result = subprocess.run(login_true_args, timeout=10, stderr=subprocess.PIPE)
if result.returncode == 0:
break
stderr = result.stderr.decode('utf8')
Expand Down
8 changes: 2 additions & 6 deletions examples/imagenet/imagenet_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,7 @@ class ImagenetBenchmark(Benchmark):
"""Benchmarks for the ImageNet Flax example."""

@flagsaver
def _test_8x_v100_half_precision(
self, num_epochs: int, min_accuracy, max_accuracy
):
def _test_8x_v100_half_precision(self, num_epochs: int, min_accuracy, max_accuracy):
"""Utility to benchmark ImageNet on 8xV100 GPUs. Use in your test func."""
# Prepare and set flags defined in main.py.
config = config_lib.get_config()
Expand Down Expand Up @@ -69,9 +67,7 @@ def _test_8x_v100_half_precision(

# Use the reporting API to report single or multiple metrics/extras.
self.report_wall_time(benchmark_time)
self.report_metrics(
{'sec_per_epoch': sec_per_epoch, 'accuracy': end_accuracy}
)
self.report_metrics({'sec_per_epoch': sec_per_epoch, 'accuracy': end_accuracy})

def test_8x_v100_half_precision_short(self):
"""Run ImageNet on 8x V100 GPUs in half precision for 2 epochs."""
Expand Down
3 changes: 1 addition & 2 deletions examples/imagenet/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,7 @@ def main(argv):
# Add a note so that we can tell which task is which JAX host.
# (Depending on the platform task 0 is not guaranteed to be host 0)
platform.work_unit().set_task_status(
f'process_index: {jax.process_index()}, '
f'process_count: {jax.process_count()}'
f'process_index: {jax.process_index()}, ' f'process_count: {jax.process_count()}'
)
platform.work_unit().create_artifact(
platform.ArtifactType.DIRECTORY, FLAGS.workdir, 'workdir'
Expand Down
28 changes: 10 additions & 18 deletions examples/imagenet/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,9 @@ def __call__(
y = self.norm(scale_init=nn.initializers.zeros_init())(y)

if residual.shape != y.shape:
residual = self.conv(
self.filters, (1, 1), self.strides, name='conv_proj'
)(residual)
residual = self.conv(self.filters, (1, 1), self.strides, name='conv_proj')(
residual
)
residual = self.norm(name='norm_proj')(residual)

return self.act(residual + y)
Expand Down Expand Up @@ -78,9 +78,9 @@ def __call__(self, x):
y = self.norm(scale_init=nn.initializers.zeros_init())(y)

if residual.shape != y.shape:
residual = self.conv(
self.filters * 4, (1, 1), self.strides, name='conv_proj'
)(residual)
residual = self.conv(self.filters * 4, (1, 1), self.strides, name='conv_proj')(
residual
)
residual = self.norm(name='norm_proj')(residual)

return self.act(residual + y)
Expand Down Expand Up @@ -137,18 +137,10 @@ def __call__(self, x, train: bool = True):

ResNet18 = partial(ResNet, stage_sizes=[2, 2, 2, 2], block_cls=ResNetBlock)
ResNet34 = partial(ResNet, stage_sizes=[3, 4, 6, 3], block_cls=ResNetBlock)
ResNet50 = partial(
ResNet, stage_sizes=[3, 4, 6, 3], block_cls=BottleneckResNetBlock
)
ResNet101 = partial(
ResNet, stage_sizes=[3, 4, 23, 3], block_cls=BottleneckResNetBlock
)
ResNet152 = partial(
ResNet, stage_sizes=[3, 8, 36, 3], block_cls=BottleneckResNetBlock
)
ResNet200 = partial(
ResNet, stage_sizes=[3, 24, 36, 3], block_cls=BottleneckResNetBlock
)
ResNet50 = partial(ResNet, stage_sizes=[3, 4, 6, 3], block_cls=BottleneckResNetBlock)
ResNet101 = partial(ResNet, stage_sizes=[3, 4, 23, 3], block_cls=BottleneckResNetBlock)
ResNet152 = partial(ResNet, stage_sizes=[3, 8, 36, 3], block_cls=BottleneckResNetBlock)
ResNet200 = partial(ResNet, stage_sizes=[3, 24, 36, 3], block_cls=BottleneckResNetBlock)


ResNet18Local = partial(
Expand Down
20 changes: 5 additions & 15 deletions examples/imagenet/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,7 @@ def loss_fn(params):
loss = cross_entropy_loss(logits, batch['label'])
weight_penalty_params = jax.tree_util.tree_leaves(params)
weight_decay = 0.0001
weight_l2 = sum(
jnp.sum(x**2) for x in weight_penalty_params if x.ndim > 1
)
weight_l2 = sum(jnp.sum(x**2) for x in weight_penalty_params if x.ndim > 1)
weight_penalty = weight_decay * 0.5 * weight_l2
loss = loss + weight_penalty
return loss, (new_model_state, logits)
Expand All @@ -134,9 +132,7 @@ def loss_fn(params):
lr = learning_rate_fn(step)

if dynamic_scale:
grad_fn = dynamic_scale.value_and_grad(
loss_fn, has_aux=True, axis_name='batch'
)
grad_fn = dynamic_scale.value_and_grad(loss_fn, has_aux=True, axis_name='batch')
dynamic_scale, is_fin, aux, grads = grad_fn(state.params)
# dynamic loss takes care of averaging gradients across replicas
else:
Expand Down Expand Up @@ -271,9 +267,7 @@ def create_train_state(
return state


def train_and_evaluate(
config: ml_collections.ConfigDict, workdir: str
) -> TrainState:
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str) -> TrainState:
"""Execute model training and evaluation loop.
Args:
Expand Down Expand Up @@ -338,9 +332,7 @@ def train_and_evaluate(
num_steps = config.num_train_steps

if config.steps_per_eval == -1:
num_validation_examples = dataset_builder.info.splits[
'validation'
].num_examples
num_validation_examples = dataset_builder.info.splits['validation'].num_examples
steps_per_eval = num_validation_examples // config.batch_size
else:
steps_per_eval = config.steps_per_eval
Expand All @@ -350,9 +342,7 @@ def train_and_evaluate(
base_learning_rate = config.learning_rate * config.batch_size / 256.0

model_cls = getattr(models, config.model)
model = create_model(
model_cls=model_cls, half_precision=config.half_precision
)
model = create_model(model_cls=model_cls, half_precision=config.half_precision)

learning_rate_fn = create_learning_rate_fn(
config, base_learning_rate, steps_per_epoch
Expand Down
8 changes: 2 additions & 6 deletions examples/linen_design_test/attention_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,7 @@ class Dense(Module):
@compact
def __call__(self, inputs):
inputs = jnp.asarray(inputs, self.dtype)
kernel = self.param(
'kernel', self.kernel_init, (inputs.shape[-1], self.features)
)
kernel = self.param('kernel', self.kernel_init, (inputs.shape[-1], self.features))
kernel = jnp.asarray(kernel, self.dtype)
y = lax.dot_general(
inputs,
Expand Down Expand Up @@ -141,9 +139,7 @@ def __call__(self, inputs_q, inputs_kv, bias=None, dtype=jnp.float32):


def concise_vmap(module, in_axes, out_axes, axis_size=None, **var_specs):
variable_axes = {
k: v[0] for k, v in var_specs.items() if isinstance(v, Sequence)
}
variable_axes = {k: v[0] for k, v in var_specs.items() if isinstance(v, Sequence)}
splits = {k: v[1] for k, v in var_specs.items() if isinstance(v, Sequence)}
return vmap(
module,
Expand Down
4 changes: 1 addition & 3 deletions examples/linen_design_test/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,7 @@ class Dense(Module):

@compact
def __call__(self, inputs):
kernel = self.param(
'kernel', self.kernel_init, (inputs.shape[-1], self.features)
)
kernel = self.param('kernel', self.kernel_init, (inputs.shape[-1], self.features))
y = lax.dot_general(
inputs,
kernel,
Expand Down
11 changes: 3 additions & 8 deletions examples/lm1b/input_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,7 @@ def pack_dataset(
for k in keys:
if k not in shapes:
raise ValueError(
'Key %s not found in dataset. Available keys are %s'
% (k, shapes.keys())
'Key %s not found in dataset. Available keys are %s' % (k, shapes.keys())
)
if not shapes[k].is_compatible_with(tf.TensorShape([None])): # type: ignore[wrong-arg-types]
raise ValueError('Tensors to be packed must be one-dimensional.')
Expand All @@ -137,9 +136,7 @@ def pack_dataset(
# Setting batch_size=length ensures that the concatenated sequences (if they
# have length >=1) are sufficient to fill at least one packed example.
batch_size = max(key2length.values())
dataset = dataset.padded_batch(
batch_size, padded_shapes={k: [-1] for k in keys}
)
dataset = dataset.padded_batch(batch_size, padded_shapes={k: [-1] for k in keys})
dataset = _pack_with_tf_ops(dataset, keys, key2length)

# Set the Tensor shapes correctly since they get lost in the process.
Expand Down Expand Up @@ -223,9 +220,7 @@ def body_fn(i, partial, outputs):
for k in keys:
can_append = tf.logical_and(
can_append,
tf.less_equal(
tf.size(partial[k]) + tf.size(one_example[k]), key2length[k]
),
tf.less_equal(tf.size(partial[k]) + tf.size(one_example[k]), key2length[k]),
)

def false_fn():
Expand Down
3 changes: 1 addition & 2 deletions examples/lm1b/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,7 @@ def main(argv):
# Add a note so that we can tell which task is which JAX host.
# (Depending on the platform task 0 is not guaranteed to be host 0)
platform.work_unit().set_task_status(
f'process_index: {jax.process_index()}, '
f'process_count: {jax.process_count()}'
f'process_index: {jax.process_index()}, ' f'process_count: {jax.process_count()}'
)
platform.work_unit().create_artifact(
platform.ArtifactType.DIRECTORY, FLAGS.workdir, 'workdir'
Expand Down
Loading

0 comments on commit 717be27

Please sign in to comment.