Skip to content

Commit

Permalink
debug eval script times
Browse files: browse the repository at this point in the history
Signed-off-by: adithyare <adithyare@nvidia.com>
  • Loading branch information
arendu committed Nov 19, 2024
1 parent 15fdf8a commit ca902fd
Showing 1 changed file with 17 additions and 9 deletions.
26 changes: 17 additions & 9 deletions examples/nlp/language_modeling/megatron_mamba_eval.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -352,23 +352,31 @@ def main(cfg) -> None:
prompts = load_prompts(cfg)

# First method of running text generation, call model.generate method
response = model.generate(inputs=prompts, length_params=length_params, sampling_params=sampling_params)
for i in range(3):
st = time.perf_counter()
response = model.generate(inputs=prompts, length_params=length_params, sampling_params=sampling_params)
tdiff = time.perf_counter() - st
print(f"[Try{i} model.generate took {tdiff} seconds...")

print("***************************")
print(response)
print("***************************")
#print("***************************")
#print(response)
#print("***************************")

# Second method of running text generation, call trainer.predict [recommended]
bs = 2
ds = RequestDataSet(prompts)
request_dl = DataLoader(dataset=ds, batch_size=bs)
config = OmegaConf.to_container(cfg.inference)
model.set_inference_config(config)
response = trainer.predict(model, request_dl)

print("***************************")
print(response)
print("***************************")
for i in range(3):
st = time.perf_counter()
response = trainer.predict(model, request_dl)
tdiff = time.perf_counter() - st
print(f"[Try{i} trainer.predict took {tdiff} seconds...")

#print("***************************")
#print(response)
#print("***************************")

# Third method of running text generation, use inference server
if cfg.server:
Expand Down

0 comments on commit ca902fd

Please sign in to comment.