runner.Task

View source on GitHub

Defines a learning objective for a GNN.

A Task represents a learning objective for a GNN model and defines all the non-GNN pieces around the base GNN. Specifically:

  1. preprocess is expected to return a GraphTensor (or GraphTensors) and a Field where (a) the base GNN's output for each GraphTensor is passed to predict and (b) the Field is used as the training label (for supervised tasks);
  2. predict is expected to (a) take the base GNN's output for each GraphTensor returned by preprocess and (b) return a tensor with the model's prediction for this task;
  3. losses is expected to return callables (tf.Tensor, tf.Tensor) -> tf.Tensor that accept (y_true, y_pred) where y_true is produced by some dataset and y_pred is the model's prediction from (2);
  4. metrics is expected to return callables (tf.Tensor, tf.Tensor) -> tf.Tensor that accept (y_true, y_pred) where y_true is produced by some dataset and y_pred is the model's prediction from (2).
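The four pieces listed above can be pictured with a toy, pure-Python stand-in. This is only a sketch of the contract's shape, not the real runner.Task: the names ToyTask and ToyRegression, the dict used as a "graph", and the float lists used as "fields" are all assumptions made for illustration, and no TensorFlow objects are involved.

```python
import abc
from typing import Callable, Sequence, Tuple

# Hypothetical stand-ins: a "graph" is a dict of features, a "field" is a float list.
Graph = dict
Field = Sequence[float]
LossFn = Callable[[Field, Field], float]


class ToyTask(abc.ABC):
    """Mirrors the four-method shape described above; not the real runner.Task."""

    @abc.abstractmethod
    def preprocess(self, inputs: Graph) -> Tuple[Graph, Field]: ...

    @abc.abstractmethod
    def predict(self, *args: Field) -> Field: ...

    @abc.abstractmethod
    def losses(self) -> LossFn: ...

    @abc.abstractmethod
    def metrics(self) -> Sequence[LossFn]: ...


class ToyRegression(ToyTask):
    def preprocess(self, inputs):
        # Split the label out of the input features.
        features = {k: v for k, v in inputs.items() if k != "label"}
        return features, inputs["label"]

    def predict(self, *args):
        # Tensor computation only, no learnable head: scale the base GNN output.
        (gnn_output,) = args
        return [2.0 * x for x in gnn_output]

    def losses(self):
        # (y_true, y_pred) -> scalar: mean squared error.
        return lambda y_true, y_pred: sum(
            (t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)

    def metrics(self):
        # (y_true, y_pred) -> scalar: mean absolute error.
        return [lambda y_true, y_pred: sum(
            abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)]


task = ToyRegression()
graph, label = task.preprocess({"x": [1.0, 2.0], "label": [2.0, 5.0]})
y_pred = task.predict(graph["x"])  # identity "GNN", for illustration only
loss_value = task.losses()(label, y_pred)
```

In the real library these methods operate on symbolic Keras tensors rather than eager Python values, so the sketch only conveys which method hands what to which.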

A Task can emit multiple outputs from predict. In that case, we require that (a) the output is a mapping, (b) the outputs of losses and metrics are also mappings with matching keys, and (c) there is exactly one loss per key (there may be a sequence of metrics per key). This is done to prevent accidental dropping of losses (see b/291874188).
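The three multi-output requirements can be expressed as a small check. The function below is a hypothetical validator written for illustration in plain Python; the name check_multi_output and its error messages are assumptions, and it does not reproduce whatever validation the runner actually performs.

```python
from collections.abc import Mapping, Sequence


def check_multi_output(predictions, losses, metrics):
    """Hypothetical validator for the three multi-output rules above."""
    # (a) Multiple outputs must come as a mapping.
    if not isinstance(predictions, Mapping):
        raise ValueError("multiple outputs must be a mapping")
    # (b) losses and metrics must be mappings with the same keys.
    for name, spec in (("losses", losses), ("metrics", metrics)):
        if not isinstance(spec, Mapping) or set(spec) != set(predictions):
            raise ValueError(f"{name} must be a mapping with keys matching predict")
    # (c) Exactly one loss per key: a bare callable, never a list of them.
    for key, loss in losses.items():
        if not callable(loss):
            raise ValueError(f"expected exactly one loss for {key!r}")
    # Metrics may be one callable or a sequence of callables per key.
    for key, ms in metrics.items():
        ms = ms if isinstance(ms, Sequence) else [ms]
        if not all(callable(m) for m in ms):
            raise ValueError(f"metrics for {key!r} must be callables")


mse = lambda y_true, y_pred: (y_true - y_pred) ** 2
mae = lambda y_true, y_pred: abs(y_true - y_pred)

preds = {"price": 1.0, "rating": 4.0}
# Passes silently: one loss per key, one or several metrics per key.
check_multi_output(preds, {"price": mse, "rating": mse},
                   {"price": [mse, mae], "rating": mae})
```

A losses mapping that omits "rating", or one that supplies a list of losses for "price", would raise ValueError under this sketch. That is the kind of silent loss-dropping the rule is meant to catch.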

No constraints are placed on the predict method; e.g., it may append a head with learnable weights or it may perform tensor computations only. The Task as a whole coordinates what that means with respect to the dataset (via preprocess), modeling (via predict), and optimization (via losses).

Tasks are applied in the scope of a training invocation: they are subject to the executing context of the Trainer and should, when needed, override it (e.g., a global policy, like tf.keras.mixed_precision.global_policy() and its implications over logit and activation layers).

Methods

losses

View source

@abc.abstractmethod
losses() -> Losses

Returns arbitrary task specific losses.

metrics

View source

@abc.abstractmethod
metrics() -> Metrics

Returns arbitrary task specific metrics.

predict

View source

@abc.abstractmethod
predict(
    *args
) -> Predictions

Produces prediction outputs for the learning objective.

Overall model composition* makes use of the Keras Functional API (https://www.tensorflow.org/guide/keras/functional) to map symbolic Keras GraphTensor inputs to symbolic Keras Field outputs. Outputs must match the structure (one or mapping) of labels from preprocess.

*) outputs = predict(GNN(inputs)) where inputs are the GraphTensor(s) returned by preprocess(...), GNN is the base GNN, predict is this method, and outputs are the prediction outputs for the learning objective.

Args
*args The symbolic Keras GraphTensor input(s). These inputs correspond (in sequence) to the base GNN output of each GraphTensor returned by preprocess(...).
Returns
The model's prediction output for this task.

preprocess

View source

@abc.abstractmethod
preprocess(
    inputs: GraphTensor
) -> tuple[OneOrSequenceOf[GraphTensor], OneOrMappingOf[Field]]

Preprocesses a scalar (after merge_batch_to_components) GraphTensor.

This function uses the Keras functional API to define non-trainable transformations of the symbolic input GraphTensor, which get executed during dataset preprocessing in a tf.data.Dataset.map(...) operation. It has two responsibilities:

  1. Splitting the training label out of the input for training. It must be returned as a separate tensor or mapping of tensors.
  2. Optionally, transforming input features. Some advanced modeling techniques require running the same base GNN on multiple different transformations, so this function may return a single GraphTensor or a non-empty sequence of GraphTensors. The corresponding base GNN output for each GraphTensor is provided to the predict(...) method.
Args
inputs A symbolic Keras GraphTensor for processing.
Returns
A tuple of processed GraphTensor(s) and a (one or mapping of) Field to be used as labels.
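The two responsibilities of preprocess can be sketched on a toy input. Here a nested dict stands in for a scalar GraphTensor, and the keys "nodes", "context", "feat", and "label" are hypothetical. A real implementation would express the same steps as Keras layers over a symbolic GraphTensor rather than as dict manipulation.

```python
import copy


def toy_preprocess(graph):
    """Toy preprocess: a nested dict stands in for a scalar GraphTensor."""
    graph = copy.deepcopy(graph)  # leave the caller's input untouched
    # Responsibility 1: split the label out so the model never sees it.
    label = graph["context"].pop("label")
    # Responsibility 2 (optional): build a second, transformed view
    # by zeroing the node features.
    masked = copy.deepcopy(graph)
    masked["nodes"]["feat"] = [0.0 for _ in masked["nodes"]["feat"]]
    return [graph, masked], label


example = {"nodes": {"feat": [0.5, 1.5]}, "context": {"label": 1}}
views, label = toy_preprocess(example)
```

Returning a sequence of two views means the base GNN runs twice, and predict then receives two positional arguments, one per view.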