
No difference in performance or speed between different adapter configs #670

Open
3 of 4 tasks
Tai-Mai opened this issue Apr 7, 2024 · 0 comments
Labels
bug Something isn't working


Tai-Mai commented Apr 7, 2024

Environment info

  • adapters version: 0.1.2
  • transformers version: 4.36.2
  • Platform: Linux-5.8.0-63-generic-x86_64-with-glibc2.29
  • Python version: 3.8.10
  • Huggingface_hub version: 0.22.0
  • Safetensors version: 0.4.2
  • Accelerate version: 0.28.0
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.0.0+cu117 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: No

Information

Model I am using (Bert, XLNet ...): XLM-RoBERTa

Language I am using the model on (English, Chinese ...): ["en", "et", "ht", "id", "it", "qu", "sw", "ta", "th", "tr", "vi", "zh"]

Adapter setup I am using (if any): ["lora", "seq_bn", "seq_bn_inv", "double_seq_bn", "double_seq_bn_inv"]
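The five adapter setups listed above are built in the reproduction script from two flags (inverse, double) plus LoRA, and each of them is run for every seed, epoch count and target language. A minimal stdlib sketch, reusing the values hard-coded in that script, that rebuilds the same grid and counts the runs:

```python
from itertools import product

# Values copied from the reproduction script
seeds = [1, 2]
epochs = [8, 16]
langs = ["et", "ht", "id", "it", "qu", "sw", "ta", "th", "tr", "vi", "zh"]

# Same naming scheme as the script: optional "double_" prefix, optional "_inv" suffix, plus LoRA
adapter_configs = [
    f"{'double_' if double else ''}seq_bn{'_inv' if inverse else ''}"
    for inverse, double in product([True, False], [True, False])
] + ["lora"]

# One run per (seed, epochs, target language, adapter config), in the script's loop order
runs = list(product(seeds, epochs, langs, adapter_configs))

print(adapter_configs)
print(len(runs), "runs")
```

Each run writes its own CSV named {trg_lang}_{adapter_config}_{num_epochs}eps_seed{seed}.csv, so the sweep produces 220 result files, 20 per target language.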

The problem arises when using:

  • the official example scripts: (give details below)
  • my own modified scripts: (give details below)

The task I am working on is:

  • an official GLUE/SQuAD task: COPA + XCOPA
  • my own task or dataset: (give details below)

To reproduce

Steps to reproduce the behavior:

Please run the following code. It's a modification of the 04_Cross_Lingual_Transfer notebook. When I run this script, I get exactly the same stats and performance for a given language, no matter which adapter configuration I choose for the language adapters. As far as I can tell, everything is identical: accuracy, training and eval losses, and training and eval samples per second. I suspected some caching in the background, but adding del model, gc.collect() and torch.cuda.empty_cache() at the end of the loop doesn't help.

These are the main modifications I've made:

  • Convert the notebook to a .py script
  • Iterate through all combinations of languages and adapter configs listed above
  • Modify the compute_accuracy() function so that it also records the accuracy on individual test-set items (this was done for a uni project).
#-*- coding: utf-8 -*-
from transformers import TrainingArguments, AutoTokenizer, AutoConfig, enable_full_determinism, EvalPrediction
from datasets import load_dataset, concatenate_datasets
from adapters.composition import Stack
from adapters import AutoAdapterModel, AdapterConfig, AdapterTrainer
import csv
import numpy as np
from itertools import product
from functools import partial
import os
import gc
import torch


def encode_batch(examples):
    """Encodes a batch of input data using the model tokenizer."""
    all_encoded = {"input_ids": [], "attention_mask": []}
    # Iterate through all examples in this batch
    for premise, question, choice1, choice2 in zip(examples["premise"], examples["question"], examples["choice1"], examples["choice2"]):
        sentences_a = [premise + " " + question for _ in range(2)]
        # Both answer choices are passed in an array according to the format needed for the multiple-choice prediction head
        sentences_b = [choice1, choice2]
        encoded = tokenizer(
            sentences_a,
            sentences_b,
            max_length=60,
            truncation=True,
            padding="max_length",
        )
        all_encoded["input_ids"].append(encoded["input_ids"])
        all_encoded["attention_mask"].append(encoded["attention_mask"])
    return all_encoded


def preprocess_dataset(dataset):
    # Encode the input data
    dataset = dataset.map(encode_batch, batched=True)
    # The transformers model expects the target class column to be named "labels"
    dataset = dataset.rename_column("label", "labels")
    # Only output the required columns (no format type is set, so values stay Python lists;
    # the Trainer's default data collator converts them to tensors)
    dataset.set_format(columns=["input_ids", "attention_mask", "labels"])
    return dataset


def compute_accuracy(p: EvalPrediction, stats):
    preds = np.argmax(p.predictions, axis=1)
    for idx, (pred, label) in enumerate(zip(preds, p.label_ids)):
        stats["index"].append(idx)
        stats["prediction"].append(pred)
        stats["label"].append(label)
        stats["accuracy"].append(int(pred == label))
    return {"acc": (preds == p.label_ids).mean()}


if __name__ == "__main__":

    model_id = "xlm-roberta-base"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model_config = AutoConfig.from_pretrained(model_id)

    dataset_en = load_dataset("super_glue", "copa")
    dataset_en = preprocess_dataset(dataset_en)
    train_dataset = concatenate_datasets([dataset_en["train"], dataset_en["validation"]])

    seeds = [1, 2]
    epochs = [8, 16]
    langs = ["et", "ht", "id", "it", "qu", "sw", "ta", "th", "tr", "vi", "zh"]
    adapter_configs = []
    for inverse, double in product([True, False], [True, False]):
        adapter_configs.append(f"{'double_' if double else ''}seq_bn{'_inv' if inverse else ''}")
    adapter_configs.append("lora")

    for seed, num_epochs, trg_lang, adapter_config in product(seeds, epochs, langs, adapter_configs):
        print("\n\n\n\n===============")
        print("seed:", seed)
        print("num_epochs:", num_epochs)
        print("trg_lang:", trg_lang)
        print("adapter_config:", adapter_config)
        lora = adapter_config == "lora"

        inverse = "inv" in adapter_config
        double = "double" in adapter_config

        enable_full_determinism(seed=seed)

        model = AutoAdapterModel.from_pretrained(
            model_id,
            config=model_config,
            device_map="auto",
        )

        # Load the language adapters
        lang_adapter_config = AdapterConfig.load(adapter_config, reduction_factor=2)
        model.load_adapter("en/wiki@ukp", config=lang_adapter_config)
        model.load_adapter(f"{trg_lang}/wiki@ukp", config=lang_adapter_config)

        # Add a new task adapter
        if lora:
            # model.add_adapter("copa", config="lora")
            # model.add_adapter("copa", config=lang_adapter_config)
            # Only the call below works; the two commented-out calls above throw:
            # ValueError: Invalid adapter setup: str is not a valid adapter name or composition block.
            model.add_adapter("copa")
        else:
            # default task adapter config (seq_bn); this is the same call as in the LoRA branch above
            model.add_adapter("copa")

        # Add a classification head for our target task
        model.add_multiple_choice_head("copa", num_choices=2)

        model.train_adapter(["copa"])

        # Unfreeze and activate stack setup
        model.active_adapters = Stack("en", "copa")

        training_args = TrainingArguments(
            learning_rate=1e-4,
            num_train_epochs=num_epochs,
            per_device_train_batch_size=32,
            per_device_eval_batch_size=32,
            logging_steps=100,
            output_dir="./training_output/adapters",
            overwrite_output_dir=True,
            # The next line is important to ensure the dataset labels are properly passed to the model
            remove_unused_columns=False,
        )
        trainer = AdapterTrainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
        )

        train_output = trainer.train()
        print(train_output)
        train_samples_per_second = train_output.metrics["train_samples_per_second"]

        dataset_trg_lang = load_dataset("xcopa", trg_lang, verification_mode="no_checks")
        dataset_trg_lang = preprocess_dataset(dataset_trg_lang)
        # print(dataset_trg_lang["test"][0])

        model.active_adapters = Stack(trg_lang, "copa")

        stats = {
            "dataset": [],
            "system": [],
            "num_epochs": [],
            "seed": [],
            "index": [],
            "target_language": [],
            "lang_double_adapter": [],
            "lang_inverse_adapter": [],
            "lang_lora_adapter": [],
            # "task_lora_adapter": [],
            "prediction": [],
            "label": [],
            "accuracy": [],
            "overall_eval_accuracy": [],
            "train_samples_per_second": [],
            "eval_samples_per_second": [],
        }

        eval_trainer = AdapterTrainer(
            model=model,
            args=TrainingArguments(output_dir="./eval_output/adapters", remove_unused_columns=False,),
            eval_dataset=dataset_trg_lang["test"],
            compute_metrics=partial(compute_accuracy, stats=stats),
        )
        eval_stats = eval_trainer.evaluate()
        print(eval_stats)

        num_datapoints = len(stats["index"])
        stats["dataset"] = ["xcopa"] * num_datapoints
        stats["system"] = ["adapters"] * num_datapoints
        stats["num_epochs"] = [num_epochs] * num_datapoints
        stats["seed"] = [seed] * num_datapoints
        stats["target_language"] = [trg_lang] * num_datapoints
        stats["lang_double_adapter"] = [int(double)] * num_datapoints
        stats["lang_inverse_adapter"] = [int(inverse)] * num_datapoints
        stats["lang_lora_adapter"] = [int(lora)] * num_datapoints
        # stats["task_lora_adapter"] = ["lora" if lora else "houlsby"] * num_datapoints
        stats["overall_eval_accuracy"] = [eval_stats["eval_acc"]] * num_datapoints
        stats["train_samples_per_second"] = [train_samples_per_second] * num_datapoints
        stats["eval_samples_per_second"] = [eval_stats["eval_samples_per_second"]] * num_datapoints
        
        for k, v in stats.items():
            assert len(v) == num_datapoints, f"\nThis column doesn't have the right amount of data:\nk:{k}\nv:\n{v}"

        filename = f"data/xcopa/{trg_lang}_{adapter_config}_{num_epochs}eps_seed{seed}.csv"
        # Create the output directory if it doesn't exist yet
        os.makedirs(os.path.dirname(filename), exist_ok=True)
        # newline="" as recommended by the csv module docs for files used with csv.writer
        with open(filename, "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(stats.keys())
            writer.writerows(zip(*stats.values()))

        del model
        gc.collect()
        torch.cuda.empty_cache()
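
A minimal stdlib sketch for checking the symptom directly from the CSV files the script writes: it groups the result files by target language, epoch count and seed, and reports the groups in which every adapter config produced exactly the same per-item predictions. The helper names (parse_result_name, load_predictions, identical_config_groups) are hypothetical, and the filename parsing assumes the pattern used in the script above and language codes without underscores.

```python
import csv
import os
from collections import defaultdict


def parse_result_name(name):
    """Split '{trg_lang}_{adapter_config}_{num_epochs}eps_seed{seed}.csv' into its parts."""
    stem = name[: -len(".csv")]
    lang, rest = stem.split("_", 1)          # XCOPA language codes contain no underscore
    config, eps, seed = rest.rsplit("_", 2)  # config names such as double_seq_bn_inv do
    return lang, config, eps, seed


def load_predictions(result_dir="data/xcopa"):
    """Read the 'prediction' column of every result CSV in result_dir."""
    predictions = {}
    for name in sorted(os.listdir(result_dir)):
        if name.endswith(".csv"):
            with open(os.path.join(result_dir, name), newline="") as f:
                predictions[name] = [row["prediction"] for row in csv.DictReader(f)]
    return predictions


def identical_config_groups(predictions_by_file):
    """Return {(lang, eps, seed): sorted config names} for every group in which
    two or more adapter configs produced exactly the same per-item predictions."""
    groups = defaultdict(dict)  # (lang, eps, seed) -> {adapter_config: predictions}
    for name, preds in predictions_by_file.items():
        lang, config, eps, seed = parse_result_name(name)
        groups[(lang, eps, seed)][config] = tuple(preds)
    return {
        key: sorted(by_config)
        for key, by_config in groups.items()
        if len(by_config) > 1 and len(set(by_config.values())) == 1
    }
```

After the sweep, identical_config_groups(load_predictions()) lists every language/epochs/seed combination whose predictions did not change across adapter configs; if the configs really have no effect, every group should show up.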

Expected behavior

I expected to observe differences in performance or speed between the different adapter configs.
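To put a rough number on the expected difference, here is a back-of-the-envelope sketch (an estimate, not read from the adapters library) of the bottleneck weights a language adapter would add to xlm-roberta-base (hidden size 768, 12 layers) with reduction_factor=2. It assumes each bottleneck module is a down-projection plus an up-projection with biases and no extra layer norms, with one module per layer for seq_bn and two (after attention and after the feed-forward block) for double_seq_bn; the invertible adapters of the _inv variants and LoRA are left out.

```python
# Back-of-the-envelope estimate under the assumptions stated above:
# one bottleneck module = Linear(hidden -> hidden // rf) + Linear(hidden // rf -> hidden), both with bias.
HIDDEN = 768  # hidden size of xlm-roberta-base
LAYERS = 12   # transformer layers in xlm-roberta-base


def bottleneck_params(hidden, reduction_factor):
    bottleneck = hidden // reduction_factor
    down = hidden * bottleneck + bottleneck  # down-projection weight + bias
    up = bottleneck * hidden + hidden        # up-projection weight + bias
    return down + up


per_module = bottleneck_params(HIDDEN, reduction_factor=2)
seq_bn_total = LAYERS * per_module       # assumed: one module per layer
double_seq_bn_total = 2 * seq_bn_total   # assumed: two modules per layer

print(f"seq_bn:        {seq_bn_total:,} parameters")
print(f"double_seq_bn: {double_seq_bn_total:,} parameters")
```

Under these assumptions a double_seq_bn language adapter would carry about twice as many weights as a seq_bn one (roughly 14.2M vs 7.1M), so runs that genuinely used different configs would not be expected to produce identical losses or throughput.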

Tai-Mai added the bug label on Apr 7, 2024.