diff --git a/src/accelerate/big_modeling.py b/src/accelerate/big_modeling.py index 065d6ad6eab..9101153077c 100644 --- a/src/accelerate/big_modeling.py +++ b/src/accelerate/big_modeling.py @@ -107,7 +107,7 @@ def init_on_device(device: torch.device, include_buffers: bool = None): from accelerate import init_on_device with init_on_device(device=torch.device("cuda")): - tst = nn.Liner(100, 100) # on `cuda` device + tst = nn.Linear(100, 100) # on `cuda` device ``` """ if include_buffers is None: