Types: Fix PEP 484 implicit-optional compliance #1297

Merged 1 commit on Jan 31, 2024
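For context, PEP 484 does not treat a `None` default as implicitly making a parameter Optional: a signature like `cache_dir: str = None` annotates `str` while actually accepting `None`, and strict type checkers such as mypy (which disallows implicit Optional by default in recent releases, and always under `--no-implicit-optional`) report it as an incompatible default. The sketch below is illustrative only; `load_model` and `checkpoint_dir` are hypothetical names, not code from this PR.

from typing import Optional

# Implicit Optional: the annotation says `str`, but the default is `None`.
# PEP 484-strict checkers flag this as an incompatible default.
def load_model(checkpoint_dir: str = None):
    ...

# Explicit Optional, the pattern applied throughout this PR.
def load_model_fixed(checkpoint_dir: Optional[str] = None):
    ...

On Python 3.10+ the same intent can be spelled `str | None`, but `Optional[...]` keeps the annotations compatible with older interpreters, which is presumably why it is used here.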
@@ -78,7 +78,7 @@ class ScriptArguments:
 def get_stack_exchange_paired(
     data_dir: str = "data/rl",
     sanity_check: bool = False,
-    cache_dir: str = None,
+    cache_dir: Optional[str] = None,
     num_proc=24,
 ) -> Dataset:
     """Load the stack-exchange-paired dataset from Hugging Face and convert it to the necessary format.
examples/scripts/dpo.py (4 changes: 2 additions & 2 deletions)
@@ -50,7 +50,7 @@
 --lora_alpha=16
 """
 from dataclasses import dataclass, field
-from typing import Dict
+from typing import Dict, Optional

 import torch
 from datasets import Dataset, load_dataset
@@ -87,7 +87,7 @@ def extract_anthropic_prompt(prompt_and_response):
     return prompt_and_response[: search_term_idx + len(search_term)]


-def get_hh(split: str, sanity_check: bool = False, silent: bool = False, cache_dir: str = None) -> Dataset:
+def get_hh(split: str, sanity_check: bool = False, silent: bool = False, cache_dir: Optional[str] = None) -> Dataset:
     """Load the Anthropic Helpful-Harmless dataset from Hugging Face and convert it to the necessary format.

     The dataset is converted to a dictionary with the following structure:
trl/core.py (2 changes: 1 addition & 1 deletion)
@@ -113,7 +113,7 @@ def whiten(values: torch.Tensor, shift_mean: bool = True) -> torch.Tensor:
     return whitened


-def masked_mean(values: torch.Tensor, mask: torch.Tensor, axis: bool = None) -> torch.Tensor:
+def masked_mean(values: torch.Tensor, mask: torch.Tensor, axis: Optional[bool] = None) -> torch.Tensor:
     """Compute mean of tensor with a masked values."""
     if axis is not None:
         return (values * mask).sum(axis=axis) / mask.sum(axis=axis)
trl/environment/base_environment.py (3 changes: 2 additions & 1 deletion)
@@ -14,6 +14,7 @@

 import re
 import warnings
+from typing import Optional

 import torch
 from accelerate.utils import extract_model_from_parallel
@@ -416,7 +417,7 @@ def _generate_batched(
         self,
         query_tensors,
         batch_size: int = 16,
-        pad_to_multiple_of: int = None,
+        pad_to_multiple_of: Optional[int] = None,
     ):
         """
         Generate responses for a list of query tensors.
trl/models/modeling_base.py (3 changes: 2 additions & 1 deletion)
@@ -15,6 +15,7 @@
 import logging
 import os
 from copy import deepcopy
+from typing import Optional

 import torch
 import torch.nn as nn
@@ -600,7 +601,7 @@ def compute_reward_score(self, input_ids, attention_mask=None, **kwargs):


 def create_reference_model(
-    model: PreTrainedModelWrapper, num_shared_layers: int = None, pattern: str = None
+    model: PreTrainedModelWrapper, num_shared_layers: Optional[int] = None, pattern: Optional[str] = None
 ) -> PreTrainedModelWrapper:
     """
     Creates a static reference copy of a model. Note that model will be in `.eval()` mode.
trl/trainer/dpo_trainer.py (12 changes: 6 additions & 6 deletions)
@@ -139,12 +139,12 @@ class DPOTrainer(Trainer):

     def __init__(
         self,
-        model: Union[PreTrainedModel, nn.Module, str] = None,
+        model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
         ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
         beta: float = 0.1,
         label_smoothing: float = 0,
         loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair"] = "sigmoid",
-        args: TrainingArguments = None,
+        args: Optional[TrainingArguments] = None,
         data_collator: Optional[DataCollator] = None,
         label_pad_token_id: int = -100,
         padding_value: Optional[int] = None,
@@ -165,11 +165,11 @@ def __init__(
         generate_during_eval: bool = False,
         compute_metrics: Optional[Callable[[EvalLoopOutput], Dict]] = None,
         precompute_ref_log_probs: bool = False,
-        dataset_num_proc: int = None,
+        dataset_num_proc: Optional[int] = None,
         model_init_kwargs: Optional[Dict] = None,
         ref_model_init_kwargs: Optional[Dict] = None,
-        model_adapter_name: str = None,
-        ref_adapter_name: str = None,
+        model_adapter_name: Optional[str] = None,
+        ref_adapter_name: Optional[str] = None,
     ):
         if model_init_kwargs is None:
             model_init_kwargs = {}
@@ -585,7 +585,7 @@ def build_tokenized_answer(self, prompt, answer):
             attention_mask=answer_attention_mask,
         )

-    def tokenize_row(self, feature, model: Union[PreTrainedModel, nn.Module] = None) -> Dict:
+    def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module]] = None) -> Dict:
         """Tokenize a single row from a DPO specific dataset.

         At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation
trl/trainer/iterative_sft_trainer.py (6 changes: 3 additions & 3 deletions)
@@ -60,9 +60,9 @@ class IterativeSFTTrainer(Trainer):

     def __init__(
         self,
-        model: PreTrainedModel = None,
-        args: TrainingArguments = None,
-        tokenizer: PreTrainedTokenizerBase = None,
+        model: Optional[PreTrainedModel] = None,
+        args: Optional[TrainingArguments] = None,
+        tokenizer: Optional[PreTrainedTokenizerBase] = None,
         optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
             None,
             None,
trl/trainer/ppo_trainer.py (12 changes: 6 additions & 6 deletions)
@@ -142,10 +142,10 @@ class PPOTrainer(BaseTrainer):

     def __init__(
         self,
-        config: PPOConfig = None,
-        model: PreTrainedModelWrapper = None,
+        config: Optional[PPOConfig] = None,
+        model: Optional[PreTrainedModelWrapper] = None,
         ref_model: Optional[PreTrainedModelWrapper] = None,
-        tokenizer: PreTrainedTokenizerBase = None,
+        tokenizer: Optional[PreTrainedTokenizerBase] = None,
         dataset: Optional[Union[torch.utils.data.Dataset, Dataset]] = None,
         optimizer: Optional[torch.optim.Optimizer] = None,
         data_collator: Optional[typing.Callable] = None,
@@ -431,7 +431,7 @@ def _remove_unused_columns(self, dataset: "Dataset"):
     def generate(
         self,
         query_tensor: Union[torch.Tensor, List[torch.Tensor]],
-        length_sampler: Callable = None,
+        length_sampler: Optional[Callable] = None,
         batch_size: int = 4,
         return_prompt: bool = True,
         generate_ref_response: bool = False,
@@ -508,10 +508,10 @@ def _generate_batched(
         self,
         model: PreTrainedModelWrapper,
         query_tensors: List[torch.Tensor],
-        length_sampler: Callable = None,
+        length_sampler: Optional[Callable] = None,
         batch_size: int = 4,
         return_prompt: bool = True,
-        pad_to_multiple_of: int = None,
+        pad_to_multiple_of: Optional[int] = None,
         remove_padding: bool = True,
         **generation_kwargs,
     ):
trl/trainer/reward_trainer.py (2 changes: 1 addition & 1 deletion)
@@ -55,7 +55,7 @@ class RewardTrainer(Trainer):

     def __init__(
         self,
-        model: Union[PreTrainedModel, nn.Module] = None,
+        model: Optional[Union[PreTrainedModel, nn.Module]] = None,
         args: Optional[RewardConfig] = None,
         data_collator: Optional[DataCollator] = None,
         train_dataset: Optional[Dataset] = None,
trl/trainer/sft_trainer.py (4 changes: 2 additions & 2 deletions)
@@ -121,8 +121,8 @@ class SFTTrainer(Trainer):

     def __init__(
         self,
-        model: Union[PreTrainedModel, nn.Module, str] = None,
-        args: TrainingArguments = None,
+        model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
+        args: Optional[TrainingArguments] = None,
         data_collator: Optional[DataCollator] = None,
         train_dataset: Optional[Dataset] = None,
         eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
trl/trainer/utils.py (2 changes: 1 addition & 1 deletion)
@@ -82,7 +82,7 @@ class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling):
     def __init__(
         self,
         response_template: Union[str, List[int]],
-        instruction_template: Union[str, List[int]] = None,
+        instruction_template: Optional[Union[str, List[int]]] = None,
         *args,
         mlm: bool = False,
         ignore_index: int = -100,