Fix Typing (part 1) #240

Merged
merged 8 commits on Jun 27, 2023
Changes from 7 commits
6 changes: 3 additions & 3 deletions llmfoundry/__init__.py
@@ -1,9 +1,9 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

try:
import torch
import torch

try:
from llmfoundry import optim, utils
from llmfoundry.data import (ConcatTokensDataset,
MixtureOfDenoisersCollator, NoConcatDataset,
@@ -24,7 +24,7 @@

except ImportError as e:
try:
is_cuda_available = torch.cuda.is_available() # type: ignore
is_cuda_available = torch.cuda.is_available()
except:
is_cuda_available = False

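This hunk hoists `import torch` out of the `try` block so that the `except ImportError` handler can call `torch.cuda.is_available()` on a name that is always bound, which is what made the old `# type: ignore` necessary. A minimal sketch of the resulting structure, assuming a simplified import list and an illustrative error message (the real module imports far more symbols and words its error differently):

```python
# Sketch of the new import layout in llmfoundry/__init__.py.
import torch  # imported unconditionally so the handler below can use it safely

try:
    from llmfoundry import optim, utils  # optional/heavy imports that may fail
except ImportError as e:
    try:
        # No `# type: ignore` needed: `torch` is always bound at this point.
        is_cuda_available = torch.cuda.is_available()
    except Exception:
        is_cuda_available = False
    raise ImportError(
        f'llm-foundry import failed (CUDA available: {is_cuda_available}). '
        'Check that the GPU-only dependencies are installed if CUDA is expected.'
    ) from e
```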
8 changes: 5 additions & 3 deletions llmfoundry/callbacks/fdiff_callback.py
@@ -47,10 +47,12 @@ def batch_end(self, state: State, logger: Logger):
def eval_end(self, state: State, logger: Logger):
if self.diff_eval_metrics:
evaluator = state.dataloader_label
metrics = list(state.eval_metrics[evaluator].keys()) # type: ignore
assert evaluator is not None, 'dataloader should have been set'

metrics = list(state.eval_metrics[evaluator].keys())

for k in metrics:
mkey = '/'.join(['metrics', evaluator, k]) # type: ignore
mkey = '/'.join(['metrics', evaluator, k])
if mkey in self.eval_prev_metric.keys():
logger.log_metrics({
f'{mkey}_fdiff':
@@ -59,5 +61,5 @@ def eval_end(self, state: State, logger: Logger):
})

for k in metrics:
mkey = '/'.join(['metrics', evaluator, k]) # type: ignore
mkey = '/'.join(['metrics', evaluator, k])
self.eval_prev_metric[mkey] = state.eval_metric_values[k]
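The pattern in this file replaces per-line `# type: ignore` comments with a single `assert evaluator is not None`, which narrows `state.dataloader_label` from `Optional[str]` to `str` for the rest of the function while also documenting the precondition. A generic sketch of the idea, with hypothetical names standing in for the Composer `State` object:

```python
from typing import Dict, List, Optional


def metric_keys(eval_metrics: Dict[str, Dict[str, float]],
                dataloader_label: Optional[str]) -> List[str]:
    # Without the assert, indexing a dict with an Optional key is a type error;
    # after it, the checker treats `dataloader_label` as a plain str.
    assert dataloader_label is not None, 'dataloader should have been set'
    return list(eval_metrics[dataloader_label].keys())
```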
12 changes: 7 additions & 5 deletions llmfoundry/callbacks/generate_callback.py
@@ -74,9 +74,9 @@ def generate(self, state: State, logger: Logger):
dummy_input = device.tensor_to_device(dummy_input)
with get_precision_context(state.precision):
with torch.no_grad():
_ = model.model(input_ids=dummy_input) # type: ignore
_ = model.model(input_ids=dummy_input)

output_token_ids = model.model.generate( # type: ignore
output_token_ids = model.model.generate(
input_ids=tokenized_input['input_ids'],
attention_mask=tokenized_input['attention_mask'],
synced_gpus=True,
@@ -85,9 +85,11 @@

if dist.get_global_rank() == 0:
if self.wandb_logger is not None:
artifact = wandb.Artifact(
'generate_samples_' + str(wandb.run.id), # type: ignore
type='predictions')
assert wandb.run is not None, 'wandb should have started run'

artifact = wandb.Artifact('generate_samples_' +
str(wandb.run.id),
type='predictions')

rows = []
for i in range(len(self.prompts)):
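The same narrowing technique applies to `wandb.run`, which is `Optional` until `wandb.init()` has been called, so asserting it is not `None` before reading `wandb.run.id` removes the ignore on the artifact name. A small sketch of the pattern, assuming wandb is installed and a run has already been started (the helper name is illustrative, not the callback's actual code path):

```python
import wandb


def make_samples_artifact() -> wandb.Artifact:
    # The assert documents the precondition and narrows `wandb.run` for pyright.
    assert wandb.run is not None, 'wandb should have started run'
    return wandb.Artifact('generate_samples_' + str(wandb.run.id),
                          type='predictions')
```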
10 changes: 5 additions & 5 deletions llmfoundry/data/denoising.py
@@ -354,7 +354,7 @@ def build_text_denoising_dataloader(
cfg: DictConfig,
tokenizer: Tokenizer,
device_batch_size: int,
) -> DataLoader:
) -> DataLoader[Dict]:
"""Constructor function for a Mixture of Denoisers dataloader.

This function constructs a dataloader that can be used to train an
@@ -480,7 +480,7 @@ def build_text_denoising_dataloader(
batch_size=device_batch_size,
)

if dataset.tokenizer.pad_token is None: # type: ignore
if dataset.tokenizer.pad_token is None:
dataset.tokenizer.pad_token = dataset.tokenizer.eos_token

if cfg.dataset.get('packing_ratio'):
@@ -564,7 +564,7 @@ def noise_token_sequence(
else:
u = np.random.uniform(low=(mask_ratio * 2) - 1, high=1.0)
mean_span_length = float(np.round(1 + u * (length - 1)))
mask_ratio = mean_span_length / length # type: ignore
mask_ratio = mean_span_length / length
use_sentinels = False
else:
use_sentinels = True
@@ -871,9 +871,9 @@ def _format_tokens_for_decoder_only(
tokenizer = build_tokenizer(tokenizer_cfg)

loader = build_text_denoising_dataloader(cfg, tokenizer, device_batch_size)
assert isinstance(loader.dataset, StreamingTextDataset)

print(
f'\n\nTRUNCATING TO: {loader.dataset.max_seq_len}\n\n') # type: ignore
print(f'\n\nTRUNCATING TO: {loader.dataset.max_seq_len}\n\n')

packing = cfg.dataset.get('packing_ratio') is not None
if packing:
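Two typing idioms show up in this file: the return annotation is parameterized as `DataLoader[Dict]` so the checker knows what each batch yields, and the test harness asserts the dataset's concrete class before touching `max_seq_len` instead of ignoring the access. A toy sketch under those assumptions (the stand-in dataset is hypothetical; the real loader wraps a `StreamingTextDataset`):

```python
from typing import Dict, List

from torch.utils.data import DataLoader, Dataset


class ToyTokenDataset(Dataset):
    """Stand-in dataset for illustration only."""

    def __init__(self, max_seq_len: int = 8) -> None:
        self.max_seq_len = max_seq_len

    def __len__(self) -> int:
        return 4

    def __getitem__(self, idx: int) -> Dict[str, List[int]]:
        return {'input_ids': list(range(self.max_seq_len))}


def build_toy_dataloader(batch_size: int) -> DataLoader[Dict]:
    # Parameterizing DataLoader records the element type for the type checker.
    return DataLoader(ToyTokenDataset(), batch_size=batch_size)


loader = build_toy_dataloader(batch_size=2)
# The isinstance assert narrows `loader.dataset` from Dataset to the concrete
# class, so attribute access needs no `# type: ignore`.
assert isinstance(loader.dataset, ToyTokenDataset)
print(f'truncating to: {loader.dataset.max_seq_len}')
```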
2 changes: 0 additions & 2 deletions llmfoundry/models/hf/hf_prefix_lm.py
@@ -98,8 +98,6 @@ def __init__(self, om_model_config: DictConfig, tokenizer: Tokenizer):
if om_model_config.get('adapt_vocab_for_denoising', False):
adapt_tokenizer_for_denoising(tokenizer)

vocab_size = len(tokenizer)

init_device = om_model_config.get('init_device', 'cpu')

# Get the device we want to initialize, and use the
2 changes: 0 additions & 2 deletions llmfoundry/models/hf/hf_t5.py
@@ -91,8 +91,6 @@ def __init__(self, om_model_config: DictConfig, tokenizer: Tokenizer):
if om_model_config.get('adapt_vocab_for_denoising', False):
adapt_tokenizer_for_denoising(tokenizer)

vocab_size = len(tokenizer)

init_device = om_model_config.get('init_device', 'cpu')

# Get the device we want to initialize, and use the
4 changes: 3 additions & 1 deletion llmfoundry/models/layers/norm.py
@@ -1,6 +1,8 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from typing import Dict, Type

import torch


@@ -107,7 +109,7 @@ def forward(self, x):
self.eps).to(dtype=x.dtype)


NORM_CLASS_REGISTRY = {
NORM_CLASS_REGISTRY: Dict[str, Type[torch.nn.Module]] = {
'layernorm': torch.nn.LayerNorm,
'low_precision_layernorm': LPLayerNorm,
'rmsnorm': RMSNorm,
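Annotating the registry as `Dict[str, Type[torch.nn.Module]]` tells pyright that every value is a `nn.Module` subclass, so lookups and downstream construction type-check without ignores. A small sketch of the idea with a hypothetical registry name and placeholder entries:

```python
from typing import Dict, Type

import torch

# Annotating the registry with a common base class lets the checker treat
# every looked-up entry as a torch.nn.Module subclass.
NORM_EXAMPLE_REGISTRY: Dict[str, Type[torch.nn.Module]] = {
    'layernorm': torch.nn.LayerNorm,
    'identity': torch.nn.Identity,  # placeholder; the real registry uses RMSNorm etc.
}


def get_norm_class(name: str) -> Type[torch.nn.Module]:
    return NORM_EXAMPLE_REGISTRY[name.lower()]
```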
37 changes: 24 additions & 13 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -33,21 +33,31 @@
from llmfoundry.models.layers.custom_embedding import SharedEmbedding
from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY
from llmfoundry.models.mpt.configuration_mpt import MPTConfig
# NOTE: We import all the utils directly just so that HuggingFace will detect
# all the files that it needs to copy into its modules folder. Otherwise it misses
# the ones imported in the submodule

# NOTE: All utils are imported directly even if unused so that
# HuggingFace can detect all the needed files to copy into its modules folder.
# Otherwise, certain modules are missing.
# isort: off
from llmfoundry.models.utils.adapt_tokenizer import (
AutoTokenizerForMOD, adapt_tokenizer_for_denoising)
AutoTokenizerForMOD, # type: ignore (see note),
adapt_tokenizer_for_denoising, # type: ignore (see note)
)
from llmfoundry.models.utils.hf_prefixlm_converter import (
add_bidirectional_mask_if_missing, convert_hf_causal_lm_to_prefix_lm)
from llmfoundry.models.utils.meta_init_context import init_empty_weights
from llmfoundry.models.utils.param_init_fns import ( # type: ignore
MODEL_INIT_REGISTRY, generic_param_init_fn_)
add_bidirectional_mask_if_missing, # type: ignore (see note)
convert_hf_causal_lm_to_prefix_lm, # type: ignore (see note)
)
from llmfoundry.models.utils.meta_init_context import \
init_empty_weights # type: ignore (see note)
from llmfoundry.models.utils.param_init_fns import (
generic_param_init_fn_, # type: ignore (see note)
MODEL_INIT_REGISTRY,
)

try:
from llmfoundry.models.layers.flash_attn_triton import flash_attn_func
except:
pass
# isort: on

Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]

@@ -145,7 +155,7 @@ def __init__(self, config: MPTConfig):
def get_input_embeddings(self):
return self.wte

def set_input_embeddings(self, value):
def set_input_embeddings(self, value: nn.Embedding):
self.wte = value

@torch.no_grad()
@@ -291,6 +301,7 @@ def forward(

if attention_mask is not None:
attention_mask = attention_mask.bool()

if prefix_mask is not None:
prefix_mask = prefix_mask.bool()

@@ -456,7 +467,7 @@ def __init__(self, config: MPTConfig):
raise ValueError(
'MPTForCausalLM only supports tied word embeddings')

self.transformer = MPTModel(config)
self.transformer: MPTModel = MPTModel(config)

for child in self.transformer.children():
if isinstance(child, torch.nn.ModuleList):
@@ -540,10 +551,10 @@ def forward(

loss = None
if labels is not None:
labels = torch.roll(labels, shifts=-1)
labels[:, -1] = -100
_labels = torch.roll(labels, shifts=-1)
_labels[:, -1] = -100
loss = F.cross_entropy(logits.view(-1, logits.size(-1)),
labels.to(logits.device).view(-1))
_labels.to(logits.device).view(-1))

return CausalLMOutputWithPast(
loss=loss,
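In the loss computation, the shifted targets are now bound to a new name `_labels` rather than re-binding the `labels` parameter, so the parameter keeps its declared `Optional` label type for the checker (`torch.roll` already returns a new tensor, so the caller's data was never mutated). A toy illustration of the causal-LM target shift with the `-100` ignore index, using made-up shapes and values:

```python
import torch
import torch.nn.functional as F

# Toy example: next-token targets are the labels rolled left by one, with the
# final position set to -100 so F.cross_entropy (ignore_index=-100) skips it.
batch, seq_len, vocab = 2, 4, 10
logits = torch.randn(batch, seq_len, vocab)
labels = torch.randint(0, vocab, (batch, seq_len))

# Bind the shifted copy to a new name; `labels` keeps its original type/contents.
_labels = torch.roll(labels, shifts=-1)
_labels[:, -1] = -100
loss = F.cross_entropy(logits.view(-1, vocab), _labels.view(-1))
print(loss.item())
```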
2 changes: 2 additions & 0 deletions llmfoundry/utils/huggingface_hub_utils.py
@@ -30,6 +30,8 @@ def convert_to_relative_import(


def find_module_file(module_name: str) -> str:
if not module_name:
raise ValueError(f'Invalid input: {module_name=}')
module = importlib.import_module(module_name)
module_file = module.__file__
return module_file
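The new guard rejects an empty module name before `importlib.import_module` is reached, which would otherwise fail with a less helpful error. A self-contained sketch of the guarded helper plus a usage example; the extra `assert` on `module.__file__` is illustrative narrowing and not part of this PR:

```python
import importlib


def find_module_file(module_name: str) -> str:
    if not module_name:
        raise ValueError(f'Invalid input: {module_name=}')
    module = importlib.import_module(module_name)
    module_file = module.__file__
    assert module_file is not None  # illustrative only; __file__ is Optional[str]
    return module_file


print(find_module_file('importlib'))  # e.g. .../importlib/__init__.py
# find_module_file('') raises ValueError("Invalid input: module_name=''")
```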
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -21,7 +21,8 @@ include = [

# Pyright
[tool.pyright]
exclude = ['env-**']
exclude = ['env-**', 'venv*']
ignore = ['llmfoundry/models/layers/flash_attn_triton.py']
stubPath = "" # suppress useless 'stubPath is not a valid directory' errors

reportUnnecessaryIsInstance = "warning"
2 changes: 1 addition & 1 deletion tests/test_hf_mpt_gen.py
@@ -3,7 +3,7 @@

import pytest
from composer.core.precision import get_precision_context
from composer.utils import dist, get_device, reproducibility
from composer.utils import get_device, reproducibility
from omegaconf import OmegaConf as om

from llmfoundry import COMPOSER_MODEL_REGISTRY
32 changes: 17 additions & 15 deletions tests/test_model.py
@@ -29,6 +29,7 @@
ComposerHFPrefixLM)
from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithZLoss
from llmfoundry.models.layers import NORM_CLASS_REGISTRY, build_alibi_bias
from llmfoundry.models.layers.blocks import MPTBlock
from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM
from llmfoundry.utils import build_tokenizer

@@ -359,13 +360,15 @@ def test_loss_fn():
pytest.skip('Fused cross entropy was not installed')

# run numerical test in pure fp32
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
torch.backends.cuda.matmul.allow_tf32 = False # type: ignore (third-party)
torch.backends.cudnn.allow_tf32 = False # type: ignore (third-party)

conf_path = 'scripts/train/yamls/pretrain/testing.yaml'
with open(conf_path) as f:
test_cfg = om.load(f)

assert isinstance(test_cfg, DictConfig)

test_cfg.device = 'cuda:0'
test_cfg.model.init_device = 'cuda:0'
test_cfg.model.init_config = {
@@ -471,25 +474,24 @@ def test_mpt_creation(norm_type, no_bias):
assert mpt.config.expansion_ratio == 2
assert mpt.config.max_seq_len == 2048

assert mpt.transformer.wte.weight.shape == torch.Size( # type: ignore
assert mpt.transformer.wte.weight.shape == torch.Size(
[hf_config.vocab_size, hf_config.d_model])
assert mpt.transformer.wpe.weight.shape == torch.Size( # type: ignore
assert mpt.transformer.wpe.weight.shape == torch.Size(
[hf_config.max_seq_len, hf_config.d_model])
assert mpt.transformer.emb_drop.p == 0.1 # type: ignore
assert len(mpt.transformer.blocks) == 2 # type: ignore
assert mpt.transformer.emb_drop.p == 0.1
assert len(mpt.transformer.blocks) == 2

d_model = hf_config.d_model
for block in mpt.transformer.blocks: # type: ignore
assert block.norm_1.weight.shape == torch.Size([d_model
]) # type: ignore
assert block.norm_2.weight.shape == torch.Size([d_model
]) # type: ignore
assert block.ffn.up_proj.weight.shape == torch.Size( # type: ignore
for block in mpt.transformer.blocks:
assert isinstance(block, MPTBlock)
assert block.norm_1.weight.shape == torch.Size([d_model])
assert block.norm_2.weight.shape == torch.Size([d_model])
assert block.ffn.up_proj.weight.shape == torch.Size(
[hf_config.d_model * hf_config.expansion_ratio, hf_config.d_model])
assert block.ffn.down_proj.weight.shape == torch.Size( # type: ignore
assert block.ffn.down_proj.weight.shape == torch.Size(
[hf_config.d_model, hf_config.d_model * hf_config.expansion_ratio])
assert block.resid_attn_dropout.p == 0.2 # type: ignore
assert block.resid_ffn_dropout.p == 0.2 # type: ignore
assert block.resid_attn_dropout.p == 0.2
assert block.resid_ffn_dropout.p == 0.2


@pytest.mark.parametrize('attention_impl,device', [('torch', 'cpu'),
Expand Down