
[tx] Support max_num_tokens in jax backend #821

@pcmoritz

Description


Currently we have train_micro_batch_size and sample_max_num_sequences to limit the batch size for training and sampling in terms of the number of sequences. It would be great to also allow limiting the maximum number of tokens, since then we can guarantee that no OOMs happen regardless of sequence length. We would probably need to implement {train, sample}_max_num_tokens separately, since the memory requirements for sample and forward_backward are quite different.

Thanks @kouroshHakha for the suggestion!
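To illustrate the idea, here is a minimal sketch of token-capped micro-batching. The function and parameter names (`split_into_micro_batches`, `max_num_sequences`, `max_num_tokens`) are hypothetical and not actual tx config keys; it greedily packs sequences into micro-batches so that each batch respects both a sequence cap and a total-token cap.

```python
# Hypothetical sketch -- names are illustrative, not the actual tx API.
def split_into_micro_batches(seq_lengths, max_num_sequences, max_num_tokens):
    """Greedily pack sequences (given by their lengths) into micro-batches
    holding at most `max_num_sequences` sequences and at most
    `max_num_tokens` total tokens each. A padded implementation would
    instead cost a batch at len(batch) * max(batch)."""
    batches, current, current_tokens = [], [], 0
    for length in seq_lengths:
        if length > max_num_tokens:
            raise ValueError(f"sequence of length {length} exceeds the token cap")
        # Flush the current batch if adding this sequence would exceed
        # either the sequence cap or the token cap.
        if current and (len(current) >= max_num_sequences
                        or current_tokens + length > max_num_tokens):
            batches.append(current)
            current, current_tokens = [], 0
        current.append(length)
        current_tokens += length
    if current:
        batches.append(current)
    return batches
```

For example, with a cap of 2 sequences and 50 tokens, lengths [10, 20, 30, 40] split into [[10, 20], [30], [40]]: the first batch fills the sequence cap, and the last two would each blow the token budget if merged.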
