Commit f74016d
easier log prob
AlexPiche committed Dec 23, 2024
1 parent dbc0c0b commit f74016d
Showing 2 changed files with 17 additions and 16 deletions.
19 changes: 9 additions & 10 deletions examples/rl_gsm8k/orchestrate_rl.py
@@ -20,6 +20,7 @@
from tqdm import tqdm

import wandb

wandb.require("core")
from .cot_math_agent import (
    CoTMathAgent,
@@ -153,16 +154,13 @@ def extract_tape_training_samples(
            continue
        llm_call = step.metadata.other["llm_call"]
        trace = agent.llm.make_training_text(llm_call.prompt, llm_call.output)

        input_ids = [lp["token_id"] for lp in llm_call.logprobs]
        labels = [lp["token_id"] for lp in llm_call.logprobs if lp["generated"]]
        labels = [-100] * (len(input_ids) - len(labels)) + labels

        trace.input_ids = input_ids
        trace.labels = labels
        # TODO: token log probs look a bit off
        ref_logprobs = agent.llm.new_get_logprobs(trace.input_ids)

        trace.reward = reward
        trace.logprobs = [lp["logprob"] for lp in llm_call.logprobs if lp["generated"]]
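A side note on the masking pattern above (a minimal sketch, not part of this commit): labels use -100, the ignore_index convention of PyTorch's cross-entropy loss, so prompt tokens contribute nothing to the fine-tuning loss. The record shape mirrors llm_call.logprobs in this diff; the concrete token ids and logprob values are invented for illustration.

# Minimal sketch of the -100 label-masking convention used above.
logprobs = [
    {"token_id": 101, "logprob": None, "generated": 0},   # prompt token
    {"token_id": 2054, "logprob": None, "generated": 0},  # prompt token
    {"token_id": 3000, "logprob": -0.7, "generated": 1},  # generated token
    {"token_id": 102, "logprob": -0.1, "generated": 1},   # generated token
]

input_ids = [lp["token_id"] for lp in logprobs]
labels = [lp["token_id"] for lp in logprobs if lp["generated"]]
# Left-pad with -100 so prompt positions are ignored by
# CrossEntropyLoss(ignore_index=-100) during fine-tuning.
labels = [-100] * (len(input_ids) - len(labels)) + labels

assert input_ids == [101, 2054, 3000, 102]
assert labels == [-100, -100, 3000, 102]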
@@ -348,7 +346,7 @@ def main(cfg: DictConfig):
        parameters=cfg.llm.parameters,
        use_cache=False,
        collect_logprobs=True,
        observe_llm_calls=False
        observe_llm_calls=False,
    )

    test_llm = TrainableLLM(
@@ -357,7 +355,7 @@ def main(cfg: DictConfig):
        tokenizer_name=str(assistant_model_path),
        parameters=cfg.test_llm.parameters,
        use_cache=False,
        observe_llm_calls=False
        observe_llm_calls=False,
    )

    train_agent = CoTMathAgent.create(llm=llm)
@@ -378,7 +376,8 @@ def main(cfg: DictConfig):
    more_llm_stats = {
        "make_data_output_tokens/s": llm_stats["total_output_tokens"] / make_data_took,
        "make_data_prompt_tokens/s": llm_stats["total_prompt_tokens"] / make_data_took,
        "make_data_tokens/s": (llm_stats["total_output_tokens"] + llm_stats["total_prompt_tokens"]) / make_data_took,
        "make_data_tokens/s": (llm_stats["total_output_tokens"] + llm_stats["total_prompt_tokens"])
        / make_data_took,
    }
    for k, v in llm_stats.items():
        if "/s" in k:
@@ -490,10 +489,10 @@ def main(cfg: DictConfig):

    start_finetune = time.time()
    launch_training(
        str(conf_dir),
        str(state["iteration"]),
        str(conf_dir),
        str(state["iteration"]),
        cfg.accelerate_cfg_path,
        use_deepspeed=cfg.use_deepspeed  # defaults to False
        use_deepspeed=cfg.use_deepspeed,  # defaults to False
    )
    time_finetune = time.time() - start_finetune
    time_iteration = time.time() - start_iteration
14 changes: 8 additions & 6 deletions tapeagents/llms.py
@@ -688,7 +688,7 @@ def new_get_logprobs(self, prompt_token_ids: list[int]):
"prompt": prompt_token_ids,
"temperature": 0.0,
"max_tokens": 0,
"logprobs": 1,
"logprobs": 0,
"echo": True,
"include_stop_str_in_output": True, # self.include_stop_str_in_output,
"skip_special_tokens": False,
@@ -708,13 +708,15 @@ def new_get_logprobs(self, prompt_token_ids: list[int]):
            raise RuntimeError(f"Generation API wrong response: {r.text}", e)
        logprobs = [
            {
                "logprob": lp,
                "top_logprobs": [],
                "token": t,
                "token_id": self.tokenizer.encode(t, add_special_tokens=False)[0],
                "logprob": None,
                "token": self.tokenizer.bos_token,
            }
            for lp, t in zip(logprobs, tokens) if t != ""
        ]
        for lp in response["choices"][0]["prompt_logprobs"]:
            if lp:
                for k, v in lp.items():
                    v.update({"generated": 0, "token_id": k})
                    logprobs.append(v)
        return {"content": logprobs}

    def get_logprobs_complete(self, prompt: str, output: str) -> dict[str, Any]:
