Skip to content

Commit

Permalink
Update usage of the deprecated is_foo functions from tff.Types.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 556870394
  • Loading branch information
michaelreneer authored and tensorflow-copybara committed Aug 14, 2023
1 parent 6998836 commit 135c996
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 23 deletions.
42 changes: 25 additions & 17 deletions tensorflow_federated/python/core/backends/mapreduce/form_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,18 +115,18 @@ def get_state_initialization_computation(
Raises:
TypeError: If the arguments are of the wrong types.
"""
initialize_tree = initialize_computation.to_building_block()
init_type = initialize_tree.type_signature
init_type = initialize_computation.type_signature
_check_type_is_no_arg_fn(init_type, '`initialize`', TypeError)
if (
not init_type.result.is_federated() # pytype: disable=attribute-error
or init_type.result.placement != placements.SERVER # pytype: disable=attribute-error
not isinstance(init_type.result, computation_types.FederatedType)
or init_type.result.placement is not placements.SERVER
):
raise TypeError(
'Expected `initialize` to return a single federated value '
'placed at server (type `T@SERVER`), found return type:\n'
f'{init_type.result}' # pytype: disable=attribute-error
)
initialize_tree = initialize_computation.to_building_block()
initialize_tree, _ = tree_transformations.replace_intrinsics_with_bodies(
initialize_tree
)
Expand Down Expand Up @@ -270,25 +270,29 @@ def _check_function_signature_compatible_with_broadcast_form(
f'{function_type.parameter}'
)
server_data_type, client_data_type = function_type.parameter # pytype: disable=attribute-error
if not (
server_data_type.is_federated() and server_data_type.placement.is_server()
if (
not isinstance(server_data_type, computation_types.FederatedType)
or server_data_type.placement is not placements.SERVER
):
raise TypeError(
'`BroadcastForm` expects a computation whose first parameter is server '
'data (a federated type placed at server) but found first parameter of '
f'type:\n{server_data_type}'
)
if not (
client_data_type.is_federated()
and client_data_type.placement.is_clients()
if (
not isinstance(client_data_type, computation_types.FederatedType)
or client_data_type.placement is not placements.CLIENTS
):
raise TypeError(
'`BroadcastForm` expects a computation whose first parameter is client '
'data (a federated type placed at clients) but found first parameter '
f'of type:\n{client_data_type}'
)
result_type = function_type.result
if not (result_type.is_federated() and result_type.placement.is_clients()): # pytype: disable=attribute-error
if (
not isinstance(result_type, computation_types.FederatedType)
or result_type.placement is not placements.CLIENTS
):
raise TypeError(
'`BroadcastForm` expects a computation whose result is client data '
'(a federated type placed at clients) but found result type:\n'
Expand Down Expand Up @@ -528,7 +532,7 @@ def _as_function_of_some_federated_subparameters(
raise tree_transformations.ParameterSelectionError(path, bb)
int_path.append(structure.name_to_index_map(selected_type)[index])
selected_type = selected_type[index]
if not selected_type.is_federated():
if not isinstance(selected_type, computation_types.FederatedType):
raise _NonFederatedSelectionError(
'Attempted to rebind references to parameter selection path '
f'{path} from type {bb.parameter_type}, but the value at that path '
Expand Down Expand Up @@ -1146,12 +1150,16 @@ def _find_non_client_placed_args(inner_comp):
# If the arg is non-placed or server-placed, prepare to create a
# federated broadcast that depends on it by normalizing it to a
# server-placed value.
if not aggregation_arg.type_signature.is_federated():
has_placement_predicate = lambda x: x.type_signature.is_federated()
if (
tree_analysis.count(aggregation_arg, has_placement_predicate)
> 0
):
if not isinstance(
aggregation_arg.type_signature, computation_types.FederatedType
):

def _has_placement(type_spec):
return isinstance(
type_spec.type_signature, computation_types.FederatedType
)

if tree_analysis.count(aggregation_arg, _has_placement) > 0:
raise TypeError(
'DistributeAggregateForm cannot handle an aggregation '
f'intrinsic arg with type {aggregation_arg.type_signature}'
Expand Down
1 change: 1 addition & 0 deletions tensorflow_federated/python/core/templates/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ py_library(
":measured_process",
"//tensorflow_federated/python/common_libs:structure",
"//tensorflow_federated/python/core/impl/computation:computation_base",
"//tensorflow_federated/python/core/impl/types:computation_types",
"//tensorflow_federated/python/core/impl/types:placements",
],
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from tensorflow_federated.python.common_libs import structure
from tensorflow_federated.python.core.impl.computation import computation_base
from tensorflow_federated.python.core.impl.types import computation_types
from tensorflow_federated.python.core.impl.types import placements
from tensorflow_federated.python.core.templates import errors
from tensorflow_federated.python.core.templates import measured_process
Expand Down Expand Up @@ -110,7 +111,9 @@ def __init__(
# validation here easier as that must be true.
super().__init__(initialize_fn, next_fn, next_is_multi_arg=True)

if not initialize_fn.type_signature.result.is_federated():
if not isinstance(
initialize_fn.type_signature.result, computation_types.FederatedType
):
raise AggregationNotFederatedError(
'Provided `initialize_fn` must return a federated type, but found '
f'return type:\n{initialize_fn.type_signature.result}\nTip: If you '
Expand All @@ -120,7 +123,11 @@ def __init__(
next_types = structure.flatten(
next_fn.type_signature.parameter
) + structure.flatten(next_fn.type_signature.result)
non_federated_types = [t for t in next_types if not t.is_federated()]
non_federated_types = [
t
for t in next_types
if not isinstance(t, computation_types.FederatedType)
]
if non_federated_types:
offending_types_str = '\n- '.join(str(t) for t in non_federated_types)
raise AggregationNotFederatedError(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,8 @@ def is_stateful(process: IterativeProcess) -> bool:
contains types other than `tff.types.StructType`, `False` otherwise.
"""
state_type = process.state_type
if state_type.is_federated():
state_type = state_type.member # pytype: disable=attribute-error
if isinstance(state_type, computation_types.FederatedType):
state_type = state_type.member
return not type_analysis.contains_only(
state_type, lambda t: isinstance(t, computation_types.StructType)
)
Original file line number Diff line number Diff line change
Expand Up @@ -147,14 +147,18 @@ def create_test_process(
) -> iterative_process.IterativeProcess:
@tensorflow_computation.tf_computation
def create_value():
if isinstance(type_spec, computation_types.FederatedType):
converted_type = type_spec.member
else:
converted_type = type_spec
return type_conversions.structure_from_tensor_type_tree(
lambda t: tf.zeros(dtype=t.dtype, shape=t.shape),
type_spec.member if type_spec.is_federated() else type_spec,
converted_type,
)

@federated_computation.federated_computation
def init_fn():
if type_spec.is_federated():
if isinstance(type_spec, computation_types.FederatedType):
return intrinsics.federated_eval(create_value, type_spec.placement)
else:
return create_value()
Expand Down

0 comments on commit 135c996

Please sign in to comment.