
Training TF 2.0 Models on TPUs


Porting Training Code to TPUs

Tensor Processing Units are the fastest processors one can get right now to train machine learning models. However, writing code that runs on TPUs straight out of the box is not something that happens in reality.

TensorFlow 2.0

With the release of TensorFlow 2.0, TensorFlow became more and more popular in the ML community. With a nicer and cleaner API, TF 2.0 gained a lot of fame among newcomers. The new features of TF 2.0 include making Keras the default API for building ML models, which caught a lot of attention not just among newcomers but also in the research community, as prototyping ML models became much easier with an extremely familiar Pythonic interface. Eager execution is enabled by default, making it more Pythonic and easier to debug. One of the biggest changes to come around is making distributed training more accessible. However, a few of the distributed "Strategy" APIs are still experimental and not well documented. Among these is the Strategy for distributing training on Tensor Processing Units (TPUs).

Tensor Processing Units (TPUs)

With the growth in the number of parameters of ML models, it became increasingly important to find a way to fit these models in system memory and train them efficiently, faster than running the code for weeks. To address this problem, Google came up with an ASIC (Application Specific Integrated Circuit) which can reach a few teraflops while running inference on trained models. With time, these ASICs were modified to support training TensorFlow models faster than any other accelerator available on the market. These high-speed ASICs, or Tensor Processing Units (TPUs), became a huge hit in the ML community.

The biggest Turn-off of TPU training on TF 2.0

Although, with all the power housed in them, TPUs became a second option for some ML researchers, mainly because the API to train models on these devices is not that easy, and the trainer code is a bit different from regular CPU or GPU training. One of the main reasons is that TPUs are not connected on-board the Cloud VMs; they run behind a gRPC network, where model graphs are sent as Protocol Buffers, and this graph is replicated among the cores of the TPU, where the training code is run. Another bad thing about TPUs is that they don't have native support for Python code, so anything that is evaluated using a Python function cannot run on the TPUs. The team is working hard to make them more accessible to the research community in TF 2.0, but these issues are still the main reason why the community prefers a multi-GPU environment over TPUs. Another thing that needs to be addressed is the lack of proper official documentation stating how to run a custom training loop on TPUs.

My Journey into TPUs

This summer, I spent my time building an Enhanced Super Resolution Generative Adversarial Network (ESRGAN) as a part of Google Summer of Code '19. The training was successfully done on a single Tesla T4 GPU. The model was performing so well that my mentors (@srjoglekar246, @vbardiovskyg) and I were spellbound, and after discussing and nerding over the output, we planned something supercool, but super difficult: compressing the ESRGAN we have now and reducing the number of parameters, to make inference much, much faster. So, after setting up the test bed for "Knowledge Distillation" of the GAN, we ran into a very bad problem: the trainer code was not fitting in the memory of a single T4 GPU (16 gigs). So, I decided to ask for a hike in the GPU quota. All hell broke loose when my request was rejected due to some payment method issues on Google Cloud Platform. So, after trying and failing with all the hacks available in the bag, we shifted our focus to training on TPUs. I was fully aware of the TPU situation and was trying to avoid TPUs at any cost, but there was no way out, and choosing to hunt for samples around GitHub for 16 hours a day for one full week was more of an obvious choice than giving up on the groundbreaking product we were planning.

Pro Tip for training on TPUs

Option 1

Unless you want to go nuts in a jungle of meaningless, pointless error messages, avoid TPUs till the last moment!! AAARRRGGGGGHHHH!!!!

Option 2

Can you give up on the project? If you can, GO FOR IT! YOU HAVE MY BEST WISHES!

Option 3 (Avoid it at any cost!)

I'm sorry to see you stuck in this position :( However, since you decided to do this, let's dive in! Don't say I didn't warn you!

Tried and tested method for running on TPUs

  • The first task of any ML project comes down to the data. If you are using a custom dataset, you have to convert it to TFRecord file(s). It is highly advisable to shard the dataset (distribute it over multiple TFRecord files) and upload it to a GCS bucket. (See the sketch right after this list.)

  • The next task is to resolve the TPU cluster. In simple terms, you need to find the TPU to actually run anything on it!
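
Here is a minimal sketch of what the sharding step might look like. Everything specific in it is an assumption made for illustration: the number of shards, the image_paths list, the load_image_bytes helper, and the "image_bytes" feature name; adapt all of them to your own data.

import tensorflow as tf

NUM_SHARDS = 16      # assumption: pick based on dataset size
image_paths = [...]  # assumption: list of files in your custom dataset

def load_image_bytes(path):
  # hypothetical helper: read the raw bytes of one image file
  with open(path, "rb") as f:
    return f.read()

def to_example(image_bytes):
  # Wrap the raw bytes into a tf.train.Example with a single "image_bytes" feature
  feature = {
      "image_bytes": tf.train.Feature(
          bytes_list=tf.train.BytesList(value=[image_bytes]))
  }
  return tf.train.Example(features=tf.train.Features(feature=feature))

for shard_id in range(NUM_SHARDS):
  shard_path = "dataset-%05d-of-%05d.tfrecord" % (shard_id, NUM_SHARDS)
  with tf.io.TFRecordWriter(shard_path) as writer:
    for path in image_paths[shard_id::NUM_SHARDS]:
      writer.write(to_example(load_image_bytes(path)).SerializeToString())
# Then copy the shards to your GCS bucket, e.g. with `gsutil -m cp *.tfrecord gs://your-bucket/...`

Once the shards are sitting in a GCS bucket, the next step from the list, resolving the TPU, looks like this: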

import tensorflow as tf
# NOTE:
# The name of the TPU is the name of the TPU instance as per the Cloud Console.
# The complete gRPC address refers to the address of the TPU cluster along with the port,
# for example:
# 1. "grpc://xxx.xxx.xxx.xxx:xxxx"
cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver("name of the TPU or the complete gRPC address of the TPU")
# If you are in eager mode, don't forget to connect to the TPU
tf.config.experimental_connect_to_host(cluster_resolver.get_master())
# Now let's initialize the TPU
tf.tpu.experimental.initialize_tpu_system(cluster_resolver)

After connecting to the TPU system, the next step is to start a distributed context, where every declared variable will be shared among the various processing units. In TF 2.0, this whole process of distribution works by using Strategies. For TPUs, the distribution strategy used is called "TPUStrategy", which, when opened using a context manager, gives a scope to work in the "PerReplica" context.

strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)

Everything inside with strategy.scope() runs in the distributed context.
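
As a quick sanity check (not required, just a habit of mine), you can ask the strategy how many replicas it will run on; a single Cloud TPU v2-8 or v3-8 should report 8, one per core:

print("Number of replicas:", strategy.num_replicas_in_sync)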

Now, coming to the things which are not written anywhere in the docs, but are very, very important to run training on TPUs.

Opening the strategy scope alone will never help! Everything (including the dataset) has to be placed on the TPU. To do this, we need to open a tf.device("/job:worker") scope along with the strategy scope, and if you forget, you're bound to get a gRPC not found / invalid argument error.

So the structure should more or less look like this.

with tf.device("/job:worker"), strategy.scope():
  # TFRecord Dataset Loading and Distribution
  # Model Initialization
  # Training
  # Anything else that I'm missing

Now, as you have already guessed, each and every one of these steps is bound to fail and make your life miserable if not dealt with properly. So, let's tackle them one by one.

Dataset Loading

Just as stated earlier, the dataset must be in sharded TFRecords. You can obviously go for TFDS (TensorFlow Datasets), but the last time I checked, it was breaking badly.

import os
TFRECORDPATH = "gs://path/to/folder/in/GCS/bucket/containing/the/sharded/dataset"
BATCH_SIZE = 1024  # Use a size which is a power of 2. Keep it bigger than 8.

# TPUs can handle huge batch sizes, so 1024 is a good place to start;
# reduce it by powers of 2 if it doesn't fit in memory.

tfrecord_files = tf.io.gfile.glob(os.path.join(TFRECORDPATH, "*.tfrecord"))
# I prefer using the .tfrecord extension while saving.
# Feel free to use anything else, just update the name pattern above to match.
dataset = tf.data.TFRecordDataset(tfrecord_files).repeat().batch(BATCH_SIZE)
dataset = iter(strategy.experimental_distribute_dataset(dataset))  # Distributed iterator to iterate over the dataset
# This will distribute the loaded dataset among the worker cores of the TPU
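
One thing the snippet above glosses over: it batches the raw serialized records, but each record still has to be parsed back into tensors before it can be fed to a model. Here is a minimal sketch of such a map step, assuming each example carries a single JPEG-encoded "image_bytes" feature as in the sharding sketch earlier; the feature name and the 256x256 size are assumptions.

def parse_example(serialized):
  # Describe the features that were written into the TFRecord files
  features = {"image_bytes": tf.io.FixedLenFeature([], tf.string)}
  parsed = tf.io.parse_single_example(serialized, features)
  image = tf.io.decode_jpeg(parsed["image_bytes"], channels=3)
  # TPUs need static shapes, so resize (or crop) to a fixed size before batching
  return tf.image.resize(image, [256, 256])

# Hook it into the pipeline above, before batching:
# dataset = tf.data.TFRecordDataset(tfrecord_files).map(parse_example).repeat().batch(BATCH_SIZE)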

Model Initialization

Building a model using the default Keras API will work perfectly fine. However, if you are planning to include an input_signature to export the saved model later on, well, that's not gonna work! (No idea why :3) But there's a workaround for that! (Huge shoutout to @vbardiovskyg for pointing this out!)

class Dummy(tf.keras.models.Model):
  def __init__(self):
    super(Dummy, self).__init__()
    # Layer initialization
    # Declare all the variables and layers here, else it will raise an error
    # saying the model is trying to initialize variables in a non-first call.

  @tf.function(input_signature=[tf.TensorSpec(....), ...])
  def call(self, inputs):
    return self.unsigned_call(inputs)

  def unsigned_call(self, inputs):
    # Layer Ops here
    return

Declaring the class like this allows you to use unsigned_call(...) while training the model, with no graph traced. When exporting the model, one can simply trace the graph by passing in random numbers of the same shape as the signature, model(tf.random.normal(shape=(.,.,..))). After that, exporting the model using tf.saved_model.save(model, export_dir) should work like a charm.
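
Concretely, the export flow looks roughly like this; the 1x256x256x3 shape and the export path below are placeholders, so use whatever matches your input_signature:

export_dir = "gs://path/to/export/dir"  # or a local path
model = Dummy()
# Trace the signed call() once with dummy data of the right shape...
model(tf.random.normal([1, 256, 256, 3]))
# ...and then export as usual
tf.saved_model.save(model, export_dir)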

Training

Setting up training is not as trivial as a normal training loop in eager mode. The main reason is that the training step needs to be written in graph mode, and that graph will be distributed among the cores. Well, how do you write graph code in eager mode? AutoGraph comes to save the day! AutoGraph is a functionality of TensorFlow which converts Python code to TensorFlow graph code, and when called it fetches the result from the graph and provides an EagerTensor out of it! Pretty dope, right? To use AutoGraph, wrapping a function with the tf.function decorator is enough! So the complete code structure to pull off such a trick should look something like this.

def _step_function(data_samples):
  with tf.GradientTape() as tape:
    # Computing loss
    # Computing Derivative
    # Applying Gradient
    # Compute the mean loss here for writing to summary
  # Don't forget to cast to float32, else taking the mean
  # will raise an error
  return tf.cast(optimizer.iterations, tf.float32)  # To get the current step number

@tf.function
def train_step(data_samples):
  distributed_metric = strategy.experimental_run_v2(_step_function, args=(data_samples,))
  mean_metric = strategy.reduce(tf.distribute.ReduceOp.MEAN, distributed_metric, axis=None)
  return mean_metric  # This will return the mean of the distributed metric (in this case, the step number)

total_num_steps = ... # You need to specify total number of steps to run for

while True:
  data_samples = next(dataset)
  num_steps = train_step(data_samples)
  if num_steps >= total_num_steps:
    break
  # DON'T COMPUTE the mean of distributed metrics (like loss or step count) outside the per-replica context, in this case, outside `_step_function`.

Hope that gives a clear idea of the whole process. Happy Training!

UPDATE

Well, after a lot of debugging and reaching out to the TPU team, I figured out how to run TPUStrategy using TensorFlow Datasets! Here's a small sample I wrote to help demonstrate this: https://github.com/captain-pool/GSOC/tree/master/E1_TPU_Sample