diff --git a/llm/inference/janus_pro/generation.py b/llm/inference/janus_pro/generation.py index 11fd731b5..aaa5f5fd1 100644 --- a/llm/inference/janus_pro/generation.py +++ b/llm/inference/janus_pro/generation.py @@ -1,3 +1,4 @@ +import sys import os import PIL.Image import mindspore @@ -77,8 +78,8 @@ def generate( generated_tokens = ops.zeros(parallel_size, image_token_num_per_image, dtype=ms.int32) + print("Generating tokens: ") for i in range(image_token_num_per_image): - print(f"generating token {i}") outputs = mmgpt.language_model.model(inputs_embeds=inputs_embeds, use_cache=True, past_key_values=outputs.past_key_values if i != 0 else None) hidden_states = outputs.last_hidden_state # (parallel_size*2, len(input_ids), 2048) @@ -97,7 +98,8 @@ def generate( # print("img_embeds.shape:", img_embeds.shape) # print("img_embeds.dtype:", img_embeds.dtype) inputs_embeds = img_embeds.unsqueeze(dim=1) #(parallel_size*2, 2048) - print("generated one token") + sys.stdout.write('.'); sys.stdout.flush() + print(f"Generated {i+1} tokens.\n") if image_token_num_per_image==576: dec = mmgpt.gen_vision_model.decode_code(generated_tokens.astype(ms.int32), shape=[parallel_size, 8, img_size//patch_size, img_size//patch_size]) @@ -121,4 +123,4 @@ def generate( vl_gpt, vl_chat_processor, prompt, - ) \ No newline at end of file + )