Skip to content

Commit

Permalink
More formatting (mlc-ai#1099)
Browse files Browse the repository at this point in the history
  • Loading branch information
junrushao authored Oct 21, 2023
1 parent cf39bf6 commit e9b85ce
Show file tree
Hide file tree
Showing 6 changed files with 92 additions and 99 deletions.
31 changes: 10 additions & 21 deletions tests/debug/compare_lib.py → tests/python/legacy/compare_lib.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
from typing import List

import argparse
import os
import json
import os
from typing import List

import tvm
from tvm import relax
from tvm import rpc
from tvm.relax.testing.lib_comparator import LibCompareVMInstrument
import numpy as np

import torch
import tvm
from transformers import AutoTokenizer, LlamaTokenizer
from tvm import relax, rpc
from tvm.relax.testing.lib_comparator import LibCompareVMInstrument

from mlc_llm import utils

Expand Down Expand Up @@ -53,7 +50,7 @@ def compare(

if self.time_eval and name not in self.time_eval_results:
res = self.mod.time_evaluator(
name, self.device, number=20, repeat=3#, cache_flush_bytes=256 * 10**6
name, self.device, number=20, repeat=3 # , cache_flush_bytes=256 * 10**6
)(*new_args)
self.time_eval_results[name] = (res.mean, 1)
print(f"Time-eval result {name} on {self.device}: {res}")
Expand Down Expand Up @@ -121,9 +118,7 @@ def __init__(self, args):
)
)
self.cmp_device = tvm.device(args.cmp_device)
self.const_params_dict = utils.load_params(
args.artifact_path, self.primary_device
)
self.const_params_dict = utils.load_params(args.artifact_path, self.primary_device)
self.cmp_instrument = LibCompare(
self.lib,
self.cmp_device,
Expand All @@ -134,9 +129,7 @@ def __init__(self, args):


def deploy_to_pipeline(args) -> None:
with open(
os.path.join(args.artifact_path, "params", "mlc-chat-config.json"), "r"
) as f:
with open(os.path.join(args.artifact_path, "params", "mlc-chat-config.json"), "r") as f:
config = json.load(f)

primary_device = tvm.device(args.primary_device)
Expand All @@ -157,18 +150,14 @@ def deploy_to_pipeline(args) -> None:
tokenizer(args.prompt, return_tensors="pt").input_ids.to(torch.int32).numpy(),
primary_device,
)
first_sampled_token = tvm.nd.array(
np.array([[6234]]).astype("int32"), primary_device
)
first_sampled_token = tvm.nd.array(np.array([[6234]]).astype("int32"), primary_device)
seq_len_shape = tvm.runtime.ShapeTuple([inputs.shape[1]])
second_seq_len_shape = tvm.runtime.ShapeTuple([inputs.shape[1] + 1])
kv_caches = state.vm["create_kv_cache"]()

print("Running inference...")
print("======================= Starts Encoding =======================")
logits, kv_caches = state.vm["prefill"](
inputs, seq_len_shape, kv_caches, const_params
)
logits, kv_caches = state.vm["prefill"](inputs, seq_len_shape, kv_caches, const_params)
print_as_table(
sorted(
state.cmp_instrument.time_eval_results.items(),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import argparse
import os
import pickle

import numpy as np
import torch
import tvm
from transformers import AutoTokenizer
from tvm import relax
import pickle

from mlc_llm import utils

Expand Down Expand Up @@ -77,12 +77,8 @@ def deploy_to_pipeline(args) -> None:
)

print("Tokenizing...")
inputs = (
tokenizer(args.prompt, return_tensors="pt").input_ids.to(torch.int32).numpy()
)
first_sampled_token = tvm.nd.array(
np.array([[6234]]).astype("int32"), primary_device
)
inputs = tokenizer(args.prompt, return_tensors="pt").input_ids.to(torch.int32).numpy()
first_sampled_token = tvm.nd.array(np.array([[6234]]).astype("int32"), primary_device)
seq_len_shape = tvm.runtime.ShapeTuple([inputs.shape[1]])
second_seq_len_shape = tvm.runtime.ShapeTuple([inputs.shape[1] + 1])
kv_caches = state.vm["create_kv_cache"]()
Expand Down
8 changes: 2 additions & 6 deletions tests/evaluate.py → tests/python/legacy/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,7 @@ def compare(
repeat=3,
)(*new_args).mean
shapes = [arg.shape for arg in new_args]
total_bytes = sum(
arg.numpy().size * arg.numpy().itemsize for arg in new_args
)
total_bytes = sum(arg.numpy().size * arg.numpy().itemsize for arg in new_args)
self.time_eval_results[name] = (res, 1, shapes, total_bytes)
else:
record = self.time_eval_results[name]
Expand Down Expand Up @@ -177,9 +175,7 @@ def deploy_to_pipeline(args) -> None: # pylint: disable=too-many-locals
print("Profiling...")
kv_caches = vm["create_kv_cache"]()

logits, kv_caches = vm["prefill"](
inputs, seq_len_shape, kv_caches, const_params
)
logits, kv_caches = vm["prefill"](inputs, seq_len_shape, kv_caches, const_params)
print("======================= Encoding Profiling =======================")
print_as_table(
sorted(
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
import dataclasses
import unittest

from mlc_llm import BuildArgs, utils, core
from mlc_llm import BuildArgs, core, utils


def old_make_args():
"""The exact old way of creating `ArgumentParser`, used to test whether
`BuildArgs` is equivalent to this. """
`BuildArgs` is equivalent to this."""
args = argparse.ArgumentParser()
args.add_argument(
"--model",
Expand All @@ -17,7 +18,7 @@ def old_make_args():
'The name of the model to build. If it is "auto", we will '
'automatically set the model name according to "--model-path", '
'"hf-path" or the model folders under "--artifact-path/models"'
)
),
)
args.add_argument(
"--hf-path",
Expand All @@ -30,19 +31,16 @@ def old_make_args():
type=str,
choices=[*utils.quantization_schemes.keys()],
default=list(utils.quantization_schemes.keys())[0],
help="The quantization mode we use to compile."
help="The quantization mode we use to compile.",
)
args.add_argument(
"--max-seq-len",
type=int,
default=-1,
help="The maximum allowed sequence length for the model."
help="The maximum allowed sequence length for the model.",
)
args.add_argument(
"--target",
type=str,
default="auto",
help="The target platform to compile the model for."
"--target", type=str, default="auto", help="The target platform to compile the model for."
)
args.add_argument(
"--reuse-lib",
Expand All @@ -51,10 +49,7 @@ def old_make_args():
help="Whether to reuse a previously generated lib.",
)
args.add_argument(
"--artifact-path",
type=str,
default="dist",
help="Where to store the output."
"--artifact-path", type=str, default="dist", help="Where to store the output."
)
args.add_argument(
"--use-cache",
Expand All @@ -66,13 +61,13 @@ def old_make_args():
"--debug-dump",
action="store_true",
default=False,
help="Whether to dump debugging files during compilation."
help="Whether to dump debugging files during compilation.",
)
args.add_argument(
"--debug-load-script",
action="store_true",
default=False,
help="Whether to load the script for debugging."
help="Whether to load the script for debugging.",
)
args.add_argument(
"--llvm-mingw",
Expand All @@ -81,10 +76,7 @@ def old_make_args():
help="/path/to/llvm-mingw-root, use llvm-mingw to cross compile to windows.",
)
args.add_argument(
"--system-lib",
action="store_true",
default=False,
help="A parameter to `relax.build`."
"--system-lib", action="store_true", default=False, help="A parameter to `relax.build`."
)
args.add_argument(
"--sep-embed",
Expand All @@ -99,17 +91,20 @@ def old_make_args():

return args


# Adapted from HfArgumentParserTest in https://github.com/huggingface/
# transformers/blob/e84bf1f734f87aa2bedc41b9b9933d00fc6add98/tests/utils
# /test_hf_argparser.py#L143
class BuildArgsTest(unittest.TestCase):
"""Tests whether BuildArgs reaches parity with regular ArgumentParser."""
def argparsers_equal(self, parse_a: argparse.ArgumentParser,
parse_b: argparse.ArgumentParser):

def argparsers_equal(self, parse_a: argparse.ArgumentParser, parse_b: argparse.ArgumentParser):
"""
Small helper to check pseudo-equality of parsed arguments on `ArgumentParser` instances.
"""
self.assertEqual(len(parse_a._actions), len(parse_b._actions)) # pylint: disable=protected-access
self.assertEqual(
len(parse_a._actions), len(parse_b._actions)
) # pylint: disable=protected-access
for x, y in zip(parse_a._actions, parse_b._actions): # pylint: disable=protected-access
xx = {k: v for k, v in vars(x).items() if k != "container"}
yy = {k: v for k, v in vars(y).items() if k != "container"}
Expand Down Expand Up @@ -175,5 +170,6 @@ def test_namespaces_are_equivalent_str_boolean_int(self):
build_args_namespace = argparse.Namespace(**build_args_as_dict)
self.assertNotEqual(build_args_namespace, parsed_args)

if __name__ == '__main__':

if __name__ == "__main__":
unittest.main()
Loading

0 comments on commit e9b85ce

Please sign in to comment.