Skip to content

Commit

Permalink
For eventual multi-task support:
Browse files Browse the repository at this point in the history
1. Remove label extraction from `GraphTensorProcessorFn`,
1. Add label extraction helpers: e.g., root node and context,
1. Change `task.preprocess` to always return labels. (Current `tasks` now take a `label_fn` on `__init__`).

PiperOrigin-RevId: 511576072
  • Loading branch information
dzelle authored and tensorflower-gardener committed Feb 22, 2023
1 parent 9e79919 commit 3188367
Show file tree
Hide file tree
Showing 15 changed files with 346 additions and 248 deletions.
16 changes: 6 additions & 10 deletions tensorflow_gnn/docs/guide/runner.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,17 +39,16 @@ train_ds_provider = runner.TFRecordDatasetProvider(file_pattern="...")
# len(valid_ds_provider.get_dataset(...)) == 1634.
valid_ds_provider = runner.TFRecordDatasetProvider(file_pattern="...")

# Extract labels from the graph context, importantly: this lambda matches
# the `GraphTensorProcessorFn` protocol (see below).
extract_labels = lambda gt: gt, gt.context["label"]

# Use `embedding` feature as the only node feature.
initial_node_states = lambda node_set, node_set_name: node_set["embedding"]
# This `tf.keras.layers.Layer` matches the `GraphTensorProcessorFn` protocol.
map_features = tfgnn.keras.layers.MapFeatures(node_sets_fn=initial_node_states)

# Extract labels from the graph context.
extract_labels = lambda inputs: inputs.context["label"]

# Binary classification by the root node.
task = runner.RootNodeBinaryClassification(node_set_name="nodes")
task = runner.RootNodeBinaryClassification("nodes", label_fn=extract_labels)

trainer = runner.KerasTrainer(
strategy=tf.distribute.TPUStrategy(...),
Expand Down Expand Up @@ -258,7 +257,7 @@ example, a custom training loop with look ahead gradients.
```python
class GraphTensorProcessorFn(Protocol):

def __call__(self, gt: tfgnn.GraphTensor) -> Union[tfgnn.GraphTensor, Tuple[tfgnn.GraphTensor, tfgnn.Field]]:
def __call__(self, inputs: tfgnn.GraphTensor) -> tfgnn.GraphTensor:
raise NotImplementedError()
```

Expand All @@ -271,10 +270,7 @@ dataset. Importantly: all `GraphTensorProcessorFn` are applied in a
`tf.data.Dataset.map` call (and correspondingly executed on CPU). All
`GraphTensorProcessorFn` are collected in a `tf.keras.Model` specifically for
feature processing. The final model exported by [orchestration](#orchestration)
will contain both the feature processing model and the client GNN. Any
`GraphTensorProcessorFn` may return a processed `GraphTensor` or a processed
`GraphTensor` *and* `Field` (e.g., where the field is used as a supervision
target).
will contain both the feature processing model and the client GNN.

TIP: A `tf.keras.Model` or `tf.keras.layers.Layer`, whose inputs and outputs are
scalar `GraphTensor`, matches the `GraphTensorProcessorFn` protocol (and may be
Expand Down
61 changes: 18 additions & 43 deletions tensorflow_gnn/models/contrastive_losses/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import abc
from collections.abc import Callable, Sequence
from typing import Optional, Union
from typing import Optional, Tuple, Union

import tensorflow as tf
import tensorflow_gnn as tfgnn
Expand All @@ -26,6 +26,9 @@
from tensorflow_gnn.models.contrastive_losses import losses
from tensorflow_gnn.models.contrastive_losses.deep_graph_infomax import layers as dgi_layers

Field = tfgnn.Field
GraphTensor = tfgnn.GraphTensor


class _ConstrastiveLossTask(runner.Task, abc.ABC):
"""Base class for unsupervised contrastive representation learning tasks.
Expand All @@ -51,8 +54,7 @@ def __init__(
*,
feature_name: str = tfgnn.HIDDEN_STATE,
representations_layer_name: Optional[str] = None,
seed: Optional[int] = None,
):
seed: Optional[int] = None):
self._representations_layer_name = (
representations_layer_name or "clean_representations"
)
Expand Down Expand Up @@ -109,27 +111,7 @@ def make_contrastive_layer(self) -> tf.keras.layers.Layer:
"""Returns the layer contrasting clean outputs with the correupted ones."""
raise NotImplementedError()

@abc.abstractmethod
def preprocess(
self, gt: tfgnn.GraphTensor
) -> tuple[tfgnn.GraphTensor, tfgnn.Field]:
"""Returns the input GraphTensor."""
raise NotImplementedError()

@abc.abstractmethod
def losses(self) -> Sequence[Callable[[tf.Tensor, tf.Tensor], tf.Tensor]]:
"""Returns an empty losses tuple.
Loss signatures are according to `tf.keras.losses.Loss.`
"""
raise NotImplementedError()

def metrics(self) -> Sequence[Callable[[tf.Tensor, tf.Tensor], tf.Tensor]]:
"""Returns an empty metrics tuple.
Metric signatures are according to `tf.keras.metrics.Metric,` here: no
metrics are returned, because metrics are task-specific.
"""
return tuple()


Expand All @@ -140,11 +122,10 @@ def make_contrastive_layer(self) -> tf.keras.layers.Layer:
return dgi_layers.DeepGraphInfomaxLogits()

def preprocess(
self, gt: tfgnn.GraphTensor
) -> tuple[tfgnn.GraphTensor, tfgnn.Field]:
self,
inputs: GraphTensor) -> Tuple[Optional[GraphTensor], Field]:
"""Creates labels--i.e., (positive, negative)--for Deep Graph Infomax."""
y = tf.tile(tf.constant(((1, 0),), dtype=tf.int32), (gt.num_components, 1))
return gt, y
return None, tf.tile(((1, 0),), (inputs.num_components, 1))

def losses(self) -> Sequence[Callable[[tf.Tensor, tf.Tensor], tf.Tensor]]:
return (tf.keras.losses.BinaryCrossentropy(from_logits=True),)
Expand All @@ -167,26 +148,23 @@ def __init__(
representations_layer_name: Optional[str] = None,
seed: Optional[int] = None,
lambda_: Optional[Union[tf.Tensor, float]] = None,
normalize_batch: bool = True,
):
normalize_batch: bool = True):
super().__init__(
node_set_name,
feature_name=feature_name,
representations_layer_name=representations_layer_name,
seed=seed,
)
seed=seed)
self._lambda = lambda_
self._normalize_batch = normalize_batch

def make_contrastive_layer(self) -> tf.keras.layers.Layer:
return tf.keras.layers.Layer()

def preprocess(
self, gt: tfgnn.GraphTensor
) -> tuple[tfgnn.GraphTensor, tfgnn.Field]:
self,
inputs: GraphTensor) -> Tuple[Optional[GraphTensor], Field]:
"""Creates unused pseudo-labels."""
y = tf.zeros((gt.num_components, 0), dtype=tf.int32)
return gt, y
return None, tf.zeros((inputs.num_components, 0), dtype=tf.int32)

def losses(self) -> Sequence[Callable[[tf.Tensor, tf.Tensor], tf.Tensor]]:
def loss_fn(_, x):
Expand All @@ -211,14 +189,12 @@ def __init__(
seed: Optional[int] = None,
sim_weight: Union[tf.Tensor, float] = 25.,
var_weight: Union[tf.Tensor, float] = 25.,
cov_weight: Union[tf.Tensor, float] = 1.,
):
cov_weight: Union[tf.Tensor, float] = 1.):
super().__init__(
node_set_name,
feature_name=feature_name,
representations_layer_name=representations_layer_name,
seed=seed,
)
seed=seed)
self._sim_weight = sim_weight
self._var_weight = var_weight
self._cov_weight = cov_weight
Expand All @@ -227,11 +203,10 @@ def make_contrastive_layer(self) -> tf.keras.layers.Layer:
return tf.keras.layers.Layer()

def preprocess(
self, gt: tfgnn.GraphTensor
) -> tuple[tfgnn.GraphTensor, tfgnn.Field]:
self,
inputs: GraphTensor) -> Tuple[Optional[GraphTensor], Field]:
"""Creates unused pseudo-labels."""
y = tf.zeros((gt.num_components, 0), dtype=tf.int32)
return gt, y
return None, tf.zeros((inputs.num_components, 0), dtype=tf.int32)

def losses(self) -> Sequence[Callable[[tf.Tensor, tf.Tensor], tf.Tensor]]:
def loss_fn(_, x):
Expand Down
5 changes: 3 additions & 2 deletions tensorflow_gnn/models/contrastive_losses/tasks_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,8 @@ def test_fit(self, task: runner.Task):
"""Verifies an adapted model's fit."""
ds = tf.data.Dataset.from_tensors(random_graph_tensor()).repeat()
ds = ds.batch(2).map(tfgnn.GraphTensor.merge_batch_to_components)
ds = ds.map(task.preprocess).take(5)
# `preprocess` performs no manipulations on `x`.
ds = ds.map(lambda x: (x, task.preprocess(x)[1])).take(5)

def get_loss():
values = model.evaluate(ds)
Expand All @@ -200,7 +201,7 @@ def test_preprocess(self, task: runner.Task):
expected = random_graph_tensor()
actual, _ = task.preprocess(expected)

self.assertEqual(actual, expected)
self.assertIsNone(actual)

@parameterized.named_parameters(all_tasks_inputs())
def test_adapt(self, task: runner.Task):
Expand Down
4 changes: 2 additions & 2 deletions tensorflow_gnn/runner/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ pytype_strict_library(
"//tensorflow_gnn/runner/tasks:regression",
"//tensorflow_gnn/runner/trainers:keras_fit",
"//tensorflow_gnn/runner/utils:attribution",
"//tensorflow_gnn/runner/utils:label_fns",
"//tensorflow_gnn/runner/utils:model",
"//tensorflow_gnn/runner/utils:model_dir",
"//tensorflow_gnn/runner/utils:model_export",
Expand Down Expand Up @@ -85,7 +86,6 @@ py_strict_test(
"//tensorflow_gnn/models/vanilla_mpnn",
"//tensorflow_gnn/runner/tasks:classification",
"//tensorflow_gnn/runner/trainers:keras_fit",
"//tensorflow_gnn/runner/utils:model_templates",
"//tensorflow_gnn/runner/utils:padding",
"//tensorflow_gnn/runner/utils:label_fns",
],
)
5 changes: 5 additions & 0 deletions tensorflow_gnn/runner/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from tensorflow_gnn.runner.tasks import regression
from tensorflow_gnn.runner.trainers import keras_fit
from tensorflow_gnn.runner.utils import attribution
from tensorflow_gnn.runner.utils import label_fns
from tensorflow_gnn.runner.utils import model as model_utils
from tensorflow_gnn.runner.utils import model_dir
from tensorflow_gnn.runner.utils import model_export
Expand All @@ -43,6 +44,10 @@
Trainer = interfaces.Trainer
Task = interfaces.Task

# Label fns
ContextLabelFn = label_fns.ContextLabelFn
RootNodeLabelFn = label_fns.RootNodeLabelFn

# Model directory
incrementing_model_dir = model_dir.incrementing_model_dir

Expand Down
102 changes: 56 additions & 46 deletions tensorflow_gnn/runner/distribute_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@
from tensorflow_gnn.runner.utils import padding

_LABELS = tuple(range(32))

_SAMPLE_DICT = immutabledict({(tfgnn.CONTEXT, None, "label"): _LABELS})

_SCHEMA = """
context {
features {
Expand Down Expand Up @@ -63,8 +65,6 @@
}
"""

TaskAndProcessor = tuple[interfaces.Task, interfaces.GraphTensorProcessorFn]


def _all_eager_strategy_combinations():
strategies = [
Expand All @@ -90,55 +90,68 @@ def _all_eager_strategy_combinations():
return tftest.combinations.combine(distribution=strategies)


def _all_task_and_processors_combinations():
def _all_task_combinations():

def extract_binary_labels(gt):
return gt, gt.context["label"] % 2
def extract_binary_labels(inputs):
return inputs.context["label"] % 2

def extract_multiclass_labels(gt):
return gt, gt.context["label"]
def extract_multiclass_labels(inputs):
return inputs.context["label"]

def extract_regression_labels(gt):
return gt, tf.ones_like(gt.context["label"], dtype=tf.float32)
def extract_regression_labels(inputs):
return tf.ones_like(inputs.context["label"], dtype=tf.float32)

task_and_processor = {
tasks = [
# Root node classification
classification.RootNodeBinaryClassification(node_set_name="node"):
extract_binary_labels,
classification.RootNodeBinaryClassification(
"node",
label_fn=extract_binary_labels),
classification.RootNodeMulticlassClassification(
node_set_name="node",
num_classes=len(_LABELS)): extract_multiclass_labels,
"node",
num_classes=len(_LABELS),
label_fn=extract_multiclass_labels),
# Graph classification
classification.GraphBinaryClassification(node_set_name="node"):
extract_binary_labels,
classification.GraphBinaryClassification(
"node",
label_fn=extract_binary_labels),
classification.GraphMulticlassClassification(
node_set_name="node",
num_classes=len(_LABELS)): extract_multiclass_labels,
"node",
num_classes=len(_LABELS),
label_fn=extract_multiclass_labels),
# Root node regression
regression.RootNodeMeanAbsoluteError(node_set_name="node"):
extract_regression_labels,
regression.RootNodeMeanAbsolutePercentageError(node_set_name="node"):
extract_regression_labels,
regression.RootNodeMeanSquaredError(node_set_name="node"):
extract_regression_labels,
regression.RootNodeMeanSquaredLogarithmicError(node_set_name="node"):
extract_regression_labels,
regression.RootNodeMeanSquaredLogScaledError(node_set_name="node"):
extract_regression_labels,
regression.RootNodeMeanAbsoluteError(
"node",
label_fn=extract_regression_labels),
regression.RootNodeMeanAbsolutePercentageError(
"node",
label_fn=extract_regression_labels),
regression.RootNodeMeanSquaredError(
"node",
label_fn=extract_regression_labels),
regression.RootNodeMeanSquaredLogarithmicError(
"node",
label_fn=extract_regression_labels),
regression.RootNodeMeanSquaredLogScaledError(
"node",
label_fn=extract_regression_labels),
# Graph regression
regression.GraphMeanAbsoluteError(node_set_name="node"):
extract_regression_labels,
regression.GraphMeanAbsolutePercentageError(node_set_name="node"):
extract_regression_labels,
regression.GraphMeanSquaredError(node_set_name="node"):
extract_regression_labels,
regression.GraphMeanSquaredLogarithmicError(node_set_name="node"):
extract_regression_labels,
regression.GraphMeanSquaredLogScaledError(node_set_name="node"):
extract_regression_labels,
}
items = list(task_and_processor.items())
return tftest.combinations.combine(task_and_processor=items)
regression.GraphMeanAbsoluteError(
"node",
label_fn=extract_regression_labels),
regression.GraphMeanAbsolutePercentageError(
"node",
label_fn=extract_regression_labels),
regression.GraphMeanSquaredError(
"node",
label_fn=extract_regression_labels),
regression.GraphMeanSquaredLogarithmicError(
"node",
label_fn=extract_regression_labels),
regression.GraphMeanSquaredLogScaledError(
"node",
label_fn=extract_regression_labels),
]
return tftest.combinations.combine(task=tasks)


class DatasetProvider(interfaces.DatasetProvider):
Expand All @@ -155,13 +168,13 @@ class OrchestrationTests(tf.test.TestCase, parameterized.TestCase):
@tfdistribute.combinations.generate(
tftest.combinations.times(
_all_eager_strategy_combinations(),
_all_task_and_processors_combinations()
_all_task_combinations()
)
)
def test_run(
self,
distribution: tf.distribute.Strategy,
task_and_processor: TaskAndProcessor):
task: interfaces.Task):
schema = tfgnn.parse_schema(_SCHEMA)
gtspec = tfgnn.create_graph_spec_from_schema_pb(schema)
gt = tfgnn.write_example(tfgnn.random_graph_tensor(
Expand Down Expand Up @@ -199,8 +212,6 @@ def test_run(
train_padding = None
valid_padding = None

task, processor = task_and_processor

orchestration.run(
train_ds_provider=ds_provider,
train_padding=train_padding,
Expand All @@ -211,7 +222,6 @@ def test_run(
task=task,
gtspec=gtspec,
global_batch_size=2,
feature_processors=(processor,),
valid_ds_provider=ds_provider,
valid_padding=valid_padding)

Expand Down
Loading

0 comments on commit 3188367

Please sign in to comment.