Skip to content

Commit

Permalink
Add adapter_to() method for moving & converting adapter weights (#699)
Browse files Browse the repository at this point in the history
This PR:
- introduces new methods `adapter_to()` & `adapter_fusion_to()` to move
only the adapter weights to a device or convert their dtype
- avoids moving full model in `AdapterTrainer` when loading best model.

Fixes #694.
  • Loading branch information
calpt committed May 11, 2024
1 parent 0c0e034 commit ed9a537
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 4 deletions.
4 changes: 1 addition & 3 deletions notebooks/QLoRA_Llama_Finetuning.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -261,9 +261,7 @@
"metadata": {},
"outputs": [],
"source": [
"# for _, v in model.get_adapter(\"assistant_adapter\").items():\n",
"# for _, module in v.items():\n",
"# module.to(\"cuda\")"
"# model.adapter_to(\"assistant_adapter\", device=\"cuda\")"
]
},
{
Expand Down
7 changes: 7 additions & 0 deletions src/adapters/methods/bottleneck.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,13 @@ def get_adapter(self, adapter_name: str):
else:
return None

def get_adapter_fusion(self, adapter_names: Union[List, str]):
    """Return the fusion layer registered for the given adapter combination.

    Args:
        adapter_names (Union[List, str]): Either the comma-joined key of the
            fusion layer or a list of adapter names to be joined.

    Returns:
        The matching fusion layer module, or ``None`` if no fusion layer is
        registered under that key.
    """
    # Normalize a list of adapter names into the comma-joined dict key.
    if not isinstance(adapter_names, str):
        adapter_names = ",".join(adapter_names)
    # Early exit keeps the happy path flat. Note: the container may be an
    # nn.ModuleDict, which supports `in` / `[]` but not `.get()`.
    if adapter_names not in self.adapter_fusion_layer:
        return None
    return self.adapter_fusion_layer[adapter_names]

def pre_block(self, adapter_setup: Union[AdapterCompositionBlock, str], state: BottleneckState) -> BottleneckState:
if isinstance(adapter_setup, AdapterCompositionBlock):
adapter_name = adapter_setup.first()
Expand Down
36 changes: 36 additions & 0 deletions src/adapters/model_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -1035,6 +1035,42 @@ def get_adapter(self, name) -> dict:

return dict(destination)

def adapter_to(
    self, name: str, device: Optional[Union[torch.device, str]] = None, dtype: Optional[torch.dtype] = None
):
    """Move the named adapter's weights to a device and/or cast their dtype.

    Only the modules belonging to this adapter are touched; the rest of the
    model is left where it is.

    Args:
        name (str): The name of the adapter to be moved.
        device (torch.device or str, optional): Target device for the adapter modules.
        dtype (torch.dtype, optional): Target data type for the adapter modules.
    """
    # get_adapter(name) yields a nested mapping; only the leaf modules matter
    # here, so iterate the values directly.
    for layer_modules in self.get_adapter(name).values():
        for adapter_module in layer_modules.values():
            adapter_module.to(device=device, dtype=dtype)

def adapter_fusion_to(
    self,
    adapter_names: Union[Fuse, list, str],
    device: Optional[Union[torch.device, str]] = None,
    dtype: Optional[torch.dtype] = None,
):
    """Move the matching adapter fusion layer to a device and/or cast its dtype.

    Walks all layers of the model and relocates every fusion layer registered
    under the given adapter combination, leaving all other weights untouched.

    Args:
        adapter_names (Union[Fuse, list, str]): The name of the adapter fusion layer to be moved.
        device (torch.device or str, optional): Target device for the fusion layer.
        dtype (torch.dtype, optional): Target data type for the fusion layer.
    """
    for _, layer in self.iter_layers():
        for submodule in layer.modules():
            # Only bottleneck layers hold fusion layers; skip everything else.
            if not isinstance(submodule, BottleneckLayer):
                continue
            fusion_layer = submodule.get_adapter_fusion(adapter_names)
            if fusion_layer is not None:
                fusion_layer.to(device=device, dtype=dtype)

def adapter_summary(self, as_dict=False) -> Union[str, dict]:
"""
Returns a string summary of all adapters currently added to the model. Each entry in the summary table has the
Expand Down
3 changes: 2 additions & 1 deletion src/adapters/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ def _load_best_model(self):
adapter_dir = os.path.join(self.state.best_model_checkpoint, adapter)
if os.path.exists(adapter_dir):
model.load_adapter(adapter_dir)
model.adapter_to(adapter, device=self.args.device)
if self.train_adapter_fusion:
logger.info(
f"Loading best adapter fusion(s) from {self.state.best_model_checkpoint} (score:"
Expand All @@ -222,7 +223,7 @@ def _load_best_model(self):
fusion_dir = os.path.join(self.state.best_model_checkpoint, fusion)
if os.path.exists(fusion_dir):
model.load_adapter_fusion(fusion_dir)
model.to(self.args.device)
model.adapter_fusion_to(fusion, device=self.args.device)


class AdapterTrainerCallback(TrainerCallback):
Expand Down

0 comments on commit ed9a537

Please sign in to comment.