Fix Typing (part 1) (#240)
* remove some # type: ignore comments

* fix typing

* fix formatting

* address PR comments

---------

Co-authored-by: Vitaliy Chiley <vitaliy@mosaicml.com>
hanlint and vchiley authored Jun 27, 2023
1 parent 73940be commit dea59aa
Showing 12 changed files with 70 additions and 53 deletions.
6 changes: 3 additions & 3 deletions llmfoundry/__init__.py
@@ -1,9 +1,9 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

try:
import torch
import torch

try:
from llmfoundry import optim, utils
from llmfoundry.data import (ConcatTokensDataset,
MixtureOfDenoisersCollator, NoConcatDataset,
@@ -24,7 +24,7 @@

except ImportError as e:
try:
is_cuda_available = torch.cuda.is_available() # type: ignore
is_cuda_available = torch.cuda.is_available()
except:
is_cuda_available = False

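The hunk above moves `import torch` out of the `try` block so that the `ImportError` handler can call `torch.cuda.is_available()` without a `# type: ignore`: the name is bound whether or not the optional imports succeed. A minimal sketch of the pattern, with a made-up optional dependency standing in for the real imports:

import torch

try:
    import made_up_optional_extra  # hypothetical; stands in for the optional llmfoundry imports
except ImportError as e:
    try:
        # torch is always bound here, so no ignore comment is needed
        is_cuda_available = torch.cuda.is_available()
    except Exception:
        is_cuda_available = False
    extras = '.[gpu]' if is_cuda_available else '.'
    # running this sketch as-is exercises the handler, since the extra does not exist
    raise ImportError(
        f'Optional dependencies are missing; try `pip install -e {extras}`') from e
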
8 changes: 5 additions & 3 deletions llmfoundry/callbacks/fdiff_callback.py
@@ -47,10 +47,12 @@ def batch_end(self, state: State, logger: Logger):
def eval_end(self, state: State, logger: Logger):
if self.diff_eval_metrics:
evaluator = state.dataloader_label
metrics = list(state.eval_metrics[evaluator].keys()) # type: ignore
assert evaluator is not None, 'dataloader should have been set'

metrics = list(state.eval_metrics[evaluator].keys())

for k in metrics:
mkey = '/'.join(['metrics', evaluator, k]) # type: ignore
mkey = '/'.join(['metrics', evaluator, k])
if mkey in self.eval_prev_metric.keys():
logger.log_metrics({
f'{mkey}_fdiff':
@@ -59,5 +61,5 @@ def eval_end(self, state: State, logger: Logger):
})

for k in metrics:
mkey = '/'.join(['metrics', evaluator, k]) # type: ignore
mkey = '/'.join(['metrics', evaluator, k])
self.eval_prev_metric[mkey] = state.eval_metric_values[k]
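
What replaces the two `# type: ignore` comments here is assert-based narrowing: `state.dataloader_label` is `Optional[str]`, and asserting it is not None lets Pyright treat it as `str` for the `eval_metrics` lookup and the `'/'.join(...)` call. The generate callback below uses the same trick for `wandb.run`. A standalone sketch, with illustrative names rather than the callback's real signature:

from typing import Dict, Optional


def fdiff_metric_keys(eval_metrics: Dict[str, Dict[str, float]],
                      dataloader_label: Optional[str]) -> Dict[str, float]:
    # Pyright narrows dataloader_label from Optional[str] to str after the
    # assert, so the lookup and join below type-check without ignores.
    assert dataloader_label is not None, 'dataloader should have been set'
    return {
        '/'.join(['metrics', dataloader_label, name]): value
        for name, value in eval_metrics[dataloader_label].items()
    }


print(fdiff_metric_keys({'eval': {'Accuracy': 0.9}}, 'eval'))
# {'metrics/eval/Accuracy': 0.9}
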
12 changes: 7 additions & 5 deletions llmfoundry/callbacks/generate_callback.py
@@ -74,9 +74,9 @@ def generate(self, state: State, logger: Logger):
dummy_input = device.tensor_to_device(dummy_input)
with get_precision_context(state.precision):
with torch.no_grad():
_ = model.model(input_ids=dummy_input) # type: ignore
_ = model.model(input_ids=dummy_input)

output_token_ids = model.model.generate( # type: ignore
output_token_ids = model.model.generate(
input_ids=tokenized_input['input_ids'],
attention_mask=tokenized_input['attention_mask'],
synced_gpus=True,
@@ -85,9 +85,11 @@

if dist.get_global_rank() == 0:
if self.wandb_logger is not None:
artifact = wandb.Artifact(
'generate_samples_' + str(wandb.run.id), # type: ignore
type='predictions')
assert wandb.run is not None, 'wandb should have started run'

artifact = wandb.Artifact('generate_samples_' +
str(wandb.run.id),
type='predictions')

rows = []
for i in range(len(self.prompts)):
10 changes: 5 additions & 5 deletions llmfoundry/data/denoising.py
@@ -354,7 +354,7 @@ def build_text_denoising_dataloader(
cfg: DictConfig,
tokenizer: Tokenizer,
device_batch_size: int,
) -> DataLoader:
) -> DataLoader[Dict]:
"""Constructor function for a Mixture of Denoisers dataloader.
This function constructs a dataloader that can be used to train an
@@ -480,7 +480,7 @@ def build_text_denoising_dataloader(
batch_size=device_batch_size,
)

if dataset.tokenizer.pad_token is None: # type: ignore
if dataset.tokenizer.pad_token is None:
dataset.tokenizer.pad_token = dataset.tokenizer.eos_token

if cfg.dataset.get('packing_ratio'):
@@ -564,7 +564,7 @@ def noise_token_sequence(
else:
u = np.random.uniform(low=(mask_ratio * 2) - 1, high=1.0)
mean_span_length = float(np.round(1 + u * (length - 1)))
mask_ratio = mean_span_length / length # type: ignore
mask_ratio = mean_span_length / length
use_sentinels = False
else:
use_sentinels = True
@@ -871,9 +871,9 @@ def _format_tokens_for_decoder_only(
tokenizer = build_tokenizer(tokenizer_cfg)

loader = build_text_denoising_dataloader(cfg, tokenizer, device_batch_size)
assert isinstance(loader.dataset, StreamingTextDataset)

print(
f'\n\nTRUNCATING TO: {loader.dataset.max_seq_len}\n\n') # type: ignore
print(f'\n\nTRUNCATING TO: {loader.dataset.max_seq_len}\n\n')

packing = cfg.dataset.get('packing_ratio') is not None
if packing:
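
Two typing moves in this file: the builder is now annotated to return `DataLoader[Dict]` instead of a bare `DataLoader`, and the `__main__` block asserts the dataset's concrete class before reading `max_seq_len` from it. A rough sketch of both, using a toy dataset rather than `StreamingTextDataset`:

from typing import Dict

import torch
from torch.utils.data import DataLoader, Dataset


class ToyTokenDataset(Dataset):
    """Hypothetical stand-in for the streaming text dataset."""

    max_seq_len: int = 8

    def __len__(self) -> int:
        return 4

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        return {'input_ids': torch.arange(self.max_seq_len)}


def build_toy_dataloader(batch_size: int) -> DataLoader[Dict]:
    # Parameterizing the return type tells the checker what each batch
    # element is, instead of leaving DataLoader unparameterized.
    return DataLoader(ToyTokenDataset(), batch_size=batch_size)


loader = build_toy_dataloader(batch_size=2)
# loader.dataset is typed as Dataset, which has no max_seq_len; the
# isinstance assert narrows it so the attribute access needs no ignore.
assert isinstance(loader.dataset, ToyTokenDataset)
print(f'TRUNCATING TO: {loader.dataset.max_seq_len}')
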
2 changes: 0 additions & 2 deletions llmfoundry/models/hf/hf_prefix_lm.py
@@ -98,8 +98,6 @@ def __init__(self, om_model_config: DictConfig, tokenizer: Tokenizer):
if om_model_config.get('adapt_vocab_for_denoising', False):
adapt_tokenizer_for_denoising(tokenizer)

vocab_size = len(tokenizer)

init_device = om_model_config.get('init_device', 'cpu')

# Get the device we want to initialize, and use the
2 changes: 0 additions & 2 deletions llmfoundry/models/hf/hf_t5.py
@@ -91,8 +91,6 @@ def __init__(self, om_model_config: DictConfig, tokenizer: Tokenizer):
if om_model_config.get('adapt_vocab_for_denoising', False):
adapt_tokenizer_for_denoising(tokenizer)

vocab_size = len(tokenizer)

init_device = om_model_config.get('init_device', 'cpu')

# Get the device we want to initialize, and use the
4 changes: 3 additions & 1 deletion llmfoundry/models/layers/norm.py
@@ -1,6 +1,8 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from typing import Dict, Type

import torch


@@ -107,7 +109,7 @@ def forward(self, x):
self.eps).to(dtype=x.dtype)


NORM_CLASS_REGISTRY = {
NORM_CLASS_REGISTRY: Dict[str, Type[torch.nn.Module]] = {
'layernorm': torch.nn.LayerNorm,
'low_precision_layernorm': LPLayerNorm,
'rmsnorm': RMSNorm,
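
Without an annotation, Pyright infers the registry's value type as a union of the specific norm classes listed; declaring it as `Dict[str, Type[torch.nn.Module]]` widens the values to the common base, so a lookup by an arbitrary key yields a uniform `Type[torch.nn.Module]`. A compressed sketch using built-in norms only, not the foundry's custom classes:

from typing import Dict, Type

import torch

# Explicit annotation: values are any nn.Module subclass, not the union of
# the concrete classes that happen to appear below.
NORM_REGISTRY_SKETCH: Dict[str, Type[torch.nn.Module]] = {
    'layernorm': torch.nn.LayerNorm,
    'groupnorm': torch.nn.GroupNorm,
}

norm_cls = NORM_REGISTRY_SKETCH['layernorm']  # Type[torch.nn.Module]
print(norm_cls.__name__)
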
40 changes: 25 additions & 15 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -33,21 +33,31 @@
from llmfoundry.models.layers.custom_embedding import SharedEmbedding
from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY
from llmfoundry.models.mpt.configuration_mpt import MPTConfig
# NOTE: We import all the utils directly just so that HuggingFace will detect
# all the files that it needs to copy into its modules folder. Otherwise it misses
# the ones imported in the submodule

# NOTE: All utils are imported directly even if unused so that
# HuggingFace can detect all the needed files to copy into its modules folder.
# Otherwise, certain modules are missing.
# isort: off
from llmfoundry.models.utils.adapt_tokenizer import (
AutoTokenizerForMOD, adapt_tokenizer_for_denoising)
AutoTokenizerForMOD, # type: ignore (see note),
adapt_tokenizer_for_denoising, # type: ignore (see note)
)
from llmfoundry.models.utils.hf_prefixlm_converter import (
add_bidirectional_mask_if_missing, convert_hf_causal_lm_to_prefix_lm)
from llmfoundry.models.utils.meta_init_context import init_empty_weights
from llmfoundry.models.utils.param_init_fns import ( # type: ignore
MODEL_INIT_REGISTRY, generic_param_init_fn_)
add_bidirectional_mask_if_missing, # type: ignore (see note)
convert_hf_causal_lm_to_prefix_lm, # type: ignore (see note)
)
from llmfoundry.models.utils.meta_init_context import \
init_empty_weights # type: ignore (see note)
from llmfoundry.models.utils.param_init_fns import (
generic_param_init_fn_, # type: ignore (see note)
MODEL_INIT_REGISTRY,
)

try:
from llmfoundry.models.layers.flash_attn_triton import flash_attn_func
except:
pass
# isort: on

Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]

@@ -145,7 +155,7 @@ def __init__(self, config: MPTConfig):
def get_input_embeddings(self):
return self.wte

def set_input_embeddings(self, value):
def set_input_embeddings(self, value: nn.Embedding):
self.wte = value

@torch.no_grad()
@@ -294,6 +304,7 @@ def forward(

if attention_mask is not None:
attention_mask = attention_mask.bool()

if prefix_mask is not None:
prefix_mask = prefix_mask.bool()

@@ -320,8 +331,7 @@ def forward(
'prefix_mask is a required argument when MPT is configured with prefix_lm=True.'
)


# Raise a not implemented error if input_embeds is not None (this is an arg in huggingface transformers and we need to support it for PEFT)
# Raise a not implemented error if input_embeds is not None (this is an arg in huggingface transformers and we need to support it for PEFT)
if inputs_embeds is not None:
raise NotImplementedError(
'inputs_embeds is not implemented for MPT.')
@@ -473,7 +483,7 @@ def __init__(self, config: MPTConfig):

print(f'Instantiating an MPTForCausalLM model from {__file__}')

self.transformer = MPTModel(config)
self.transformer: MPTModel = MPTModel(config)

for child in self.transformer.children():
if isinstance(child, torch.nn.ModuleList):
@@ -565,11 +575,11 @@ def forward(

loss = None
if labels is not None:
labels = torch.roll(labels, shifts=-1)
labels[:, -1] = -100
_labels = torch.roll(labels, shifts=-1)
_labels[:, -1] = -100
loss = F.cross_entropy(
logits.view(-1, logits.size(-1)),
labels.to(logits.device).view(-1),
_labels.to(logits.device).view(-1),
)

return CausalLMOutputWithPast(
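
Two of the changes in this file are worth a note: `self.transformer: MPTModel = MPTModel(config)` gives the attribute an explicit type so later access to MPTModel-specific members checks cleanly, and the loss computation rolls the labels into a fresh local (`_labels`) rather than reassigning the `Optional` `labels` parameter, which keeps the parameter's declared type intact. A loose sketch of the latter, with a simplified signature rather than the model's actual forward:

from typing import Optional

import torch
import torch.nn.functional as F


def shifted_lm_loss(logits: torch.Tensor,
                    labels: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    if labels is None:
        return None
    # Use a new local so the Optional parameter keeps its declared type and
    # still refers to exactly what the caller passed in.
    _labels = torch.roll(labels, shifts=-1)
    _labels[:, -1] = -100  # the wrapped-around position carries no target
    return F.cross_entropy(
        logits.view(-1, logits.size(-1)),
        _labels.to(logits.device).view(-1),
    )


logits = torch.randn(2, 5, 11)
labels = torch.randint(0, 11, (2, 5))
print(shifted_lm_loss(logits, labels))
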
2 changes: 2 additions & 0 deletions llmfoundry/utils/huggingface_hub_utils.py
@@ -30,6 +30,8 @@ def convert_to_relative_import(


def find_module_file(module_name: str) -> str:
if not module_name:
raise ValueError(f'Invalid input: {module_name=}')
module = importlib.import_module(module_name)
module_file = module.__file__
return module_file
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -21,7 +21,8 @@ include = [

# Pyright
[tool.pyright]
exclude = ['env-**']
exclude = ['env-**', 'venv*']
ignore = ['llmfoundry/models/layers/flash_attn_triton.py']
stubPath = "" # suppress useless 'stubPath is not a valid directory' errors

reportUnnecessaryIsInstance = "warning"
2 changes: 1 addition & 1 deletion tests/test_hf_mpt_gen.py
@@ -3,7 +3,7 @@

import pytest
from composer.core.precision import get_precision_context
from composer.utils import dist, get_device, reproducibility
from composer.utils import get_device, reproducibility
from omegaconf import OmegaConf as om

from llmfoundry import COMPOSER_MODEL_REGISTRY
32 changes: 17 additions & 15 deletions tests/test_model.py
@@ -29,6 +29,7 @@
ComposerHFPrefixLM)
from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithZLoss
from llmfoundry.models.layers import NORM_CLASS_REGISTRY, build_alibi_bias
from llmfoundry.models.layers.blocks import MPTBlock
from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM
from llmfoundry.utils import build_tokenizer

@@ -359,13 +360,15 @@ def test_loss_fn():
pytest.skip('Fused cross entropy was not installed')

# run numerical test in pure fp32
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
torch.backends.cuda.matmul.allow_tf32 = False # type: ignore (third-party)
torch.backends.cudnn.allow_tf32 = False # type: ignore (third-party)

conf_path = 'scripts/train/yamls/pretrain/testing.yaml'
with open(conf_path) as f:
test_cfg = om.load(f)

assert isinstance(test_cfg, DictConfig)

test_cfg.device = 'cuda:0'
test_cfg.model.init_device = 'cuda:0'
test_cfg.model.init_config = {
@@ -471,25 +474,24 @@ def test_mpt_creation(norm_type, no_bias):
assert mpt.config.expansion_ratio == 2
assert mpt.config.max_seq_len == 2048

assert mpt.transformer.wte.weight.shape == torch.Size( # type: ignore
assert mpt.transformer.wte.weight.shape == torch.Size(
[hf_config.vocab_size, hf_config.d_model])
assert mpt.transformer.wpe.weight.shape == torch.Size( # type: ignore
assert mpt.transformer.wpe.weight.shape == torch.Size(
[hf_config.max_seq_len, hf_config.d_model])
assert mpt.transformer.emb_drop.p == 0.1 # type: ignore
assert len(mpt.transformer.blocks) == 2 # type: ignore
assert mpt.transformer.emb_drop.p == 0.1
assert len(mpt.transformer.blocks) == 2

d_model = hf_config.d_model
for block in mpt.transformer.blocks: # type: ignore
assert block.norm_1.weight.shape == torch.Size([d_model
]) # type: ignore
assert block.norm_2.weight.shape == torch.Size([d_model
]) # type: ignore
assert block.ffn.up_proj.weight.shape == torch.Size( # type: ignore
for block in mpt.transformer.blocks:
assert isinstance(block, MPTBlock)
assert block.norm_1.weight.shape == torch.Size([d_model])
assert block.norm_2.weight.shape == torch.Size([d_model])
assert block.ffn.up_proj.weight.shape == torch.Size(
[hf_config.d_model * hf_config.expansion_ratio, hf_config.d_model])
assert block.ffn.down_proj.weight.shape == torch.Size( # type: ignore
assert block.ffn.down_proj.weight.shape == torch.Size(
[hf_config.d_model, hf_config.d_model * hf_config.expansion_ratio])
assert block.resid_attn_dropout.p == 0.2 # type: ignore
assert block.resid_ffn_dropout.p == 0.2 # type: ignore
assert block.resid_attn_dropout.p == 0.2
assert block.resid_ffn_dropout.p == 0.2


@pytest.mark.parametrize('attention_impl,device', [('torch', 'cpu'),
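
The test now imports MPTBlock and asserts each block's type inside the loop: iterating an `nn.ModuleList` yields values typed as `nn.Module`, so attribute checks such as `block.norm_1.weight` previously needed per-line ignores, and the isinstance assert narrows each element to the concrete block class instead. A tiny illustration with a hypothetical block class:

import torch
import torch.nn as nn


class SketchBlock(nn.Module):
    """Hypothetical block with a named sub-layer, standing in for MPTBlock."""

    def __init__(self, d_model: int):
        super().__init__()
        self.norm_1 = nn.LayerNorm(d_model)


blocks = nn.ModuleList([SketchBlock(16) for _ in range(2)])
for block in blocks:
    # ModuleList elements are typed as nn.Module; the assert narrows each one
    # to SketchBlock so block.norm_1 type-checks without an ignore.
    assert isinstance(block, SketchBlock)
    assert block.norm_1.weight.shape == torch.Size([16])
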
