diff --git a/README.md b/README.md
index fc05233b..3e0ba563 100644
--- a/README.md
+++ b/README.md
@@ -22,18 +22,17 @@ You can evaluate the models on multiple datasets with a single command. No model
### Accelerator support and Tasks grouping.
We support the usage of `accelerate` to wrap the model for distributed evaluation, supporting multi-gpu and tensor parallelism. With **Task Grouping**, all instances from all tasks are grouped and evaluated in parallel, which significantly improves the throughput of the evaluation.
-### Efficiency benchmark
Below are the total runtime on different datasets using 4 x A100 40G.
-|Dataset|LLaVA-v1.5-7b|LLaVA-v1.5-13b|
+|Dataset (#num)|LLaVA-v1.5-7b|LLaVA-v1.5-13b|
|-------|-------------|--------------|
-|mme | 2 mins 43 seconds | 3 mins 27 seconds |
-|gqa | 10 mins 43 seconds | 14 mins 23 seconds |
-|scienceqa_img| 1 mins 58 seconds | 2 mins 52 seconds |
-|ai2d | 3 mins 17 seconds | 4 mins 12 seconds |
-|coco2017_cap_val| 14 mins 13 seconds | 19 mins 58 seconds |
+|mme (2374) | 2 mins 43 seconds | 3 mins 27 seconds |
+|gqa (12578) | 10 mins 43 seconds | 14 mins 23 seconds |
+|scienceqa_img (2017) | 1 mins 58 seconds | 2 mins 52 seconds |
+|ai2d (3088) | 3 mins 17 seconds | 4 mins 12 seconds |
+|coco2017_cap_val (5000) | 14 mins 13 seconds | 19 mins 58 seconds |
### Prepared HF datasets.
-We are hosting more than 40 (and it's increasing) datasets on [huggingface/lmms-lab](https://huggingface.co/lmms-lab), we carefully converted these datasets from original sources and included all variants, versions and splits. Now they can be directly accessed without any burden of data preprocessing. They also serve for the purpose of visualizing the data and grasping the sense of evaluation tasks distribution.
+We are hosting more than 40 (and increasing) datasets on [huggingface/lmms-lab](https://huggingface.co/lmms-lab), we carefully converted these datasets from original sources and included all variants, versions and splits. Now they can be directly accessed without any burden of data preprocessing. They also serve for the purpose of visualizing the data and grasping the sense of evaluation tasks distribution.
@@ -45,6 +44,8 @@ Including prompt pre-processing, output post-processing, answer extraction, mode
### Reproducible results (for LLaVA series models) and Logging Utilites.
We provide a set of pre-defined configurations & environments for llava-1.5, which can be directly used to reproduce the results in the paper.
+You can refer to the [repr_scripts.sh](https://github.com/EvolvingLMMs-Lab/lmms-eval/blob/dev/readme/miscs/repr_scripts.sh) we provide to see how to build and set-up the enviroments to reproduce the results from the paper. However, this environment is not recommended when you try to evaluating your own model or other models since it only install packages necessary to run llava and has a lower pytorch version that may results in a lower speed.
+
With `lmms-eval`, all evaluation details will be recorded including log samples and results, generating report tables to terminal output and to Weights & Biases Runs/Tables.
> Development will be continuing on the main branch, and we encourage you to give us feedback on what features are desired and how to improve the library further, or ask questions, either in issues or PRs on GitHub.
@@ -70,6 +71,8 @@ cd LLaVA
pip install -e .
```
+You can check the [environment install script](miscs/repr_scripts.sh) and [torch environment info](miscs/repr_torch_envs.txt) to reproduce LLaVA-1.5's paper results. We found torch/cuda versions difference would cause small variations in the results, we provide the [results check](miscs/llava_result_check.md) with different environments.
+
If you want to test on caption dataset such as `coco`, `refcoco`, and `nocaps`, you will need to have `java==1.8.0 ` to let pycocoeval api to work. If you don't have it, you can install by using conda
```
conda install openjdk=8
@@ -209,10 +212,10 @@ Please refer to our [documentation](docs/README.md).
# Acknowledgement
-The API, togegher with many code blocks of this project come from [lm-eval-harness](https://github.com/EleutherAI/lm-evaluation-harness). We recommend you to read through the [docs of lm-eval-harness](https://github.com/EleutherAI/lm-evaluation-harness/tree/main/docs) for relevant informations.
+lmms_eval is a fork of [lm-eval-harness](https://github.com/EleutherAI/lm-evaluation-harness). We recommend you to read through the [docs of lm-eval-harness](https://github.com/EleutherAI/lm-evaluation-harness/tree/main/docs) for relevant information.
Below are the changes we made to the original API:
-- Build context now only pass in idx and process image and doc during the model responding phase. This is due to the fact that dataset now contains lots of images and we can't store them in the doc like the original lm-eval-harness other wise the memory would explode.
+- Build context now only pass in idx and process image and doc during the model responding phase. This is due to the fact that dataset now contains lots of images and we can't store them in the doc like the original lm-eval-harness other wise the cpu memory would explode.
- Instance.args (lmms_eval/api/instance.py) now contains a list of images to be inputted to lmms.
- lm-eval-harness supports all HF language models as single model class. Currently this is not possible of lmms because the input/output format of lmms in HF are not yet unified. Thererfore, we have to create a new class for each lmms model. This is not ideal and we will try to unify them in the future.
diff --git a/demo.tape b/demo.tape
deleted file mode 100644
index ceb2c982..00000000
--- a/demo.tape
+++ /dev/null
@@ -1,17 +0,0 @@
-# Where should we write the GIF?
-Output demo.gif
-
-# Set up a 1200x600 terminal with 46px font.
-Set FontSize 24
-Set Width 1440
-Set Height 2560
-Set WindowBar Colorful
-Set LoopOffset 5 # Start the GIF at the 5th frame
-Set Framerate 6
-Set TypingSpeed 15ms
-
-# Type a command in the terminal.
-Type "python -m accelerate.commands.launch --main_process_port=12350 --num_processes=8 lmms_eval --model=llava --model_args=pretrained=liuhaotian/llava-v1.5-7b --tasks=mme --limit=8 --batch_size=1 --log_samples --log_samples_suffix=demo --output_path=./logs/"
-Enter
-# Admire the output for a bit.
-Sleep 30
diff --git a/docs/README.md b/docs/README.md
index 522407bc..020e351b 100644
--- a/docs/README.md
+++ b/docs/README.md
@@ -6,7 +6,6 @@ Majority of this documentation is adapted from [lm-eval-harness](https://github.
## Table of Contents
-* To learn about the command line flags, see the [commands](https://github.com/EvolvingLMMs-Lab/lmms-eval/tree/main/docs/commands.md)
-* To learn how to add a new moddel, see the [Model Guide](https://github.com/EvolvingLMMs-Lab/lmms-eval/tree/main/docs/model_guide.md).
-* For a crash course on adding new tasks to the library, see our [New Task Guide](https://github.com/EvolvingLMMs-Lab/lmms-eval/tree/main/docs/new_task_guide.md).
-* To learn more about pushing the limits of task configuration that the Eval Harness supports, see the [Task Configuration Guide](https://github.com/EvolvingLMMs-Lab/lmms-eval/tree/main/docs/task_guide.md).
+* To learn about the command line flags, see the [commands](commands.md)
+* To learn how to add a new moddel, see the [Model Guide](model_guide.md).
+* For a crash course on adding new tasks to the library, see our [Task Guide](task_guide.md).
\ No newline at end of file
diff --git a/docs/commands.md b/docs/commands.md
index f5ebf0b6..4f8c7a7d 100644
--- a/docs/commands.md
+++ b/docs/commands.md
@@ -12,7 +12,7 @@ This mode supports a number of command-line arguments, the details of which can
* `--model_args` : Controls parameters passed to the model constructor. Accepts a string containing comma-separated keyword arguments to the model class of the format `"arg1=val1,arg2=val2,..."`, such as, for example `--model_args pretrained=liuhaotian/llava-v1.5-7b,batch_size=1`. For a full list of what keyword arguments, see the initialization of the corresponding model class in `lmms_eval/models/`.
-* `--tasks` : Determines which tasks or task groups are evaluated. Accepts a comma-separated list of task names or task group names. Must be solely comprised of valid tasks/groups.
+* `--tasks` : Determines which tasks or task groups are evaluated. Accepts a comma-separated list of task names or task group names. Must be solely comprised of valid tasks/groups. You can use `--tasks list` to see all the available tasks. If you add your own tasks but not shown on the list, you can try to set `--verbosity=DEBUG` to view the error message. You can also use `--tasks list_with_num` to check every tasks and the number of question each task contains. However, `list_with_num` will download all the available datasets and may require lots of memory and time.
* `--batch_size` : Sets the batch size used for evaluation. Can be a positive integer or `"auto"` to automatically select the largest batch size that will fit in memory, speeding up evaluation. One can pass `--batch_size auto:N` to re-select the maximum batch size `N` times during evaluation. This can help accelerate evaluation further, since `lm-eval` sorts documents in descending order of context length.
diff --git a/docs/model_guide.md b/docs/model_guide.md
index 13ae8caf..0a4e4fee 100644
--- a/docs/model_guide.md
+++ b/docs/model_guide.md
@@ -19,9 +19,7 @@ Now, we'll create a new file where we'll be adding our model:
touch lmms_eval/models/.py
```
-As a rule of thumb, we recommend you to use `lmms_eval/models/qwen_vl.py` and `lmms_eval/models/instructblip.py` as reference implementations for your model. You can copy and paste the contents of one of these files into your new file to get started.
-
-**Tip: this filename should not shadow package names! For example, naming your file `anthropic.py` is disallowed since the API's name on pypi is `anthropic`, but naming it `anthropic_llms.py` works with no problems.**
+**As a rule of thumb, we recommend you to use `lmms_eval/models/qwen_vl.py` and `lmms_eval/models/instructblip.py` as reference implementations for your model. You can copy and paste the contents of one of these files into your new file to get started.**
## Interface
@@ -35,11 +33,6 @@ class MyCustomLM(lmms):
def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]:
#...
-
- def loglikelihood_rolling(self, requests: list[Instance]) -> list[tuple[float, bool]]:
- #...
-
-
def generate_until(self, requests: list[Instance]) -> list[str]:
#...
#...
@@ -61,11 +54,6 @@ All three request types take as input `requests` of type `list[Instance]` that h
- In each `Instance.args` there will be 6 elements which are ` contexts, doc_to_target, doc_to_visual, doc_id, task, split`. `contexts` refers to the formatted question and is the text input for the LMM. Sometimes it might contains image token and need to address differently for different models. `doc_to_target` is a function reference that get the get the answer from the doc. This will be the continuation of the answer and only tokens belong to this part should be calculated for the loglikelihood.
- Each request will have, as result, `(ll, is_greedy): Tuple[float, int]` returned, where `ll` is a floating point number representing the log probability of generating the target string conditioned on the input, and `is_greedy` being either the value `0` or `1`, with it being `1` if and only if the target string *would be generated by greedy sampling from the LM* (that is, if the target string is the *most likely* N-token string to be output by the LM given the input. )
-- `loglikelihood_rolling`
- - Each request contains `Instance.args : Tuple[str]`, which is an input string to the model whose *entire* loglikelihood, conditioned on purely the EOT token, will be calculated.
- - This is used to evaluate *perplexity* on a data distribution.
- - It should return `(ll,) : Tuple[float]` , a.k.a. solely the *loglikelihood* of producing each piece of text given no starting input.
-
diff --git a/docs/task_guide.md b/docs/task_guide.md
index e69de29b..31fb443d 100644
--- a/docs/task_guide.md
+++ b/docs/task_guide.md
@@ -0,0 +1,113 @@
+# Task Configuration
+
+The `lmms_eval` is meant to be an extensible and flexible framework within which many different evaluation tasks can be defined. All tasks in the new version of the harness are built around a YAML configuration file format.
+
+These YAML configuration files, along with the current codebase commit hash, are intended to be shareable such that providing the YAML config enables another researcher to precisely replicate the evaluation setup used by another, in the case that the prompt or setup differs from standard `lmms_eval` task implementations.
+
+While adding a standard evaluation task on a new dataset can be occasionally as simple as swapping out a Hugging Face dataset path in an existing file, more specialized evaluation setups also exist. Here we'll provide a crash course on the more advanced logic implementable in YAML form available to users.
+
+## Good Reference Tasks
+
+Contributing a new task can be daunting! Luckily, much of the work has often been done for you in a different, similarly evaluated task. Good examples of task implementations to study include:
+
+Generation-based tasks:
+
+- MME (`lmms_eval/tasks/mme/mme.yaml`)
+
+```yaml
+dataset_path: lmms-lab/MME
+dataset_kwargs:
+ token: True
+task: "mme"
+test_split: test
+output_type: generate_until
+doc_to_visual: !function utils.mme_doc_to_visual
+doc_to_text: !function utils.mme_doc_to_text
+doc_to_target: "answer"
+generation_kwargs:
+ max_new_tokens: 16
+ temperature: 0
+ top_p: 0
+ num_beams: 1
+ do_sample: false
+# The return value of process_results will be used by metrics
+process_results: !function utils.mme_process_results
+# Note that the metric name can be either a registed metric function (such as the case for GQA) or a key name returned by process_results
+metric_list:
+ - metric: mme_percetion_score
+ aggregation: !function utils.mme_aggregate_results
+ higher_is_better: true
+ - metric: mme_cognition_score
+ aggregation: !function utils.mme_aggregate_results
+ higher_is_better: true
+model_specific_prompt_kwargs:
+ default:
+ pre_prompt: ""
+ post_prompt: "\nAnswer the question using a single word or phrase."
+ qwen_vl:
+ pre_prompt: ""
+ post_prompt: " Answer:"
+metadata:
+ - version: 0.0
+```
+
+You can pay special attention to the `process_results` and `metric_list` fields, which are used to define how the model output is post-processed and scored.
+Also, the `model_specific_prompt_kwargs` field is used to define model-specific prompt configurations. The default is set to follow Llava.
+
+PPL-based tasks:
+- Seedbench (`lmms_eval/tasks/seedbench/seedbench_ppl.yaml`)
+
+```yaml
+dataset_path: lmms-lab/SEED-Bench
+dataset_kwargs:
+ token: True
+task: "seedbench_ppl"
+test_split: test
+output_type: multiple_choice
+doc_to_visual: !function utils.seed_doc_to_visual
+doc_to_text: !function utils.seed_doc_to_text_mc
+doc_to_choice : !function utils.seed_doc_to_choice
+doc_to_target: !function utils.seed_doc_to_mc_target
+# Note that the metric name can be either a registed metric function (such as the case for GQA) or a key name returned by process_results
+metric_list:
+ - metric: acc
+metadata:
+ - version: 0.0
+```
+
+## Configurations
+
+Tasks are configured via the `TaskConfig` object. Below, we describe all fields usable within the object, and their role in defining a task.
+
+### Parameters
+
+Task naming + registration:
+- **task** (`str`, defaults to None) — name of the task.
+- **group** (`str`, *optional*) — name of the task group(s) a task belongs to. Enables one to run all tasks with a specified tag or group name at once.
+
+Dataset configuration options:
+- **dataset_path** (`str`) — The name of the dataset as listed by HF in the datasets Hub.
+- **dataset_name** (`str`, *optional*, defaults to None) — The name of what HF calls a “config” or sub-task of the benchmark. If your task does not contain any data instances, just leave this to default to None. (If you're familiar with the HF `datasets.load_dataset` function, these are just the first 2 arguments to it.)
+- **dataset_kwargs** (`dict`, *optional*) — Auxiliary arguments that `datasets.load_dataset` accepts. This can be used to specify arguments such as `data_files` or `data_dir` if you want to use local datafiles such as json or csv.
+- **training_split** (`str`, *optional*) — Split in the dataset to use as the training split.
+- **validation_split** (`str`, *optional*) — Split in the dataset to use as the validation split.
+- **test_split** (`str`, *optional*) — Split in the dataset to use as the test split.
+- **fewshot_split** (`str`, *optional*) — Split in the dataset to draw few-shot exemplars from. assert that this not None if num_fewshot > 0. **This function is not well tested so far**
+- **process_docs** (`Callable`, *optional*) — Optionally define a function to apply to each HF dataset split, to preprocess all documents before being fed into prompt template rendering or other evaluation steps. Can be used to rename dataset columns, or to process documents into a format closer to the expected format expected by a prompt template.
+
+Prompting / in-context formatting options:
+- **doc_to_text** (`Union[Callable, str]`, *optional*) — Column name or function to process a sample into the appropriate input for the model
+- **doc_to_visial** (`Union[Callable, str]`, *optional*) — Function to process a sample into the appropriate input images for the model.
+- **doc_to_target** (`Union[Callable, str]`, *optional*) — Column name or or function to process a sample into the appropriate target output for the model. For multiple choice tasks, this should return an index into
+- **doc_to_choice** (`Union[Callable, str]`, *optional*) — Column name or or function to process a sample into a list of possible string choices for `multiple_choice` tasks. Left undefined for `generate_until` tasks.
+
+Runtime configuration options:
+- **num_fewshot** (`int`, *optional*, defaults to 0) — Number of few-shot examples before the input. **This function is not well tested so far**
+- **batch_size** (`int`, *optional*, defaults to 1) — Batch size.
+
+**So far some models (such as qwen) may not support batch size > 1. Some models (such as llava) will generate different scores for different batch sizes. We recommend setting batch size to 1 for final benchmarking runs.**
+
+Scoring details:
+- **metric_list** (`str`, *optional*, defaults to None) — A list of metrics to use for evaluation.
+- **output_type** (`str`, *optional*, defaults to "generate_until") — Selects the type of model output for the given task. Options are `generate_until`, `loglikelihood`, and `multiple_choice`.
+- **generation_kwargs** (`dict`, *optional*) — Auxiliary arguments for the `generate` function from HF transformers library. Advanced keyword arguments may not be supported for non-HF LM classes.
diff --git a/llava_repr_requirements.txt b/llava_repr_requirements.txt
new file mode 100644
index 00000000..f1f6dcf8
--- /dev/null
+++ b/llava_repr_requirements.txt
@@ -0,0 +1,33 @@
+llava@git+https://github.com/haotian-liu/LLaVA@v1.1.3
+accelerate>=0.21.0
+black==24.1.0
+datasets==2.16.1
+evaluate>=0.4.0
+jsonlines
+numexpr
+peft>=0.2.0
+pybind11>=2.6.2
+pytablewriter
+rouge-score>=0.0.4
+sacrebleu>=1.5.0
+scikit-learn>=0.24.1
+sqlitedict
+torch==2.0.1
+openai>=1.0.0
+pycocoevalcap
+tqdm-multiprocess
+transformers>=4.36.2
+zstandard
+pillow
+pyyaml
+sympy
+mpmath
+Jinja2
+openpyxl
+Levenshtein
+hf_transfer
+tenacity
+wandb>=0.16.0
+transformers-stream-generator
+tiktoken
+pre-commit
\ No newline at end of file
diff --git a/lmms_eval/__main__.py b/lmms_eval/__main__.py
index 9f005109..1f45a85e 100644
--- a/lmms_eval/__main__.py
+++ b/lmms_eval/__main__.py
@@ -298,6 +298,8 @@ def cli_evaluate_single(args: Union[argparse.Namespace, None] = None) -> None:
if results is not None:
if args.log_samples:
samples = results.pop("samples")
+ else:
+ samples = None
dumped = json.dumps(results, indent=4, default=_handle_non_serializable)
if args.show_config:
print(dumped)
diff --git a/lmms_eval/api/instance.py b/lmms_eval/api/instance.py
index 7324dae4..41875358 100644
--- a/lmms_eval/api/instance.py
+++ b/lmms_eval/api/instance.py
@@ -4,7 +4,7 @@
@dataclass
class Instance:
- request_type: Literal["loglikelihood", "loglikelihood_rolling", "generate_until"]
+ request_type: Literal["loglikelihood", "generate_until"]
arguments: tuple
idx: int
metadata: Tuple[str, int, int] = field(default_factory=lambda: (None, None, None)) # TODO: better typehints here
diff --git a/lmms_eval/api/metrics.py b/lmms_eval/api/metrics.py
index 1417d595..56e269a7 100644
--- a/lmms_eval/api/metrics.py
+++ b/lmms_eval/api/metrics.py
@@ -166,25 +166,6 @@ def perplexity_fn(items): # This is a passthrough function
return items
-@register_metric(
- metric="word_perplexity",
- higher_is_better=False,
- output_type="loglikelihood_rolling",
- aggregation="weighted_perplexity",
-)
-def word_perplexity_fn(items): # This is a passthrough function
- return items
-
-
-@register_metric(
- metric="byte_perplexity",
- higher_is_better=False,
- output_type="loglikelihood_rolling",
- aggregation="weighted_perplexity",
-)
-def byte_perplexity_fn(items): # This is a passthrough function
- return items
-
def levenshtein_distance(s1, s2):
if len(s1) > len(s2):
@@ -232,16 +213,6 @@ def anls(
return {"anls": question_result}
-@register_metric(
- metric="bits_per_byte",
- higher_is_better=False,
- output_type="loglikelihood_rolling",
- aggregation="bits_per_byte",
-)
-def bits_per_byte_fn(items): # This is a passthrough function
- return items
-
-
def pop_stddev(arr):
mu = mean(arr)
return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / len(arr))
diff --git a/lmms_eval/api/model.py b/lmms_eval/api/model.py
index d956e85e..9afed21d 100644
--- a/lmms_eval/api/model.py
+++ b/lmms_eval/api/model.py
@@ -54,49 +54,6 @@ def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
"""
pass
- @abc.abstractmethod
- def loglikelihood_rolling(self, requests) -> List[Tuple[float, bool]]:
- """Compute full log-likelihood of a string, with no truncation, for perplexity computation
- - We will use the full max context length of the model.
- - For inputs that exceed the max context length, we divide the tokenized string into chunks of up to
- the max context length.
- - IMPORTANT: Each document's loglikelihood/perplexity is computed *separately*, unlike other implementations
- which may simply concatenate multiple documents together.
- - IMPORTANT: We maximize the amount of context for each prediction. Specifically, for inputs that we break into
- multiple chunks, the last input will still a full-sized context.
- Example:
- Input tokens: [ 0 1 2 3 4 5 6 7 8 9 ]
- Prefix: EOT
- Max context length: 4
- Resulting input/prediction pairs:
-
- INPUT: EOT 0 1 2
- PRED: 0 1 2 3
-
- INPUT: 3 4 5 6
- PRED: 4 5 6 7
-
- INPUT: 5 6 7 8
- PRED: 8 9
-
- Observe that:
- 1. Each token is predicted exactly once
- 2. For the last pair, we provide the full context, but only score the last two tokens
-
- :param requests: list[Instance]
- A list of Instance objects with property `args` which returns a tuple (context, continuation).
- string: str
- String for which we are computing per-token loglikelihood
- 'visual_list: list[dict]'
- Visual input to the model. Can be None.
- :return: list[tuple[float, bool]]
- A list of pairs (logprob, isgreedy)
- logprob: float
- The log probability of `continuation`
- isgreedy:
- Whether `continuation` would be generated by greedy sampling from `context`
- """
- pass
# TODO: Add an optional max length
@abc.abstractmethod
diff --git a/lmms_eval/api/registry.py b/lmms_eval/api/registry.py
index 288b9368..0728b86d 100644
--- a/lmms_eval/api/registry.py
+++ b/lmms_eval/api/registry.py
@@ -72,7 +72,6 @@ def decorate(fn):
"perplexity",
"acc",
],
- "loglikelihood_rolling": ["word_perplexity", "byte_perplexity", "bits_per_byte"],
"multiple_choice": ["acc", "acc_norm"],
"generate_until": ["exact_match"],
}
diff --git a/lmms_eval/api/task.py b/lmms_eval/api/task.py
index db3cfd1e..f50549d0 100644
--- a/lmms_eval/api/task.py
+++ b/lmms_eval/api/task.py
@@ -37,7 +37,6 @@
ALL_OUTPUT_TYPES = [
"loglikelihood",
"multiple_choice",
- "loglikelihood_rolling",
"generate_until",
]
@@ -440,11 +439,6 @@ def count_bytes(cls, doc):
"""Used for byte-level perplexity metrics in rolling loglikelihood"""
return len(doc.encode("utf-8"))
- @classmethod
- def count_words(cls, doc):
- """Downstream loglikelihood_rolling perplexity tasks with custom word boundaries should override this!"""
- return len(re.split(r"\s+", doc))
-
@utils.positional_deprecated
def fewshot_context(
self,
@@ -931,8 +925,6 @@ def construct_requests(self, doc_id: int, ctx: str, **kwargs) -> Union[List[Inst
kwargs.pop("split")
if self.OUTPUT_TYPE == "loglikelihood":
arguments = (ctx, self.doc_to_target, self.doc_to_visual, doc_id, self.config.task, split)
- elif self.OUTPUT_TYPE == "loglikelihood_rolling":
- arguments = (self.doc_to_target,)
elif self.OUTPUT_TYPE == "multiple_choice":
doc = self.dataset[split][doc_id]
choices = self.doc_to_choice(doc)
@@ -993,15 +985,6 @@ def process_results(self, doc, results):
**({"perplexity": ll} if "perplexity" in use_metric else {}),
**({"acc": int(is_greedy)} if "acc" in use_metric else {}),
}
- elif self.OUTPUT_TYPE == "loglikelihood_rolling":
- (loglikelihood,) = results
- _words = self.count_words(self.doc_to_target(doc))
- _bytes = self.count_bytes(self.doc_to_target(doc))
- return {
- **({"word_perplexity": (loglikelihood, _words)} if "word_perplexity" in use_metric else {}),
- **({"byte_perplexity": (loglikelihood, _bytes)} if "byte_perplexity" in use_metric else {}),
- **({"bits_per_byte": (loglikelihood, _bytes)} if "bits_per_byte" in use_metric else {}),
- }
elif self.OUTPUT_TYPE == "multiple_choice":
lls, is_greedy = zip(*results)
@@ -1123,7 +1106,7 @@ def process_results(self, doc, results):
else:
raise ValueError(
f"Passed invalid output_type '{self.OUTPUT_TYPE}' ! Please use one of ",
- "'loglikelihood', 'loglikelihood_rolling', 'generate_until' or 'multiple_choice'",
+ "'loglikelihood','generate_until' or 'multiple_choice'",
)
return result_dict
diff --git a/lmms_eval/evaluator.py b/lmms_eval/evaluator.py
index c3100dca..a97edff0 100644
--- a/lmms_eval/evaluator.py
+++ b/lmms_eval/evaluator.py
@@ -318,7 +318,7 @@ def evaluate(
# hack: remove image columns to speed avoid loading images and speed up postprocessing
# reason: doc_iterator will actually load image if it's in the doc.
docs = task.test_docs() if task.has_test_docs() else task.validation_docs()
- if "d170" not in task_name or "dc100" not in task_name or "dc200" not in task_name:
+ if "d170" not in task_name and "dc100" not in task_name and "dc200" not in task_name:
remove_cols = []
features = docs.features
# If it is an Image instance or a Sequence of Image instance. Remove it
diff --git a/lmms_eval/logging_utils.py b/lmms_eval/logging_utils.py
index 800dfcd1..21a2ee04 100644
--- a/lmms_eval/logging_utils.py
+++ b/lmms_eval/logging_utils.py
@@ -276,10 +276,6 @@ def _generate_dataset(self, data: List[Dict[str, Any]], config: Dict[str, Any])
choices = ["\n".join([f"{idx}. {y[1]}" for idx, y in enumerate(x["arguments"])]) for x in data]
resps = [np.argmax([n[0][0] for n in x["resps"]]) for x in data]
filtered_resps = [np.argmax([n[0] for n in x["filtered_resps"]]) for x in data]
- elif config["output_type"] == "loglikelihood_rolling":
- instance = [x["arguments"][0][0] for x in data]
- resps = [x["resps"][0][0] for x in data]
- filtered_resps = [x["filtered_resps"][0] for x in data]
elif config["output_type"] == "generate_until":
instance = [x["arguments"][0][0] for x in data]
resps = [x["resps"][0][0] for x in data]
diff --git a/lmms_eval/models/fuyu.py b/lmms_eval/models/fuyu.py
index 9ab39bf7..32566063 100644
--- a/lmms_eval/models/fuyu.py
+++ b/lmms_eval/models/fuyu.py
@@ -253,9 +253,7 @@ def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
pbar.close()
return res
- def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
- # TODO
- assert False, "We have not implemented this function for llava yet"
+
def tok_encode(self, string: str, left_truncate_len=None, add_special_tokens=None) -> List[int]:
""" """
diff --git a/lmms_eval/models/gpt4v.py b/lmms_eval/models/gpt4v.py
index 46c851f6..2ab489ae 100644
--- a/lmms_eval/models/gpt4v.py
+++ b/lmms_eval/models/gpt4v.py
@@ -107,7 +107,7 @@ def generate_until(self, requests) -> List[str]:
for attempt in range(5):
try:
- response = url_requests.post(API_URL, headers=headers, json=payload)
+ response = url_requests.post(API_URL, headers=headers, json=payload, timeout=20)
response_data = response.json()
content = response_data["choices"][0]["message"]["content"].strip()
@@ -128,6 +128,4 @@ def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
# TODO
assert False, "GPT4V not support"
- def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
- # TODO
- assert False, "GPT4V not support"
+
diff --git a/lmms_eval/models/instructblip.py b/lmms_eval/models/instructblip.py
index 7086f346..ebbf7bec 100644
--- a/lmms_eval/models/instructblip.py
+++ b/lmms_eval/models/instructblip.py
@@ -138,10 +138,6 @@ def tok_decode(self, tokens):
def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
# TODO
assert False, "We have not implemented this function for InstructBLIP yet"
-
- def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
- # TODO
- assert False, "We have not implemented this function for InstructBLIP yet"
def flatten(self, input):
new_list = []
diff --git a/lmms_eval/models/llava.py b/lmms_eval/models/llava.py
index 961fafea..f7d9184b 100644
--- a/lmms_eval/models/llava.py
+++ b/lmms_eval/models/llava.py
@@ -232,10 +232,6 @@ def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
pbar.close()
return res
- def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
- # TODO
- assert False, "We have not implemented this function for llava yet"
-
def flatten(self, input):
new_list = []
for i in input:
diff --git a/lmms_eval/models/minicpm_v.py b/lmms_eval/models/minicpm_v.py
index 1838b56f..ad7c5ac7 100644
--- a/lmms_eval/models/minicpm_v.py
+++ b/lmms_eval/models/minicpm_v.py
@@ -135,10 +135,6 @@ def tok_decode(self, tokens):
def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
# TODO
assert False, "We have not implemented this function for MiniCPM_V yet"
-
- def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
- # TODO
- assert False, "We have not implemented this function for MiniCPM_V yet"
def flatten(self, input):
new_list = []
diff --git a/lmms_eval/models/qwen_vl.py b/lmms_eval/models/qwen_vl.py
index 5201f79f..503d091b 100644
--- a/lmms_eval/models/qwen_vl.py
+++ b/lmms_eval/models/qwen_vl.py
@@ -174,11 +174,7 @@ def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
pbar.close()
return res
- assert False, "We have not implemented this function for Qwen VL yet"
- def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
- # TODO
- assert False, "We have not implemented this function for Qwen VL yet"
def flatten(self, input):
new_list = []
diff --git a/lmms_eval/tasks/coco_cap/utils.py b/lmms_eval/tasks/coco_cap/utils.py
index 58fcb141..4e0551f6 100644
--- a/lmms_eval/tasks/coco_cap/utils.py
+++ b/lmms_eval/tasks/coco_cap/utils.py
@@ -42,7 +42,7 @@ def coco_process_result(doc, result):
return {f"coco_{metric}": data_dict for metric in COCO_METRICS}
-def coco_aggregation_result(results, metric):
+def coco_aggregation_result(results, metric, args):
scorers = [(Bleu(4), "Bleu_1"), (Bleu(4), "Bleu_2"), (Bleu(4), "Bleu_3"), (Bleu(4), "Bleu_4"), (Meteor(), "METEOR"), (Rouge(), "ROUGE_L"), (Cider(), "CIDEr"), (Spice(), "SPICE")]
scorers_dict = {s[1]: s for s in scorers}
@@ -89,45 +89,45 @@ def coco_aggregation_result(results, metric):
n = int(metric.split("_")[-1])
score = score[n - 1]
- os.makedirs("./submissions", exist_ok=True)
- if not os.path.exists("./submissions/coco_captions_val2014_alg_results.json"):
+ path = generate_submission_file("coco_captions_val2014_alg_results.json", args)
+ if not os.path.exists(path):
eval_logger.info("Storing prediction that can be submitted to the server ...")
- with open("./submissions/coco_captions_val2014_alg_results.json", "w") as f:
+ with open(path, "w") as f:
json.dump(stored_results, f, indent=4)
return score
-def coco_bleu4(results):
- return coco_aggregation_result(results, "Bleu_4")
+def coco_bleu4(results, args):
+ return coco_aggregation_result(results, "Bleu_4", args)
-def coco_bleu3(results):
- return coco_aggregation_result(results, "Bleu_3")
+def coco_bleu3(results, args):
+ return coco_aggregation_result(results, "Bleu_3", args)
-def coco_bleu2(results):
- return coco_aggregation_result(results, "Bleu_2")
+def coco_bleu2(results, args):
+ return coco_aggregation_result(results, "Bleu_2", args)
-def coco_bleu1(results):
- return coco_aggregation_result(results, "Bleu_1")
+def coco_bleu1(results, args):
+ return coco_aggregation_result(results, "Bleu_1", args)
-def coco_meteor(results):
- return coco_aggregation_result(results, "METEOR")
+def coco_meteor(results, args):
+ return coco_aggregation_result(results, "METEOR", args)
-def coco_rougel(results):
- return coco_aggregation_result(results, "ROUGE_L")
+def coco_rougel(results, args):
+ return coco_aggregation_result(results, "ROUGE_L", args)
-def coco_cider(results):
- return coco_aggregation_result(results, "CIDEr")
+def coco_cider(results, args):
+ return coco_aggregation_result(results, "CIDEr", args)
-def coco_spice(results):
- return coco_aggregation_result(results, "SPICE")
+def coco_spice(results, args):
+ return coco_aggregation_result(results, "SPICE", args)
def coco_test_process_result(doc, result):
diff --git a/lmms_eval/tasks/llava-bench-coco/utils.py b/lmms_eval/tasks/llava-bench-coco/utils.py
index b7bea9c4..8858637f 100644
--- a/lmms_eval/tasks/llava-bench-coco/utils.py
+++ b/lmms_eval/tasks/llava-bench-coco/utils.py
@@ -13,7 +13,7 @@
eval_logger = logging.getLogger("lmms-eval")
NUM_SECONDS_TO_SLEEP = 0.5
-LLAVA_W_METRICS = ["gpt_eval_llava_conv", "gpt_eval_llava_detail", "gpt_eval_llava_conv"]
+LLAVA_W_METRICS = ["gpt_eval_llava_conv", "gpt_eval_llava_detail", "gpt_eval_llava_complex"]
rule_dict = json.load(open(os.path.join(os.path.dirname(os.path.abspath(__file__)), "rule.json"), "r"))
@@ -69,7 +69,7 @@ def get_eval(content: str, max_tokens: int, retries: int = 3):
for attempt in range(retries):
try:
- response = requests.post(API_URL, headers=headers, json=payload)
+ response = requests.post(API_URL, headers=headers, json=payload, timeout=60)
response.raise_for_status()
response_data = response.json()
diff --git a/lmms_eval/tasks/llava-in-the-wild/utils.py b/lmms_eval/tasks/llava-in-the-wild/utils.py
index c788cb24..ac86ee99 100644
--- a/lmms_eval/tasks/llava-in-the-wild/utils.py
+++ b/lmms_eval/tasks/llava-in-the-wild/utils.py
@@ -11,9 +11,9 @@
from copy import deepcopy
eval_logger = logging.getLogger("lmms-eval")
-NUM_SECONDS_TO_SLEEP = 0.5
+NUM_SECONDS_TO_SLEEP = 5
-LLAVA_W_METRICS = ["gpt_eval_llava_conv", "gpt_eval_llava_detail", "gpt_eval_llava_conv"]
+LLAVA_W_METRICS = ["gpt_eval_llava_conv", "gpt_eval_llava_detail", "gpt_eval_llava_complex"]
rule_dict = json.load(open(os.path.join(os.path.dirname(os.path.abspath(__file__)), "rule.json"), "r"))
@@ -47,7 +47,7 @@
}
-def get_eval(content: str, max_tokens: int, retries: int = 3):
+def get_eval(content: str, max_tokens: int, retries: int = 5):
global headers
messages = [
@@ -67,7 +67,7 @@ def get_eval(content: str, max_tokens: int, retries: int = 3):
for attempt in range(retries):
try:
- response = requests.post(API_URL, headers=headers, json=payload)
+ response = requests.post(API_URL, headers=headers, json=payload, timeout=60)
response.raise_for_status()
response_data = response.json()
@@ -78,7 +78,7 @@ def get_eval(content: str, max_tokens: int, retries: int = 3):
except Exception as e:
eval_logger.info(f"Attempt {attempt + 1} failed with error: {e}")
- if attempt < retries - 1: # If we have retries left, sleep and then continue to next attempt
+ if attempt < retries: # If we have retries left, sleep and then continue to next attempt
time.sleep(NUM_SECONDS_TO_SLEEP)
else: # If this was the last attempt, log and return empty
eval_logger.error(f"All {retries} attempts failed. Last error message: {e}")
diff --git a/lmms_eval/tasks/mathvista/utils.py b/lmms_eval/tasks/mathvista/utils.py
index 471e9727..620e3f28 100644
--- a/lmms_eval/tasks/mathvista/utils.py
+++ b/lmms_eval/tasks/mathvista/utils.py
@@ -47,7 +47,7 @@ def mathvista_process_results(doc, results):
problem = {
"question_type": doc["question_type"],
"answer_type": doc["answer_type"],
- "query": doc["question"],
+ "query": doc["query"],
"choices": doc["choices"],
"answer": doc["answer"] if "answer" in doc else None,
"precision": doc["precision"] if "precision" in doc else 0,
@@ -60,7 +60,7 @@ def mathvista_process_results(doc, results):
result = {
"question_id": doc["pid"],
- "query": doc["question"],
+ "query": doc["query"],
"choices": doc["choices"],
"answer": doc["answer"] if "answer" in doc else None,
"extraction": extraction,
diff --git a/lmms_eval/tasks/mme/mme_test.yaml b/lmms_eval/tasks/mme/mme_test.yaml
deleted file mode 100644
index c529cf83..00000000
--- a/lmms_eval/tasks/mme/mme_test.yaml
+++ /dev/null
@@ -1,31 +0,0 @@
-dataset_path: lmms-lab/MME
-dataset_kwargs:
- token: True
-task: "mme_test"
-test_split: test
-output_type: generate_until
-doc_to_visual: !function utils.mme_doc_to_visual
-doc_to_text: !function utils.mme_doc_to_text
-doc_to_target: "answer"
-generation_kwargs:
- max_new_tokens: 16
- temperature: 0
- top_p: 0
- num_beams: 1
- do_sample: false
-# The return value of process_results will be used by metrics
-process_results: !function utils.mme_process_results
-# Note that the metric name can be either a registed metric function (such as the case for GQA) or a key name returned by process_results
-metric_list:
- - metric: mme_percetion_score
- aggregation: !function utils.mme_aggregate_results
- higher_is_better: true
- - metric: mme_cognition_score
- aggregation: !function utils.mme_aggregate_results
- higher_is_better: true
-model_specific_prompt_kwargs:
- default:
- pre_prompt: ""
- post_prompt: "\nAnswer the question using a single word or phrase."
-metadata:
- - version: 0.0
\ No newline at end of file
diff --git a/lmms_eval/tasks/mmvet/utils.py b/lmms_eval/tasks/mmvet/utils.py
index b54f19f1..5caaba46 100644
--- a/lmms_eval/tasks/mmvet/utils.py
+++ b/lmms_eval/tasks/mmvet/utils.py
@@ -58,6 +58,7 @@ def get_chat_response(prompt, model=GPT_EVAL_MODEL_NAME, temperature=0.0, max_to
API_URL,
headers=headers,
json=payload,
+ timeout=60,
)
response.raise_for_status()
response_data = response.json()
@@ -67,7 +68,7 @@ def get_chat_response(prompt, model=GPT_EVAL_MODEL_NAME, temperature=0.0, max_to
return content, response_data["model"]
except Exception as e:
- eval_logger.info(f"Error in response: {response.json()['error']['message']}")
+ eval_logger.error(f"Error: {e}")
if "Rate limit" in str(e):
eval_logger.info("Sleeping due to rate limit...")
time.sleep(sleep_time)
diff --git a/lmms_eval/tasks/seedbench/utils.py b/lmms_eval/tasks/seedbench/utils.py
index 4d9334ee..c2938f13 100644
--- a/lmms_eval/tasks/seedbench/utils.py
+++ b/lmms_eval/tasks/seedbench/utils.py
@@ -16,7 +16,8 @@ def seed_doc_to_text(doc):
def seed_process_result(doc, result):
pred = result[0].strip()
- pred = pred[0]
+ if len(pred) > 1:
+ pred = pred[0]
answer = doc["answer"]
data_type = doc["data_type"]
diff --git a/lmms_eval/tasks/seedbench_2/utils.py b/lmms_eval/tasks/seedbench_2/utils.py
index af8d8571..f88ded9c 100644
--- a/lmms_eval/tasks/seedbench_2/utils.py
+++ b/lmms_eval/tasks/seedbench_2/utils.py
@@ -27,7 +27,8 @@ def seed_doc_to_text(doc, model_specific_kwargs=None):
def seed_process_result(doc, result):
pred = result[0].strip()
- pred = pred[0]
+ if len(pred) > 1:
+ pred = pred[0]
answer = doc["answer"]
data_type = doc["data_type"].split(" ")
data_type = "_".join(data_type)
diff --git a/miscs/llava_result_check.md b/miscs/llava_result_check.md
new file mode 100644
index 00000000..e69de29b
diff --git a/miscs/repr_scripts.sh b/miscs/repr_scripts.sh
new file mode 100644
index 00000000..f5a74309
--- /dev/null
+++ b/miscs/repr_scripts.sh
@@ -0,0 +1,10 @@
+# install lmms_eval without building dependencies
+cd lmms_eval;
+pip install --no-deps -U -e .
+
+# install all the requirements that require for reproduce llava results
+pip install -r llava_repr_requirements.txt
+
+# Run and exactly reproduce llava_v1.5 results!
+# mme as an example
+accelerate launch --num_processes=1 -m lmms_eval --model llava --model_args pretrained="liuhaotian/llava-v1.5-7b" --tasks mme --batch_size 1 --log_samples --log_samples_sufix reproduce --output_path ./logs/
\ No newline at end of file
diff --git a/miscs/repr_torch_envs.txt b/miscs/repr_torch_envs.txt
new file mode 100644
index 00000000..6a7f22ae
--- /dev/null
+++ b/miscs/repr_torch_envs.txt
@@ -0,0 +1,69 @@
+Collecting environment information...
+PyTorch version: 2.0.1+cu117
+Is debug build: False
+CUDA used to build PyTorch: 11.7
+ROCM used to build PyTorch: N/A
+
+OS: Ubuntu 22.04.2 LTS (x86_64)
+GCC version: (Ubuntu 11.3.0-1ubuntu1~22.04.1) 11.3.0
+Clang version: Could not collect
+CMake version: version 3.28.3
+Libc version: glibc-2.35
+
+Python version: 3.9.18 (main, Sep 11 2023, 13:41:44) [GCC 11.2.0] (64-bit runtime)
+Python platform: Linux-5.15.0-76-generic-x86_64-with-glibc2.35
+Is CUDA available: False
+CUDA runtime version: 11.8.89
+CUDA_MODULE_LOADING set to: N/A
+GPU models and configuration: Could not collect
+Nvidia driver version: Could not collect
+cuDNN version: Could not collect
+HIP runtime version: N/A
+MIOpen runtime version: N/A
+Is XNNPACK available: True
+
+CPU:
+Architecture: x86_64
+CPU op-mode(s): 32-bit, 64-bit
+Address sizes: 42 bits physical, 48 bits virtual
+Byte Order: Little Endian
+CPU(s): 16
+On-line CPU(s) list: 0-15
+Vendor ID: GenuineIntel
+Model name: Intel(R) Xeon(R) Gold 6348 CPU @ 2.60GHz
+CPU family: 6
+Model: 106
+Thread(s) per core: 2
+Core(s) per socket: 8
+Socket(s): 1
+Stepping: 6
+BogoMIPS: 5200.01
+Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch cpuid_fault invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves wbnoinvd arat avx512vbmi umip avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq fsrm md_clear arch_capabilities
+Hypervisor vendor: KVM
+Virtualization type: full
+L1d cache: 384 KiB (8 instances)
+L1i cache: 256 KiB (8 instances)
+L2 cache: 10 MiB (8 instances)
+L3 cache: 42 MiB (1 instance)
+NUMA node(s): 1
+NUMA node0 CPU(s): 0-15
+Vulnerability Itlb multihit: Not affected
+Vulnerability L1tf: Not affected
+Vulnerability Mds: Not affected
+Vulnerability Meltdown: Not affected
+Vulnerability Mmio stale data: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
+Vulnerability Retbleed: Not affected
+Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
+Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
+Vulnerability Spectre v2: Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
+Vulnerability Srbds: Not affected
+Vulnerability Tsx async abort: Not affected
+
+Versions of relevant libraries:
+[pip3] mypy-extensions==1.0.0
+[pip3] numpy==1.26.4
+[pip3] torch==2.0.1
+[pip3] torchvision==0.16.2
+[conda] numpy 1.26.4 pypi_0 pypi
+[conda] torch 2.0.1 pypi_0 pypi
+[conda] torchvision 0.16.2 pypi_0 pypi
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index 05043fd1..4483b3ab 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -19,7 +19,7 @@ classifiers = [
"Operating System :: OS Independent",
]
requires-python = ">=3.8"
-license = { "text" = "MIT" }
+license = { text = "MIT" }
dependencies = [
"accelerate>=0.21.0",
"black==24.1.0",
@@ -34,11 +34,11 @@ dependencies = [
"sacrebleu>=1.5.0",
"scikit-learn>=0.24.1",
"sqlitedict",
- "torch>=1.8", # Note the version specification here for torch
+ "torch>=1.8",
"openai>=1.0.0",
"pycocoevalcap",
"tqdm-multiprocess",
- "transformers>=4.36.2",
+ "transformers>=4.31.0",
"zstandard",
"pillow",
"pyyaml",
@@ -53,50 +53,11 @@ dependencies = [
"transformers-stream-generator",
"tiktoken",
"pre-commit",
- "llava@git+https://github.com/haotian-liu/LLaVA",
-]
-
-[project.optional-dependencies]
-llava_repr = [
- "accelerate>=0.21.0",
- "black==24.1.0",
- "datasets==2.16.1",
- "evaluate>=0.4.0",
- "jsonlines",
- "numexpr",
- "peft>=0.2.0",
- "pybind11>=2.6.2",
- "pytablewriter",
- "rouge-score>=0.0.4",
- "sacrebleu>=1.5.0",
- "scikit-learn>=0.24.1",
- "sqlitedict",
- "openai>=1.0.0",
- "pycocoevalcap",
- "tqdm-multiprocess",
- "transformers>=4.36.2",
- "zstandard",
- "pillow",
- "pyyaml",
- "sympy",
- "mpmath",
- "Jinja2",
- "openpyxl",
- "Levenshtein",
- "hf_transfer",
- "tenacity",
- "wandb>=0.16.0",
- "transformers-stream-generator",
- "tiktoken",
- "pre-commit",
- "torch==2.0.1", # Specific version for llava_repr
- "llava@git+https://github.com/haotian-liu/LLaVA",
]
[tool.setuptools.packages.find]
include = ["lmms_eval*"]
-# required to include yaml files in pip installation
[tool.setuptools.package-data]
lmms_eval = ["**/*.yaml", "tasks/**/*"]
diff --git a/ttyd b/ttyd
new file mode 160000
index 00000000..68521f5b
--- /dev/null
+++ b/ttyd
@@ -0,0 +1 @@
+Subproject commit 68521f5b029f3faba7b693e59cf4c175ad06a0db