Change unclear variables & fix typos #134

Merged · 4 commits · Feb 8, 2023

2 changes: 1 addition & 1 deletion trl/core.py
@@ -121,7 +121,7 @@ def entropy_from_logits(logits):


def average_torch_dicts(list_of_dicts):
"""Average values of a list of dicts wiht torch tensors."""
"""Average values of a list of dicts with torch tensors."""
average_dict = dict()
for key in list_of_dicts[0].keys():
average_dict[key] = torch.mean(torch.stack([d[key] for d in list_of_dicts]), axis=0)
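
For context on the helper above, here is a minimal usage sketch; the sample dicts are invented, and it assumes a trl version where `average_torch_dicts` is still importable from `trl.core`:

```python
import torch

from trl.core import average_torch_dicts  # assumes the helper still lives here

# Two hypothetical per-step stat dicts with matching keys.
stats_a = {"loss": torch.tensor(0.5), "kl": torch.tensor(0.1)}
stats_b = {"loss": torch.tensor(1.5), "kl": torch.tensor(0.3)}

averaged = average_torch_dicts([stats_a, stats_b])
print(averaged["loss"])  # tensor(1.), the element-wise mean over the list
print(averaged["kl"])    # tensor(0.2000)
```
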
2 changes: 1 addition & 1 deletion trl/models/modeling_base.py
@@ -120,7 +120,7 @@ class and the arguments that are specific to trl models.
is_shared = True

if is_shared:
-# dowload each file and add it to the state_dict
+# download each file and add it to the state_dict
state_dict = {}
for shard_file in files_to_download:
filename = hf_hub_download(pretrained_model_name_or_path, shard_file)
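
The hunk above is the sharded-checkpoint path: each shard file is downloaded from the Hub and merged into a single `state_dict`. A standalone sketch of that pattern follows; the repo id and shard names are placeholders, and the merge step (`torch.load` plus `update`) is an assumption, since the diff cuts off before it:

```python
import torch
from huggingface_hub import hf_hub_download

pretrained_model_name_or_path = "some-org/some-sharded-model"  # placeholder repo id
files_to_download = [                                          # placeholder shard names
    "pytorch_model-00001-of-00002.bin",
    "pytorch_model-00002-of-00002.bin",
]

state_dict = {}
for shard_file in files_to_download:
    # Fetch one shard from the Hub (cached locally after the first call).
    filename = hf_hub_download(pretrained_model_name_or_path, shard_file)
    # Each shard holds a disjoint subset of the parameters; merge them all.
    state_dict.update(torch.load(filename, map_location="cpu"))
```
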
8 changes: 4 additions & 4 deletions trl/models/modeling_value_head.py
@@ -189,7 +189,7 @@ def generate(self, *args, **kwargs):
def state_dict(self, *args, **kwargs):
r"""
Returns the state dictionary of the model. We add the state dictionary of the value head
-to the state dictionary of the wrapped model by preprending the key with `v_head.`.
+to the state dictionary of the wrapped model by prepending the key with `v_head.`.
"""
pretrained_model_state_dict = self.pretrained_model.state_dict(*args, **kwargs)
v_head_state_dict = self.v_head.state_dict(*args, **kwargs)
@@ -205,7 +205,7 @@ def push_to_hub(self, *args, **kwargs):
def post_init(self, state_dict):
r"""
We add the state dictionary of the value head to the state dictionary of the wrapped model
-by preprending the key with `v_head.`. This function removes the `v_head.` prefix from the
+by prepending the key with `v_head.`. This function removes the `v_head.` prefix from the
keys of the value head state dictionary.
"""
for k in list(state_dict.keys()):
@@ -260,7 +260,7 @@ def _has_lm_head(self):
def post_init(self, state_dict):
r"""
We add the state dictionary of the value head to the state dictionary of the wrapped model
-by preprending the key with `v_head.`. This function removes the `v_head.` prefix from the
+by prepending the key with `v_head.`. This function removes the `v_head.` prefix from the
keys of the value head state dictionary.
"""
for k in list(state_dict.keys()):
@@ -272,7 +272,7 @@ def post_init(self, state_dict):
def state_dict(self, *args, **kwargs):
r"""
Returns the state dictionary of the model. We add the state dictionary of the value head
-to the state dictionary of the wrapped model by preprending the key with `v_head.`.
+to the state dictionary of the wrapped model by prepending the key with `v_head.`.
"""
pretrained_model_state_dict = self.pretrained_model.state_dict(*args, **kwargs)
v_head_state_dict = self.v_head.state_dict(*args, **kwargs)
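
The docstrings above all describe the same round trip: when saving, the value-head keys are folded into the wrapped model's state dict under a `v_head.` prefix; when loading (`post_init`), the prefix is stripped again. A minimal sketch, with plain dicts standing in for the real state dicts:

```python
# Saving: add the value-head entries to the wrapped model's state dict
# under a `v_head.` prefix so a single checkpoint holds both.
pretrained_model_state_dict = {"transformer.wte.weight": "..."}       # stand-in
v_head_state_dict = {"summary.weight": "...", "summary.bias": "..."}  # stand-in
for k, v in v_head_state_dict.items():
    pretrained_model_state_dict[f"v_head.{k}"] = v

# Loading (post_init): strip the `v_head.` prefix so the value head can
# consume its own keys directly.
state_dict = dict(pretrained_model_state_dict)
for k in list(state_dict.keys()):
    if "v_head." in k:
        state_dict[k.replace("v_head.", "")] = state_dict.pop(k)
```
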
6 changes: 4 additions & 2 deletions trl/trainer/ppo_trainer.py
@@ -92,6 +92,8 @@
class PPOTrainer(BaseTrainer):
"""
The PPOTrainer uses Proximal Policy Optimization to optimise language models.
+Note, this trainer is heavily inspired by the original OpenAI learning to summarize work here:
+https://github.com/openai/summarize-from-feedback

Attributes:
**config** (`PPOConfig`) -- Configuration object for PPOTrainer. Check the documentation of `PPOConfig` for more
@@ -346,7 +348,7 @@ def generate(self, query_tensor: torch.Tensor, **generation_kwargs):
Args:
query_tensor (`torch.LongTensor`):
A tensor of shape (`batch_size`, `seq_len`) containing query tokens.
-gen_kwargs (dict[str, Any]):
+generation_kwargs (dict[str, Any]):
Keyword arguments for generation.

Returns:
@@ -812,7 +814,7 @@ def log_stats(
stats (dict[str, Any]):
A dictionary of training stats.
batch (dict[str, Any]):
-A dictionary of batch data, this containes the queries and responses.
+A dictionary of batch data, this contains the queries and responses.
rewards (`List[torch.FloatTensor]`):
A tensor of rewards.
"""
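
The renamed `generation_kwargs` argument is simply forwarded to the underlying `generate` call. A hedged usage sketch follows; the model name, kwargs, and single-query handling are illustrative assumptions modeled loosely on the trl example scripts of that era, not taken from this PR:

```python
import torch
from transformers import AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer

model_name = "gpt2"  # placeholder checkpoint
model = AutoModelForCausalLMWithValueHead.from_pretrained(model_name)
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

ppo_trainer = PPOTrainer(PPOConfig(batch_size=1), model, ref_model, tokenizer)

# A single query as a 1-D LongTensor, as in the example scripts.
query_tensor = tokenizer.encode("My favorite movie is", return_tensors="pt")[0]

# Everything in generation_kwargs is passed through to the model's generate().
generation_kwargs = {
    "max_new_tokens": 20,
    "do_sample": True,
    "pad_token_id": tokenizer.eos_token_id,
}
response_tensor = ppo_trainer.generate(query_tensor, **generation_kwargs)
print(tokenizer.decode(response_tensor[0]))
```
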