Move interfaces from orchestration.py to interfaces.py. #248

Merged · 1 commit · Dec 13, 2022
14 changes: 13 additions & 1 deletion tensorflow_gnn/runner/BUILD
@@ -9,6 +9,7 @@ pytype_strict_library(
srcs = ["__init__.py"],
visibility = ["//visibility:public"],
deps = [
":interfaces",
":orchestration",
"//tensorflow_gnn/runner/input:datasets",
"//tensorflow_gnn/runner/tasks:classification",
@@ -25,12 +26,23 @@
],
)

pytype_strict_library(
name = "interfaces",
srcs = ["interfaces.py"],
srcs_version = "PY3",
visibility = [":__subpackages__"],
deps = [
"//:expect_tensorflow_installed",
"//tensorflow_gnn",
],
)

pytype_strict_library(
name = "orchestration",
srcs = ["orchestration.py"],
srcs_version = "PY3",
visibility = [":__subpackages__"],
deps = [
":interfaces",
"//:expect_tensorflow_installed",
"//tensorflow_gnn",
"//tensorflow_gnn/runner/utils:model",
24 changes: 12 additions & 12 deletions tensorflow_gnn/runner/__init__.py
@@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================
"""A general purpose runner for TF-GNN."""
from tensorflow_gnn.runner import interfaces
from tensorflow_gnn.runner import orchestration
from tensorflow_gnn.runner.input import datasets
from tensorflow_gnn.runner.tasks import classification
@@ -35,6 +36,14 @@
SampleTFRecordDatasetsProvider = datasets.SampleTFRecordDatasetsProvider
TFRecordDatasetProvider = datasets.TFRecordDatasetProvider

# Interfaces
DatasetProvider = interfaces.DatasetProvider
GraphTensorPadding = interfaces.GraphTensorPadding
GraphTensorProcessorFn = interfaces.GraphTensorProcessorFn
ModelExporter = interfaces.ModelExporter
Trainer = interfaces.Trainer
Task = interfaces.Task

# Model directory
incrementing_model_dir = model_dir.incrementing_model_dir

@@ -47,35 +56,26 @@
chain_first_output = model_utils.chain_first_output

# Orchestration
DatasetProvider = orchestration.DatasetProvider
ModelExporter = orchestration.ModelExporter
run = orchestration.run
Trainer = orchestration.Trainer
Task = orchestration.Task
TFDataServiceConfig = orchestration.TFDataServiceConfig

# Padding
GraphTensorPadding = orchestration.GraphTensorPadding
GraphTensorProcessorFn = orchestration.GraphTensorProcessorFn
one_node_per_component = padding_utils.one_node_per_component
FitOrSkipPadding = padding_utils.FitOrSkipPadding
GraphTensorPadding = orchestration.GraphTensorPadding
TightPadding = padding_utils.TightPadding

# Strategies
ParameterServerStrategy = strategies.ParameterServerStrategy
TPUStrategy = strategies.TPUStrategy

# Tasks
#
# Unsupervised
# Tasks (Unsupervised)
DeepGraphInfomax = dgi.DeepGraphInfomax
# Classification
# Tasks (Classification)
RootNodeBinaryClassification = classification.RootNodeBinaryClassification
RootNodeMulticlassClassification = classification.RootNodeMulticlassClassification
GraphBinaryClassification = classification.GraphBinaryClassification
GraphMulticlassClassification = classification.GraphMulticlassClassification
# Regression
# Tasks (Regression)
GraphMeanAbsoluteError = regression.GraphMeanAbsoluteError
GraphMeanAbsolutePercentageError = regression.GraphMeanAbsolutePercentageError
GraphMeanSquaredError = regression.GraphMeanSquaredError
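Because `__init__.py` re-exports the moved classes under their existing names,
user code that reaches them through the `runner` namespace is unaffected. A
minimal sketch of what the re-export implies, assuming only the imports shown
in this diff:

    from tensorflow_gnn import runner

    # After this change, the public names and the new module agree:
    assert runner.DatasetProvider is runner.interfaces.DatasetProvider
    assert runner.Task is runner.interfaces.Task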
2 changes: 1 addition & 1 deletion tensorflow_gnn/runner/input/BUILD
@@ -10,7 +10,7 @@ pytype_strict_library(
visibility = ["//tensorflow_gnn/runner:__pkg__"],
deps = [
"//:expect_tensorflow_installed",
"//tensorflow_gnn/runner:orchestration",
"//tensorflow_gnn/runner:interfaces",
],
)

10 changes: 5 additions & 5 deletions tensorflow_gnn/runner/input/datasets.py
@@ -16,7 +16,7 @@
from typing import Callable, List, Optional, Sequence

import tensorflow as tf
from tensorflow_gnn.runner import orchestration
from tensorflow_gnn.runner import interfaces


def _process_dataset(
@@ -56,7 +56,7 @@ def _process_dataset(
return dataset.prefetch(tf.data.AUTOTUNE)


class PassthruDatasetProvider(orchestration.DatasetProvider):
class PassthruDatasetProvider(interfaces.DatasetProvider):
"""Builds a `tf.data.Dataset` from a pass thru dataset.

Passes any `dataset` thru, omitting any sharding. For detailed documentation,
@@ -83,7 +83,7 @@ def get_dataset(self, _: tf.distribute.InputContext) -> tf.data.Dataset:
examples_shuffle_size=self._examples_shuffle_size)


class SimpleDatasetProvider(orchestration.DatasetProvider):
class SimpleDatasetProvider(interfaces.DatasetProvider):
"""Builds a `tf.data.Dataset` from a file pattern.

This `SimpleDatasetProvider` builds a `tf.data.Dataset` as follows:
@@ -208,7 +208,7 @@ def dataset_fn(dataset):
return sampled_dataset.prefetch(tf.data.AUTOTUNE)


class PassthruSampleDatasetsProvider(orchestration.DatasetProvider):
class PassthruSampleDatasetsProvider(interfaces.DatasetProvider):
"""Builds a sampled `tf.data.Dataset` from multiple pass thru datasets.

Passes any `principal_dataset` and `extra_datasets` thru: omitting any
Expand Down Expand Up @@ -255,7 +255,7 @@ def get_dataset(self, _: tf.distribute.InputContext) -> tf.data.Dataset:
examples_shuffle_size=self._examples_shuffle_size)


class SimpleSampleDatasetsProvider(orchestration.DatasetProvider):
class SimpleSampleDatasetsProvider(interfaces.DatasetProvider):
"""Builds a sampling `tf.data.Dataset` from multiple file patterns.

For complete explanations regarding sampling see `_process_sampled_dataset()`.
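With `datasets.py` now depending on `interfaces`, third-party input pipelines
can do the same. A minimal sketch of a custom provider against the new module;
the class name and sharding policy are illustrative, not part of this PR:

    import tensorflow as tf
    from tensorflow_gnn.runner import interfaces

    class InMemoryDatasetProvider(interfaces.DatasetProvider):
      """Serves a dataset already materialized in memory."""

      def __init__(self, dataset: tf.data.Dataset):
        self._dataset = dataset

      def get_dataset(self, context: tf.distribute.InputContext) -> tf.data.Dataset:
        # Shard per input pipeline so each replica reads a disjoint slice.
        return self._dataset.shard(
            context.num_input_pipelines, context.input_pipeline_id)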
193 changes: 193 additions & 0 deletions tensorflow_gnn/runner/interfaces.py
@@ -0,0 +1,193 @@
# Copyright 2021 The TensorFlow GNN Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Interfaces for the runner entry point."""
import abc
import sys
from typing import Callable, Optional, Sequence, Tuple, Union

import tensorflow as tf
import tensorflow_gnn as tfgnn

# pylint:disable=g-import-not-at-top
if sys.version_info >= (3, 8):
from typing import Protocol
from typing import runtime_checkable
else:
from typing_extensions import Protocol
from typing_extensions import runtime_checkable
# pylint:enable=g-import-not-at-top

GraphTensor = tfgnn.GraphTensor
GraphTensorAndField = Tuple[GraphTensor, tfgnn.Field]
SizeConstraints = tfgnn.SizeConstraints


class DatasetProvider(abc.ABC):

@abc.abstractmethod
def get_dataset(self, context: tf.distribute.InputContext) -> tf.data.Dataset:
"""Gets a `tf.data.Dataset` for the given per-replica `context`."""
raise NotImplementedError()


class GraphTensorPadding(abc.ABC):
"""Collects `GraphTensor` padding helpers."""

@abc.abstractmethod
def get_filter_fn(
self,
size_constraints: SizeConstraints) -> Callable[..., bool]:
raise NotImplementedError()

@abc.abstractmethod
def get_size_constraints(self, target_batch_size: int) -> SizeConstraints:
raise NotImplementedError()

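A minimal sketch of one way to implement this interface with fixed,
caller-supplied constraints; it assumes `tfgnn.satisfies_size_constraints` for
the filter (the runner's real implementations are `FitOrSkipPadding` and
`TightPadding` in `padding_utils`, re-exported in `__init__.py` above):

    import tensorflow_gnn as tfgnn
    from tensorflow_gnn.runner import interfaces

    class FixedSizePadding(interfaces.GraphTensorPadding):
      """Pads to constraints chosen up front, independent of batch size."""

      def __init__(self, size_constraints: tfgnn.SizeConstraints):
        self._size_constraints = size_constraints

      def get_filter_fn(self, size_constraints: tfgnn.SizeConstraints):
        # Drop any input that cannot fit within the padding budget.
        return lambda gt: tfgnn.satisfies_size_constraints(gt, size_constraints)

      def get_size_constraints(self, target_batch_size: int) -> tfgnn.SizeConstraints:
        return self._size_constraints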

@runtime_checkable
class GraphTensorProcessorFn(Protocol):
"""A protocol for callables that process a `GraphTensor`."""

def __call__(
self,
gt: GraphTensor) -> Union[GraphTensor, GraphTensorAndField]:
"""Processes a `GraphTensor` with optional `Field` extraction.

Args:
gt: A `GraphTensor` for processing.

Returns:
A processed `GraphTensor` or a processed `GraphTensor` and `Field`.
"""
raise NotImplementedError()

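Because the protocol is decorated with `runtime_checkable`, any callable with
a compatible `__call__` passes an `isinstance` check; no subclassing is
required. A small sketch:

    import tensorflow_gnn as tfgnn
    from tensorflow_gnn.runner import interfaces

    def identity_processor(gt: tfgnn.GraphTensor) -> tfgnn.GraphTensor:
      # A no-op processor; a real one might map features or split out a label.
      return gt

    assert isinstance(identity_processor, interfaces.GraphTensorProcessorFn)

Note that `runtime_checkable` only verifies the presence of `__call__`, not
its signature.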

class ModelExporter(abc.ABC):
"""Saves a Keras model."""

@abc.abstractmethod
def save(
self,
preprocess_model: Optional[tf.keras.Model],
model: tf.keras.Model,
export_dir: str):
"""Saves a Keras model.

All persistence decisions are left to the implementation: e.g., a Keras
model with full API or a simple `tf.train.Checkpoint` may be saved.

Args:
preprocess_model: An optional `tf.keras.Model` for preprocessing.
model: A `tf.keras.Model` to save.
export_dir: A destination directory for the model.
"""
raise NotImplementedError()

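A minimal exporter sketch that persists only the trained model via the
standard Keras SavedModel API; the class name is illustrative, and real
exporters may instead compose `preprocess_model` and `model` before saving:

    from typing import Optional
    import tensorflow as tf
    from tensorflow_gnn.runner import interfaces

    class PlainKerasExporter(interfaces.ModelExporter):
      """Saves the trained model only, discarding the preprocess model."""

      def save(self,
               preprocess_model: Optional[tf.keras.Model],
               model: tf.keras.Model,
               export_dir: str):
        # Persistence decisions are left to implementations (see docstring);
        # here, the plain Keras SavedModel format is used.
        del preprocess_model  # Unused in this sketch.
        model.save(export_dir)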

class Task(abc.ABC):
"""Collects the ancillary, supporting pieces to train a Keras model.

`Task`s are applied and used to compile a `tf.keras.Model` in the scope
of a training invocation: they are subject to the executing context
of the `Trainer` and should, when needed, override it (e.g., a global
policy like `tf.keras.mixed_precision.global_policy()` and its implications
for logit and activation layers).

A `Task` is expected to coordinate all of its methods and their return values
to define a graph learning objective. Precisely:

1) `preprocess` is expected to return a `GraphTensor` or a
(`GraphTensor`, `Field`) pair, where the `GraphTensor` matches the input
of the model returned by `adapt` and the `Field` is a training label;
2) `adapt` is expected to return a `tf.keras.Model` that accepts a
`GraphTensor` matching the output of `preprocess`;
3) `losses` is expected to return callables (`tf.Tensor`, `tf.Tensor`) ->
`tf.Tensor` that accept (`y_true`, `y_pred`), where `y_true` is produced
by some dataset and `y_pred` is the output of the adapted model (see (2));
4) `metrics` is expected to return callables (`tf.Tensor`, `tf.Tensor`) ->
`tf.Tensor` that accept (`y_true`, `y_pred`), where `y_true` is produced
by some dataset and `y_pred` is the output of the adapted model (see (2)).

No constraints are placed on the `adapt` method; e.g., it may adapt its input
by appending a head, add losses to its input, add metrics to its input, or do
any combination of these modifications. The `adapt` method is expected to
adapt an arbitrary `tf.keras.Model` to the graph learning objective. (The
`Task` as a whole coordinates what that means with respect to input, via
`preprocess`; modeling, via `adapt`; and optimization, via `losses`.)
"""

@abc.abstractmethod
def adapt(self, model: tf.keras.Model) -> tf.keras.Model:
"""Adapt a model to a task by appending arbitrary head(s)."""
raise NotImplementedError()

@abc.abstractmethod
def preprocess(
self,
gt: GraphTensor) -> Union[GraphTensor, GraphTensorAndField]:
"""Preprocesses a scalar `GraphTensor` (after `merge_batch_to_components`)."""
raise NotImplementedError()

@abc.abstractmethod
def losses(self) -> Sequence[Callable[[tf.Tensor, tf.Tensor], tf.Tensor]]:
"""Arbitrary losses matching any head(s)."""
raise NotImplementedError()

@abc.abstractmethod
def metrics(self) -> Sequence[Callable[[tf.Tensor, tf.Tensor], tf.Tensor]]:
"""Arbitrary task specific metrics."""
raise NotImplementedError()

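A skeletal scalar-regression `Task` illustrating points (1)-(4) of the class
docstring above. The context feature name, head size, and loss/metric choices
are illustrative only:

    from typing import Callable, Sequence
    import tensorflow as tf
    import tensorflow_gnn as tfgnn
    from tensorflow_gnn.runner import interfaces

    class ScalarRegression(interfaces.Task):
      """Predicts one float per graph from a model emitting pooled states."""

      def adapt(self, model: tf.keras.Model) -> tf.keras.Model:
        # Point (2): append a linear head; the base model is assumed to
        # output a fixed-size hidden state.
        return tf.keras.Model(model.input, tf.keras.layers.Dense(1)(model.output))

      def preprocess(self, gt: tfgnn.GraphTensor):
        # Point (1): split a context feature off as the training label
        # ('label' is a hypothetical feature name; a real task would also
        # remove it from the GraphTensor).
        return gt, gt.context['label']

      def losses(self) -> Sequence[Callable[[tf.Tensor, tf.Tensor], tf.Tensor]]:
        # Point (3): losses accept (y_true, y_pred).
        return (tf.keras.losses.MeanSquaredError(),)

      def metrics(self) -> Sequence[Callable[[tf.Tensor, tf.Tensor], tf.Tensor]]:
        # Point (4): metrics accept (y_true, y_pred).
        return (tf.keras.metrics.MeanAbsoluteError(),)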

class Trainer(abc.ABC):
"""A class for training and validation."""

@property
@abc.abstractmethod
def model_dir(self) -> str:
raise NotImplementedError()

@property
@abc.abstractmethod
def strategy(self) -> tf.distribute.Strategy:
raise NotImplementedError()

@abc.abstractmethod
def train(
self,
model_fn: Callable[[], tf.keras.Model],
train_ds_provider: DatasetProvider,
*,
epochs: int = 1,
valid_ds_provider: Optional[DatasetProvider] = None) -> tf.keras.Model:
"""Trains a `tf.keras.Model` with optional validation.

Args:
model_fn: Returns a `tf.keras.Model` for use in training and validation.
train_ds_provider: A `DatasetProvider` for training. The items of the
`tf.data.Dataset` are pairs `(graph_tensor, label)` that represent one
batch of per-replica training inputs after
`GraphTensor.merge_batch_to_components()` has been applied.
epochs: The number of epochs to train.
valid_ds_provider: A `DatasetProvider` for validation. The items of the
`tf.data.Dataset` are pairs `(graph_tensor, label)` that represent one
batch of per-replica validation inputs after
`GraphTensor.merge_batch_to_components()` has been applied.

Returns:
A trained `tf.keras.Model`.
"""
raise NotImplementedError()
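A compressed `Trainer` sketch built on `Model.fit`, eliding checkpointing,
callbacks, and steps-per-epoch handling; it assumes `model_fn` returns an
already-compiled model:

    from typing import Callable, Optional
    import tensorflow as tf
    from tensorflow_gnn.runner import interfaces

    class SimpleFitTrainer(interfaces.Trainer):
      """Runs `Model.fit` under a `tf.distribute.Strategy`."""

      def __init__(self, strategy: tf.distribute.Strategy, model_dir: str):
        self._strategy = strategy
        self._model_dir = model_dir

      @property
      def model_dir(self) -> str:
        return self._model_dir

      @property
      def strategy(self) -> tf.distribute.Strategy:
        return self._strategy

      def train(self,
                model_fn: Callable[[], tf.keras.Model],
                train_ds_provider: interfaces.DatasetProvider,
                *,
                epochs: int = 1,
                valid_ds_provider: Optional[interfaces.DatasetProvider] = None
                ) -> tf.keras.Model:
        with self._strategy.scope():
          model = model_fn()  # Assumed compiled (optimizer, losses, metrics).
        train_ds = self._strategy.distribute_datasets_from_function(
            train_ds_provider.get_dataset)
        valid_ds = None
        if valid_ds_provider is not None:
          valid_ds = self._strategy.distribute_datasets_from_function(
              valid_ds_provider.get_dataset)
        model.fit(train_ds, epochs=epochs, validation_data=valid_ds)
        return model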