diff --git a/docs/source/sft_trainer.mdx b/docs/source/sft_trainer.mdx index a22811f21c..2fc8c07e5f 100644 --- a/docs/source/sft_trainer.mdx +++ b/docs/source/sft_trainer.mdx @@ -356,7 +356,7 @@ pip install -U optimum Once you have loaded your model, wrap the `trainer.train()` call under the `with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):` context manager: ```diff -# ... +... + with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False): trainer.train()