runner.GraphMulticlassClassification

Graph multiclass classification from pooled node states.

Inherits From: Task

runner.GraphMulticlassClassification(
    node_set_name: str,
    *,
    num_classes: Optional[int] = None,
    class_names: Optional[Sequence[str]] = None,
    per_class_statistics: bool = False,
    state_name: str = tfgnn.HIDDEN_STATE,
    reduce_type: str = 'mean',
    name: str = 'classification_logits',
    label_fn: Optional[LabelFn] = None,
    label_feature_name: Optional[str] = None
)

Args

node_set_name The node set to pool.
num_classes The number of classes. Exactly one of num_classes or class_names must be specified.
class_names The class names. Exactly one of num_classes or class_names must be specified.
per_class_statistics Whether to compute statistics per class.
state_name The feature name for activations (e.g., tfgnn.HIDDEN_STATE).
reduce_type The reduction type used to pool node states into the graph context (default: 'mean').
name The classification head's layer name. To control the naming of saved model outputs see the runner model exporters (e.g., KerasModelExporter).
label_fn A label extraction function. This function mutates the input GraphTensor. Mutually exclusive with label_feature_name.
label_feature_name A label feature name for readout from the auxiliary '_readout' node set. Readout does not mutate the input GraphTensor. Mutually exclusive with label_fn.
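
The "exactly one of" constraint on num_classes and class_names can be sketched in plain Python. This is only an illustration: resolve_num_classes is a hypothetical helper, not part of the runner API, and it assumes the class count is the length of class_names when names are given.

```python
from typing import Optional, Sequence


def resolve_num_classes(
    num_classes: Optional[int] = None,
    class_names: Optional[Sequence[str]] = None,
) -> int:
    """Hypothetical helper: derives the class count from exactly one argument."""
    # Reject both "neither given" and "both given".
    if (num_classes is None) == (class_names is None):
        raise ValueError(
            "Exactly one of `num_classes` or `class_names` must be specified")
    if class_names is not None:
        # Assumption: one class per name.
        return len(class_names)
    return num_classes


print(resolve_num_classes(num_classes=3))
print(resolve_num_classes(class_names=["cat", "dog"]))
```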

Methods

gather_activations

gather_activations(
    inputs: GraphTensor
) -> Field
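
As a rough picture of what pooling node states with reduce_type='mean' amounts to, the sketch below averages per-node state vectors within each graph. It uses plain Python lists rather than a GraphTensor; mean_pool and graph_ids are hypothetical names for illustration and do not reflect the library's implementation.

```python
from collections import defaultdict
from typing import Dict, List, Sequence


def mean_pool(node_states: Sequence[Sequence[float]],
              graph_ids: Sequence[int]) -> Dict[int, List[float]]:
    """Averages node state vectors per graph (hypothetical illustration)."""
    sums: Dict[int, List[float]] = {}
    counts: Dict[int, int] = defaultdict(int)
    for state, gid in zip(node_states, graph_ids):
        if gid not in sums:
            sums[gid] = [0.0] * len(state)
        # Accumulate the element-wise sum for this graph.
        sums[gid] = [s + x for s, x in zip(sums[gid], state)]
        counts[gid] += 1
    return {gid: [s / counts[gid] for s in total]
            for gid, total in sums.items()}


# Three nodes: the first two belong to graph 0, the last to graph 1.
pooled = mean_pool([[1.0, 2.0], [3.0, 4.0], [10.0, 20.0]], [0, 0, 1])
print(pooled)  # {0: [2.0, 3.0], 1: [10.0, 20.0]}
```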

losses

losses() -> interfaces.Losses

Sparse categorical crossentropy loss.
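
For reference, sparse categorical crossentropy takes an integer class label (not a one-hot vector) and scores it against the predicted distribution. The sketch below computes it for a single example, assuming the loss is applied directly to logits since predict returns logits; it is a standard-library illustration, not the Keras implementation.

```python
import math
from typing import Sequence


def sparse_categorical_crossentropy(logits: Sequence[float],
                                    label: int) -> float:
    """Returns -log(softmax(logits)[label]) for one example."""
    # Subtract the max before exponentiating for numerical stability.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_z - logits[label]


# With uniform logits over 3 classes, the loss equals log(3).
loss = sparse_categorical_crossentropy([0.0, 0.0, 0.0], label=1)
print(loss)
```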

metrics

metrics() -> interfaces.Metrics

Sparse categorical metrics.
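
One common sparse categorical metric is accuracy: the fraction of examples whose highest-scoring class equals the integer label. The sketch below shows that computation in plain Python; which metrics the task actually reports is not stated here, so treat this as an illustration of the "sparse" convention only.

```python
from typing import Sequence


def sparse_categorical_accuracy(batch_logits: Sequence[Sequence[float]],
                                labels: Sequence[int]) -> float:
    """Fraction of examples whose argmax over logits matches the label."""
    correct = 0
    for logits, label in zip(batch_logits, labels):
        # Index of the largest logit.
        predicted = max(range(len(logits)), key=lambda i: logits[i])
        if predicted == label:
            correct += 1
    return correct / len(labels)


accuracy = sparse_categorical_accuracy(
    [[0.1, 2.0, -1.0], [3.0, 0.0, 0.0]], labels=[1, 2])
print(accuracy)  # 0.5
```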

predict

predict(
    inputs: tfgnn.GraphTensor
) -> interfaces.Predictions

Apply a linear head for classification.

Args
inputs A tfgnn.GraphTensor for classification.
Returns
The classification logits.
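
A linear head maps each pooled graph vector to one logit per class via a weight matrix and a bias. The sketch below spells out that arithmetic for a single pooled vector; linear_head, weights, and bias are hypothetical names with hand-picked values, whereas the task's head would use trainable parameters.

```python
from typing import List, Sequence


def linear_head(pooled: Sequence[float],
                weights: Sequence[Sequence[float]],
                bias: Sequence[float]) -> List[float]:
    """Computes one logit per class: logits[c] = weights[c] . pooled + bias[c]."""
    return [
        sum(w * x for w, x in zip(class_weights, pooled)) + b
        for class_weights, b in zip(weights, bias)
    ]


# A 2-dimensional pooled state mapped to 3 class logits.
logits = linear_head(
    pooled=[1.0, 2.0],
    weights=[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
    bias=[0.0, 0.5, -1.0],
)
print(logits)  # [1.0, 2.5, 2.0]
```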

preprocess

preprocess(
    inputs: GraphTensor
) -> tuple[GraphTensor, Field]

Preprocesses a scalar GraphTensor (i.e., one obtained after merge_batch_to_components).

This function uses the Keras functional API to define non-trainable transformations of the symbolic input GraphTensor, which get executed during dataset preprocessing in a tf.data.Dataset.map(...) operation. It has two responsibilities:

  1. Splitting the training label out of the input for training. It must be returned as a separate tensor or mapping of tensors.
  2. Optionally, transforming input features. Some advanced modeling techniques require running the same base GNN on multiple different transformations, so this function may return a single GraphTensor or a non-empty sequence of GraphTensors. The corresponding base GNN output for each GraphTensor is provided to the predict(...) method.

Args
inputs A symbolic Keras GraphTensor for processing.
Returns
A tuple of the processed GraphTensor(s) and a Field (or a mapping of Fields) to be used as labels.
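
The first responsibility, splitting the label out of the input, can be pictured with a plain dictionary standing in for a GraphTensor's features. This is a simplified sketch under that assumption: split_label and the feature names are hypothetical, and the real method operates on symbolic GraphTensors rather than dicts.

```python
from typing import Any, Dict, Mapping, Tuple


def split_label(features: Mapping[str, Any],
                label_key: str) -> Tuple[Dict[str, Any], Any]:
    """Returns (features without the label, label) without mutating the input."""
    remaining = {k: v for k, v in features.items() if k != label_key}
    return remaining, features[label_key]


example = {"hidden_state": [0.2, 0.7], "label": 1}
inputs, label = split_label(example, "label")
print(inputs, label)  # {'hidden_state': [0.2, 0.7]} 1
```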