diff --git a/.github/workflows/ci_pipeline.yaml b/.github/workflows/ci_pipeline.yaml
index 1c52c64df..2daa959a3 100644
--- a/.github/workflows/ci_pipeline.yaml
+++ b/.github/workflows/ci_pipeline.yaml
@@ -129,8 +129,7 @@ jobs:
         OS: ubuntu-latest
         PYTHON: 3.11
       run: |
-        python .github/install_mindspore.py
-        pip install -r download.txt
+        pip install mindspore
     - name: Test with pytest
       run: |
         pytest -vs tests/transformers/models/${{ matrix.alpha }}*/test_modeling*
diff --git a/.github/workflows/make_wheel_releases.yml b/.github/workflows/make_wheel_releases.yml
index 554ce5017..76dd5354a 100644
--- a/.github/workflows/make_wheel_releases.yml
+++ b/.github/workflows/make_wheel_releases.yml
@@ -27,7 +27,7 @@ jobs:
       run: python -m build --wheel
     - name: Upload file
-      uses: actions/upload-artifact@v3
+      uses: actions/upload-artifact@v4
       with:
         name: mindnlp-whl
         path: dist/*
diff --git a/llm/finetune/Align/eval_model.py b/llm/finetune/Align/eval_model.py
new file mode 100644
index 000000000..b45179aea
--- /dev/null
+++ b/llm/finetune/Align/eval_model.py
@@ -0,0 +1,160 @@
+import collections
+import collections.abc
+
+collections.Iterable = collections.abc.Iterable
+
+import mindspore as ms
+from mindnlp.transformers import AlignModel, AlignProcessor
+from mindspore import Tensor, nn, ops, Parameter
+from PIL import Image  # Image.open is used in get_image_and_caption below
+from pycocotools.coco import COCO
+import os
+from tqdm import tqdm
+import pickle
+from concurrent.futures import ThreadPoolExecutor
+import numpy as np
+import gc
+
+HYPERPARAMS = {
+    "model_name": "E:/Code/align_ft_torch/cache/model/kakaobrain/align-base",
+    "batch_size": 4,
+    "val_samples": 50,
+    "max_length": 128,
+    "num_workers": 8,
+    "data_dir": "MSCOCO",
+    "data_type": "val2017",
+    "val_cache_file": "mscoco_preprocessed_val_50.pkl",
+    "save_dir": "cache/model",
+    "model_save_path": "cache/model/finetuned_align_model_epoch_{epoch}.ckpt",
+    "processor_save_path": "cache/model/finetuned_align_processor"
+}
+
+ms.set_context(mode=ms.PYNATIVE_MODE, device_target="Ascend")
+ms.context.reset_auto_parallel_context()
+
+
+def setup_coco():
+    dataDir = HYPERPARAMS["data_dir"]
+    dataType = HYPERPARAMS["data_type"]
+    os.makedirs(dataDir, exist_ok=True)
+    os.makedirs(f"{dataDir}/annotations", exist_ok=True)
+    os.makedirs(f"{dataDir}/{dataType}", exist_ok=True)
+    ann_file = f"{dataDir}/annotations/captions_{dataType}.json"
+    if not os.path.exists(ann_file):
+        ann_zip = f"{dataDir}/annotations_trainval2017.zip"
+        if not os.path.exists(ann_zip):
+            raise FileNotFoundError(f"{ann_zip} not found. 
Please download it manually.") + print("Extracting annotations...") + os.system(f"unzip -o {ann_zip} -d {dataDir}") + return dataDir, dataType + + +dataDir, dataType = setup_coco() +annFile = f'{dataDir}/annotations/captions_{dataType}.json' +coco = COCO(annFile) + + +def get_image_and_caption(coco, img_id, cache_dir=f"{HYPERPARAMS['data_dir']}/{HYPERPARAMS['data_type']}"): + ann_ids = coco.getAnnIds(imgIds=img_id) + anns = coco.loadAnns(ann_ids) + caption = anns[0]['caption'] + img_info = coco.loadImgs(img_id)[0] + img_path = f"{cache_dir}/{img_info['file_name']}" + image = Image.open(img_path) + if image.mode != "RGB": + image = image.convert("RGB") + return image, caption + + +def process_sample(img_id, coco): + image, caption = get_image_and_caption(coco, img_id) + processor = AlignProcessor.from_pretrained(HYPERPARAMS["processor_save_path"]) + inputs = processor( + text=caption, + images=image, + return_tensors="ms", + padding="max_length", + max_length=HYPERPARAMS["max_length"] + ) + return (inputs["input_ids"][0], inputs["attention_mask"][0], inputs["pixel_values"][0]) + + +def preprocess_and_save(coco, num_samples, cache_file): + if os.path.exists(cache_file): + print(f"Loading preprocessed data from {cache_file}") + with open(cache_file, "rb") as f: + dataset = pickle.load(f) + print(f"Loaded dataset size: {len(dataset)} samples") + return dataset + img_ids = coco.getImgIds()[:num_samples] + dataset = [] + with ThreadPoolExecutor(max_workers=HYPERPARAMS["num_workers"]) as executor: + dataset = list(tqdm(executor.map(lambda x: process_sample(x, coco), img_ids), + total=num_samples, desc=f"Preprocessing dataset ({num_samples} samples)")) + with open(cache_file, "wb") as f: + pickle.dump(dataset, f) + return dataset + + +def create_val_dataloader(coco, batch_size=HYPERPARAMS["batch_size"]): + val_dataset = preprocess_and_save(coco, HYPERPARAMS["val_samples"], HYPERPARAMS["val_cache_file"]) + val_dataloader = ms.dataset.GeneratorDataset( + val_dataset, + column_names=["input_ids", "attention_mask", "pixel_values"] + ).batch(batch_size) + return val_dataloader + + +class TrainingNet(nn.Cell): + def __init__(self, model): + super().__init__() + self.model = model + self.global_pool = nn.AdaptiveAvgPool2d(1) + self.text_projection = nn.Dense(768, 640) + self.logit_scale = Parameter(Tensor(np.log(1 / 0.07), dtype=ms.float32), requires_grad=True) + self.image_embeds = None + self.text_embeds = None + + def construct(self, input_ids, attention_mask, pixel_values): + embedding_output = self.model.vision_model.embeddings(pixel_values) + encoder_outputs = self.model.vision_model.encoder(embedding_output) + last_hidden_state = encoder_outputs[0] + pooled_output = self.global_pool(last_hidden_state) + self.image_embeds = pooled_output.reshape(pooled_output.shape[:2]) + text_outputs = self.model.text_model(input_ids=input_ids, attention_mask=attention_mask) + text_embeds = text_outputs[0][:, 0, :] + self.text_embeds = self.text_projection(text_embeds) + logits = ops.matmul(self.image_embeds, self.text_embeds.T) * ops.exp(self.logit_scale) + labels = ops.arange(len(logits), dtype=ms.int32) + loss_i2t = nn.CrossEntropyLoss()(logits, labels) + loss_t2i = nn.CrossEntropyLoss()(logits.T, labels) + return (loss_i2t + loss_t2i) / 2 + + +def evaluate_model(coco, epoch_to_eval): + processor = AlignProcessor.from_pretrained(HYPERPARAMS["processor_save_path"]) + model = AlignModel.from_pretrained(HYPERPARAMS["model_name"], local_files_only=True) + net = TrainingNet(model) # 使用 TrainingNet 包装 AlignModel + 
param_dict = ms.load_checkpoint(HYPERPARAMS["model_save_path"].format(epoch=epoch_to_eval)) + ms.load_param_into_net(net, param_dict) # 加载到 TrainingNet + net.set_train(False) + + val_dataloader = create_val_dataloader(coco) + print(f"Val dataloader created with batch_size={HYPERPARAMS['batch_size']}, samples={HYPERPARAMS['val_samples']}") + + total_val_loss = 0 + val_steps = 0 + for batch in tqdm(val_dataloader.create_dict_iterator(), desc=f"Evaluating Epoch {epoch_to_eval}"): + loss = net(batch["input_ids"], batch["attention_mask"], batch["pixel_values"]) + total_val_loss += loss.asnumpy() + val_steps += 1 + avg_val_loss = total_val_loss / val_steps + print(f"Epoch {epoch_to_eval}, Eval Loss: {avg_val_loss:.4f}") + + gc.collect() + return avg_val_loss + + +if __name__ == "__main__": + print("Starting model evaluation...") + for epoch in range(1, 11): + evaluate_model(coco, epoch) \ No newline at end of file diff --git a/llm/finetune/Align/finetune.py b/llm/finetune/Align/finetune.py new file mode 100644 index 000000000..e3a4ffeb9 --- /dev/null +++ b/llm/finetune/Align/finetune.py @@ -0,0 +1,203 @@ +import collections +import collections.abc + +collections.Iterable = collections.abc.Iterable + +import mindspore as ms +from mindnlp.transformers import AlignModel, AlignProcessor +from mindspore import Tensor, nn, ops, Parameter +from PIL import Image +from pycocotools.coco import COCO +import os +from tqdm import tqdm +import pickle +from concurrent.futures import ThreadPoolExecutor +import numpy as np + +HYPERPARAMS = { + "model_name": "E:/Code/align_ft_torch/cache/model/kakaobrain/align-base", + "epochs": 10, + "batch_size": 4, + "learning_rate": 1e-4, + "train_samples": 200, + "max_length": 128, + "num_workers": 8, + "data_dir": "MSCOCO", + "data_type": "val2017", + "train_cache_file": "mscoco_preprocessed_train_200.pkl", + "save_dir": "cache/model", + "model_save_path": "cache/model/finetuned_align_model_epoch_{epoch}.ckpt", + "processor_save_path": "cache/model/finetuned_align_processor" +} + +ms.set_context(mode=ms.PYNATIVE_MODE, device_target="Ascend") +ms.context.reset_auto_parallel_context() + +processor = AlignProcessor.from_pretrained(HYPERPARAMS["model_name"], local_files_only=True) +model = AlignModel.from_pretrained(HYPERPARAMS["model_name"], local_files_only=True) +model.set_train(True) + +print("Model config:", model.config) +params = model.trainable_params() +print("Number of trainable params:", len(params)) + + +def setup_coco(): + dataDir = HYPERPARAMS["data_dir"] + dataType = HYPERPARAMS["data_type"] + os.makedirs(dataDir, exist_ok=True) + os.makedirs(f"{dataDir}/annotations", exist_ok=True) + os.makedirs(f"{dataDir}/{dataType}", exist_ok=True) + ann_file = f"{dataDir}/annotations/captions_{dataType}.json" + if not os.path.exists(ann_file): + ann_zip = f"{dataDir}/annotations_trainval2017.zip" + if not os.path.exists(ann_zip): + raise FileNotFoundError(f"{ann_zip} not found. 
Please download it manually.") + print("Extracting annotations...") + os.system(f"unzip -o {ann_zip} -d {dataDir}") + return dataDir, dataType + + +dataDir, dataType = setup_coco() +annFile = f'{dataDir}/annotations/captions_{dataType}.json' +coco = COCO(annFile) + + +def get_image_and_caption(coco, img_id, cache_dir=f"{HYPERPARAMS['data_dir']}/{HYPERPARAMS['data_type']}"): + ann_ids = coco.getAnnIds(imgIds=img_id) + anns = coco.loadAnns(ann_ids) + caption = anns[0]['caption'] + img_info = coco.loadImgs(img_id)[0] + img_path = f"{cache_dir}/{img_info['file_name']}" + image = Image.open(img_path) + if image.mode != "RGB": + image = image.convert("RGB") + return image, caption + + +def process_sample(img_id, coco): + image, caption = get_image_and_caption(coco, img_id) + inputs = processor( + text=caption, + images=image, + return_tensors="ms", + padding="max_length", + max_length=HYPERPARAMS["max_length"] + ) + return (inputs["input_ids"][0], inputs["attention_mask"][0], inputs["pixel_values"][0]) + + +def preprocess_and_save(coco, num_samples, cache_file): + if os.path.exists(cache_file): + print(f"Loading preprocessed data from {cache_file}") + with open(cache_file, "rb") as f: + dataset = pickle.load(f) + print(f"Loaded dataset size: {len(dataset)} samples") + return dataset + img_ids = coco.getImgIds()[:num_samples] + dataset = [] + with ThreadPoolExecutor(max_workers=HYPERPARAMS["num_workers"]) as executor: + dataset = list(tqdm(executor.map(lambda x: process_sample(x, coco), img_ids), + total=num_samples, desc=f"Preprocessing dataset ({num_samples} samples)")) + with open(cache_file, "wb") as f: + pickle.dump(dataset, f) + return dataset + + +def create_train_dataloader(coco, batch_size=HYPERPARAMS["batch_size"]): + train_dataset = preprocess_and_save(coco, HYPERPARAMS["train_samples"], HYPERPARAMS["train_cache_file"]) + train_dataloader = ms.dataset.GeneratorDataset( + train_dataset, + column_names=["input_ids", "attention_mask", "pixel_values"] + ).batch(batch_size) + return train_dataloader + + +class TrainingNet(nn.Cell): + def __init__(self, model): + super().__init__() + self.model = model + self.global_pool = nn.AdaptiveAvgPool2d(1) + self.text_projection = nn.Dense(768, 640) + self.logit_scale = Parameter(Tensor(np.log(1 / 0.07), dtype=ms.float32), requires_grad=True) + self.image_embeds = None + self.text_embeds = None + + def construct(self, input_ids, attention_mask, pixel_values): + embedding_output = self.model.vision_model.embeddings(pixel_values) + encoder_outputs = self.model.vision_model.encoder(embedding_output) + last_hidden_state = encoder_outputs[0] + pooled_output = self.global_pool(last_hidden_state) + self.image_embeds = pooled_output.reshape(pooled_output.shape[:2]) + text_outputs = self.model.text_model(input_ids=input_ids, attention_mask=attention_mask) + text_embeds = text_outputs[0][:, 0, :] + self.text_embeds = self.text_projection(text_embeds) + logits = ops.matmul(self.image_embeds, self.text_embeds.T) * ops.exp(self.logit_scale) + labels = ops.arange(len(logits), dtype=ms.int32) + loss_i2t = nn.CrossEntropyLoss()(logits, labels) + loss_t2i = nn.CrossEntropyLoss()(logits.T, labels) + return (loss_i2t + loss_t2i) / 2 + + +def convert_to_parameter(params): + converted = [] + for i, param in enumerate(params): + if not isinstance(param, Parameter): + name = getattr(param, 'name', f"param_{i}") if hasattr(param, 'name') else f"param_{i}" + converted.append(Parameter(param.data, name=name, requires_grad=True)) + else: + converted.append(param) + return 
converted + + +def finetune_model(coco, model, processor, + epochs=HYPERPARAMS["epochs"], + batch_size=HYPERPARAMS["batch_size"], + learning_rate=HYPERPARAMS["learning_rate"]): + train_dataloader = create_train_dataloader(coco, batch_size) + print(f"Train dataloader created with batch_size={batch_size}, samples={HYPERPARAMS['train_samples']}") + + params = model.trainable_params() + if not params: + print("No trainable params found, enabling all parameters.") + for param in model.parameters_and_names(): + param[1].requires_grad = True + params = model.trainable_params() + + params = convert_to_parameter(params) + print(f"Optimizer initialized with {len(params)} parameters") + net = TrainingNet(model) + optimizer = nn.Adam(params + [net.text_projection.weight, net.text_projection.bias, net.logit_scale], + learning_rate=learning_rate) + train_net = nn.TrainOneStepCell(net, optimizer) + + for epoch in range(epochs): + iterator = train_dataloader.create_dict_iterator() + total_train_loss = 0 + steps = 0 + for batch in tqdm(iterator, desc=f"Epoch {epoch + 1}/{epochs} (Train)"): + loss = train_net(batch["input_ids"], batch["attention_mask"], batch["pixel_values"]) + total_train_loss += loss.asnumpy() + steps += 1 + if steps == 1: + print(f"Epoch {epoch + 1}, Step 1 - Train Loss: {loss.asnumpy():.4f}") + logits = ops.matmul(net.image_embeds, net.text_embeds.T) * ops.exp(net.logit_scale) + print(f"Logits sample: {logits[:2, :2]}") + avg_train_loss = total_train_loss / steps + print(f"Epoch {epoch + 1}/{epochs}, Average Train Loss: {avg_train_loss:.4f}") + + param_after = net.text_projection.weight.asnumpy() + if epoch == 0: + param_before = param_after.copy() + print("Params updated:", not np.array_equal(param_before, param_after)) + + save_dir = HYPERPARAMS["save_dir"] + os.makedirs(save_dir, exist_ok=True) + ms.save_checkpoint(net, HYPERPARAMS["model_save_path"].format(epoch=epoch + 1)) + + processor.save_pretrained(HYPERPARAMS["processor_save_path"]) + return model + + +print("Starting model finetuning...") +finetuned_model = finetune_model(coco, model, processor) \ No newline at end of file diff --git a/llm/finetune/albert/Albert_mind.py b/llm/finetune/albert/Albert_mind.py new file mode 100644 index 000000000..68c41bc9a --- /dev/null +++ b/llm/finetune/albert/Albert_mind.py @@ -0,0 +1,130 @@ +import random +import mindspore as ms +from mindspore import nn, ops, Tensor +from mindspore.dataset import GeneratorDataset +from mindnlp.transformers import AlbertTokenizer, AlbertForSequenceClassification +from mindnlp.engine import Trainer, TrainingArguments +from datasets import load_dataset +import numpy as np +import os +import evaluate + +# 1. 加载预训练模型和分词器 +model_name = "albert-base-v1" +tokenizer = AlbertTokenizer.from_pretrained(model_name) +model = AlbertForSequenceClassification.from_pretrained( + model_name, num_labels=2) + +# 2. 加载IMDb数据集 +dataset = load_dataset("stanfordnlp/imdb", trust_remote_code=True) +print("dataset:", dataset) +# 3. 
数据预处理函数 + + +def tokenize_function(examples): + tokenized = tokenizer( + examples["text"], + padding="max_length", + truncation=True, + max_length=512 + ) + # 添加标签到返回字典 + tokenized["labels"] = examples["label"] + return tokenized + + +# 应用预处理 +tokenized_datasets = dataset.map(tokenize_function, batched=True) + +# 检查标签分布(修正后的代码) +print("\n==== 数据分布验证 ====") + +# 检查训练集 +train_labels = np.array(tokenized_datasets["train"]["labels"]) +print("训练集标签统计:") +print("- 唯一值:", np.unique(train_labels)) +print("- 分布:", np.bincount(train_labels)) + +# 检查测试集 +test_labels = np.array(tokenized_datasets["test"]["labels"]) +print("\n测试集标签统计:") +print("- 唯一值:", np.unique(test_labels)) +print("- 分布:", np.bincount(test_labels)) +# 4. 转换数据集格式 + +def create_dataset(data, batch_size=8): + # 将数据转换为列表以便打乱 + data_list = list(data) + random.shuffle(data_list) # 打乱数据顺序 + + def generator(): + for item in data_list: # 遍历打乱后的数据 + yield item["input_ids"], item["attention_mask"], Tensor(item["labels"], dtype=ms.int32) + + return GeneratorDataset(generator(), ["input_ids", "attention_mask", "labels"]).batch(batch_size) + + +train_dataset = create_dataset(tokenized_datasets["train"]) +eval_dataset = create_dataset(tokenized_datasets["test"]) + +# 5. 加载评估指标 +accuracy = evaluate.load("accuracy") +f1 = evaluate.load("f1") +precision = evaluate.load("precision") +recall = evaluate.load("recall") + +sample = next(iter(train_dataset)) +print("Input IDs:", sample[0]) +print("Attention Mask:", sample[1]) +print("Labels:", sample[2]) + +# 自定义指标计算函数 +def compute_metrics(eval_pred): + logits, labels = eval_pred # 直接解包为logits和labels + predictions = np.argmax(logits, axis=-1) + + return { + "accuracy": accuracy.compute(predictions=predictions, references=labels)["accuracy"], + "f1": f1.compute(predictions=predictions, references=labels, average="binary")["f1"], + "precision": precision.compute(predictions=predictions, references=labels, average="binary")["precision"], + "recall": recall.compute(predictions=predictions, references=labels, average="binary")["recall"] + } + + +# 6. 配置训练参数 +training_args = TrainingArguments( + num_train_epochs=3, + per_device_train_batch_size=8, + per_device_eval_batch_size=8, + learning_rate=1e-5, + weight_decay=0.01, + output_dir="./results", + logging_dir="./logs", + logging_steps=10, + evaluation_strategy="epoch", + save_strategy="epoch", + load_best_model_at_end=True, + metric_for_best_model="accuracy", # 根据准确率选择最佳模型 + greater_is_better=True, # 准确率越高越好 +) + +# 7. 初始化并运行训练 +trainer = Trainer( + model=model, + args=training_args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + compute_metrics=compute_metrics, # 添加指标计算函数 +) + +trainer.train() + +# 8. 
评估模型 +eval_results = trainer.evaluate(eval_dataset) +print(f"Evaluation results: {eval_results}") +print("\nFinal evaluation results:") +print(f"Accuracy: {eval_results['eval_accuracy']:.4f}") +print(f"F1 Score: {eval_results['eval_f1']:.4f}") +print(f"Precision: {eval_results['eval_precision']:.4f}") +print(f"Recall: {eval_results['eval_recall']:.4f}") + diff --git a/llm/finetune/albert/albert_StanfordIMDB_mindnlp.md b/llm/finetune/albert/albert_StanfordIMDB_mindnlp.md new file mode 100644 index 000000000..b7012f6ba --- /dev/null +++ b/llm/finetune/albert/albert_StanfordIMDB_mindnlp.md @@ -0,0 +1,58 @@ +# Albert mindnlp StanfordIMDB reviewer Finetune + +- Albert模型微调任务链接:[【开源实习】albert模型微调 · Issue #IAUONP · MindSpore/community - Gitee.com](https://gitee.com/mindspore/community/issues/IAUONP) +- 实现了Albert-base-v1 基准权重 在 [Sentiment analysis of IMDb reviews - Stanford University] 数据集上的微调 + +- base model: [albert/albert-base-v1 · Hugging Face](https://huggingface.co/albert/albert-base-v1) +- dataset: [stanfordnlp/imdb · Datasets at Hugging Face](https://huggingface.co/datasets/stanfordnlp/imdb) + +# Requirments +## Pytorch + +- GPU: RTX 4070ti 12G +- cuda: 11.8 +- Python version: 3.10 +- torch version: 2.5.0 +- transformers version : 4.47.0 + +## Mindspore 启智社区 Ascend910B算力资源 +- Ascend: 910B +- python: 3.11 +- mindspore: 2.5.0 +- mindnlp: 0.4.1 + +# Result for finetune + +training for 3 epochs + +## torch + +| Epoch | eval_loss | +| ------------------ | --------- | +| 1 | 0.3868 | +| 2 | 0.2978 | +| 3 | 0.3293 | +| Evaluation results | 0.2978 | + +**评估结果** + +| Accuracy | Precision | Recall | F1_score | +| -------- | --------- | ------ | -------- | +| 0.9212 | 0.9218 | 0.9284 | 0.9218 | + + + +## mindspore + +| Epoch | eval_loss | +| ------------------ | --------- | +| 1 | 0.2677 | +| 2 | 0.2314 | +| 3 | 0.2332 | +| Evaluation results | 0.2314 | + +**评估结果** + +| Accuracy | Precision | Recall | F1_score | +| -------- | --------- | ------ | -------- | +| 0.9219 | 0.9238 | 0.9218 | 0.9228 | diff --git a/llm/finetune/bit/README.md b/llm/finetune/bit/README.md new file mode 100644 index 000000000..9638e7b34 --- /dev/null +++ b/llm/finetune/bit/README.md @@ -0,0 +1,61 @@ +# bit微调 + +实现了"HorcruxNo13/bit-50"模型在"dpdl-benchmark/oxford_flowers102"数据集上的微调实验。 +任务链接在https://gitee.com/mindspore/community/issues/IAUPCI +transformers+pytorch+3090的benchmark是自己编写的,仓库位于https://github.com/outbreak-sen/Bit_flowers102_Finetune +更改代码位于llm/finetune/bit,只包含mindnlp+mindspore的 +实验结果如下 + +## 硬件 + +资源规格:NPU: 1*Ascend-D910B(显存: 64GB), CPU: 24, 内存: 192GB + +智算中心:武汉智算中心 + +镜像:mindspore_2_5_py311_cann8 + +torch训练硬件资源规格:Nvidia 3090 + +## 模型与数据集 + +模型:"HorcruxNo13/bit-50" + +数据集:"dpdl-benchmark/oxford_flowers102" + +## Eval Loss Values 表格 + +| Epoch | mindNLP | torch | +|-------|---------------|---------------| +| 1 | 3.5184175968 | 4.6460494995 | +| 2 | 1.7758612633 | 4.2146801949 | +| 3 | 0.9314232469 | 3.8055384159 | +| 4 | 0.6095938683 | 3.4315345287 | +| 5 | 0.4878421128 | 3.1143600941 | +| 6 | 0.4401741028 | 2.8422958851 | +| 7 | 0.4239776731 | 2.6192340851 | +| 8 | 0.4162144363 | 2.4506986141 | +| 9 | 0.4113974869 | 2.3450050354 | +| 10 | 0.4095760584 | 2.2997686863 | + +## Test Accuracy 表格 + +| Epoch | mindNLP | torch | +|-------|---------------|---------------| +| 1 | 0.9219 | 0.6225 | + +## 图片分类测试 + +问题来自评估数据集的第一个问题,微调后结果准确 + +* 问题输入: + dataset['test'][0]['image'] +* 真实标签: + 26 +* mindnlp未微调前的回答: + 25 +* mindnlp微调后的回答: + 26 +* torch微调前的回答: + 41 +* torch微调后的回答: + 26 \ No newline at end of file diff --git 
a/llm/finetune/bit/mindNLP_Bit_flowers.py b/llm/finetune/bit/mindNLP_Bit_flowers.py new file mode 100644 index 000000000..8e1d450ab --- /dev/null +++ b/llm/finetune/bit/mindNLP_Bit_flowers.py @@ -0,0 +1,142 @@ +import mindspore as ms +import mindspore.dataset as ds +from datasets import load_dataset +from mindnlp.transformers import ( + BitForImageClassification, + AutoImageProcessor +) +from mindnlp.engine import Trainer, TrainingArguments +import os +import numpy as np +ms.set_context(device_target="Ascend") +model_name = "HorcruxNo13/bit-50" +processor = AutoImageProcessor.from_pretrained(model_name) +model = BitForImageClassification.from_pretrained( + model_name, + num_labels=102, + ignore_mismatched_sizes=True +) +dataset = load_dataset("dpdl-benchmark/oxford_flowers102", split="train") +# 将训练集按8:2的比例拆分为训练集和测试集 +dataset = dataset.train_test_split(test_size=0.2, seed=42) +dataset.save_to_disk("./flowers102") + +print(dataset) +# 选择一个测试集样本进行测试 +test_image = dataset['test'][0]['image'] +test_label = dataset['test'][0]['label'] + +print("\n=== 训练参数 ===") +training_args = TrainingArguments( + output_dir="./mindNLP_bit_flowers102", + evaluation_strategy="epoch", + save_strategy="epoch", + learning_rate=5e-5, + per_device_train_batch_size=64, + per_device_eval_batch_size=128, + num_train_epochs=10, + gradient_accumulation_steps=1, + logging_steps=50, + load_best_model_at_end=True, + warmup_steps=0, + weight_decay=0.01, + remove_unused_columns=False, + max_grad_norm=0.0 # 禁用梯度裁剪 +) +print("\n=== 先生成np数据 ===") +train_data = [] +train_labels = [] +for item in dataset['train']: + img = item['image'].convert('RGB') + inputs = processor(images=img, return_tensors="np", size={"height": 384, "width": 384}) + train_data.append(inputs['pixel_values'][0]) + train_labels.append(item['label']) +test_data = [] +test_labels = [] +for item in dataset['test']: + img = item['image'].convert('RGB') + inputs = processor(images=img, return_tensors="np", size={"height": 384, "width": 384}) + test_data.append(inputs['pixel_values'][0]) + test_labels.append(item['label']) +train_data = np.array(train_data, dtype=np.float32) +train_labels = np.array(train_labels, dtype=np.int32) +test_data = np.array(test_data, dtype=np.float32) +test_labels = np.array(test_labels, dtype=np.int32) +print("\n=== 将预处理后的数据集转换为MindSpore格式 ===") +def create_mindspore_dataset(data, labels, batch_size, shuffle=True): + dataset = ds.NumpySlicesDataset( + { + "pixel_values": data, + "labels": labels + }, + shuffle=shuffle + ) + dataset = dataset.batch(batch_size, drop_remainder=True) + return dataset + +# 创建训练和评估数据集 +train_dataset = create_mindspore_dataset( + train_data, + train_labels, + batch_size=training_args.per_device_train_batch_size, + shuffle=True +) + +eval_dataset = create_mindspore_dataset( + test_data, + test_labels, + batch_size=training_args.per_device_eval_batch_size, + shuffle=False +) + +# 单图测试函数 +def test_single_image(model, processor, image): + inputs = processor( + images=image.convert('RGB'), + return_tensors="ms", + size={"height": 384, "width": 384} + ) + model.set_train(False) + outputs = model(**inputs) + predictions = outputs.logits.argmax(-1) + return predictions.asnumpy().item() + +print("\n=== 训练前测试 ===") +pred_before = test_single_image(model, processor, test_image) +print(f"真实标签: {test_label}") +print(f"预测标签: {pred_before}") + +import evaluate +import numpy as np +from mindnlp.engine.utils import EvalPrediction + +metric = evaluate.load("accuracy") +# 添加调试信息 +def compute_metrics(eval_pred: EvalPrediction): + 
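+    # Note: `eval_pred` bundles the logits accumulated over the eval set and the reference
+    # labels; argmax over the last axis turns the logits into predicted class ids before
+    # the accuracy metric is computed.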
logits, labels = eval_pred + predictions = np.argmax(logits, axis=-1) + result = metric.compute(predictions=predictions, references=labels) + return result +print("\n=== 创建Trainer实例 ===") +trainer = Trainer( + model=model, + args=training_args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + compute_metrics=compute_metrics, +) +# trainer = Trainer( +# model=model, +# args=training_args, +# train_dataset=train_dataset, +# eval_dataset=eval_dataset +# ) +print("\n=== 训练 ===") +trainer.train() +test_results = trainer.evaluate() +print(f"Test Accuracy: {test_results['eval_accuracy']:.4f}") + +print("\n=== 训练后测试 ===") +pred_after = test_single_image(model, processor, test_image) +print(f"真实标签: {test_label}") +print(f"预测标签: {pred_after}") diff --git a/llm/finetune/bit/mindnlplog.txt b/llm/finetune/bit/mindnlplog.txt new file mode 100644 index 000000000..a3e9fa3b6 --- /dev/null +++ b/llm/finetune/bit/mindnlplog.txt @@ -0,0 +1,69 @@ +(MindSpore) [ma-user work]$python mindNLP_Bit_flowers.py +Building prefix dict from the default dictionary ... +Loading model from cache /tmp/jieba.cache +Loading model cost 1.241 seconds. +Prefix dict has been built successfully. +Some weights of BitForImageClassification were not initialized from the model checkpoint at HorcruxNo13/bit-50 and are newly initialized because the shapes did not match: +- classifier.1.weight: found shape (1000, 2048) in the checkpoint and (102, 2048) in the model instantiated +- classifier.1.bias: found shape (1000,) in the checkpoint and (102,) in the model instantiated +You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference. +test-00000-of-00006.parquet: 100%|█████████████████████████████████████████████████████████████████| 420M/420M [02:14<00:00, 3.12MB/s] +test-00001-of-00006.parquet: 100%|█████████████████████████████████████████████████████████████████| 416M/416M [02:11<00:00, 3.17MB/s] +test-00002-of-00006.parquet: 0%| | 0.00/429M [00:00= 2 + first_axis_dim, other_shape = input.shape[0], input.shape[1:] + input_flat = input.reshape(first_axis_dim, -1) + indices_expanded = ops.expand_dims(indices, -1) + indices_expanded = ops.broadcast_to(indices_expanded, (-1, input_flat.shape[1])) + output_flat = ops.gather(input_flat, 0, indices_expanded) + output = output_flat.reshape(-1, *other_shape) + return output + + def bprop(self, input, indices, out, dout): + assert dout.ndim >= 2 + other_shape = dout.shape[1:] + grad_output = dout + + grad_flat = grad_output.reshape(grad_output.shape[0], -1) + grad_shape = (input.shape[0], grad_flat.shape[1]) + grad_input = ops.zeros(grad_shape, grad_flat.dtype) + + indices_expanded = ops.expand_dims(indices, -1) + indices_expanded = ops.broadcast_to(indices_expanded, (-1, grad_flat.shape[1])) + grad_input.scatter_(0, indices_expanded, grad_flat) + + return grad_input.reshape(input.shape[0], *other_shape), None + + +index_first_axis = IndexFirstAxis() + + +class IndexPutFirstAxis(nn.Cell): + def __init__(self): + super(IndexPutFirstAxis, self).__init__() + + def construct(self, values: mindspore.Tensor, indices: mindspore.Tensor, first_axis_dim: int): + assert indices.ndim == 1 + assert values.ndim >= 2 + output = ops.zeros( + (first_axis_dim, *values.shape[1:]), + values.dtype + ) + output[indices] = values + return output + + def bprop(self, values, indices, first_axis_dim, out, dout): + grad_values = dout[indices] + return grad_values, None, None + + +index_put_first_axis = IndexPutFirstAxis() + + +def pad_input( + hidden_states: 
mindspore.Tensor, + indices: mindspore.Tensor, + batch: int, + seqlen: int +): + """ + Arguments: + hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. + indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence. + batch: int, batch size for the padded sequence. + seqlen: int, maximum sequence length for the padded sequence. + Return: + hidden_states: (batch, seqlen, ...) + """ + output = index_put_first_axis(hidden_states, indices, batch * seqlen) + return output.reshape(batch, seqlen, *hidden_states.shape[1:]) + + +def unpad_input( + hidden_states: mindspore.Tensor, + attention_mask: mindspore.Tensor, + unused_mask: Optional[mindspore.Tensor] = None, +): + """ + Arguments: + hidden_states: (batch, seqlen, ...) + attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid. + unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused. + Return: + hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask. + indices: (total_nnz), the indices of masked tokens from the flattened input sequence. + cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states. + max_seqlen_in_batch: int + seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask. + """ + all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask + seqlens_in_batch = all_masks.sum(dim=-1, dtype=mindspore.int32) + used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=mindspore.int32) + indices = ops.nonzero(all_masks.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = ops.pad(ops.cumsum(seqlens_in_batch, dim=0, dtype=mindspore.int32), (1, 0)) + + hidden_states_flat = hidden_states.reshape(-1, *hidden_states.shape[2:]) + hidden_states = index_first_axis(hidden_states_flat, indices) + return ( + hidden_states, + indices, + cu_seqlens, + max_seqlen_in_batch, + used_seqlens_in_batch, + ) + + +def create_attn_mask(causal: bool, sparse_mode: int) -> Tuple[int, mindspore.Tensor]: + """ + Create a causal mask for the attention scores. + + Args: + causal (`bool`): + If `True`, the mask will be causal. + sparse_mode (`bool`): + If `True`, the mask will be top-left + aligned, otherwise it will be bottom-right aligned. + Returns: + `Tuple[bool, mindspore.Tensor]`: + A tuple containing sparse_mode and the mask tensor. 
+ """ + if not causal: + sparse_mode = 0 + attn_mask = None + else: + if sparse_mode == TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE: + attn_mask = ops.tril(ops.ones((2048, 2048)), diagonal=-1).bool() + else: + attn_mask = ops.triu(ops.ones((2048, 2048)), diagonal=1).bool() + return sparse_mode, attn_mask + + +def npu_flash_attn_func( + q: mindspore.Tensor, + k: mindspore.Tensor, + v: mindspore.Tensor, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + **kwargs, +): + head_num = q.shape[2] + sparse_mode, attn_mask = create_attn_mask(causal, SPARSE_MODE) + if softmax_scale is None: + softmax_scale = 1.0 / math.sqrt(q.shape[-1]) + output = flash_attention_score( + q, + k, + v, + head_num, + keep_prob=1.0 - dropout_p, + scalar_value=softmax_scale, + attn_mask=attn_mask, + input_layout="BSND", + sparse_mode=sparse_mode, + prefix=None, + ) + + return output + + +def npu_flash_attn_varlen_func( + q: mindspore.Tensor, + k: mindspore.Tensor, + v: mindspore.Tensor, + cu_seqlens_q: Optional[mindspore.Tensor] = None, + cu_seqlens_k: Optional[mindspore.Tensor] = None, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + **kwargs, +): + head_num = q.shape[1] + sparse_mode, attn_mask = create_attn_mask(causal, SPARSE_MODE) + if softmax_scale is None: + softmax_scale = 1.0 / math.sqrt(q.shape[-1]) + + output = flash_attention_score( + q, + k, + v, + head_num, + keep_prob=1.0 - dropout_p, + scalar_value=softmax_scale, + attn_mask=attn_mask, + input_layout="TND", + actual_seq_qlen=cu_seqlens_q[1:].asnumpy().tolist(), + actual_seq_kvlen=cu_seqlens_k[1:].asnumpy().tolist(), + sparse_mode=sparse_mode, + prefix=None, + ) + + return output diff --git a/mindnlp/transformers/modeling_flash_attention_utils.py b/mindnlp/transformers/modeling_flash_attention_utils.py new file mode 100644 index 000000000..6ed1c417f --- /dev/null +++ b/mindnlp/transformers/modeling_flash_attention_utils.py @@ -0,0 +1,372 @@ +# Copyright 2024 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This module provides utilities for flash attention in Transformers models.""" + + +import os +import inspect +from typing import Optional, Tuple +import mindspore +from mindnlp.core import ops +from ..utils import logging +from .integrations.npu_flash_attention import index_first_axis, pad_input, unpad_input +from .integrations.npu_flash_attention import npu_flash_attn_func as flash_attn_func +from .integrations.npu_flash_attention import npu_flash_attn_varlen_func as flash_attn_varlen_func + + +logger = logging.get_logger(__name__) + + +if flash_attn_func is not None: + _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) + + +def flash_attn_supports_top_left_mask(): + # down-right mask is used on Ascend NPU by default, set env `NPU_FA2_SPARSE_MODE=2` to activate top-left mask. 
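+    # The returned flag is intended to be passed as the `use_top_left_mask` argument of
+    # `_flash_attention_forward` defined later in this module.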
+ from .integrations.npu_flash_attention import is_npu_fa2_top_left_aligned_causal_mask + return is_npu_fa2_top_left_aligned_causal_mask() + + +def _get_unpad_data(attention_mask: mindspore.Tensor) -> Tuple[mindspore.Tensor, mindspore.Tensor, int]: + """ + Retrieves indexing data required to repad unpadded (ragged) tensors. + + Arguments: + attention_mask (`mindspore.Tensor`): + Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. + + Return: + indices (`mindspore.Tensor`): + The indices of non-masked tokens from the flattened input sequence. + cu_seqlens (`mindspore.Tensor`): + The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). + max_seqlen_in_batch (`int`): + Maximum sequence length in batch. + """ + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=mindspore.int32) + indices = ops.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = ops.pad(ops.cumsum(seqlens_in_batch, dim=0, dtype=mindspore.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +def _upad_input( + query_layer: mindspore.Tensor, + key_layer: mindspore.Tensor, + value_layer: mindspore.Tensor, + attention_mask: mindspore.Tensor, + query_length: int, +): + """ + Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches. + + This function is used instead of `unpad_input` in order to avoid the recomputation of the same intermediary + tensors for query, key, value tensors. + + Arguments: + query_layer (`mindspore.Tensor`): + Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim). + key_layer (`mindspore.Tensor`): + Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + value_layer (`mindspore.Tensor`): + Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + attention_mask (`mindspore.Tensor`): + Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. + query_length (`int`): + Target length. + + Return: + query_layer (`mindspore.Tensor`): + Query state without padding. Shape: (total_target_length, num_heads, head_dim). + key_layer (`mindspore.Tensor`): + Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). + value_layer (`mindspore.Tensor`): + Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). + indices_q (`mindspore.Tensor`): + The indices of non-masked tokens from the flattened input target sequence. + (cu_seqlens_q, cu_seqlens_k) (`Tuple[int]`): + The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). + (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`): + Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value). 
+ """ + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + + if query_length == kv_seq_len: + query_layer = index_first_axis(query_layer.reshape(batch_size * kv_seq_len, -1, head_dim), indices_k) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = ops.arange(batch_size + 1, dtype=mindspore.int32) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q, *_ = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +def prepare_fa2_from_position_ids( + query: mindspore.Tensor, + key: mindspore.Tensor, + value: mindspore.Tensor, + position_ids: mindspore.Tensor, +): + """ + This function returns necessary arguments to call `flash_attn_varlen_func`. + All three query, key, value states will be flattened. + Cumulative lengths of each examples in the batch will be extracted from position_ids. + + NOTE: ideally cumulative lengths should be prepared at the data collator stage + + Arguments: + query (`mindspore.Tensor`): + Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim). + key (`mindspore.Tensor`): + Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + value (`mindspore.Tensor`): + Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + position_ids (`mindspore.Tensor`): + Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. + + Return: + query (`mindspore.Tensor`): + Query state without padding. Shape: (total_target_length, num_heads, head_dim). + key (`mindspore.Tensor`): + Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). + value (`mindspore.Tensor`): + Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). + indices_q (`mindspore.Tensor`): + The indices of non-masked tokens from the flattened input target sequence. + (cu_seqlens_q, cu_seqlens_k) (`Tuple[int]`): + The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). + (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`): + Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value). 
+ """ + query = query.view(-1, query.shape[-2], query.shape[-1]) + key = key.contiguous().view(-1, key.shape[-2], key.shape[-1]) + value = value.contiguous().view(-1, value.shape[-2], value.shape[-1]) + position_ids = position_ids.flatten() + indices_q = ops.arange(position_ids.shape[0], dtype=mindspore.int32) + + cu_seq_lens = ops.cat( + (indices_q[position_ids == 0], + mindspore.tensor(position_ids.shape, dtype=mindspore.int32) + ) + ) + max_length = position_ids.max() + 1 + + return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length)) + + +def fa_peft_integration_check( + query: mindspore.Tensor, + key: mindspore.Tensor, + value: mindspore.Tensor, + target_dtype: Optional[mindspore.dtype.TensorType] = None, +): + """ + PEFT usually casts the layer norms in float32 for training stability reasons + therefore the input hidden states gets silently casted in float32. Hence, we need + cast them back in float16 / bfloat16 just to be sure everything works as expected. + This might slowdown training & inference so it is recommended to not cast the LayerNorms! + + Args: + query (`mindspore.Tensor`): + Input query states to be passed to Flash Attention API + key (`mindspore.Tensor`): + Input key states to be passed to Flash Attention API + value (`mindspore.Tensor`): + Input value states to be passed to Flash Attention API + target_dtype (`mindspore.dtype`, *optional*): + The dtype to convert the attention tensors to. Conversion can be ignored by + not providing the target dtype. + """ + if target_dtype is None: + return query, key, value + + input_dtype = query.dtype + if input_dtype == mindspore.float32: + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query = query.to(target_dtype) + key = key.to(target_dtype) + value = value.to(target_dtype) + + return query, key, value + + +flash_241 = False +deterministic_g = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1" + + +def _flash_attention_forward( + query_states: mindspore.Tensor, + key_states: mindspore.Tensor, + value_states: mindspore.Tensor, + attention_mask: mindspore.Tensor, + query_length: int, + is_causal: bool, + dropout: float = 0.0, + position_ids: Optional[mindspore.Tensor] = None, + softmax_scale: Optional[float] = None, + sliding_window: Optional[int] = None, + use_top_left_mask: bool = False, + softcap: Optional[float] = None, + deterministic: Optional[bool] = None, + cu_seq_lens_q: Optional[mindspore.Tensor] = None, + cu_seq_lens_k: Optional[mindspore.Tensor] = None, + max_length_q: Optional[int] = None, + max_length_k: Optional[int] = None, + target_dtype: Optional[mindspore.dtype.TensorType] = None, + **kwargs, +): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. 
+ + Args: + query_states (`mindspore.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`mindspore.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`mindspore.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`mindspore.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + use_top_left_mask (`bool`, defaults to `False`): + flash_attn<2.1 generates top-left aligned causal mask, + while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. + This attribute is used to handle this difference. + """ + if not use_top_left_mask: + causal = is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. + causal = is_causal and query_length != 1 + + # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length). + use_sliding_windows = ( + _flash_supports_window_size and sliding_window is not None and key_states.shape[1] > sliding_window + ) + flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {} + + if flash_241: + if deterministic is None: + deterministic = deterministic_g + flash_kwargs["deterministic"] = deterministic + + if softcap is not None: + flash_kwargs["softcap"] = softcap + + # PEFT possibly silently casts tensors to fp32, this potentially reconverts to correct dtype or is a no op + query_states, key_states, value_states = fa_peft_integration_check( + query_states, key_states, value_states, target_dtype + ) + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = _upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + **flash_kwargs, + ) + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + + # If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing + # then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage. 
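+    # Example: position_ids = [[0, 1, 2, 0, 1]] decreases at index 3, so that batch row packs
+    # two sequences and must take the varlen path below instead of plain flash_attn_func.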
+ # Use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach + elif position_ids is not None and ( + max_length_q is not None or (query_length != 1 and not (position_ids.diff(dim=-1) >= 0).all()) + ): + batch_size = query_states.size(0) + if cu_seq_lens_q is None or cu_seq_lens_k is None: + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = ( + prepare_fa2_from_position_ids(query_states, key_states, value_states, position_ids) + ) + + cu_seq_lens_q, cu_seq_lens_k = cu_seq_lens + max_length_q, max_length_k = max_seq_lens + + else: + query_states = query_states.reshape(-1, query_states.size(-2), query_states.size(-1)) + key_states = key_states.reshape(-1, key_states.size(-2), key_states.size(-1)) + value_states = value_states.reshape(-1, value_states.size(-2), value_states.size(-1)) + + attn_output = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seq_lens_q, + cu_seqlens_k=cu_seq_lens_k, + max_seqlen_q=max_length_q, + max_seqlen_k=max_length_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + **flash_kwargs, + ) + + attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1)) + + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal, **flash_kwargs + ) + + return attn_output diff --git a/mindnlp/transformers/modeling_utils.py b/mindnlp/transformers/modeling_utils.py index 414404aaf..7138cf0af 100644 --- a/mindnlp/transformers/modeling_utils.py +++ b/mindnlp/transformers/modeling_utils.py @@ -1349,7 +1349,20 @@ def _autoset_attn_implementation( # If a config is passed with a preset attn_implementation, we skip the automatic dispatch and use the user-provided config, with hard checks that the requested attention implementation is available. requested_attn_implementation = config._attn_implementation_internal - config._attn_implementation = "eager" + if use_flash_attention_2: + logger.warning_once( + 'The model was loaded with use_flash_attention_2=True, which is deprecated and may be removed in a future release. Please use `attn_implementation="flash_attention_2"` instead.' 
+ ) + config._attn_implementation = "flash_attention_2" + + if config._attn_implementation == "flash_attention_2": + config._attn_implementation = "flash_attention_2" + elif isinstance(requested_attn_implementation, dict): + config._attn_implementation = None + else: + config._attn_implementation = "eager" + + config._attn_implementation_autoset = True return config diff --git a/mindnlp/transformers/models/__init__.py b/mindnlp/transformers/models/__init__.py index 722aa0f7d..ff8a93c76 100644 --- a/mindnlp/transformers/models/__init__.py +++ b/mindnlp/transformers/models/__init__.py @@ -135,6 +135,7 @@ luke, lxmert, mamba, + mamba2, marian, markuplm, m2m_100, @@ -381,6 +382,7 @@ from .lxmert import * from .m2m_100 import * from .mamba import * +from .mamba2 import * from .marian import * from .markuplm import * from .maskformer import * @@ -626,6 +628,7 @@ __all__.extend(lxmert.__all__) __all__.extend(m2m_100.__all__) __all__.extend(mamba.__all__) +__all__.extend(mamba2.__all__) __all__.extend(marian.__all__) __all__.extend(markuplm.__all__) __all__.extend(maskformer.__all__) diff --git a/mindnlp/transformers/models/auto/configuration_auto.py b/mindnlp/transformers/models/auto/configuration_auto.py index 73d5851f2..96ae6008e 100644 --- a/mindnlp/transformers/models/auto/configuration_auto.py +++ b/mindnlp/transformers/models/auto/configuration_auto.py @@ -135,6 +135,7 @@ ("lxmert", "LxmertConfig"), ("m2m_100", "M2M100Config"), ("mamba", "MambaConfig"), + ("mamba2", "Mamba2Config"), ("marian", "MarianConfig"), ('markuplm', "MarkupLMConfig"), ("mask2former", "Mask2FormerConfig"), @@ -353,6 +354,7 @@ ("lxmert", "LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("m2m_100", "M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("mamba", "MAMBA_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("mamba2", "MAMBA2_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("marian", "MARIAN_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("markuplm", "MARKUPLM_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("mask2former", "MASK2FORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"), @@ -608,6 +610,7 @@ ("lxmert", "LXMERT"), ("m2m_100", "M2M100"), ("mamba", "Mamba"), + ("mamba2", "Mamba2"), ("marian", "Marian"), ("markuplm", "MarkupLM"), ("mask2former", "Mask2Former"), diff --git a/mindnlp/transformers/models/auto/modeling_auto.py b/mindnlp/transformers/models/auto/modeling_auto.py index 026ea2a43..3a8d5cb33 100644 --- a/mindnlp/transformers/models/auto/modeling_auto.py +++ b/mindnlp/transformers/models/auto/modeling_auto.py @@ -151,6 +151,7 @@ ("lxmert", "LxmertModel"), ("m2m_100", "M2M100Model"), ("mamba", "MambaModel"), + ("mamba2", "Mamba2Model"), ("marian", "MarianModel"), ("markuplm", "MarkupLMModel"), ("mask2former", "Mask2FormerModel"), @@ -318,6 +319,7 @@ ("luke", "LukeForMaskedLM"), ("lxmert", "LxmertForPreTraining"), ("mamba", "MambaForCausalLM"), + ("mamba2", "Mamba2ForCausalLM"), ("mega", "MegaForMaskedLM"), ("megatron-bert", "MegatronBertForPreTraining"), ('minicpm', 'MiniCPMForCausalLM'), @@ -405,6 +407,7 @@ ("luke", "LukeForMaskedLM"), ("m2m_100", "M2M100ForConditionalGeneration"), ("mamba", "MambaForCausalLM"), + ("mamba2", "Mamba2ForCausalLM"), ("marian", "MarianMTModel"), ("mega", "MegaForMaskedLM"), ("megatron-bert", "MegatronBertForCausalLM"), @@ -491,6 +494,7 @@ ("jetmoe", "JetMoeForCausalLM"), ("llama", "LlamaForCausalLM"), ("mamba", "MambaForCausalLM"), + ("mamba2", "Mamba2ForCausalLM"), ("marian", "MarianForCausalLM"), ("mbart", "MBartForCausalLM"), ("mega", "MegaForCausalLM"), diff --git a/mindnlp/transformers/models/auto/tokenization_auto.py 
b/mindnlp/transformers/models/auto/tokenization_auto.py index 1ad0adbe4..054141fad 100644 --- a/mindnlp/transformers/models/auto/tokenization_auto.py +++ b/mindnlp/transformers/models/auto/tokenization_auto.py @@ -269,6 +269,7 @@ ("lxmert", ("LxmertTokenizer", "LxmertTokenizerFast" if is_tokenizers_available() else None)), ("m2m_100", ("M2M100Tokenizer" if is_sentencepiece_available() else None, None)), ("mamba", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)), + ("mamba2", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)), ("marian", ("MarianTokenizer" if is_sentencepiece_available() else None, None)), ( "mbart", diff --git a/mindnlp/transformers/models/mamba2/__init__.py b/mindnlp/transformers/models/mamba2/__init__.py new file mode 100644 index 000000000..74e92a14b --- /dev/null +++ b/mindnlp/transformers/models/mamba2/__init__.py @@ -0,0 +1,23 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Mamba2 Model. +""" +from . import modeling_mamba2, configuration_mamba2 +from .modeling_mamba2 import * +from .configuration_mamba2 import * + +__all__ = [] +__all__.extend(modeling_mamba2.__all__) +__all__.extend(configuration_mamba2.__all__) diff --git a/mindnlp/transformers/models/mamba2/configuration_mamba2.py b/mindnlp/transformers/models/mamba2/configuration_mamba2.py new file mode 100644 index 000000000..c884be60e --- /dev/null +++ b/mindnlp/transformers/models/mamba2/configuration_mamba2.py @@ -0,0 +1,181 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""MAMBA2 configuration""" + +import math + +from mindnlp.utils import logging +from ...configuration_utils import PretrainedConfig + +logger = logging.get_logger(__name__) + +class Mamba2Config(PretrainedConfig): + """ + This is the configuration class to store the configuration of a [`Mamba2Model`]. It is used to instantiate a MAMBA2 + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the MAMBA2 + [state-spaces/mamba2-2.8b](https://huggingface.co/state-spaces/mamba2-2.8b) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. 
+ + + Args: + num_heads (`int`, *optional*, defaults to 128): + Number of heads for the evolution matrices of mamba 2. + head_dim (`int`, *optional*, defaults to 64): + Dimension of each head. + vocab_size (`int`, *optional*, defaults to 32768): + Vocabulary size of the MAMBA2 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Mamba2Model`]. + hidden_size (`int`, *optional*, defaults to 4096): + Dimensionality of the embeddings and hidden states. + state_size (`int`, *optional*, defaults to 128): shape of the state space latents. + num_hidden_layers (`int`, *optional*, defaults to 64): + Number of hidden layers in the model. + layer_norm_epsilon (`float`, *optional*, defaults to 1e-05): + The epsilon to use in the layer normalization layers. + pad_token_id (`int`, *optional*, defaults to 1): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 0): + The id of the beginning of sentence token in the vocabulary. + eos_token_id (`int`, *optional*, defaults to 2): + The id of the end of sentence token in the vocabulary. + expand (`int`, *optional*, defaults to 2): Expanding factor used to determine the intermediate size. + conv_kernel (`int`, *optional*, defaults to 4): Size of the convolution kernel. + n_groups (`int`, *optional*, defaults to 8): + Number of groups for the evolution matrices of mamba 2. + use_bias (`bool`, *optional*, defaults to `False`): + Whether or not to use bias in ["in_proj", "out_proj"] of the mixer block + use_conv_bias (`bool`, *optional*, defaults to `True`): + Whether or not to use bias in the convolution layer of the mixer block. + hidden_act (`str`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + initializer_range (`float`, *optional*, defaults to 0.1): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + residual_in_fp32 (`bool`, *optional*, defaults to `True`): + Whether or not residuals should be in `float32`. If set to `False` residuals will keep the same `dtype` as the rest of the model + time_step_rank (`Union[int,str]`, *optional*, defaults to `"auto"`): + Rank of the discretization projection matrix. `"auto"` means that it will default to `math.ceil(self.hidden_size / 16)` + time_step_min (`float`, *optional*, defaults to 0.001): + Minimum `time_step` used to bound `dt_proj.bias`. + time_step_max (`float`, *optional*, defaults to 0.1): + Maximum `time_step` used to bound `dt_proj.bias`. + time_step_floor (`float`, *optional*, defaults to 0.0001): + Minimum clamping value of the `dt_proj.bias` layer initialization. + time_step_limit (`tuple`, *optional*, defaults to `(0.0, inf)`): + Accepted range of time step values. + rescale_prenorm_residual (`bool`, *optional*, defaults to `False`): + Whether or not to rescale `out_proj` weights when initializing. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the cache should be used. + rms_norm (`bool`, *optional*, defaults to `True`): + Whether to use RMS norm or not. + chunk_size (`int`, *optional*, defaults to 256): + Size of the chunks that will comprise the sequence. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie word embeddings or not. 
+ + + Example: + + ```python + >>> from transformers import Mamba2Config, Mamba2Model + + >>> # Initializing a Mamba2 configuration + >>> configuration = Mamba2Config() + + >>> # Initializing a model (with random weights) from the configuration + >>> model = Mamba2Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "mamba2" + + def __init__( + self, + num_heads=128, + head_dim=64, + vocab_size=32768, + hidden_size=4096, + state_size=128, + num_hidden_layers=64, + layer_norm_epsilon=1e-5, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + expand=2, + conv_kernel=4, + n_groups=8, + use_bias=False, + use_conv_bias=True, + hidden_act="silu", + initializer_range=0.1, + residual_in_fp32=True, + time_step_rank="auto", + time_step_min=0.001, + time_step_max=0.1, + time_step_floor=1e-4, + time_step_limit=(0.0, float("inf")), + rescale_prenorm_residual=False, + use_cache=True, + rms_norm=True, + chunk_size=256, + tie_word_embeddings=False, + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.state_size = state_size + self.num_hidden_layers = num_hidden_layers + self.layer_norm_epsilon = layer_norm_epsilon + self.conv_kernel = conv_kernel + self.expand = expand + + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.pad_token_id = pad_token_id + self.use_bias = use_bias + self.use_conv_bias = use_conv_bias + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.time_step_rank = math.ceil(self.hidden_size / 16) if time_step_rank == "auto" else time_step_rank + self.time_step_min = time_step_min + self.time_step_max = time_step_max + self.time_step_floor = time_step_floor + self.rescale_prenorm_residual = rescale_prenorm_residual + self.residual_in_fp32 = residual_in_fp32 + self.use_cache = use_cache + self.n_groups = n_groups + self.num_heads = num_heads + self.head_dim = head_dim + self.rms_norm = rms_norm + self.state_size = state_size + self.chunk_size = chunk_size + self.time_step_limit = time_step_limit + self.tie_word_embeddings = tie_word_embeddings + + super().__init__( + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + pad_token_id=pad_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +__all__ = ["Mamba2Config"] diff --git a/mindnlp/transformers/models/mamba2/modeling_mamba2.py b/mindnlp/transformers/models/mamba2/modeling_mamba2.py new file mode 100644 index 000000000..d6044b53c --- /dev/null +++ b/mindnlp/transformers/models/mamba2/modeling_mamba2.py @@ -0,0 +1,917 @@ +# coding=utf-8 +# Copyright 2024 state-spaces/mamba2 org and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
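Before the modeling code below, a quick hedged sanity check of the configuration class defined above: as a `PretrainedConfig` subclass it is expected to round-trip through `save_pretrained`/`from_pretrained`. The tiny sizes and the output path are illustrative only.

```python
from mindnlp.transformers import Mamba2Config

cfg = Mamba2Config(hidden_size=64, num_heads=8, head_dim=16, num_hidden_layers=2,
                   state_size=16, n_groups=2, chunk_size=32, vocab_size=1000)
cfg.save_pretrained("./tiny-mamba2-config")                # writes config.json
reloaded = Mamba2Config.from_pretrained("./tiny-mamba2-config")
assert reloaded.model_type == "mamba2" and reloaded.hidden_size == 64
```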
+"""MindSpore MAMBA2 model.""" + +import math +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import mindspore +from mindnlp.core import nn, ops, no_grad +from mindnlp.core.nn import CrossEntropyLoss + +from ....common.activations import ACT2FN +from ...generation import GenerationMixin +from ...modeling_utils import PreTrainedModel + +from ....utils import ( + ModelOutput, + logging, +) + +from .configuration_mamba2 import Mamba2Config + + +logger = logging.get_logger(__name__) + + + +_CHECKPOINT_FOR_DOC = "mistralai/mamba-codestral-7B-v0.1" +_CONFIG_FOR_DOC = "Mamba2Config" + + +# Helper methods for segment sum computation + + +def pad_tensor_by_size(input_tensor: mindspore.Tensor, pad_size: int): + """ + Padding x tensor with `pad_size` on the seq_len axis (axis=1) + + Assumes that we only have tensors of either size 4 or 3 + """ + pad_shape = (0, 0, 0, 0, 0, pad_size, 0, 0) if len(input_tensor.shape) == 4 else (0, 0, 0, pad_size, 0, 0) + + return nn.functional.pad(input_tensor, pad_shape, mode="constant", value=0) + + +def reshape_into_chunks(input_tensor, pad_size, chunk_size): + """ + Padding input_tensor with `pad_size` on the seq_len axis (axis=1) and + simultaneously splitting it into chunk sequences. + + Assumes that we only have tensors of either size 4 or 3 + """ + # [bsz, seq_len, ...] -> [bsz, seq_len multiple of chunk_size, ...] + input_tensor = pad_tensor_by_size(input_tensor, pad_size) + + if len(input_tensor.shape) == 3: + # [bsz, seq_len multiple of chunk_size, num_heads] -> [bsz, -1, chunk_size, num_heads] + return input_tensor.reshape(input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2]) + else: + # [bsz, seq_len multiple of chunk_size, num_heads, head_dim or state_size] -> [bsz, -1, chunk_size, num_heads, head_dim or state_size] + return input_tensor.reshape( + input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2], input_tensor.shape[3] + ) + + +def segment_sum(input_tensor): + """ + More stable segment sum calculation. Uses cumulative sums and masking instead of direct subtractions. + """ + chunk_size = input_tensor.shape[-1] + # 1. expand input tensor to have an additional dimension and repeat along that dimension + # [..., chunk_size] -> [..., chunk_size, chunk_size] + input_tensor = input_tensor.unsqueeze(-1) + target_shape = tuple(input_tensor.shape[:-1] + (chunk_size,)) + input_tensor = input_tensor.broadcast_to(target_shape) + # 2. create a lower triangular mask with the diagonal set to 0 to 0 out elements above diag + mask = ops.tril(ops.ones(chunk_size, chunk_size, dtype=mindspore.bool_), diagonal=-1) + input_tensor = input_tensor.masked_fill(~mask, mindspore.Tensor(0, dtype=input_tensor.dtype)) + # 3. compute actual cumsum + tensor_segsum = ops.cumsum(input_tensor, dim=-2) + + # 4. 
apply mask to keep only the lower triangular part of the cumulative sum result (incl diagonal this time) + mask = ops.tril(ops.ones(chunk_size, chunk_size, dtype=mindspore.bool_), diagonal=0) + tensor_segsum = tensor_segsum.masked_fill(~mask, mindspore.Tensor(float('-inf'), dtype=tensor_segsum.dtype)) + return tensor_segsum + + +def apply_mask_to_padding_states(hidden_states, attention_mask): + """ + Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66 + """ + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + dtype = hidden_states.dtype + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + + return hidden_states + +# Simple roll function for CPU and NPU +if mindspore.context.get_context("device_target") == "GPU": + from mindspore.ops import roll +else: + def roll(x: mindspore.Tensor, shifts, dims=None): + """ + + Args: + x (mindspore.Tensor): Input tensor + shifts (Union[list(int), tuple(int), int]): Specifies the number of places by which elements are shifted positively (towards larger indices) along the specified dimension. Negative shifts will roll the elements in the opposite direction. + dims (Union[list(int), tuple(int), int], optional): Specifies the dimension indexes of shape to be rolled. Default: None. If dims is None, the Tensor will be flattened before rolling and then restored to the original shape. + Returns: + Tensor, has the same shape and type as input. + """ + # If dims is None, first flatten the tensor + if dims is None: + x = x.reshape(-1) + dims = 0 + + # Convert shifts and dims to lists if they are not already + if isinstance(shifts, int): + shifts = [shifts] + if isinstance(dims, int): + dims = [dims] + + # Ensure shifts and dims have the same length + if len(shifts) != len(dims): + raise ValueError("shifts and dims must have the same length") + + # Move each dimension + for shift, dim in zip(shifts, dims): + # Handle negative shifts + if shift < 0: + shift = x.shape[dim] + shift + + # Normalize shift, ensuring it is within valid range + shift = shift % x.shape[dim] + + if shift == 0: + continue + + # Split at the specified dimension + indices = list(range(x.ndim)) + indices[0], indices[dim] = indices[dim], indices[0] + x = x.swapaxes(0, dim) # Move the target dimension to the first dimension + + shape = x.shape + x = x.reshape(shape[0], -1) # Flatten the other dimensions + + # Perform roll operation + x = ops.concat([x[shape[0]-shift:], x[:shape[0]-shift]], dim=0) + + # Restore original shape + x = x.reshape(shape) + x = x.swapaxes(0, dim) # Restore dimensions + + return x + +class Mamba2Cache: + """ + Arguments: + config: Mamba2Config + batch_size: int + dtype: mindspore.dtype + + Attributes: + dtype: (`mindspore.dtype`): + The default `dtype` used to initializing the cache. + conv_kernel_size: (`int`): + Model's convolution kernel size taken from config. + n_groups: (`int`): + Model's number of groups taken from the config - similar to tensor parallel in Transformer. + state_size: (`int`): + Model's SSM state size taken from config. + num_heads: (`int`): + The number of heads used in the linear attention / SSM. + head_dim: (`int`): + The respective dimension of the heads used in the linear attention / SSM. + intermediate_size: (`int`): + Model's intermediate_size based on (expand * hidden_dim) from config. 
+ conv_states: (`mindspore.Tensor`): + A tensor of shape `[num_layers, batch_size, conv_kernel_size, intermediate_size + 2 * n_groups * state_size]` that holds convolutional states. + ssm_states: (`mindspore.Tensor`): + A tensor of shape `[num_layers, batch_size, num_heads, head_dim, state_size]` that holds ssm states. + """ + + def __init__( + self, config: Mamba2Config, batch_size: int, dtype: mindspore.dtype = mindspore.float16): + self.dtype = dtype + self.conv_kernel_size = config.conv_kernel + self.n_groups = config.n_groups + self.state_size = config.state_size + self.num_heads = config.num_heads + self.head_dim = config.head_dim + self.intermediate_size = int(config.expand * config.hidden_size) + + self.conv_states = ops.zeros( + (config.num_hidden_layers, + batch_size, + self.intermediate_size + 2 * self.n_groups * self.state_size, + self.conv_kernel_size), + dtype=dtype, + ) + self.ssm_states = ops.zeros( + (config.num_hidden_layers, + batch_size, + self.num_heads, + self.head_dim, + self.state_size), + dtype=dtype, + ) + + def update_conv_state( + self, layer_idx: int, new_conv_state: mindspore.Tensor, cache_init: bool = False + ) -> mindspore.Tensor: + if cache_init: + self.conv_states[layer_idx] = new_conv_state + else: + self.conv_states[layer_idx] = roll(self.conv_states[layer_idx], shifts=-1, dims=-1) + self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :] + return self.conv_states[layer_idx] + + def update_ssm_state(self, layer_idx: int, new_ssm_state: mindspore.Tensor): + self.ssm_states[layer_idx] = new_ssm_state + return self.ssm_states[layer_idx] + + def reset(self): + self.conv_states.zero_() + self.ssm_states.zero_() + + +class MambaRMSNormGated(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + super().__init__() + self.weight = nn.Parameter(ops.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states, gate=None): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(dtype=mindspore.float32) + + if gate is not None: + hidden_states = hidden_states * nn.functional.silu(gate.to(dtype=mindspore.float32)) + variance = hidden_states.pow(2).mean(-1, keep_dims=True) + hidden_states = hidden_states * ops.rsqrt(variance + self.variance_epsilon) + + return self.weight * hidden_states.to(input_dtype) + + +class Mamba2Mixer(nn.Module): + """ + Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`. 
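The decode-time branch of `update_conv_state` above (`cache_init=False`) is simply a sliding window over the last `conv_kernel` inputs: roll the kernel axis left by one, then write the newest value into the last slot. A minimal NumPy sketch of that behaviour for a single channel (NumPy stands in for the MindSpore ops; shapes reduced to 1-D for clarity):

```python
import numpy as np

conv_kernel_size = 4
conv_state = np.zeros(conv_kernel_size)
for new_value in [1.0, 2.0, 3.0, 4.0, 5.0]:
    conv_state = np.roll(conv_state, shift=-1)  # drop the oldest entry
    conv_state[-1] = new_value                  # append the newest one
assert conv_state.tolist() == [2.0, 3.0, 4.0, 5.0]  # last conv_kernel_size inputs, oldest first
```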
+ A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective) + ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4, + and is why Mamba is called **selective** state spaces) + """ + + def __init__(self, config: Mamba2Config, layer_idx: int): + super().__init__() + self.num_heads = config.num_heads + self.hidden_size = config.hidden_size + self.ssm_state_size = config.state_size + self.conv_kernel_size = config.conv_kernel + self.intermediate_size = int(config.expand * self.hidden_size) + self.time_step_rank = int(config.time_step_rank) + self.layer_idx = layer_idx + self.use_conv_bias = config.use_conv_bias + self.activation = config.hidden_act + self.act = ACT2FN[config.hidden_act] + + self.layer_norm_epsilon = config.layer_norm_epsilon + self.rms_norm = config.rms_norm + + self.n_groups = config.n_groups + self.head_dim = config.head_dim + self.chunk_size = config.chunk_size + + self.time_step_limit = config.time_step_limit + self.time_step_min = config.time_step_min + self.time_step_max = config.time_step_max + + self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size + self.conv1d = nn.Conv1d( + in_channels=self.conv_dim, + out_channels=self.conv_dim, + bias=config.use_conv_bias, + kernel_size=config.conv_kernel, + groups=self.conv_dim, + padding=config.conv_kernel - 1, + ) + + # projection of the input hidden states + projection_size = self.intermediate_size + self.conv_dim + self.num_heads + self.in_proj = nn.Linear( + self.hidden_size, + projection_size, + bias=config.use_bias, + ) + # selective projection used to make dt, B and C input dependant + + # time step projection (discretization) + # instantiate once and copy inv_dt in init_weights of PretrainedModel + self.dt_bias = nn.Parameter(ops.ones(self.num_heads)) + + # S4D real initialization. These are not discretized! + # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded + A = ops.arange(1, self.num_heads + 1).astype(mindspore.float32) + self.A_log = nn.Parameter(ops.log(A)) + self.A_log._no_weight_decay = True + self.norm = MambaRMSNormGated(self.intermediate_size, eps=self.layer_norm_epsilon) + self.D = nn.Parameter(ops.ones(self.num_heads)) + self.D._no_weight_decay = True + + # use_bias (`bool`, *optional*, defaults to `False`) + self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias) + self.use_bias = config.use_bias + + # fmt: off + def mindspore_forward(self, input_states, cache_params: Optional[Mamba2Cache]=None, cache_position:Optional[mindspore.Tensor]=None, attention_mask: Optional[mindspore.Tensor]=None): + batch_size, seq_len, _ = input_states.shape + dtype = input_states.dtype + + # 1. Gated MLP's linear projection + input_states = apply_mask_to_padding_states(input_states, attention_mask) + projected_states = self.in_proj(input_states) + d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - 2 * self.n_groups * self.ssm_state_size-self.num_heads) // 2 + _, _, gate, hidden_states_B_C, dt = ops.split( + projected_states, [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 + ) + + # 2. 
Convolution sequence transformation + if cache_params is not None and cache_position is not None and cache_position[0] > 0: + cache_params.update_conv_state(layer_idx=self.layer_idx, new_conv_state=hidden_states_B_C, cache_init=False) + + conv_states = cache_params.conv_states[self.layer_idx] + + hidden_states_B_C = ops.sum( + conv_states * self.conv1d.weight.squeeze(1), dim=-1 + ) + if self.use_conv_bias: + hidden_states_B_C = hidden_states_B_C + self.conv1d.bias + hidden_states_B_C = self.act(hidden_states_B_C) + else: + # Init cache + if cache_params is not None: + hidden_states_B_C_transposed = hidden_states_B_C.swapaxes(1, 2) + conv_states = nn.functional.pad( + hidden_states_B_C_transposed, (cache_params.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0) + ) + cache_params.update_conv_state(layer_idx=self.layer_idx, new_conv_state=conv_states, cache_init=True) + + hidden_states_B_C = self.act(self.conv1d(hidden_states_B_C.swapaxes(1, 2))[..., :seq_len].swapaxes(1, 2)) + + hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask) + hidden_states, B, C = ops.split( + hidden_states_B_C, + [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size], + dim=-1 + ) + + # 3. SSM transformation + A = -ops.exp(self.A_log.float()) # [num_heads] + if cache_params is not None and cache_position is not None and cache_position[0] > 0: + # Delete 'device' in mindspore + cache_device = cache_params.ssm_states + + # Note: there is no need to pad parameter matrices here, as there is just one new token + # for batched generation + dt = dt[:, 0, :][:, None, ...] + dt = dt.swapaxes(1, 2).broadcast_to((batch_size, dt.shape[-1], self.head_dim)) + # [num_heads] -> [num_heads, head_dim] + dt_bias = self.dt_bias[..., None].broadcast_to((self.dt_bias.shape[0], self.head_dim)) + + dt = nn.functional.softplus(dt + dt_bias.to(dt.dtype)) + dt = ops.clamp(dt, self.time_step_limit[0], self.time_step_limit[1]) + A = A[..., None, None].broadcast_to((self.num_heads, self.head_dim, self.ssm_state_size)).to(dtype=mindspore.float32) + # [bsz, num_heads, head_dim, state_size] + dA = (ops.exp(dt[..., None] * A)) + + # Discretize B + # [bsz, n_groups * state_size] -> [bsz, n_groups, 1, state_size] -> + # -> [bsz, n_groups, group to head repetition factor, state_size] -> [bsz, num_heads, state_size] + B = B.reshape(batch_size, self.n_groups, -1)[..., None, :] + B = B.broadcast_to((batch_size, self.n_groups, self.num_heads // self.n_groups, B.shape[-1])).contiguous() + B = B.reshape(batch_size, -1, B.shape[-1]) + # [bsz, num_heads, head_dim, state_size] + dB = dt[..., None] * B[..., None, :] + + # Discretize x into dB + # [bsz, intermediate_size] -> [bsz, num_heads, head_dim] + hidden_states = hidden_states.reshape(batch_size, -1, self.head_dim) + dBx = (dB * hidden_states[..., None]) + + # State calculation + cache_params.update_ssm_state( + layer_idx=self.layer_idx, + new_ssm_state=cache_params.ssm_states[self.layer_idx] * dA + dBx + ) + + # Subsequent output + # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size] + C = C.reshape(batch_size, self.n_groups, -1)[..., None, :] + C = C.broadcast_to((batch_size, self.n_groups, self.num_heads // self.n_groups, C.shape[-1])) + C = C.reshape(batch_size, -1, C.shape[-1]) + # [bsz, num_heads, head_dim] + + ssm_states = cache_params.ssm_states[self.layer_idx].to(dtype=C.dtype) # Shape: [b, h, d, n] + + # Reshape ssm_states to merge the first two dimensions + ssm_states_reshaped = 
ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size) # Shape: [b*h, d, n] + C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1] + y = ops.bmm(ssm_states_reshaped, C_reshaped) + y = y.view(batch_size, self.num_heads, self.head_dim) + + # D skip connection + # [num_heads] -> [num_heads, head_dim] + D = self.D[..., None].broadcast_to((self.D.shape[0], self.head_dim)) + y = (y + hidden_states * D).to(y.dtype) + + # [bsz, num_heads, head_dim] -> [bsz, 1, intermediate_size] + y = y.reshape(batch_size, -1)[:, None, ...] + else: + # begin ssd naive implementation without einsums + dt = nn.functional.softplus(dt + self.dt_bias) + dt = ops.clamp(dt, self.time_step_limit[0], self.time_step_limit[1]) + hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float() + B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() + C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() + B = B.tile((1, 1, self.num_heads // self.n_groups, 1)) + C = C.tile((1, 1, self.num_heads // self.n_groups, 1)) + pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size + + D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size) + + # Discretize x and A + hidden_states = hidden_states * dt[..., None] + A = A.to(hidden_states.dtype) * dt + + # Rearrange into blocks/chunks + hidden_states, A, B, C = [reshape_into_chunks(t, pad_size, self.chunk_size) for t in (hidden_states, A, B, C)] + + # [bsz, -1, chunk_size, num_heads] -> [bsz, num_heads, -1, chunk_size] + A = A.permute(0, 3, 1, 2) + A_cumsum = ops.cumsum(A, dim=-1) + + # 1. Compute the output for each intra-chunk (diagonal blocks) + # This is the analog of a causal mask + L = ops.exp(segment_sum(A)) + + # Contraction of C and B to get G (attention-weights like) + G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, :, :] # shape: (b, c, l, s, h, n) + G = G_intermediate.sum(axis=-1) # shape: (b, c, l, s, h) + + # Compute M, equivalent to applying attention mask to weights + M_intermediate = G[..., None] * L.permute(0, 2, 3, 4, 1)[..., None] + M = M_intermediate.sum(axis=-1) + + # Compute Y_diag (apply to values) + Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(axis=3) + + # 2. Compute the state for each intra-chunk + # (right term of low-rank factorization of off-diagonal blocks; B terms) + decay_states = ops.exp((A_cumsum[:, :, :, -1:] - A_cumsum)) + B_decay = B * decay_states.permute(0, -2, -1, 1)[..., None] + states = (B_decay[..., None, :] * hidden_states[..., None]).sum(axis=2) + + # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries + # (middle term of factorization of off-diag blocks; A terms) + if cache_params is not None and cache_position is not None and cache_position[0] > 0: + previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...] + else: + previous_states = ops.zeros_like(states[:, :1]) + states = ops.cat([previous_states, states], dim=1) + decay_chunk = ops.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0)))) + decay_chunk = decay_chunk.swapaxes(1, 3) + new_states = (decay_chunk[..., None, None] * states[:, :, None, ...]).sum(axis=1) + states, ssm_state = new_states[:, :-1], new_states[:, -1] + + # 4. 
Compute state -> output conversion per chunk + # (left term of low-rank factorization of off-diagonal blocks; C terms) + state_decay_out = ops.exp(A_cumsum) + C_times_states = (C[..., None, :] * states[:, :, None, ...]) + state_decay_out_permuted = state_decay_out.permute(0, 2, 3, 1) + Y_off = (C_times_states.sum(axis=-1) * state_decay_out_permuted[..., None]) + + # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks) + y = Y_diag + Y_off + # [bsz, -1, self.chunk_size, num_heads, head_dim] -> [bsz, (padded) seq_len, num_heads, head_dim] + y = y.reshape(batch_size, -1, self.num_heads, self.head_dim) + + y = y + D_residual + # Cutting off padded chunks + if pad_size > 0: + y = y[:, :seq_len, :, :] + y = y.reshape(batch_size, seq_len, -1) + + # Init cache + if ssm_state is not None and cache_params is not None: + cache_params.update_ssm_state(layer_idx=self.layer_idx, new_ssm_state=ssm_state) + + scan_output = self.norm(y, gate) + + # end ssd naive + + # 4. Final linear projection + contextualized_states = self.out_proj(scan_output.to(dtype)) # [batch, seq_len, hidden_size] + return contextualized_states + + # fmt: on + def forward( + self, + hidden_states, + cache_params: Optional[Mamba2Cache] = None, + cache_position: Optional[mindspore.Tensor] = None, + attention_mask: Optional[mindspore.Tensor] = None, + ): + dtype = hidden_states.dtype + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + + return self.mindspore_forward(hidden_states, cache_params, cache_position, attention_mask) + + +class Mamba2RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Mamba2RMSNorm is equivalent to T5LayerNorm and LlamaRMSNorm + """ + super().__init__() + self.weight = nn.Parameter(ops.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(mindspore.float32) + variance = ops.mean(hidden_states.pow(2), -1, keepdim=True) + hidden_states = hidden_states * ops.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +class Mamba2Block(nn.Module): + def __init__(self, config, layer_idx): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.residual_in_fp32 = config.residual_in_fp32 + self.norm = Mamba2RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + self.mixer = Mamba2Mixer(config, layer_idx=layer_idx) + + def forward( + self, + hidden_states, + cache_params: Optional[Mamba2Cache] = None, + cache_position: Optional[mindspore.Tensor] = None, + attention_mask: Optional[mindspore.Tensor] = None, + ): + residual = hidden_states + hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype)) + if self.residual_in_fp32: + residual = residual.to(dtype=mindspore.float32) + + hidden_states = self.mixer( + hidden_states, cache_params=cache_params, cache_position=cache_position, attention_mask=attention_mask + ) + hidden_states = residual + hidden_states + return hidden_states + + +class Mamba2PreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = Mamba2Config + base_model_prefix = "backbone" + _no_split_modules = ["Mamba2Block"] + supports_gradient_checkpointing = True + _is_stateful = True + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, Mamba2Mixer): + module.A_log._no_weight_decay = True + module.D._no_weight_decay = True + + dt = ops.exp( + ops.rand(self.config.num_heads) + * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min)) + + math.log(self.config.time_step_min) + ).clamp(min=self.config.time_step_floor) + + # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 + inv_dt = dt + ops.log(-ops.expm1(-dt)) + with no_grad(): + module.dt_bias.assign_value(inv_dt) + module.dt_bias._no_reinit = True + + if isinstance(module, nn.Linear): + if module.bias is not None: + if not getattr(module.bias, "_no_reinit", False): + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, std=self.config.initializer_range) + + if self.config.rescale_prenorm_residual: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. + # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + for name, p in module.named_parameters(): + if name in ["out_proj.weight"]: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + with no_grad(): + p /= math.sqrt(self.config.num_hidden_layers) + + +@dataclass +# Copied from transformers.models.mamba.modeling_mamba.MambaOutput with MAMBA->MAMBA2,Mamba->Mamba2 +class Mamba2Output(ModelOutput): + """ + Class for the MAMBA2 model outputs. + + Args: + last_hidden_state (`mindspore.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + cache_params (`Mamba2Cache`): + The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to + avoid providing the old `input_ids`. + + Includes both the State space model state matrices after the selective scan, and the Convolutional states + hidden_states (`tuple(mindspore.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `mindspore.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + """ + + last_hidden_state: Optional[mindspore.Tensor] = None + cache_params: Optional[Mamba2Cache] = None + hidden_states: Optional[Tuple[mindspore.Tensor]] = None + + +@dataclass +# Copied from transformers.models.mamba.modeling_mamba.MambaCausalLMOutput with Mamba->Mamba2 +class Mamba2CausalLMOutput(ModelOutput): + """ + Base class for causal language model (or autoregressive) outputs. 
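`_init_weights` above samples `dt` log-uniformly in `[time_step_min, time_step_max]`, floors it at `time_step_floor`, and stores the softplus inverse `dt + log(-expm1(-dt))` in `dt_bias`. A short NumPy check (hedged; NumPy stands in for the MindSpore ops) that applying softplus to that stored value recovers `dt`:

```python
import numpy as np

rng = np.random.default_rng(0)
time_step_min, time_step_max, time_step_floor = 1e-3, 1e-1, 1e-4

dt = np.exp(rng.random(8) * (np.log(time_step_max) - np.log(time_step_min)) + np.log(time_step_min))
dt = np.clip(dt, time_step_floor, None)            # same clamp as in _init_weights
inv_dt = dt + np.log(-np.expm1(-dt))               # inverse of softplus
assert np.allclose(np.log1p(np.exp(inv_dt)), dt)   # softplus(inv_dt) == dt
```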
+ + Args: + loss (`mindspore.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`mindspore.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + cache_params (`Mamba2Cache`): + The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to + avoid providing the old `input_ids`. + + Includes both the State space model state matrices after the selective scan, and the Convolutional states + hidden_states (`tuple(mindspore.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `mindspore.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + """ + + loss: Optional[mindspore.Tensor] = None + logits: Optional[mindspore.Tensor] = None + cache_params: Optional[Mamba2Cache] = None + hidden_states: Optional[Tuple[mindspore.Tensor]] = None + +class Mamba2Model(Mamba2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size) + self.layers = nn.ModuleList([Mamba2Block(config, layer_idx=idx) for idx in range(config.num_hidden_layers)]) + + self.gradient_checkpointing = False + self.norm_f = Mamba2RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + # Initialize weights and apply final processing + self._register_load_state_dict_pre_hook(self.load_hook) + self.post_init() + + def load_hook(self, state_dict, prefix, *args): + for k in state_dict: + if "embedding." 
in k: + state_dict[k.replace("embedding.", "embeddings.")] = state_dict.pop(k) + break + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, new_embeddings): + self.embeddings = new_embeddings + + def forward( + self, + input_ids: Optional[mindspore.Tensor] = None, + inputs_embeds: Optional[mindspore.Tensor] = None, + cache_params: Optional[Mamba2Cache] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[mindspore.Tensor] = None, + attention_mask: Optional[mindspore.Tensor] = None, + **kwargs, + ) -> Union[Tuple, Mamba2Output]: + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): # ^ is python for xor + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + + if self.gradient_checkpointing and self.training and use_cache: + use_cache = False + + if use_cache: + if cache_params is None: + cache_params = Mamba2Cache( + self.config, inputs_embeds.shape[0], dtype=inputs_embeds.dtype + ) + cache_position = ops.arange(0, self.config.conv_kernel, dtype=mindspore.int64) + elif cache_position is None: + # cases when we do manual forward instead of using `model.generate` which will initiate + # `cache_position` and makes sure it is not None, throw error here instead of doing some + # hack to conjecture the current cache position + raise ValueError( + "You have to specify the `cache_position` manually when `use_cache=True` and `cache_params` is passed, " + "you don't have to pass a `cache_params` if you are in prefilling stage because in that case it will " + "be initialized for you automatically" + ) + else: + cache_params = None + + hidden_states = inputs_embeds + all_hidden_states = () if output_hidden_states else None + for mixer_block in self.layers: + if self.gradient_checkpointing and self.training: + hidden_states = self._gradient_checkpointing_func( + mixer_block.__call__, hidden_states, cache_params, cache_position, attention_mask + ) + else: + hidden_states = mixer_block( + hidden_states, + cache_params=cache_params, + cache_position=cache_position, + attention_mask=attention_mask, + ) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + hidden_states = self.norm_f(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, cache_params, all_hidden_states] if v is not None) + + return Mamba2Output( + last_hidden_state=hidden_states, + cache_params=cache_params if use_cache else None, + hidden_states=all_hidden_states, + ) + +class Mamba2ForCausalLM(Mamba2PreTrainedModel, GenerationMixin): + _tied_weights_keys = [] + + def __init__(self, config): + super().__init__(config) + self.backbone = Mamba2Model(config) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.lm_head + + def 
set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_input_embeddings(self): + return self.backbone.get_input_embeddings() + + def set_input_embeddings(self, new_embeddings): + return self.backbone.set_input_embeddings(new_embeddings) + + def prepare_inputs_for_generation( + self, + input_ids=None, + inputs_embeds=None, + use_cache=None, + cache_params: Optional[Mamba2Cache] = None, + cache_position: Optional[mindspore.Tensor] = None, + attention_mask: Optional[mindspore.Tensor] = None, + **kwargs, + ): + # Overwitten -- uses `cache_params` as opposed to `past_key_values` + + if use_cache: + # `cache_position` should have been initialized in `generate` + if cache_position is None: + raise ValueError( + "`cache_position` should not be None as it should have been initialized in " + "`model.generate`, you are responsible for passing in a valid `cache_position` if " + "you are calling `prepare_inputs_for_generation` directly with `use_cache=True`" + ) + if cache_position[0] > 0: + input_ids = input_ids[:, -1][..., None] + + if attention_mask is not None: + attention_mask = None + else: + # we initialize the `cache_position` to full size of `conv_states` at prefill stage + # considering padding will be applied when input length is shorter, and truncation + # will be applied when it is longer, so it will be equivalent to always have it match + # the length of `cache_params.conv_states`, which is `config.conv_kernel` + cache_position = ops.arange(0, self.config.conv_kernel, dtype=mindspore.int64) + + if inputs_embeds is not None and cache_params is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "attention_mask": attention_mask, + "cache_params": cache_params, + "use_cache": use_cache, + "cache_position": cache_position, + } + ) + return model_inputs + + def forward( + self, + input_ids: Optional[mindspore.Tensor] = None, + inputs_embeds: Optional[mindspore.Tensor] = None, + cache_params: Optional[Mamba2Cache] = None, + labels: Optional[mindspore.Tensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[mindspore.Tensor] = None, + attention_mask: Optional[mindspore.Tensor] = None, + **kwargs, # for now we need this for generation + ) -> Union[Tuple, Mamba2CausalLMOutput]: + r""" + labels (`mindspore.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. 
you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + mamba2_outputs = self.backbone( + input_ids, + cache_params=cache_params, + inputs_embeds=inputs_embeds, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + use_cache=use_cache, + cache_position=cache_position, + attention_mask=attention_mask, + ) + hidden_states = mamba2_outputs[0] + + logits = self.lm_head(hidden_states.to(self.lm_head.weight.dtype)).float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :] + shift_labels = labels[..., 1:] + + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, shift_logits.shape[-1]), shift_labels.view(-1)) + + if not return_dict: + output = (logits,) + mamba2_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return Mamba2CausalLMOutput( + loss=loss, + logits=logits, + cache_params=mamba2_outputs.cache_params, + hidden_states=mamba2_outputs.hidden_states, + ) + + +__all__ = ["Mamba2ForCausalLM", "Mamba2Model", "Mamba2PreTrainedModel"] diff --git a/mindnlp/transformers/models/whisper/modeling_whisper.py b/mindnlp/transformers/models/whisper/modeling_whisper.py index 1900818cd..c8f1698ec 100644 --- a/mindnlp/transformers/models/whisper/modeling_whisper.py +++ b/mindnlp/transformers/models/whisper/modeling_whisper.py @@ -41,6 +41,8 @@ from .configuration_whisper import WhisperConfig from .generation_whisper import WhisperGenerationMixin +from ...modeling_flash_attention_utils import _flash_attention_forward + logger = logging.get_logger(__name__) @@ -388,8 +390,118 @@ def forward( return attn_output, attn_weights, past_key_value +class WhisperFlashAttention2(WhisperAttention): + """ + Whisper flash attention module. This module inherits from `WhisperAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self._flash_attn_uses_top_left_mask = False + + def forward( + self, + hidden_states: mindspore.Tensor, + key_value_states: Optional[mindspore.Tensor] = None, + past_key_value: Optional[EncoderDecoderCache] = None, + attention_mask: Optional[mindspore.Tensor] = None, + layer_head_mask: Optional[mindspore.Tensor] = None, + output_attentions: bool = False, + cache_position: Optional[mindspore.Tensor] = None, + ) -> Tuple[mindspore.Tensor, Optional[mindspore.Tensor], Optional[Tuple[mindspore.Tensor]]]: + logger.warning_once( + "The `flash_attention_2` implementation is in beta and is subject to change. Please use with caution." + ) + if isinstance(past_key_value, StaticCache): + raise ValueError( + "The `static` cache implementation is not compatible with `attn_implementation='flash_attention_2'`. 
" + ) + + if output_attentions: + raise ValueError("WhisperFlashAttention2 attention does not support output_attentions") + + # if key_value_states are provided this layer is used as a cross-attention layer for the decoder + is_cross_attention = key_value_states is not None + bsz, tgt_len, _ = hidden_states.shape + + # get query proj + query_states = ops.reshape(self.q_proj(hidden_states), (bsz, tgt_len, self.num_heads, self.head_dim)) + + if past_key_value is not None: + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + past_key_value.is_updated[self.layer_idx] = True + past_key_value = past_key_value.cross_attention_cache + else: + past_key_value = past_key_value.self_attention_cache + + # use key_value_states if cross attention + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value and is_updated: + # reuse k,v, cross_attentions + key_states = past_key_value.key_cache[self.layer_idx] + value_states = past_key_value.value_cache[self.layer_idx] + else: + key_states = self._shape(self.k_proj(current_states), -1, bsz) + value_states = self._shape(self.v_proj(current_states), -1, bsz) + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + + key_states = key_states.swapaxes(1, 2) + value_states = value_states.swapaxes(1, 2) + + causal_mask = attention_mask + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, : key_states.shape[-2]] + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == mindspore.float32: + if hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." 
+ ) + query_states = query_states.astype(target_dtype) + key_states = key_states.astype(target_dtype) + value_states = value_states.astype(target_dtype) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + causal_mask, + tgt_len, + dropout=self.dropout if self.training else 0.0, + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) + + attn_output = attn_output.reshape(bsz, tgt_len, -1) + attn_output = self.out_proj(attn_output) + + return attn_output, None, past_key_value + + WHISPER_ATTENTION_CLASSES = { "eager": WhisperAttention, + "flash_attention_2": WhisperFlashAttention2, } @@ -794,6 +906,8 @@ def __init__(self, config: WhisperConfig): self.layers = nn.ModuleList( [WhisperDecoderLayer(config, layer_idx) for layer_idx in range(config.decoder_layers)] ) + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self.layer_norm = nn.LayerNorm(config.d_model) self.gradient_checkpointing = False @@ -1052,8 +1166,13 @@ def _update_causal_mask( input_tensor: mindspore.Tensor, cache_position: mindspore.Tensor, past_key_values: Cache, - output_attentions: bool, + output_attentions: bool = False, ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) diff --git a/mindnlp/utils/import_utils.py b/mindnlp/utils/import_utils.py index 547f2efd9..8a902abba 100644 --- a/mindnlp/utils/import_utils.py +++ b/mindnlp/utils/import_utils.py @@ -382,6 +382,11 @@ def is_essentia_available(): """ return _essentia_available +def is_mamba_2_ssm_available(): + return _is_package_available("mamba_ssm") + +def is_causal_conv1d_available(): + return _is_package_available("causal_conv1d") def is_pyctcdecode_available(): """ diff --git a/mindnlp/utils/testing_utils.py b/mindnlp/utils/testing_utils.py index 13d4b4f1d..466c4e4c3 100644 --- a/mindnlp/utils/testing_utils.py +++ b/mindnlp/utils/testing_utils.py @@ -262,6 +262,78 @@ def require_librosa(test_case): """ return unittest.skipUnless(is_librosa_available(), "test requires librosa")(test_case) +################################################################################ +### update_wrapper() and wraps() decorator +################################################################################ + +# update_wrapper() and wraps() are tools to help write +# wrapper functions that can handle naive introspection +# Note from mamba2 model porting: Original mamba2 code require python 3.13+ +# so we copy the codes from python 3.13+ +WRAPPER_ASSIGNMENTS = ('__module__', '__name__', '__qualname__', '__doc__', + '__annotations__', '__type_params__') +WRAPPER_UPDATES = ('__dict__',) +def update_wrapper(wrapper, + wrapped, + assigned = WRAPPER_ASSIGNMENTS, + updated = WRAPPER_UPDATES): + """Update a wrapper function to look like the wrapped function + + wrapper is the function to be updated + wrapped is the original function + assigned is a tuple naming the attributes assigned directly + from the wrapped function to the wrapper function (defaults to + functools.WRAPPER_ASSIGNMENTS) + updated is a tuple naming the attributes of the wrapper that + are updated with the corresponding attribute from the wrapped + function (defaults to functools.WRAPPER_UPDATES) + """ + for attr in assigned: + try: + value = 
getattr(wrapped, attr) + except AttributeError: + pass + else: + setattr(wrapper, attr, value) + for attr in updated: + getattr(wrapper, attr).update(getattr(wrapped, attr, {})) + # Issue #17482: set __wrapped__ last so we don't inadvertently copy it + # from the wrapped function when updating __dict__ + wrapper.__wrapped__ = wrapped + # Return the wrapper so this can be used as a decorator via partial() + return wrapper + +def wraps(wrapped, + assigned = WRAPPER_ASSIGNMENTS, + updated = WRAPPER_UPDATES): + """Decorator factory to apply update_wrapper() to a wrapper function + + Returns a decorator that invokes update_wrapper() with the decorated + function as the wrapper argument and the arguments to wraps() as the + remaining arguments. Default arguments are as for update_wrapper(). + This is a convenience function to simplify applying partial() to + update_wrapper(). + """ + return functools.partial(update_wrapper, wrapped=wrapped, + assigned=assigned, updated=updated) + +def require_read_token(fn): + """ + A decorator that loads the HF token for tests that require to load gated models. + """ + token = os.getenv("HF_HUB_READ_TOKEN") + + @wraps(fn) + def _inner(*args, **kwargs): + if token is not None: + with patch("huggingface_hub.utils._headers.get_token", return_value=token): + return fn(*args, **kwargs) + else: # Allow running locally with the default token env variable + return fn(*args, **kwargs) + + return _inner + + def require_essentia(test_case): """ Decorator marking a test that requires essentia diff --git a/tests/transformers/generation/test_utils.py b/tests/transformers/generation/test_utils.py index 35b486d6a..74ed50116 100644 --- a/tests/transformers/generation/test_utils.py +++ b/tests/transformers/generation/test_utils.py @@ -1630,16 +1630,16 @@ def test_generate_from_inputs_embeds_decoder_only(self): # Traditional way of generating text outputs_from_ids = model.generate( - input_ids, max_new_tokens=5, return_dict_in_generate=True, output_scores=True + input_ids, max_new_tokens=1, return_dict_in_generate=True, output_scores=True ) - self.assertEqual(outputs_from_ids.sequences.shape, (input_ids.shape[0], input_ids.shape[1] + 5)) + self.assertEqual(outputs_from_ids.sequences.shape, (input_ids.shape[0], input_ids.shape[1] + 1)) # Same thing, but from input embeddings (`input_ids` is passed so the prompt is present in the output) inputs_embeds = model.get_input_embeddings()(input_ids) outputs_from_embeds = model.generate( input_ids, inputs_embeds=inputs_embeds, - max_new_tokens=5, + max_new_tokens=1, return_dict_in_generate=True, output_scores=True, ) @@ -1651,7 +1651,7 @@ def test_generate_from_inputs_embeds_decoder_only(self): outputs_from_rand_embeds = model.generate( input_ids, inputs_embeds=random_embeds, - max_new_tokens=5, + max_new_tokens=1, return_dict_in_generate=True, output_scores=True, ) @@ -1660,7 +1660,7 @@ def test_generate_from_inputs_embeds_decoder_only(self): # input_ids is not a required input -- if we don't pass it, the newly generated tokens will be the same outputs_from_embeds_wo_ids = model.generate( - inputs_embeds=inputs_embeds, max_new_tokens=5, return_dict_in_generate=True, output_scores=True + inputs_embeds=inputs_embeds, max_new_tokens=1, return_dict_in_generate=True, output_scores=True ) self.assertListEqual( outputs_from_embeds.sequences[:, inputs_embeds.shape[1] :].tolist(), diff --git a/tests/transformers/models/mamba2/__init__.py b/tests/transformers/models/mamba2/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git 
a/tests/transformers/models/mamba2/test_modeling_mamba2.py b/tests/transformers/models/mamba2/test_modeling_mamba2.py new file mode 100644 index 000000000..1a684efa4 --- /dev/null +++ b/tests/transformers/models/mamba2/test_modeling_mamba2.py @@ -0,0 +1,404 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest +from typing import Dict, List, Tuple + + +from mindnlp.transformers import AutoTokenizer, Mamba2Config, is_mindspore_available +from mindnlp.utils.testing_utils import require_read_token, slow, require_mindspore + +from ...generation.test_utils import GenerationTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, ids_tensor + + +if is_mindspore_available(): + import mindspore + from mindnlp.core import ops, nn, no_grad + + from mindnlp.transformers import ( + Mamba2ForCausalLM, + Mamba2Model, + ) + from mindnlp.transformers.models.mamba2.modeling_mamba2 import Mamba2Cache, Mamba2Mixer + + +class Mamba2ModelTester: + def __init__( + self, + parent, + batch_size=14, + num_heads=8, + n_groups=8, + state_size=2, + head_dim=8, + conv_kernel=4, + chunk_size=8, + seq_length=7, + is_training=True, + use_labels=True, + vocab_size=99, + hidden_size=32, + num_hidden_layers=2, + hidden_act="silu", + hidden_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=16, + type_sequence_label_size=2, + num_labels=3, + num_choices=4, + scope=None, + tie_word_embeddings=False, + ): + self.parent = parent + self.num_heads = num_heads + self.n_groups = n_groups + self.head_dim = head_dim + self.state_size = state_size + self.conv_kernel = conv_kernel + self.chunk_size = chunk_size + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_labels = use_labels + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.type_sequence_label_size = type_sequence_label_size + self.num_labels = num_labels + self.num_choices = num_choices + self.scope = scope + self.bos_token_id = vocab_size - 1 + self.eos_token_id = vocab_size - 1 + self.pad_token_id = vocab_size - 1 + self.tie_word_embeddings = tie_word_embeddings + + def get_large_model_config(self): + return Mamba2Config.from_pretrained("mistralai/Mamba-Codestral-7B-v0.1", from_pt=True) + + def prepare_config_and_inputs( + self, gradient_checkpointing=False, scale_attn_by_inverse_layer_idx=False, reorder_and_upcast_attn=False + ): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + # Only left padding is valid + attention_mask = ops.ones((self.batch_size, self.seq_length), mindspore.int64) + attention_mask[0, :1] = 0 + + sequence_labels = None + token_labels = None + choice_labels = 
None + if self.use_labels: + sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) + token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) + choice_labels = ids_tensor([self.batch_size], self.num_choices) + + config = self.get_config( + gradient_checkpointing=gradient_checkpointing, + ) + + return ( + config, + input_ids, + attention_mask, + sequence_labels, + token_labels, + choice_labels, + ) + + def get_config(self, gradient_checkpointing=False): + return Mamba2Config( + head_dim=self.head_dim, + num_heads=self.num_heads, + n_groups=self.n_groups, + state_size=self.state_size, + conv_kernel=self.conv_kernel, + chunk_size=self.chunk_size, + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + activation_function=self.hidden_act, + n_positions=self.max_position_embeddings, + type_vocab_size=self.type_vocab_size, + use_cache=True, + bos_token_id=self.bos_token_id, + eos_token_id=self.eos_token_id, + pad_token_id=self.pad_token_id, + gradient_checkpointing=gradient_checkpointing, + tie_word_embeddings=self.tie_word_embeddings, + ) + + def prepare_config_and_inputs_for_common(self): + ( + config, + input_ids, + _, + sequence_labels, + token_labels, + choice_labels, + ) = self.prepare_config_and_inputs() + inputs_dict = {"input_ids": input_ids} + return config, inputs_dict + + def create_and_check_mamba2_caching(self, config, input_ids, attention_mask, *args): + model = Mamba2Model(config=config) + model.eval() + + output_whole = model(input_ids, attention_mask=attention_mask).last_hidden_state + + outputs = model( + input_ids[:, :-1], + attention_mask=attention_mask[:, :-1], + use_cache=True, + cache_position=ops.arange(0, config.conv_kernel), + ) + output_one = outputs.last_hidden_state + + # Using the state computed on the first inputs, we will get the same output + outputs = model( + input_ids[:, -1:], + attention_mask=attention_mask[:, -1:], + use_cache=True, + cache_params=outputs.cache_params, + cache_position=ops.arange(config.conv_kernel, config.conv_kernel + 1), + ) + output_two = outputs.last_hidden_state + + self.parent.assertTrue( + ops.allclose(ops.cat([output_one, output_two], dim=1), output_whole, atol=1e-3, rtol=1e-3) + ) + +@require_mindspore +class Mamba2ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): + all_model_classes = (Mamba2Model, Mamba2ForCausalLM) if is_mindspore_available() else () + all_generative_model_classes = (Mamba2ForCausalLM,) if is_mindspore_available() else () + has_attentions = False # Mamba does not support attentions + fx_compatible = False # FIXME let's try to support this @molbap + test_missing_keys = False + test_model_parallel = False + test_pruning = False + test_head_masking = False # Mamba does not have attention heads + + pipeline_model_mapping = ( + {"feature-extraction": Mamba2Model, "text-generation": Mamba2ForCausalLM} if is_mindspore_available() else {} + ) + + def setUp(self): + self.model_tester = Mamba2ModelTester(self) + self.config_tester = ConfigTester( + self, config_class=Mamba2Config, n_embd=37, common_properties=["hidden_size", "num_hidden_layers"] + ) + + @unittest.skip(reason="Skipped in mamba") + def test_mamba2_caching(self): + pass + # config_and_inputs = self.model_tester.prepare_config_and_inputs() + # self.model_tester.create_and_check_mamba2_caching(*config_and_inputs) + + def test_initialization(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in 
self.all_model_classes:
+            model = model_class(config=config)
+            for name, param in model.named_parameters():
+                if "D" in name:
+                    if param.requires_grad:
+                        # check that D is initialized as an all-ones tensor
+                        assert ops.allclose(param.data, ops.ones_like(param.data), rtol=1e-5, atol=1e-5)
+
+    @unittest.skip(reason="Mamba 2 weights are not tied")
+    def test_tied_weights_keys(self):
+        pass
+
+    @unittest.skip(reason="A large mamba2 would be necessary (and costly) for that")
+    def test_multi_gpu_data_parallel_forward(self):
+        pass
+
+    def test_model_outputs_equivalence(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
+            with no_grad():
+                tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
+                dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()
+
+                def recursive_check(tuple_object, dict_object):
+                    if isinstance(tuple_object, Mamba2Cache):  # MODIFIED PART START
+                        recursive_check(tuple_object.conv_states, dict_object.conv_states)
+                        recursive_check(tuple_object.ssm_states, dict_object.ssm_states)
+                    elif isinstance(tuple_object, (List, Tuple)):  # MODIFIED PART END
+                        for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
+                            recursive_check(tuple_iterable_value, dict_iterable_value)
+                    elif isinstance(tuple_object, Dict):
+                        for tuple_iterable_value, dict_iterable_value in zip(
+                            tuple_object.values(), dict_object.values()
+                        ):
+                            recursive_check(tuple_iterable_value, dict_iterable_value)
+                    elif tuple_object is None:
+                        return
+                    else:
+                        self.assertTrue(
+                            ops.allclose(tuple_object, dict_object, atol=1e-5),
+                            msg=(
+                                "Tuple and dict output are not equal. Difference:"
+                                f" {ops.max(ops.abs(tuple_object - dict_object))}. Tuple has `nan`:"
+                                f" {ops.isnan(tuple_object).any()} and `inf`: {ops.isinf(tuple_object).any()}. Dict has"
+                                f" `nan`: {ops.isnan(dict_object).any()} and `inf`: {ops.isinf(dict_object).any()}."
+                            ),
+                        )
+
+                recursive_check(tuple_output, dict_output)
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            model.eval()
+
+            tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
+            dict_inputs = self._prepare_for_class(inputs_dict, model_class)
+            check_equivalence(model, tuple_inputs, dict_inputs)
+
+            tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+            dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+            check_equivalence(model, tuple_inputs, dict_inputs)
+
+            tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
+            dict_inputs = self._prepare_for_class(inputs_dict, model_class)
+            check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
+
+            tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+            dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+            check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
+
+
+@require_mindspore
+@slow
+@require_read_token
+class Mamba2IntegrationTest(unittest.TestCase):
+    def setUp(self):
+        self.model_id = "mistralai/Mamba-Codestral-7B-v0.1"
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, from_slow=True, legacy=False, from_pt=True)
+        self.prompt = ("[INST]Write a hello world program in C++.",)
+
+    @require_read_token
+    @slow
+    @require_mindspore
+    def test_simple_generate(self):
+        """
+        Simple generate test to avoid regressions.
+        Note: state-spaces (cuda) implementation and pure torch implementation
+        have irreconcilable differences as of now, which will cause this test to fail
+        in an environment with state-spaces installed.
+        """
+        tokenizer = self.tokenizer
+        tokenizer.pad_token_id = tokenizer.eos_token_id
+
+        model = Mamba2ForCausalLM.from_pretrained(self.model_id, mindspore_dtype=mindspore.bfloat16, from_pt=True)
+        input_ids = tokenizer("[INST]Write a hello world program in C++.[/INST]", return_tensors="ms")["input_ids"]
+
+        out = model.generate(input_ids, do_sample=False, use_cache=True, max_new_tokens=30)
+        output_sentence = tokenizer.decode(out[0])
+        ground_truth_sentence = """[INST]Write a hello world program in C++.[/INST] Sure, here is a simple "Hello, World!" program in C++:\n\n```cpp\n#include <iostream>\n\n"""
+        assert output_sentence == ground_truth_sentence
+
+    @require_read_token
+    @slow
+    @require_mindspore
+    def test_batched_equivalence_with_cache(self):
+        """
+        Verifies that batched generation matches individual generation.
+        Important because of the specific caching mechanism + statefulness of mamba model.
+        Depending on precision and devices, differences can be observed from generation to generation.
+        """
+        tokenizer = self.tokenizer
+        prompt = [
+            "[INST]Write C#.[/INST]",
+            "[INST]Write a hello world in C++.[/INST]",
+            "[INST] Write a simple Fibonacci number computation function in Rust that does memoization, with comments, in safe Rust.[/INST]",
+        ]
+
+        model = Mamba2ForCausalLM.from_pretrained(self.model_id, mindspore_dtype=mindspore.bfloat16, from_pt=True)
+        tokenizer.pad_token_id = tokenizer.eos_token_id
+        # batched generation
+        tokenized_prompts = tokenizer(prompt, return_tensors="ms", padding="longest")
+        batched_gen = model.generate(**tokenized_prompts, max_new_tokens=30, use_cache=True)
+        batched_output = tokenizer.batch_decode(batched_gen, skip_special_tokens=True)
+
+        # individual generation
+        for index_gen, individual_prompt in enumerate(prompt):
+            inputs = tokenizer(individual_prompt, return_tensors="ms", padding="longest")
+            individual_gen = model.generate(**inputs, max_new_tokens=30, use_cache=True)
+            individual_output = tokenizer.batch_decode(individual_gen, skip_special_tokens=True)[0]
+            assert individual_output[:100] == batched_output[index_gen][:100]
+
+    @require_read_token
+    @slow
+    def test_batched_equivalence_without_cache(self):
+        """
+        Verifies that batched generation matches individual generation without cache.
+        Important because of the specific caching mechanism + statefulness of mamba model.
+        Depending on precision and devices, differences can be observed from generation to generation.
+        """
+        tokenizer = self.tokenizer
+        prompt = [
+            "[INST]Write C#.[/INST]",
+            "[INST]Write a hello world in C++.[/INST]",
+            "[INST] Write a simple Fibonacci number computation function in Rust that does memoization, with comments, in safe Rust.[/INST]",
+        ]
+
+        model = Mamba2ForCausalLM.from_pretrained(self.model_id, mindspore_dtype=mindspore.bfloat16, from_pt=True)
+        tokenizer.pad_token_id = tokenizer.eos_token_id
+        # batched generation
+        tokenized_prompts = tokenizer(prompt, return_tensors="ms", padding="longest")
+        batched_gen = model.generate(**tokenized_prompts, max_new_tokens=30, use_cache=True)
+        batched_output = tokenizer.batch_decode(batched_gen, skip_special_tokens=True)
+
+        # individual generation
+        for index_gen, individual_prompt in enumerate(prompt):
+            inputs = tokenizer(individual_prompt, return_tensors="ms", padding="longest")
+            individual_gen = model.generate(**inputs, max_new_tokens=30, use_cache=True)
+            individual_output = tokenizer.batch_decode(individual_gen, skip_special_tokens=True)[0]
+            assert individual_output[:100] == batched_output[index_gen][:100]
+
+    @slow
+    @require_mindspore
+    def test_mamba2_mixer_train_vs_eval_equivalence(self):
+        # Based on https://github.com/sustcsonglin/flash-linear-attention/issues/63
+        # Credit to zhixuan-lin
+
+        B, T, D = 4, 512, 768
+        dtype = mindspore.bfloat16
+        config = Mamba2Config(num_heads=24, head_dim=64, hidden_size=768, expand=2, n_groups=1)
+
+        mindspore.set_seed(42)
+        with mindspore.amp.autocast(dtype=dtype):
+            with no_grad():
+                mixer = Mamba2Mixer(config, layer_idx=0)
+                hidden_states = ops.rand(size=(B, T, D), dtype=dtype)
+
+                mixer.train()
+                out_train = mixer(hidden_states)
+
+                mixer.eval()
+                out_eval = mixer(hidden_states)
+
+                assert ops.allclose(out_train, out_eval, rtol=1e-3, atol=1e-3)
\ No newline at end of file
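Reviewer note: for a quick manual check of the generation path these tests exercise, the sketch below mirrors Mamba2IntegrationTest.test_simple_generate outside of pytest. It is only a sketch and assumes mindnlp is installed and the gated mistralai/Mamba-Codestral-7B-v0.1 checkpoint is accessible locally; it uses no APIs beyond those already present in this diff.

    # Standalone smoke-test sketch (not part of the test suite); assumes the gated
    # Codestral checkpoint has already been downloaded and converted (from_pt=True).
    import mindspore
    from mindnlp.transformers import AutoTokenizer, Mamba2ForCausalLM

    model_id = "mistralai/Mamba-Codestral-7B-v0.1"
    tokenizer = AutoTokenizer.from_pretrained(model_id, from_slow=True, legacy=False, from_pt=True)
    model = Mamba2ForCausalLM.from_pretrained(model_id, mindspore_dtype=mindspore.bfloat16, from_pt=True)

    # Greedy decode a short completion, matching the settings used in the tests above.
    inputs = tokenizer("[INST]Write a hello world program in C++.[/INST]", return_tensors="ms")
    out = model.generate(inputs["input_ids"], do_sample=False, use_cache=True, max_new_tokens=30)
    print(tokenizer.decode(out[0]))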