Clean up logging for generate subcommand and setup log levels
George Hong committed Apr 17, 2024
1 parent 3623645 commit 45b68c3
Showing 2 changed files with 18 additions and 12 deletions.
27 changes: 15 additions & 12 deletions generate.py
@@ -27,6 +27,9 @@
 from cli import add_arguments_for_generate, arg_init, check_args
 from quantize import set_precision
 
+import logging
+logger = logging.getLogger(__name__)
+
 B_INST, E_INST = "[INST]", "[/INST]"
 
 @dataclass
@@ -66,7 +69,7 @@ def device_sync(device):
     elif ("cpu" in device) or ("mps" in device):
         pass
     else:
-        print(f"device={ device } is not yet suppported")
+        logging.error(f"device={ device } is not yet suppported")
 
 
 torch._inductor.config.coordinate_descent_tuning = True
@@ -106,15 +109,15 @@ def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
 def prefill(
     model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs
 ) -> torch.Tensor:
-    print(f"x: {x}, input_pos: {input_pos}")
+    logging.debug(f"x: {x}, input_pos: {input_pos}")
     width = x.size(1)
     assert input_pos.size(0) == width
     sequential_prefill = True
 
     if sequential_prefill:
         for i in range(width):
             x_sliced, ip_sliced = x[:, i].view(-1, 1), input_pos[i].view(-1)
-            print(f"<sliced> x: {x_sliced}, input_pos: {ip_sliced}")
+            logging.debug(f"<sliced> x: {x_sliced}, input_pos: {ip_sliced}")
             logits = model(x_sliced, ip_sliced)  # (x[:, i], input_pos[i])
     else:
         # input_pos: [B, S]
@@ -340,12 +343,12 @@ def _main(
     # # only print on rank 0
     # print = lambda *args, **kwargs: None
 
-    print(f"Using device={builder_args.device}")
+    logging.info(f"Using device={builder_args.device}")
     set_precision(builder_args.precision)
     is_speculative = speculative_builder_args.checkpoint_path is not None
 
     if generator_args.chat_mode and not builder_args.is_chat_model:
-        print("""
+        logging.warning("""
 *******************************************************
 This model is not known to support the chat function.
 We will enable chat mode based on your instructions.
@@ -374,7 +377,7 @@ def _main(
     encoded = encode_tokens(
         tokenizer, generator_args.prompt, bos=True, device=builder_args.device
     )
-    print(encoded)
+    logging.debug(encoded)
     prompt_length = encoded.size(0)
 
     model_size = sum(
@@ -469,7 +472,7 @@ def callback(x):
         )
         aggregate_metrics["accept_counts"].append(metrics["accept_counts"])
         if i == -1:
-            print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
+            logging.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
             continue
         if hasattr(prof, "export_chrome_trace"):
             if use_tp:
@@ -486,23 +489,23 @@ def callback(x):
         tokens_generated = y.size(0) - prompt_length
         tokens_sec = tokens_generated / t
         aggregate_metrics["tokens_per_sec"].append(tokens_sec)
-        print(
+        logging.info(
             f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec"
         )
         print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")
         print("==========")
     if is_speculative:
         counts_aggregated = [sum(i) for i in zip(*aggregate_metrics["accept_counts"])]
         acceptance_probs = [i / sum(counts_aggregated) for i in counts_aggregated]
-        print(f"Acceptance probs: {acceptance_probs}")
-        print(
+        logging.info(f"Acceptance probs: {acceptance_probs}")
+        logging.info(
             f"Mean Accepted: {sum([idx * i for idx, i in enumerate(counts_aggregated)])/sum(counts_aggregated)}"
         )
 
-    print(
+    logging.info(
         f"Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f}"
     )
-    print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
+    logging.info(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
 
 
 def main(args):
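For context, the generate.py changes above replace ad-hoc print calls with the standard logging module, picking a level per message: DEBUG for the per-token prefill dumps, INFO for timing and throughput metrics, WARNING for the chat-mode caveat, and ERROR for an unsupported device. The sketch below is a minimal, self-contained illustration of that pattern, not code from this commit; the report_device helper and the "tpu" device name are made up for the example.

import logging

# generate.py also adds a module-level logger (logging.getLogger(__name__)),
# though the calls rewritten in this diff go through the root-logger helpers
# used below.

def report_device(device: str) -> None:
    # Hypothetical stub echoing device_sync's logging choices.
    if ("cuda" in device) or ("cpu" in device) or ("mps" in device):
        logging.info(f"Using device={device}")
    else:
        logging.error(f"device={device} is not yet supported")

# Until logging.basicConfig is called (torchchat.py does this; see below),
# the root logger's effective threshold is WARNING, so only the last two
# calls here produce output.
logging.debug("x: ..., input_pos: ...")    # per-token detail
logging.info("Time for inference 1: ...")  # progress and metrics
logging.warning("This model is not known to support the chat function.")
report_device("tpu")                       # logs at ERROR level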
3 changes: 3 additions & 0 deletions torchchat.py
@@ -14,6 +14,8 @@
     check_args,
 )
 
+import logging
+
 default_device = "cpu" # 'cuda' if torch.cuda.is_available() else 'cpu'
 

@@ -35,6 +37,7 @@
 
     args = parser.parse_args()
     args = arg_init(args)
+    logging.basicConfig(format='%(message)s', level=logging.DEBUG if args.verbose else logging.INFO)
 
     if args.subcommand == "generate":
         check_args(args, "generate")
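The torchchat.py change is what gives those levels effect: logging.basicConfig installs a root handler whose threshold follows the verbosity flag, and the '%(message)s' format keeps output as bare messages, matching the print statements the commit replaces. Below is a minimal standalone sketch of that gating, assuming a --verbose flag like the one behind args.verbose; it is an illustration, not the project's actual CLI.

import argparse
import logging

parser = argparse.ArgumentParser()
parser.add_argument("--verbose", action="store_true")
args = parser.parse_args()

# Same shape as the call added to torchchat.py: message-only format,
# DEBUG when verbose, INFO otherwise.
logging.basicConfig(
    format="%(message)s",
    level=logging.DEBUG if args.verbose else logging.INFO,
)

logging.debug("<sliced> x: ..., input_pos: ...")  # visible only with --verbose
logging.info("Average tokens/sec: 12.34")         # visible either way

Run with --verbose to surface the DEBUG lines generate.py now emits during prefill; without it, only INFO and above appear.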
