Defines a learning objective for a GNN.

A `Task` represents a learning objective for a GNN model and defines all the
non-GNN pieces around the base GNN. Specifically:

1) `preprocess` is expected to return a `GraphTensor` (or `GraphTensor`s) and
   a `Field` where (a) the base GNN's output for each `GraphTensor` is passed
   to `predict` and (b) the `Field` is used as the training label (for
   supervised tasks);
2) `predict` is expected to (a) take the base GNN's output for each
   `GraphTensor` returned by `preprocess` and (b) return a tensor with the
   model's prediction for this task;
3) `losses` is expected to return callables (`tf.Tensor`, `tf.Tensor`) ->
   `tf.Tensor` that accept (`y_true`, `y_pred`), where `y_true` is produced by
   some dataset and `y_pred` is the model's prediction from (2);
4) `metrics` is expected to return callables (`tf.Tensor`, `tf.Tensor`) ->
   `tf.Tensor` that accept (`y_true`, `y_pred`), where `y_true` is produced by
   some dataset and `y_pred` is the model's prediction from (2).
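The four-method contract above can be sketched without TensorFlow. The following is a simplified, dependency-free illustration: plain Python dicts and lists stand in for `GraphTensor` and `Field`, and all class and key names here are made up for illustration, not the library's API:

```python
import abc


class TaskSketch(abc.ABC):
    """Simplified stand-in for the Task contract (illustrative only)."""

    @abc.abstractmethod
    def preprocess(self, inputs):
        """Returns (one or more graph values, training label)."""

    @abc.abstractmethod
    def predict(self, *gnn_outputs):
        """Maps base-GNN output(s) to this task's prediction."""

    @abc.abstractmethod
    def losses(self):
        """Returns callables (y_true, y_pred) -> loss."""

    @abc.abstractmethod
    def metrics(self):
        """Returns callables (y_true, y_pred) -> metric."""


class MeanRegressionSketch(TaskSketch):
    """Toy supervised task: pop the label, predict the mean node state."""

    def preprocess(self, inputs):
        features = dict(inputs)
        label = features.pop("label")  # graph for the GNN, label kept apart
        return features, label

    def predict(self, gnn_output):
        # A "head" with no learnable weights: mean-pool the node states.
        return sum(gnn_output) / len(gnn_output)

    def losses(self):
        return (lambda y_true, y_pred: (y_true - y_pred) ** 2,)

    def metrics(self):
        return (lambda y_true, y_pred: abs(y_true - y_pred),)


task = MeanRegressionSketch()
graph, label = task.preprocess({"states": [1.0, 2.0, 3.0], "label": 2.5})
prediction = task.predict(graph["states"])  # pretend GNN(graph) == the states
loss = task.losses()[0](label, prediction)  # (2.5 - 2.0) ** 2 == 0.25
```

In the real library the trainer, not user code, wires `preprocess`, the base GNN, `predict`, and `losses` together; the sketch only shows how the four pieces relate.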
A `Task` can emit multiple outputs in `predict`: in that case we require that
(a) it is a mapping, (b) the outputs of `losses` and `metrics` are also
mappings with matching keys, and (c) there is exactly one loss per key (there
may be a sequence of metrics per key). This is done to prevent accidental
dropping of losses (see b/291874188).
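The matching-keys requirement can be checked mechanically. A minimal sketch of the multi-output case, with made-up output names and plain Python values in place of tensors:

```python
# Hypothetical multi-output task pieces; names are illustrative.
predictions = {"rating": 3.7, "genre": [0.2, 0.8]}

# (c) exactly one loss per key...
losses = {
    "rating": lambda y_true, y_pred: (y_true - y_pred) ** 2,
    "genre": lambda y_true, y_pred: sum(
        (t - p) ** 2 for t, p in zip(y_true, y_pred)),
}

# ...but a sequence of metrics per key is fine (even an empty one).
metrics = {
    "rating": (lambda y_true, y_pred: abs(y_true - y_pred),),
    "genre": (),
}

# (a)/(b): predict's output is a mapping, and losses/metrics share its keys.
# This check is what prevents a loss from being silently dropped.
assert predictions.keys() == losses.keys() == metrics.keys()

labels = {"rating": 4.0, "genre": [0.0, 1.0]}
per_key_loss = {k: losses[k](y_true, predictions[k])
                for k, y_true in labels.items()}
```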
No constraints are made on the `predict` method; e.g., it may append a head
with learnable weights or it may perform tensor computations only. (The entire
`Task` coordinates what that means with respect to the dataset (via
`preprocess`), modeling (via `predict`), and optimization (via `losses`).)

`Task`s are applied in the scope of a training invocation: they are subject to
the executing context of the `Trainer` and should, when needed, override it
(e.g., a global policy, like `tf.keras.mixed_precision.global_policy()`, and
its implications for logit and activation layers).
@abc.abstractmethod
losses() -> Losses
Returns arbitrary task-specific losses.
@abc.abstractmethod
metrics() -> Metrics
Returns arbitrary task-specific metrics.
@abc.abstractmethod
predict(*args) -> Predictions
Produces prediction outputs for the learning objective.
Overall model composition* makes use of the Keras Functional API
(https://www.tensorflow.org/guide/keras/functional) to map symbolic Keras
`GraphTensor` inputs to symbolic Keras `Field` outputs. Outputs must match the
structure (one or a mapping) of labels from `preprocess`.

*) `outputs = predict(GNN(inputs))`, where `inputs` are the `GraphTensor`s
returned by `preprocess(...)`, `GNN` is the base GNN, `predict` is this method
and `outputs` are the prediction outputs for the learning objective.
Args | |
---|---|
`*args` | The symbolic Keras `GraphTensor` input(s). These inputs correspond (in sequence) to the base GNN output of each `GraphTensor` returned by `preprocess(...)`.

Returns | |
---|---|
The model's prediction output for this task.
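The composition in the footnote, `outputs = predict(GNN(inputs))`, can be illustrated with stand-in functions; `gnn` and `predict` below are toy placeholders with no TensorFlow and no learnable weights:

```python
def gnn(graph):
    """Stand-in base GNN: doubles every node state."""
    return [2.0 * x for x in graph["states"]]


def predict(hidden_states):
    """Stand-in task head: sum-pool node states to one graph-level score."""
    return sum(hidden_states)


inputs = {"states": [1.0, 2.0, 3.0]}  # stand-in for a symbolic GraphTensor
outputs = predict(gnn(inputs))        # 2.0 + 4.0 + 6.0 == 12.0
```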
@abc.abstractmethod
preprocess(inputs: GraphTensor) -> tuple[OneOrSequenceOf[GraphTensor], OneOrMappingOf[Field]]
Preprocesses a scalar (after `merge_batch_to_components`) `GraphTensor`.

This function uses the Keras Functional API to define non-trainable
transformations of the symbolic input `GraphTensor`, which get executed during
dataset preprocessing in a `tf.data.Dataset.map(...)` operation. It has two
responsibilities:

- Splitting the training label out of the input for training. It must be
  returned as a separate tensor or mapping of tensors.
- Optionally, transforming input features. Some advanced modeling techniques
  require running the same base GNN on multiple different transformations, so
  this function may return a single `GraphTensor` or a non-empty sequence of
  `GraphTensor`s. The corresponding base GNN output for each `GraphTensor` is
  provided to the `predict(...)` method.
Args | |
---|---|
`inputs` | A symbolic Keras `GraphTensor` for processing.

Returns | |
---|---|
A tuple of the processed `GraphTensor`(s) and a (one or mapping of) `Field` to be used as labels.
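As a sketch of the two responsibilities, here is a dependency-free illustration of a preprocess that splits out the label and emits two transformed views of the same input, as techniques that run one base GNN on several transformations do. Plain dicts stand in for `GraphTensor`, and the transformations are made up for illustration:

```python
def preprocess(inputs):
    """Returns (sequence of graph views, label), mirroring the contract."""
    features = dict(inputs)
    label = features.pop("label")  # responsibility 1: split the label out

    # Responsibility 2: two feature transformations of the same input. The
    # base GNN would run once per view, and predict(...) would then receive
    # one base-GNN output per view, in sequence.
    view_a = {"states": [x * 0.5 for x in features["states"]]}
    view_b = {"states": [x * 2.0 for x in features["states"]]}
    return (view_a, view_b), label


graphs, label = preprocess({"states": [1.0, 2.0], "label": 1})
```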