import time

import pandas as pd
import torch
from torch.cuda.amp import autocast


def bench(dtype, dim, shape, auto=True):
    linear = torch.nn.Linear(shape[-1], dim, device='cuda')
    inp = torch.randn(shape, device='cuda')
    ctx_manager = None
    if not auto:
        # Explicit casting: move both the input and the layer weights to the target dtype.
        inp = inp.to(dtype)
        linear = linear.to(dtype)
    else:
        # Autocast: keep fp32 storage and let the context manager choose the op dtype.
        ctx_manager = autocast(dtype=dtype)

    def run(inp, layer, ctx_manager):
        if ctx_manager is not None:
            with ctx_manager:
                layer(inp)
        else:
            layer(inp)

    # Warm-up pass so one-time setup (kernel launches, autocast caching) is excluded from timing.
    run(inp, linear, ctx_manager)
    torch.cuda.synchronize()

    t1 = time.time()
    for _ in range(1000):
        run(inp, linear, ctx_manager)
    torch.cuda.synchronize()
    t2 = time.time()
    return t2 - t1


if __name__ == '__main__':
    hidden_dim = 1024
    for auto in (True, False):
        print(f"\n## Autocast: {auto}\n")
        results, speedup = [], []
        for batch_size in (8, 16, 32, 64):
            shape = [batch_size, 56, hidden_dim]
            results_row, speedup_row = dict(bs=str(batch_size)), dict(bs=str(batch_size))
            for dtype in (torch.float32, torch.float16, torch.bfloat16):
                results_row[dtype] = bench(dtype, hidden_dim, shape, auto)
                # Speedup is measured relative to the float32 timing for the same batch size.
                speedup_row[dtype] = results_row[torch.float32] / results_row[dtype]
            results.append(results_row)
            speedup.append(speedup_row)

        print("\nResults:")
        print(pd.DataFrame(results).to_markdown(index=False, floatfmt=".3f"))
        print("\nSpeedup:")
        print(pd.DataFrame(speedup).to_markdown(index=False, floatfmt=".3f"))