use chat template, set attention mask and properly adjust length of max new tokens
kyriediculous committed Jul 31, 2024
1 parent 5070365 commit 145437a
Showing 2 changed files with 52 additions and 9 deletions.
43 changes: 34 additions & 9 deletions runner/app/pipelines/llm_generate.py
@@ -74,44 +74,69 @@ async def __call__(self, prompt: str, history: Optional[List[tuple]] = None, sys
"role": "assistant", "content": assistant}])
conversation.append({"role": "user", "content": prompt})

logger.info(f"Conversation prepared: {conversation}")

logger.info("Applying chat template")
input_ids = self.tokenizer.apply_chat_template(
conversation, return_tensors="pt").to(self.model.device)
logger.info(f"Inputs after apply_chat_template: {input_ids}")
attention_mask = torch.ones_like(input_ids)

max_new_tokens = kwargs.get("max_tokens", 256)
temperature = kwargs.get("temperature", 0.7)

streamer = TextIteratorStreamer(
self.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

generate_kwargs = {
# Start with the generation config
generate_kwargs = self.generation_config.to_dict()
logger.info(f"Generation config: {generate_kwargs}")
# Update with our specific parameters
generate_kwargs.update({
"input_ids": input_ids,
"attention_mask": attention_mask,
"streamer": streamer,
"max_new_tokens": max_new_tokens,
"do_sample": True,
"temperature": temperature,
"eos_token_id": self.terminators,
}

# Use the generation config
generate_kwargs.update(self.generation_config.to_dict())
"eos_token_id": self.tokenizer.eos_token_id,
"pad_token_id": self.tokenizer.eos_token_id,
})

# Enforce greedy generation (do_sample=False) when temperature is 0; sampling with temperature 0 would otherwise crash.
if temperature == 0:
generate_kwargs['do_sample'] = False

logger.info(f"Final generate kwargs: {generate_kwargs}")

# Start generation in a separate thread
thread = Thread(target=self.model.generate, kwargs=generate_kwargs)
thread = Thread(target=self.model_generate_wrapper, kwargs=generate_kwargs)
thread.start()

total_tokens = 0
try:
for text in streamer:
total_tokens += 1
yield text
await asyncio.sleep(0) # Allow other tasks to run
except Exception as e:
logger.error(f"Error during streaming: {str(e)}")
raise

# Yield the final result
yield {"tokens_used": len(input_ids[0]) + max_new_tokens}
input_length = input_ids.size(1)
logger.info(f"Total tokens generated: {total_tokens}")
logger.info(f"Input length: {input_length}")
yield {"tokens_used": input_length + total_tokens}

def __str__(self):
return f"LLMGeneratePipeline(model_id={self.model_id})"

def model_generate_wrapper(self, **kwargs):
try:
logger.debug("Entering model.generate")
with torch.cuda.amp.autocast(): # Use automatic mixed precision
self.model.generate(**kwargs)
logger.debug("Exiting model.generate")
except Exception as e:
logger.error(f"Error in model.generate: {str(e)}", exc_info=True)
raise
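
Note: the following standalone sketch is illustrative only and not part of this commit. It exercises the same pattern the diff introduces: apply the chat template, seed generate kwargs from the model's GenerationConfig, override the request-specific values, and stream tokens from a background thread. The model id and prompts are placeholder assumptions, and device_map="auto" assumes the accelerate package and a CUDA device.

# Illustrative sketch, not part of commit 145437a. Mirrors the pipeline code above
# outside the pipeline class; model id and messages are placeholder assumptions.
from threading import Thread

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"  # placeholder model id
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto")

conversation = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Say hello in one sentence."},
]
input_ids = tokenizer.apply_chat_template(
    conversation, return_tensors="pt").to(model.device)
attention_mask = torch.ones_like(input_ids)

streamer = TextIteratorStreamer(
    tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

# Start with the model's generation config, then apply request-specific values.
generate_kwargs = model.generation_config.to_dict()
generate_kwargs.update({
    "input_ids": input_ids,
    "attention_mask": attention_mask,
    "streamer": streamer,
    "max_new_tokens": 256,
    "do_sample": True,
    "temperature": 0.7,
    "eos_token_id": tokenizer.eos_token_id,
    "pad_token_id": tokenizer.eos_token_id,
})

# Generation runs in a worker thread; the streamer yields decoded text chunks.
thread = Thread(target=model.generate, kwargs=generate_kwargs)
thread.start()

total_tokens = 0
for text in streamer:
    total_tokens += 1
    print(text, end="", flush=True)
print(f"\ntokens_used: {input_ids.size(1) + total_tokens}")
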
18 changes: 18 additions & 0 deletions runner/check_torch_cuda.py
@@ -0,0 +1,18 @@
import torch
import subprocess

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")

# Check system CUDA version via nvcc
try:
    nvcc_output = subprocess.check_output(["nvcc", "--version"]).decode("utf-8")
    cuda_version = nvcc_output.split("release ")[-1].split(",")[0]
    print(f"System CUDA version: {cuda_version}")
except (OSError, subprocess.CalledProcessError):
    print("Unable to check system CUDA version")

# Print the current device (guarded so the script also runs on CPU-only hosts)
if torch.cuda.is_available():
    print(f"Current device: {torch.cuda.get_device_name(0)}")
