Currently we have `train_micro_batch_size` and `sample_max_num_sequences` to limit the batch size for sampling and training in terms of the number of sequences. It would be great to also allow limiting the maximum number of tokens, since that would guarantee no OOMs happen regardless of sequence length. We would probably need separate `{train, sample}_max_num_tokens` options, since the memory requirements for sampling and forward_backward are quite different.
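For illustration, here is a minimal sketch of what a token-budget batcher could look like: a greedy packer that starts a new batch whenever adding the next sequence would exceed the token cap, with an optional per-batch sequence cap on top. The names (`TokenBudgetBatcher`, `max_num_tokens`, `max_num_sequences`) are hypothetical and not part of the current API.

```python
from typing import Iterable, Iterator, List, Optional, Sequence


class TokenBudgetBatcher:
    """Greedily groups sequences into batches whose total token count stays under a cap.

    Hypothetical sketch only; not the project's actual implementation.
    """

    def __init__(self, max_num_tokens: int, max_num_sequences: Optional[int] = None):
        self.max_num_tokens = max_num_tokens
        self.max_num_sequences = max_num_sequences

    def batches(self, sequences: Iterable[Sequence[int]]) -> Iterator[List[Sequence[int]]]:
        batch: List[Sequence[int]] = []
        tokens_in_batch = 0
        for seq in sequences:
            seq_len = len(seq)
            if seq_len > self.max_num_tokens:
                raise ValueError(
                    f"Sequence of length {seq_len} can never fit under the "
                    f"token budget of {self.max_num_tokens}."
                )
            over_token_budget = tokens_in_batch + seq_len > self.max_num_tokens
            over_seq_budget = (
                self.max_num_sequences is not None
                and len(batch) >= self.max_num_sequences
            )
            # Flush the current batch before it would exceed either budget.
            if batch and (over_token_budget or over_seq_budget):
                yield batch
                batch, tokens_in_batch = [], 0
            batch.append(seq)
            tokens_in_batch += seq_len
        if batch:
            yield batch


# Separate budgets, since forward-only sampling tolerates larger batches
# than forward_backward training (the concrete numbers are placeholders).
sample_batcher = TokenBudgetBatcher(max_num_tokens=8192)
train_batcher = TokenBudgetBatcher(max_num_tokens=2048, max_num_sequences=8)

dummy_sequences = [[0] * n for n in (1500, 700, 900, 300, 2000)]
for batch in train_batcher.batches(dummy_sequences):
    print(len(batch), "sequences,", sum(len(s) for s in batch), "tokens")
```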
Thanks @kouroshHakha for the suggestion!