diff --git a/trl/core.py b/trl/core.py
index fa18f8548e..a787301b0f 100644
--- a/trl/core.py
+++ b/trl/core.py
@@ -121,7 +121,7 @@ def entropy_from_logits(logits):
 
 
 def average_torch_dicts(list_of_dicts):
-    """Average values of a list of dicts wiht torch tensors."""
+    """Average values of a list of dicts with torch tensors."""
     average_dict = dict()
     for key in list_of_dicts[0].keys():
         average_dict[key] = torch.mean(torch.stack([d[key] for d in list_of_dicts]), axis=0)
diff --git a/trl/models/modeling_base.py b/trl/models/modeling_base.py
index 01322d172d..7b18da4a1c 100644
--- a/trl/models/modeling_base.py
+++ b/trl/models/modeling_base.py
@@ -120,7 +120,7 @@ class and the arguments that are specific to trl models.
                     is_shared = True
 
             if is_shared:
-                # dowload each file and add it to the state_dict
+                # download each file and add it to the state_dict
                 state_dict = {}
                 for shard_file in files_to_download:
                     filename = hf_hub_download(pretrained_model_name_or_path, shard_file)
diff --git a/trl/models/modeling_value_head.py b/trl/models/modeling_value_head.py
index b4ad3d7795..db7e40cf21 100644
--- a/trl/models/modeling_value_head.py
+++ b/trl/models/modeling_value_head.py
@@ -189,7 +189,7 @@ def generate(self, *args, **kwargs):
     def state_dict(self, *args, **kwargs):
         r"""
         Returns the state dictionary of the model. We add the state dictionary of the value head
-        to the state dictionary of the wrapped model by preprending the key with `v_head.`.
+        to the state dictionary of the wrapped model by prepending the key with `v_head.`.
         """
         pretrained_model_state_dict = self.pretrained_model.state_dict(*args, **kwargs)
         v_head_state_dict = self.v_head.state_dict(*args, **kwargs)
@@ -205,7 +205,7 @@ def push_to_hub(self, *args, **kwargs):
     def post_init(self, state_dict):
         r"""
         We add the state dictionary of the value head to the state dictionary of the wrapped model
-        by preprending the key with `v_head.`. This function removes the `v_head.` prefix from the
+        by prepending the key with `v_head.`. This function removes the `v_head.` prefix from the
         keys of the value head state dictionary.
         """
         for k in list(state_dict.keys()):
@@ -260,7 +260,7 @@ def _has_lm_head(self):
     def post_init(self, state_dict):
         r"""
         We add the state dictionary of the value head to the state dictionary of the wrapped model
-        by preprending the key with `v_head.`. This function removes the `v_head.` prefix from the
+        by prepending the key with `v_head.`. This function removes the `v_head.` prefix from the
         keys of the value head state dictionary.
         """
         for k in list(state_dict.keys()):
@@ -272,7 +272,7 @@ def post_init(self, state_dict):
     def state_dict(self, *args, **kwargs):
         r"""
         Returns the state dictionary of the model. We add the state dictionary of the value head
-        to the state dictionary of the wrapped model by preprending the key with `v_head.`.
+        to the state dictionary of the wrapped model by prepending the key with `v_head.`.
         """
         pretrained_model_state_dict = self.pretrained_model.state_dict(*args, **kwargs)
         v_head_state_dict = self.v_head.state_dict(*args, **kwargs)
diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py
index de809de2ce..b25642528b 100644
--- a/trl/trainer/ppo_trainer.py
+++ b/trl/trainer/ppo_trainer.py
@@ -92,6 +92,8 @@ class PPOTrainer(BaseTrainer):
     """
     The PPOTrainer uses Proximal Policy Optimization to optimise language models.
 
+    Note, this trainer is heavily inspired by the original OpenAI learning to summarize work here:
+    https://github.com/openai/summarize-from-feedback
 
     Attributes:
         **config** (`PPOConfig`) -- Configuration object for PPOTrainer. Check the documentation of `PPOConfig` for more
@@ -346,7 +348,7 @@ def generate(self, query_tensor: torch.Tensor, **generation_kwargs):
         Args:
             query_tensor (`torch.LongTensor`):
                 A tensor of shape (`batch_size`, `seq_len`) containing query tokens.
-            gen_kwargs (dict[str, Any]):
+            generation_kwargs (dict[str, Any]):
                 Keyword arguments for generation.
 
         Returns:
@@ -812,7 +814,7 @@ def log_stats(
             stats (dict[str, Any]):
                 A dictionary of training stats.
             batch (dict[str, Any]):
-                A dictionary of batch data, this containes the queries and responses.
+                A dictionary of batch data, this contains the queries and responses.
             rewards (`List[torch.FloatTensor]`):
                 A tensor of rewards.
         """
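The `modeling_value_head.py` docstrings fixed above describe how the value-head parameters are merged into, and later recovered from, the wrapped model's state dict via a `v_head.` key prefix. Below is a minimal, self-contained sketch of that prefix handling; `ValueHeadWrapperSketch` and its sizes are illustrative stand-ins, not TRL's actual `AutoModelForCausalLMWithValueHead` implementation.

```python
import torch.nn as nn


class ValueHeadWrapperSketch(nn.Module):
    """Illustrative wrapper (not TRL's real class) showing the `v_head.` prefix handling."""

    def __init__(self, pretrained_model: nn.Module, hidden_size: int):
        super().__init__()
        self.pretrained_model = pretrained_model
        self.v_head = nn.Linear(hidden_size, 1)

    def state_dict(self, *args, **kwargs):
        # Merge the value head's parameters into the wrapped model's state dict,
        # prepending `v_head.` so they do not collide with the base model's keys.
        merged = self.pretrained_model.state_dict(*args, **kwargs)
        for k, v in self.v_head.state_dict(*args, **kwargs).items():
            merged[f"v_head.{k}"] = v
        return merged

    def post_init(self, state_dict):
        # Strip the `v_head.` prefix again so the value head can load its own parameters.
        for k in list(state_dict.keys()):
            if "v_head." in k:
                state_dict[k.replace("v_head.", "")] = state_dict.pop(k)
        self.v_head.load_state_dict(state_dict, strict=False)


# Usage sketch: round-trip the merged state dict.
base = nn.Linear(8, 8)  # stand-in for a pretrained language model
model = ValueHeadWrapperSketch(base, hidden_size=8)
merged = model.state_dict()
assert any(k.startswith("v_head.") for k in merged)
model.post_init({k: v for k, v in merged.items() if k.startswith("v_head.")})
```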
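For the renamed `generation_kwargs` argument of `PPOTrainer.generate`, here is a usage sketch under the assumption of the TRL API around the time of this diff; the `gpt2` checkpoint and the sampling settings are placeholders, and exact tensor shapes and constructor arguments may differ between TRL versions.

```python
from transformers import AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer

# Placeholder models/tokenizer; any causal LM checkpoint would do for the sketch.
model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

ppo_trainer = PPOTrainer(PPOConfig(), model, ref_model, tokenizer)

query_tensor = tokenizer.encode("How are you?", return_tensors="pt")[0]

# Everything passed as keyword arguments is forwarded to the wrapped model's
# `generate` call -- this is the `generation_kwargs` the docstring fix refers to.
generation_kwargs = {
    "do_sample": True,
    "top_k": 0,
    "top_p": 1.0,
    "max_new_tokens": 20,
    "pad_token_id": tokenizer.eos_token_id,
}
response_tensor = ppo_trainer.generate(query_tensor, **generation_kwargs)
print(tokenizer.decode(response_tensor[0]))
```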