[Wandb Logger] add models, and args to wandb tables. (EvolvingLMMs-Lab#55)

* Refactor logging in lmms_eval package

* Refactor variable names in lmms_eval package
Luodian authored Feb 27, 2024
1 parent 9c039a7 commit 521ece2
Showing 6 changed files with 21 additions and 10 deletions.
6 changes: 4 additions & 2 deletions lmms_eval/__main__.py
@@ -201,12 +201,14 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
results_list.append(results)

accelerator.wait_for_everyone()
- if is_main_process:
+ if is_main_process and args.wandb_args:
wandb_logger.post_init(results)
wandb_logger.log_eval_result()
if args.wandb_log_samples and samples is not None:
wandb_logger.log_eval_samples(samples)

+ wandb_logger.finish()

except Exception as e:
traceback.print_exc()
eval_logger.error(f"Error during evaluation: {e}")
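
The new guard means W&B is only touched on the main process and only when `--wandb_args` was supplied, and the added `wandb_logger.finish()` closes the run once logging is done. A minimal sketch of that control flow (argument names mirror the diff; the logger object is assumed to expose the methods referenced above):

```python
# Sketch only: `wandb_logger` is assumed to expose the methods named in the diff.
def finalize_wandb(is_main_process: bool, args, wandb_logger, results, samples) -> None:
    # Skip W&B entirely unless --wandb_args was passed, and only log from the main process.
    if is_main_process and args.wandb_args:
        wandb_logger.post_init(results)              # attach the finished results dict
        wandb_logger.log_eval_result()               # push the aggregated metrics table
        if args.wandb_log_samples and samples is not None:
            wandb_logger.log_eval_samples(samples)   # optional per-sample logging
        wandb_logger.finish()                        # close the run so data is flushed
```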
@@ -312,7 +314,7 @@ def cli_evaluate_single(args: Union[argparse.Namespace, None] = None) -> None:
for task_name, config in results["configs"].items():
filename = args.output_path.joinpath(f"{task_name}.json")
# Structure the data with 'args' and 'logs' keys
- data_to_dump = {"args": vars(args), "config": config, "logs": sorted(samples[task_name], key=lambda x: x["doc_id"])} # Convert Namespace to dict
+ data_to_dump = {"args": vars(args), "model_configs": config, "logs": sorted(samples[task_name], key=lambda x: x["doc_id"])} # Convert Namespace to dict
samples_dumped = json.dumps(data_to_dump, indent=4, default=_handle_non_serializable)
filename.open("w").write(samples_dumped)
eval_logger.info(f"Saved samples to {filename}")
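The second hunk only renames the `config` key to `model_configs` in the per-task JSON dump. A self-contained sketch of the resulting file layout, with stand-in values and a simplified `_handle_non_serializable` (the real helper in `__main__.py` may differ):

```python
import json
from pathlib import Path

def _handle_non_serializable(obj):
    # Stand-in: fall back to repr() for objects json cannot encode.
    return repr(obj)

args_dict = {"model": "llava", "batch_size": 1}                    # vars(args) stand-in
config = {"task": "mme", "output_type": "generate_until"}          # results["configs"][task_name] stand-in
task_samples = [{"doc_id": 1, "target": "Yes"}, {"doc_id": 0, "target": "No"}]

data_to_dump = {
    "args": args_dict,
    "model_configs": config,
    "logs": sorted(task_samples, key=lambda x: x["doc_id"]),
}
Path("mme.json").write_text(json.dumps(data_to_dump, indent=4, default=_handle_non_serializable))
```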
2 changes: 1 addition & 1 deletion lmms_eval/evaluator.py
@@ -134,7 +134,7 @@ def simple_evaluate(

if lm.rank == 0:
# add info about the model and few shot config
results["config"] = {
results["model_configs"] = {
"model": model if isinstance(model, str) else model.model.config._name_or_path,
"model_args": model_args,
"batch_size": batch_size,
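With the rename, the evaluator publishes model metadata under `results["model_configs"]`, which is exactly the key the W&B logger reads below. An illustrative snippet of what that block produces (values are placeholders, not from a real run):

```python
model = "llava"                          # or an instantiated model object
model_args = "pretrained=example/model"  # placeholder model_args string
batch_size = 1

results = {}
results["model_configs"] = {
    "model": model if isinstance(model, str) else model.model.config._name_or_path,
    "model_args": model_args,
    "batch_size": batch_size,
}

# Consumed later by logging_utils.py:
model_name = results.get("model_configs").get("model")
model_args = results.get("model_configs").get("model_args")
```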
17 changes: 13 additions & 4 deletions lmms_eval/logging_utils.py
@@ -82,6 +82,9 @@ def __init__(self, args):
os.environ["WANDB_MODE"] = "offline"
self.init_run()

+ def finish(self):
+     self.run.finish()

@tenacity.retry(wait=tenacity.wait_fixed(5), stop=tenacity.stop_after_attempt(5))
def init_run(self):
if "name" not in self.wandb_args:
@@ -152,6 +155,9 @@ def _sanitize_results_dict(self) -> Tuple[Dict[str, str], Dict[str, Any]]:
def _log_results_as_table(self) -> None:
"""Generate and log evaluation results as a table to W&B."""
columns = [
"Model",
"Args",
"Tasks",
"Version",
"Filter",
"num_fewshot",
@@ -164,6 +170,9 @@ def make_table(columns: List[str], key: str = "results"):
table = wandb.Table(columns=columns)
results = copy.deepcopy(self.results)

+ model_name = results.get("model_configs").get("model")
+ model_args = results.get("model_configs").get("model_args")

for k, dic in results.get(key).items():
if k in self.group_names and not key == "groups":
continue
@@ -183,14 +192,14 @@ def make_table(columns: List[str], key: str = "results"):
se = dic[m + "_stderr" + "," + f]
if se != "N/A":
se = "%.4f" % se
- table.add_data(*[k, version, f, n, m, str(v), str(se)])
+ table.add_data(*[model_name, model_args, k, version, f, n, m, str(v), str(se)])
else:
- table.add_data(*[k, version, f, n, m, str(v), ""])
+ table.add_data(*[model_name, model_args, k, version, f, n, m, str(v), ""])

return table

# log the complete eval result to W&B Table
table = make_table(["Tasks"] + columns, "results")
table = make_table(columns, "results")
self.run.log({"evaluation/eval_results": table})

if "groups" in self.results.keys():
@@ -209,7 +218,7 @@ def log_eval_result(self) -> None:
"""Log evaluation results to W&B."""
# Log configs to wandb
configs = self._get_config()
- self.run.config.update(configs)
+ self.run.config.update(configs, allow_val_change=True)

wandb_summary, self.wandb_results = self._sanitize_results_dict()
# update wandb.run.summary with items that were removed
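Taken together, the logging changes prepend a Model and an Args column to every results row, update the run config with `allow_val_change=True` so repeated updates do not raise, and close the run via the new `finish()` method. A standalone sketch of the same table shape (the trailing column names and the example row are assumptions, since the full column list is truncated above):

```python
import wandb

columns = ["Model", "Args", "Tasks", "Version", "Filter", "num_fewshot", "Metric", "Value", "Stderr"]
table = wandb.Table(columns=columns)

model_name = "llava"                      # from results["model_configs"]["model"]
model_args = "pretrained=example/model"   # from results["model_configs"]["model_args"]

# One illustrative row per (task, filter, metric); real values come from results["results"].
table.add_data(model_name, model_args, "mme", "N/A", "none", 0, "mme_percetion_score", "1234.5", "")

run = wandb.init(project="lmms-eval", mode="offline")  # offline so the sketch runs anywhere
run.config.update({"batch_size": 1}, allow_val_change=True)
run.log({"evaluation/eval_results": table})
run.finish()
```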
2 changes: 1 addition & 1 deletion lmms_eval/tasks/internal_eval/d170_cn.yaml
@@ -4,8 +4,8 @@ dataset_kwargs:
task: "d170_cn"
test_split: test
output_type: generate_until
- doc_to_visual: !function utils.doc_to_visual
doc_to_text: !function utils.doc_to_text # Such that {{prompt}} will be replaced by doc["question"]
+ doc_to_visual: !function d170_cn_utils.doc_to_visual
doc_to_target: "{{annotation}}"
generation_kwargs:
until:
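The YAML change only repoints the `!function` tag from the shared `utils` module to a task-local `d170_cn_utils` module next to the YAML file. A hypothetical stub of what that module must provide (the real implementation in the repo may differ):

```python
# d170_cn_utils.py (hypothetical stub)

def doc_to_visual(doc):
    # Return the list of images the model should see for this document.
    return [doc["image"].convert("RGB")]

def doc_to_text(doc):
    # {{prompt}} in the task template is filled from doc["question"].
    return doc["question"]
```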
2 changes: 1 addition & 1 deletion lmms_eval/tasks/mme/mme_test.yaml
@@ -14,7 +14,7 @@ generation_kwargs:
num_beams: 1
do_sample: false
# The return value of process_results will be used by metrics
- process_results: !function utils.mme_process_result
+ process_results: !function utils.mme_process_results
# Note that the metric name can be either a registed metric function (such as the case for GQA) or a key name returned by process_results
metric_list:
- metric: mme_percetion_score
2 changes: 1 addition & 1 deletion lmms_eval/tasks/mme/utils.py
@@ -67,7 +67,7 @@ def parse_pred_ans(pred_ans):
return pred_label


- def mme_process_result(doc, results):
+ def mme_process_results(doc, results):
"""
Args:
doc: a instance of the eval dataset
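The rename in `utils.py` keeps the function name in sync with the `process_results: !function utils.mme_process_results` key in `mme_test.yaml`; the dict it returns must use the metric names declared under `metric_list`. An illustrative-only sketch of that contract (the scoring logic is not the repo's actual implementation):

```python
def mme_process_results(doc, results):
    # results[0] is the raw generation for this doc; doc["answer"] is an assumed ground-truth field.
    pred = str(results[0]).strip().lower()
    gt = str(doc.get("answer", "")).strip().lower()
    score = 1.0 if gt and gt in pred else 0.0
    # The key must match a metric name from metric_list, e.g. mme_percetion_score.
    return {"mme_percetion_score": score}
```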
