Commit 5caddf6: Change unclear variables & fix typos (#134)
* rename variables, fix typos
* style
* remove breaking changes
* revert line
Nathan Lambert authored Feb 8, 2023
1 parent 8070ac0 commit 5caddf6
Showing 4 changed files with 10 additions and 8 deletions.
trl/core.py (2 changes: 1 addition & 1 deletion)
@@ -121,7 +121,7 @@ def entropy_from_logits(logits):


 def average_torch_dicts(list_of_dicts):
-    """Average values of a list of dicts wiht torch tensors."""
+    """Average values of a list of dicts with torch tensors."""
     average_dict = dict()
     for key in list_of_dicts[0].keys():
         average_dict[key] = torch.mean(torch.stack([d[key] for d in list_of_dicts]), axis=0)
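For context, a minimal sketch of what this helper computes, assuming every dict in the list shares the same keys and tensor shapes (the values below are illustrative, not from the repo):

import torch

stats_a = {"loss": torch.tensor([1.0, 3.0])}
stats_b = {"loss": torch.tensor([3.0, 5.0])}

# Stack the per-dict tensors along a new dim and average over it,
# giving the element-wise mean across dicts.
averaged = {
    key: torch.mean(torch.stack([d[key] for d in [stats_a, stats_b]]), axis=0)
    for key in stats_a.keys()
}
print(averaged["loss"])  # tensor([2., 4.])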
trl/models/modeling_base.py (2 changes: 1 addition & 1 deletion)
@@ -120,7 +120,7 @@ class and the arguments that are specific to trl models.
             is_shared = True
 
         if is_shared:
-            # dowload each file and add it to the state_dict
+            # download each file and add it to the state_dict
             state_dict = {}
             for shard_file in files_to_download:
                 filename = hf_hub_download(pretrained_model_name_or_path, shard_file)
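For context, a hedged sketch of the sharded-checkpoint pattern this hunk sits in: download each shard, load it, and merge the pieces into one state_dict. The standalone function and the torch.load call are assumptions for illustration, not the file's exact code:

import torch
from huggingface_hub import hf_hub_download

def load_sharded_state_dict(repo_id, shard_files):
    state_dict = {}
    for shard_file in shard_files:
        # Each shard holds a subset of the parameters; merging the
        # per-shard dicts reassembles the full checkpoint.
        filename = hf_hub_download(repo_id, shard_file)
        state_dict.update(torch.load(filename, map_location="cpu"))
    return state_dict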
trl/models/modeling_value_head.py (8 changes: 4 additions & 4 deletions)
@@ -189,7 +189,7 @@ def generate(self, *args, **kwargs):
     def state_dict(self, *args, **kwargs):
         r"""
         Returns the state dictionary of the model. We add the state dictionary of the value head
-        to the state dictionary of the wrapped model by preprending the key with `v_head.`.
+        to the state dictionary of the wrapped model by prepending the key with `v_head.`.
         """
         pretrained_model_state_dict = self.pretrained_model.state_dict(*args, **kwargs)
         v_head_state_dict = self.v_head.state_dict(*args, **kwargs)
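The prefixing described in this docstring amounts to roughly the following sketch (the key names are made up for illustration):

# Merge value-head weights into the wrapped model's state dict,
# namespacing them under a `v_head.` prefix.
pretrained_model_state_dict = {"transformer.wte.weight": "..."}
v_head_state_dict = {"summary.weight": "..."}
for k, v in v_head_state_dict.items():
    pretrained_model_state_dict[f"v_head.{k}"] = v
# Resulting keys: "transformer.wte.weight", "v_head.summary.weight"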
@@ -205,7 +205,7 @@ def push_to_hub(self, *args, **kwargs):
     def post_init(self, state_dict):
         r"""
         We add the state dictionary of the value head to the state dictionary of the wrapped model
-        by preprending the key with `v_head.`. This function removes the `v_head.` prefix from the
+        by prepending the key with `v_head.`. This function removes the `v_head.` prefix from the
         keys of the value head state dictionary.
         """
         for k in list(state_dict.keys()):
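The diff truncates the loop body, but based on the docstring it presumably strips the prefix roughly like this sketch (an assumption, not the file's exact code):

for k in list(state_dict.keys()):
    if "v_head." in k:
        # Drop the namespace so the keys match the value head's own
        # parameter names before loading.
        state_dict[k.replace("v_head.", "")] = state_dict.pop(k)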
@@ -260,7 +260,7 @@ def _has_lm_head(self):
     def post_init(self, state_dict):
         r"""
         We add the state dictionary of the value head to the state dictionary of the wrapped model
-        by preprending the key with `v_head.`. This function removes the `v_head.` prefix from the
+        by prepending the key with `v_head.`. This function removes the `v_head.` prefix from the
         keys of the value head state dictionary.
         """
         for k in list(state_dict.keys()):
@@ -272,7 +272,7 @@ def post_init(self, state_dict):
     def state_dict(self, *args, **kwargs):
         r"""
         Returns the state dictionary of the model. We add the state dictionary of the value head
-        to the state dictionary of the wrapped model by preprending the key with `v_head.`.
+        to the state dictionary of the wrapped model by prepending the key with `v_head.`.
         """
         pretrained_model_state_dict = self.pretrained_model.state_dict(*args, **kwargs)
         v_head_state_dict = self.v_head.state_dict(*args, **kwargs)
trl/trainer/ppo_trainer.py (6 changes: 4 additions & 2 deletions)
@@ -92,6 +92,8 @@
 class PPOTrainer(BaseTrainer):
     """
     The PPOTrainer uses Proximal Policy Optimization to optimise language models.
+    Note, this trainer is heavily inspired by the original OpenAI learning to summarize work here:
+    https://github.com/openai/summarize-from-feedback
 
     Attributes:
         **config** (`PPOConfig`) -- Configuration object for PPOTrainer. Check the documentation of `PPOConfig` for more
@@ -346,7 +348,7 @@ def generate(self, query_tensor: torch.Tensor, **generation_kwargs):
         Args:
             query_tensor (`torch.LongTensor`):
                 A tensor of shape (`batch_size`, `seq_len`) containing query tokens.
-            gen_kwargs (dict[str, Any]):
+            generation_kwargs (dict[str, Any]):
                 Keyword arguments for generation.
 
         Returns:
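A hedged usage sketch for the renamed keyword arguments; the tokenizer, trainer instance, and generation settings here are assumptions:

query_tensor = tokenizer("How are you?", return_tensors="pt").input_ids
response_tensor = ppo_trainer.generate(
    query_tensor,
    max_new_tokens=20,  # forwarded to the wrapped model's generate()
    do_sample=True,
)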
@@ -812,7 +814,7 @@ def log_stats(
             stats (dict[str, Any]):
                 A dictionary of training stats.
             batch (dict[str, Any]):
-                A dictionary of batch data, this containes the queries and responses.
+                A dictionary of batch data, this contains the queries and responses.
             rewards (`List[torch.FloatTensor]`):
                 A tensor of rewards.
         """
