Graph multiclass classification from pooled node states.
Inherits From: Task
runner.GraphMulticlassClassification(
    node_set_name: str,
    *,
    num_classes: Optional[int] = None,
    class_names: Optional[Sequence[str]] = None,
    per_class_statistics: bool = False,
    state_name: str = tfgnn.HIDDEN_STATE,
    reduce_type: str = 'mean',
    name: str = 'classification_logits',
    label_fn: Optional[LabelFn] = None,
    label_feature_name: Optional[str] = None
)
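The class summary says classification is done from pooled node states, and reduce_type selects the reduction (default 'mean'). As a rough, library-free illustration of what such pooling computes, the sketch below reduces per-node state vectors to one vector per graph. It is not TF-GNN's implementation: the helper name pool_node_states, the list-of-lists layout, the graph_ids array, and the zero vector returned for a graph with no nodes are all assumptions of this sketch. Only 'sum', 'mean', and 'max' are handled here; this section does not say which reduce_type strings the library accepts.

```python
from collections import defaultdict
from typing import Sequence


def pool_node_states(
    node_states: Sequence[Sequence[float]],
    graph_ids: Sequence[int],
    num_graphs: int,
    reduce_type: str = "mean",
) -> list[list[float]]:
    """Hypothetical helper: reduces per-node state vectors to one vector per graph."""
    # Group the state vectors by the graph each node belongs to.
    groups: dict[int, list[Sequence[float]]] = defaultdict(list)
    for state, graph_id in zip(node_states, graph_ids):
        groups[graph_id].append(state)

    dim = len(node_states[0])
    pooled = []
    for g in range(num_graphs):
        rows = groups.get(g, [])
        if not rows:
            # Assumption of this sketch: a graph without nodes pools to zeros.
            pooled.append([0.0] * dim)
            continue
        columns = list(zip(*rows))  # one tuple per state dimension
        if reduce_type == "sum":
            pooled.append([sum(c) for c in columns])
        elif reduce_type == "mean":
            pooled.append([sum(c) / len(c) for c in columns])
        elif reduce_type == "max":
            pooled.append([max(c) for c in columns])
        else:
            raise ValueError(f"unsupported reduce_type: {reduce_type!r}")
    return pooled


# Two nodes in graph 0, one node in graph 1, 2-dimensional states.
print(pool_node_states([[1.0, 2.0], [3.0, 4.0], [10.0, 0.0]], [0, 0, 1], 2))
```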
gather_activations(
    inputs: GraphTensor
) -> Field
losses() -> interfaces.Losses
Sparse categorical crossentropy loss.
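Because predict returns logits rather than probabilities, the loss presumably applies the softmax itself. The plain-Python sketch below (a hypothetical helper, not the library's code) computes the mean sparse categorical crossentropy from integer class labels and unnormalized logits, using the log-sum-exp trick for numerical stability. The comparable Keras built-in is tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True); whether this task configures it exactly that way is an assumption.

```python
import math
from typing import Sequence


def sparse_categorical_crossentropy(
    labels: Sequence[int], logits: Sequence[Sequence[float]]
) -> float:
    """Mean cross-entropy between integer labels and unnormalized logits."""
    total = 0.0
    for label, row in zip(labels, logits):
        # log(sum(exp(x))) computed with the max subtracted to avoid overflow.
        m = max(row)
        log_z = m + math.log(sum(math.exp(x - m) for x in row))
        # -log(softmax(row)[label]) == log_z - row[label]
        total += log_z - row[label]
    return total / len(labels)


# Uniform logits over two classes give a loss of log(2).
print(sparse_categorical_crossentropy([0], [[0.0, 0.0]]))
```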
metrics() -> interfaces.Metrics
Sparse categorical metrics.
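This section does not list the individual metrics. Sparse categorical accuracy is one common example of a metric computed from integer labels and logits; the sketch below takes the arg-max class of each row of logits and counts matches. The helper name is hypothetical and it is not presented as the task's actual metric set.

```python
from typing import Sequence


def sparse_categorical_accuracy(
    labels: Sequence[int], logits: Sequence[Sequence[float]]
) -> float:
    """Fraction of examples whose highest-scoring class equals the integer label."""
    correct = 0
    for label, row in zip(labels, logits):
        # Index of the largest logit (first one wins on ties).
        predicted = max(range(len(row)), key=row.__getitem__)
        correct += int(predicted == label)
    return correct / len(labels)


labels = [0, 2, 1]
logits = [[2.0, 1.0, 0.0], [0.1, 0.2, 0.9], [0.5, 0.4, 0.3]]
print(sparse_categorical_accuracy(labels, logits))  # third example is misclassified
```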
predict(
    inputs: tfgnn.GraphTensor
) -> interfaces.Predictions
Apply a linear head for classification.
Args | |
---|---|
inputs | A tfgnn.GraphTensor for classification. |
Returns | |
---|---|
The classification logits. |
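A linear head maps the pooled graph representation to one score per class; since the output is logits, no activation is applied. The sketch below shows that computation with plain lists. In the real task the weights and bias would be trainable model variables; the helper name linear_head and the [dim][num_classes] weight layout are assumptions of this sketch.

```python
from typing import Sequence


def linear_head(
    pooled: Sequence[Sequence[float]],
    weights: Sequence[Sequence[float]],
    bias: Sequence[float],
) -> list[list[float]]:
    """Computes logits = pooled @ weights + bias, one row of num_classes scores per graph."""
    num_classes = len(bias)
    logits = []
    for x in pooled:
        # Dot product of the pooled vector with each class column, plus its bias.
        row = [
            sum(x_i * weights[i][c] for i, x_i in enumerate(x)) + bias[c]
            for c in range(num_classes)
        ]
        logits.append(row)
    return logits


# One graph with a 2-dimensional pooled state, three classes.
print(linear_head([[1.0, 2.0]], [[1.0, 0.0, -1.0], [0.5, 1.0, 0.0]], [0.0, 0.0, 1.0]))
```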
preprocess(
    inputs: GraphTensor
) -> tuple[GraphTensor, Field]
Preprocesses a scalar (after merge_batch_to_components) GraphTensor.
This function uses the Keras functional API to define non-trainable transformations of the symbolic input GraphTensor, which get executed during dataset preprocessing in a tf.data.Dataset.map(...) operation. It has two responsibilities:
- Splitting the training label out of the input for training. It must be returned as a separate tensor or mapping of tensors.
- Optionally, transforming input features. Some advanced modeling techniques require running the same base GNN on multiple different transformations, so this function may return a single GraphTensor or a non-empty sequence of GraphTensors. The corresponding base GNN output for each GraphTensor is provided to the predict(...) method.
Args | |
---|---|
inputs | A symbolic Keras GraphTensor for processing. |
Returns | |
---|---|
A tuple of the processed GraphTensor(s) and a Field (or a mapping of Fields) to be used as labels. |
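The label-splitting responsibility of preprocess can be pictured with a plain dictionary standing in for a GraphTensor. The sketch assumes, hypothetically, that the graph-level label is stored as a context feature whose name is given by label_feature_name. It removes that feature, so the model cannot read the label as an input, and returns the label separately. The dictionary layout and the helper split_label are illustrative only; the real method operates on a symbolic Keras GraphTensor inside tf.data.Dataset.map(...).

```python
from typing import Any


def split_label(
    graph: dict[str, Any], label_feature_name: str
) -> tuple[dict[str, Any], Any]:
    """Hypothetical helper: returns (graph without the label feature, label)."""
    context = dict(graph["context"])  # shallow copy so the input is not mutated
    label = context.pop(label_feature_name)
    # All other parts of the graph (e.g. node sets) are passed through unchanged.
    stripped = {**graph, "context": context}
    return stripped, label


graph = {
    "context": {"label": 3, "num_atoms": 5},
    "node_sets": {"atoms": {"hidden_state": [[1.0], [2.0]]}},
}
features, label = split_label(graph, "label")
print(features["context"], label)
```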