# Gematria inference API

This document describes the APIs for inference with trained Gematria models.

## Command-line inference API

The module [`gematria.model.python.main_function`](http://gematria/model/python/main_function.py) provides an inference mode in which the binary reads a [`.tfrecord` file](representation.md) where each record contains a single `BasicBlockWithThroughputProto` in serialized proto format. The output is written to another file in the same format, preserving the order of the samples. Model binaries that use this module support inference automatically.

The required flags to run inference are:

- `--gematria_action=predict`: runs the model in batch inference mode.
- `--gematria_input_file={filename}`: the path to the input `.tfrecord` file.
- `--gematria_output_file={filename}`: the path to the output `.tfrecord` file.
- `--gematria_checkpoint_file={checkpoint}`: the path to a TensorFlow checkpoint that contains the trained model used for inference.

In addition to these flags, you must also provide the parameters of the model through model-specific flags, using the same values as those used to train the model. The example below passes `--gematria_tokens_file` in addition to the required flags.

Example command line:

```shell
$ bazel run -c opt \
    //gematria/granite/python:run_granite_model \
    -- \
    --gematria_action=predict \
    --gematria_input_file=/tmp/input.tfrecord \
    --gematria_output_file=/tmp/output.tfrecord \
    --gematria_tokens_file=/tmp/tokens.txt \
    --gematria_checkpoint_file=/tmp/granite_model/model.ckpt-10000
```

## Python inference API

Python code can interact directly with the Gematria model class, without going through a `.tfrecord` file. Gematria models based on the `gematria.model.python.model_base.ModelBase` class all provide a `Predict` method that takes a list of `BasicBlockWithThroughputProto` and returns a list of the same protos with the predictions added to them.
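Because `Predict` takes a list of blocks and returns one proto per input, a caller with many blocks may want to split them into several smaller calls, for example to bound the size of each batch. The following is a minimal sketch of such a wrapper. `predict_in_batches` is a hypothetical helper, not part of Gematria; `predict_fn` stands in for a call such as `lambda batch: model.Predict(sess, batch)`, and the sketch assumes each call returns one proto per input block, in input order.

```python
from typing import Callable, List, Sequence, TypeVar

T = TypeVar("T")


def predict_in_batches(
    predict_fn: Callable[[List[T]], List[T]],
    blocks: Sequence[T],
    batch_size: int,
) -> List[T]:
  """Calls `predict_fn` on consecutive slices of `blocks` (hypothetical helper).

  Results are concatenated in input order, so result[i] corresponds to
  blocks[i] as long as `predict_fn` preserves order within each batch.
  """
  if batch_size <= 0:
    raise ValueError("batch_size must be positive")
  results: List[T] = []
  for start in range(0, len(blocks), batch_size):
    batch = list(blocks[start:start + batch_size])
    predictions = list(predict_fn(batch))
    # Assumption: the prediction function returns exactly one proto per input.
    if len(predictions) != len(batch):
      raise RuntimeError(
          f"expected {len(batch)} predictions, got {len(predictions)}")
    results.extend(predictions)
  return results
```

With a model and session set up as in the Python example that follows, it could be called as `predict_in_batches(lambda batch: model.Predict(sess, batch), _INPUT_BLOCKS, batch_size=100)`; the batch size here is arbitrary.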
Example code using the Python API:

```python
import tensorflow.compat.v1 as tf

from gematria.basic_block.python import tokens
from gematria.granite.python import token_graph_builder_model
from gematria.model.python import options

# The session-based API below requires TF1-style graph execution.
tf.disable_v2_behavior()

_INPUT_BLOCKS = []  # Replace with a list of BasicBlockWithThroughputProtos.
_CHECKPOINT_FILE = ''  # Replace with a path to the TensorFlow checkpoint.
_MODEL_TOKENS = []  # Replace with the list of tokens used to train the model.

# The model parameters must match the values used to train the checkpoint.
model = token_graph_builder_model.TokenGraphBuilderModel(
    tokens=_MODEL_TOKENS,
    dtype=tf.dtypes.float32,
    immediate_token=tokens.IMMEDIATE,
    fp_immediate_token=tokens.IMMEDIATE,
    address_token=tokens.ADDRESS,
    memory_token=tokens.MEMORY,
    node_embedding_size=256,
    edge_embedding_size=256,
    global_embedding_size=256,
    node_update_layers=(256, 256),
    edge_update_layers=(256, 256),
    global_update_layers=(256, 256),
    readout_layers=(256, 256),
    task_readout_layers=(256, 256),
    num_message_passing_iterations=8,
    loss_type=options.LossType.MEAN_SQUARED_ERROR,
    loss_normalization=options.ErrorNormalization.PERCENTAGE_ERROR,
)
# Create the model's TensorFlow ops and variables.
model.Initialize()

with tf.Session() as sess:
  # Load the trained weights from the checkpoint into the session.
  saver = tf.train.Saver()
  saver.restore(sess, _CHECKPOINT_FILE)
  # Returns the input protos with the model's predictions added to them.
  output_blocks = model.Predict(sess, _INPUT_BLOCKS)
```
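The command-line inference mode described in the first section can also be launched from a Python script. The sketch below only assembles the argument list; `build_predict_command` is a hypothetical helper, and its default Bazel target and flag names are taken from the Granite example command. Model-specific flags such as `gematria_tokens_file` are passed through `model_flags` and must, as noted above, match the values used for training.

```python
from typing import Dict, List, Optional


def build_predict_command(
    input_file: str,
    output_file: str,
    checkpoint_file: str,
    model_flags: Optional[Dict[str, str]] = None,
    target: str = "//gematria/granite/python:run_granite_model",
) -> List[str]:
  """Builds a `bazel run` argument list for batch inference (hypothetical helper).

  `model_flags` maps model-specific flag names (without the leading `--`)
  to their values; they are appended after the required inference flags.
  """
  command = [
      "bazel", "run", "-c", "opt", target, "--",
      # The four flags required for inference.
      "--gematria_action=predict",
      f"--gematria_input_file={input_file}",
      f"--gematria_output_file={output_file}",
      f"--gematria_checkpoint_file={checkpoint_file}",
  ]
  for name, value in (model_flags or {}).items():
    command.append(f"--{name}={value}")
  return command
```

The resulting list can be passed to `subprocess.run(command, check=True)` from the Python standard library. Building a list rather than a single shell string avoids quoting problems with unusual file paths.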