Skip to content

Commit

Permalink
Expose a `warn` argument in `vectorized_map` to mute fallback-to-while-loop conversion warnings.
Browse files Browse the repository at this point in the history

Reference issues: keras-team/keras-cv#264, keras-team/keras-cv#291

PiperOrigin-RevId: 443809170
  • Loading branch information
LukeWood authored and tensorflower-gardener committed Apr 23, 2022
1 parent e6ba479 commit 88a263e
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 16 deletions.
6 changes: 6 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@
difference range from 8 to 100 times depending on the size of k.
When running on CPU and GPU, a non-optimized XLA kernel is used.

* `tf.vectorized_map`:

* Added an optional parameter: `warn`. This parameter controls whether or
not warnings will be printed when operations in the provided `fn` fall
back to a while loop.

# Bug Fixes and Other Changes

* <SIMILAR TO ABOVE SECTION, BUT FOR OTHER IMPORTANT CHANGES / BUG FIXES>
Expand Down
41 changes: 29 additions & 12 deletions tensorflow/python/ops/parallel_for/control_flow_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,11 @@ def _is_under_xla_context():
return False


def pfor(loop_fn, iters, fallback_to_while_loop=True, parallel_iterations=None):
def pfor(loop_fn,
iters,
fallback_to_while_loop=True,
parallel_iterations=None,
warn=False):
"""Equivalent to running `loop_fn` `iters` times and stacking the outputs.
`pfor` has functionality similar to `for_loop`, i.e. running `loop_fn` `iters`
Expand Down Expand Up @@ -176,6 +180,7 @@ def pfor(loop_fn, iters, fallback_to_while_loop=True, parallel_iterations=None):
vectorizing all the iterations. If `parallel_iterations` is smaller than
`iters`, then chunks of at most that many iterations are dispatched in
sequence. This knob can be used to control the total memory usage.
warn: Whether or not to warn when falling back to while loops.
Returns:
Returns a nested structure of stacked tensor objects with the same nested
Expand All @@ -184,10 +189,12 @@ def pfor(loop_fn, iters, fallback_to_while_loop=True, parallel_iterations=None):
ValueError: If parallel_iterations is not None and not an integer > 1.
"""
def f():
return _pfor_impl(loop_fn,
iters,
fallback_to_while_loop=fallback_to_while_loop,
parallel_iterations=parallel_iterations)
return _pfor_impl(
loop_fn,
iters,
fallback_to_while_loop=fallback_to_while_loop,
parallel_iterations=parallel_iterations,
warn=warn)
# Note that we wrap into a tf.function if in eager execution mode or under
# XLA compilation. The latter is so that we don't compile operations like
# tf.placeholder that are created by the loop body.
Expand Down Expand Up @@ -266,7 +273,8 @@ def _pfor_impl(loop_fn,
iters,
fallback_to_while_loop,
parallel_iterations=None,
pfor_config=None):
pfor_config=None,
warn=False):
"""Implementation of pfor."""
assert not context.executing_eagerly()
loop_fn_has_config = _loop_fn_has_config(loop_fn)
Expand Down Expand Up @@ -319,9 +327,13 @@ def _pfor_impl(loop_fn,
parallel_iterations = None
if parallel_iterations is None:
with ops.name_scope("pfor"):
converter = PFor(loop_var, iters, new_ops,
fallback_to_while_loop=fallback_to_while_loop,
pfor_config=pfor_config)
converter = PFor(
loop_var,
iters,
new_ops,
fallback_to_while_loop=fallback_to_while_loop,
pfor_config=pfor_config,
warn=warn)
flattened_output_tensors = []
for loop_fn_output in nest.flatten(loop_fn_output_tensors):
output = converter.convert(loop_fn_output)
Expand Down Expand Up @@ -424,7 +436,7 @@ def _gather_from_tensor_or_composite(x, i):


@tf_export("vectorized_map")
def vectorized_map(fn, elems, fallback_to_while_loop=True):
def vectorized_map(fn, elems, fallback_to_while_loop=True, warn=True):
"""Parallel map on the list of tensors unpacked from `elems` on dimension 0.
This method works similar to `tf.map_fn` but is optimized to run much faster,
Expand Down Expand Up @@ -504,6 +516,8 @@ def model_fn(arg):
unsupported op, a ValueError is thrown. Note that the fallbacks can result
in slowdowns since vectorization often yields speedup of one to two orders
of magnitude.
warn: If set to `False`, this will suppress any warnings due to operation
conversions in the provided `fn` falling back to while loops.
Returns:
A tensor or (possibly nested) sequence of tensors. Each tensor packs the
Expand Down Expand Up @@ -546,5 +560,8 @@ def _get_shape(x):
else:
batch_size = max(static_first_dims)

return pfor(loop_fn, batch_size,
fallback_to_while_loop=fallback_to_while_loop)
return pfor(
loop_fn,
batch_size,
fallback_to_while_loop=fallback_to_while_loop,
warn=warn)
8 changes: 6 additions & 2 deletions tensorflow/python/ops/parallel_for/pfor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1276,7 +1276,8 @@ def __init__(self,
fallback_to_while_loop,
all_indices=None,
all_indices_partitioned=False,
pfor_config=None):
pfor_config=None,
warn=False):
"""Creates an object to rewrite a parallel-for loop.
Args:
Expand All @@ -1296,6 +1297,7 @@ def __init__(self,
control flow construct where not all the pfor iterations are guaranteed
to be active.
pfor_config: PForConfig object used while constructing the loop body.
warn: Whether or not to warn on while loop conversions.
"""
assert isinstance(loop_var, ops.Tensor)
assert loop_var.op.type == "PlaceholderWithDefault"
Expand All @@ -1315,6 +1317,7 @@ def __init__(self,
self._pfor_ops = set(pfor_ops)
self._pfor_op_ids = set(x._id for x in pfor_ops)
self._fallback_to_while_loop = fallback_to_while_loop
self._warn = warn
self._pfor_config = pfor_config

def op_is_inside_loop(self, op):
Expand Down Expand Up @@ -1601,7 +1604,8 @@ def _add_to_stack(x):
y_op.inputs)
if (self._fallback_to_while_loop and not has_variant_outputs
and not has_vectorized_variant_inputs):
converter = partial(_fallback_converter, root_cause=root_cause)
converter = partial(
_fallback_converter, root_cause=root_cause, warn=self._warn)
else:
message = (f"No pfor vectorization defined for {y_op.type}\n"
f"{y_op}\n inputs: {converted_inputs}.")
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/tools/api/golden/v1/tensorflow.pbtxt
Original file line number Diff line number Diff line change
Expand Up @@ -2514,7 +2514,7 @@ tf_module {
}
member_method {
name: "vectorized_map"
argspec: "args=[\'fn\', \'elems\', \'fallback_to_while_loop\'], varargs=None, keywords=None, defaults=[\'True\'], "
argspec: "args=[\'fn\', \'elems\', \'fallback_to_while_loop\', \'warn\'], varargs=None, keywords=None, defaults=[\'True\', \'True\'], "
}
member_method {
name: "verify_tensor_all_finite"
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/tools/api/golden/v2/tensorflow.pbtxt
Original file line number Diff line number Diff line change
Expand Up @@ -1158,7 +1158,7 @@ tf_module {
}
member_method {
name: "vectorized_map"
argspec: "args=[\'fn\', \'elems\', \'fallback_to_while_loop\'], varargs=None, keywords=None, defaults=[\'True\'], "
argspec: "args=[\'fn\', \'elems\', \'fallback_to_while_loop\', \'warn\'], varargs=None, keywords=None, defaults=[\'True\', \'True\'], "
}
member_method {
name: "where"
Expand Down

0 comments on commit 88a263e

Please sign in to comment.