
Commit

fix launcher
OlivierDehaene committed Oct 8, 2024
1 parent 9c6aed1 commit 70fe6ce
Showing 3 changed files with 37 additions and 53 deletions.
12 changes: 0 additions & 12 deletions benchmark/src/main.rs
@@ -179,18 +179,6 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
.await
.expect("Unable to clear cache");

// Warmup shard
let max_batch_size = batch_size.iter().max().unwrap();
sharded_client
.warmup(
sequence_length,
sequence_length * max_batch_size,
(sequence_length + decode_length) * max_batch_size,
Some(*max_batch_size as usize),
)
.await
.expect("Unable to warmup");

tracing::info!("Connected");

// Run app
12 changes: 0 additions & 12 deletions launcher/src/main.rs
@@ -1727,12 +1727,6 @@ fn main() -> Result<(), LauncherError> {
"`max_input_tokens must be < `max_total_tokens`".to_string(),
));
}
if max_input_tokens as u32 > max_batch_prefill_tokens {
return Err(LauncherError::ArgumentValidation(format!(
"`max_batch_prefill_tokens` must be >= `max_input_tokens`. Given: {} and {}",
max_batch_prefill_tokens, max_input_tokens
)));
}

if matches!(args.quantize, Some(Quantization::Bitsandbytes)) {
tracing::warn!("Bitsandbytes is deprecated, use `eetq` instead, which provides better latencies overall and is drop-in in most cases.");
@@ -1786,12 +1780,6 @@ fn main() -> Result<(), LauncherError> {
}

if let Some(ref max_batch_total_tokens) = args.max_batch_total_tokens {
if max_batch_prefill_tokens > *max_batch_total_tokens {
return Err(LauncherError::ArgumentValidation(format!(
"`max_batch_prefill_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
max_batch_prefill_tokens, max_batch_total_tokens
)));
}
if max_total_tokens as u32 > *max_batch_total_tokens {
return Err(LauncherError::ArgumentValidation(format!(
"`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {} and {}",
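
The launcher diff above keeps the ordering constraints between a single request's token budget and the batch budget, but drops the two checks that tied `max_batch_prefill_tokens` to them, presumably because chunked prefill means a long prompt no longer has to fit in one prefill pass. Below is a minimal sketch of the surviving validation, written in Python purely for illustration (the launcher itself is Rust, and `validate_token_budgets` is a hypothetical helper, not part of the codebase):

from typing import Optional

def validate_token_budgets(
    max_input_tokens: int,
    max_total_tokens: int,
    max_batch_total_tokens: Optional[int],
) -> None:
    # Still enforced: the prompt must leave room for at least one generated token.
    if max_input_tokens >= max_total_tokens:
        raise ValueError(
            f"`max_input_tokens` must be < `max_total_tokens`. "
            f"Given: {max_input_tokens} and {max_total_tokens}"
        )
    # Still enforced: a single request must fit inside the batch token budget.
    if max_batch_total_tokens is not None and max_total_tokens > max_batch_total_tokens:
        raise ValueError(
            f"`max_total_tokens` must be <= `max_batch_total_tokens`. "
            f"Given: {max_total_tokens} and {max_batch_total_tokens}"
        )
    # No longer enforced after this commit:
    #   max_batch_prefill_tokens >= max_input_tokens
    #   max_batch_prefill_tokens <= max_batch_total_tokens

With the two checks gone, `max_batch_prefill_tokens` presumably only caps how many prompt tokens are scheduled in a single prefill pass, rather than bounding a whole request.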
66 changes: 37 additions & 29 deletions server/text_generation_server/models/flash_causal_lm.py
@@ -173,9 +173,6 @@ class FlashCausalLMBatch(Batch):
# Will be set by `generate_token` and reset after each prefill forward
prefill_logprob_tokens: List[Optional[Tokens]]

# Prefixes
prefix_ids: List[List[int]]

# All tokens
all_input_ids: List[List[int]]
all_input_ids_tensor: torch.Tensor
@@ -259,7 +256,6 @@ def from_tokenized(
read_offsets = []
all_input_ids = []
all_postfix_ids = []
prefix_ids = []
requests_idx_mapping = {}

next_token_chooser_parameters = []
@@ -297,7 +293,6 @@ def from_tokenized(
assert get_support_chunking()
assert input_length > 0

prefix_ids.append(tokenized_input[:cache_length])
postfix_ids = tokenized_input[cache_length : cache_length + input_length]

assert (
@@ -400,7 +395,6 @@ def from_tokenized(
read_offsets=read_offsets,
all_input_ids=all_input_ids,
all_input_ids_tensor=all_input_ids_tensor,
prefix_ids=prefix_ids,
next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
@@ -464,7 +458,6 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
requests = []
block_tables = []
all_input_ids = []
prefix_ids = []
input_ids = []

prompt_lengths = []
@@ -505,7 +498,6 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
)

all_input_ids.append(self.all_input_ids[idx])
prefix_ids.append(self.prefix_ids[idx])

prompt_lengths.append(self.prompt_lengths[idx])
input_lengths.append(request_input_length)
@@ -621,7 +613,6 @@ def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch":
read_offsets=read_offsets,
all_input_ids=all_input_ids,
all_input_ids_tensor=all_input_ids_tensor,
prefix_ids=prefix_ids,
next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
@@ -718,7 +709,6 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
block_tables = []
cache_lengths = []
all_input_ids = []
prefix_ids = []

prompt_lengths = []
input_lengths = []
@@ -802,7 +792,6 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
block_tables.extend(batch.block_tables)
cache_lengths.extend(batch.cache_lengths)
all_input_ids.extend(batch.all_input_ids)
prefix_ids.extend(batch.prefix_ids)

prompt_lengths.extend(batch.prompt_lengths)
input_lengths.extend(batch.input_lengths)
@@ -873,7 +862,6 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch":
read_offsets=read_offsets,
all_input_ids=all_input_ids,
all_input_ids_tensor=all_input_ids_tensor,
prefix_ids=prefix_ids,
next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias,
top_n_tokens=top_n_tokens,
@@ -1839,6 +1827,8 @@ def generate_token(
batch.input_lengths,
batch.all_input_ids,
accepted_ids,
current_prefilling_mask,
batch.prefilling_mask,
)

# We do two for loops as the first one can run completely asynchronously from the GPU while for the second
@@ -1855,6 +1845,8 @@ def generate_token(
input_length,
all_input_ids,
n_accepted_ids,
request_was_prefilling,
request_is_prefilling,
) in enumerate(iterator):
# Indexing metadata
start_index = cumulative_length
@@ -1864,7 +1856,6 @@ def generate_token(
# Indexing metadata
out_start_index = batch.prefill_cu_outlens[i]
out_end_index = batch.prefill_cu_outlens[i + 1]
out_length = out_end_index - out_start_index

if finished_prefilling:
# Initialize position_ids
@@ -1880,21 +1871,25 @@ def generate_token(
# Used to gather prefill logprobs
# Copy batch.input_ids to prefill_token_indices
if prefill_logprobs:
# If the request was prefilling and cache_length == 0, the first token is a bogus token
# and needs to be removed. We do so by incrementing the start_index
if request_was_prefilling and cache_length == 0:
start_index += 1

# If the request was prefilling, and it is done prefilling, the last token was generated and is
# therefore not part of the prefill. We remove it by decrementing out_end_index
if request_was_prefilling and not request_is_prefilling:
out_end_index -= 1

if len(batch) > 1:
prefill_tokens_indices[out_start_index : out_end_index - 1] = (
batch.input_ids[start_index + 1 : start_index + out_length]
prefill_tokens_indices[out_start_index:out_end_index] = (
batch.input_ids[start_index:end_index]
)
else:
# Set prefill_tokens_indices to the correct slice
prefill_tokens_indices = batch.input_ids[
start_index + 1 : start_index + out_length
]
prefill_tokens_indices = batch.input_ids[start_index:end_index]

# Represent whether this request is still prefilling
# If it is, the tokens we decoded should be ignored
accept_tokens = cache_length + input_length >= prompt_length

if accept_tokens:
if not request_is_prefilling:
# Only save tokens if we are done prefilling for this request
for j in range(n_accepted_ids):
batch.all_input_ids_tensor[i, cache_length + input_length + j] = (
@@ -1995,7 +1990,6 @@ def generate_token(
batch.read_offsets,
batch.stopping_criterias,
batch.all_input_ids,
batch.prefix_ids,
batch.next_token_chooser.do_sample,
batch.next_token_chooser.seeds,
batch.top_n_tokens,
@@ -2019,7 +2013,6 @@ def generate_token(
read_offset,
stopping_criteria,
all_input_ids,
prefix_ids,
do_sample,
seed,
top_n_tokens,
@@ -2039,26 +2032,41 @@ def generate_token(
out_start_index = batch.prefill_cu_outlens[i]
out_end_index = batch.prefill_cu_outlens[i + 1]

# log_master(logger.info, f"{prefill_logprobs}")

if not request_is_prefilling:
# If the request is done prefilling, then the last logprob is a generated token
# We need to remove it
out_end_index -= 1

request_prefill_logprobs = prefill_logprobs[
out_start_index : out_end_index - 1
out_start_index:out_end_index
]
prefill_token_ids = all_input_ids[
cache_length : cache_length + input_length
]
prefill_token_ids = all_input_ids[:-1]

past_prefill_logprob_tokens = batch.prefill_logprob_tokens[i]

if past_prefill_logprob_tokens is None:
# Remove generated token to only have prefill and add nan for first prompt token
request_prefill_logprobs = [float("nan")] * (
len(prefix_ids) + 1
cache_length + 1
) + request_prefill_logprobs
prefill_token_ids = prefix_ids + prefill_token_ids
prefill_token_ids = (
all_input_ids[:cache_length] + prefill_token_ids
)

prefill_texts = self.tokenizer.batch_decode(
prefill_token_ids,
clean_up_tokenization_spaces=False,
skip_special_tokens=False,
)

# log_master(logger.info, f"{prefill_token_ids}")
# log_master(logger.info, f"{request_prefill_logprobs}")
# log_master(logger.info, f"{prefill_texts}")

prefill_logprob_tokens = Tokens(
prefill_token_ids,
request_prefill_logprobs,
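
The most intricate part of the Python diff is the per-request assembly of prefill logprobs now that a request can be mid-prefill (chunking) or start with cache_length > 0 (prefix cache hit); this is also what makes the separate prefix_ids field redundant, since all_input_ids[:cache_length] already holds the cached prefix. The stand-alone sketch below is a simplified illustration of that bookkeeping, not the actual generate_token code (the function name and the list-based inputs are assumptions):

import math
from typing import List

def assemble_prefill_logprob_chunk(
    prefill_logprobs: List[float],  # logprobs already sliced out for this request
    all_input_ids: List[int],       # full prompt so far (cached prefix + current chunk)
    cache_length: int,              # prompt tokens already in the KV cache before this pass
    input_length: int,              # prompt tokens processed in this pass
    request_is_prefilling: bool,    # True if more prompt chunks remain after this pass
    has_past_chunk: bool,           # True if earlier chunks already produced logprobs
):
    end = len(prefill_logprobs)
    # If the request just finished prefilling, the last logprob belongs to the
    # first generated token, not to the prompt, so it is dropped here
    # (out_end_index -= 1 in the diff).
    if not request_is_prefilling:
        end -= 1
    chunk_logprobs = prefill_logprobs[:end]

    # Prompt tokens covered by this pass.
    chunk_token_ids = all_input_ids[cache_length : cache_length + input_length]

    if not has_past_chunk:
        # First chunk for this request: no logprob exists for the cached prefix
        # or for the first prompt token, so pad with NaN and reattach the prefix
        # ids taken from all_input_ids instead of the removed prefix_ids field.
        chunk_logprobs = [math.nan] * (cache_length + 1) + chunk_logprobs
        chunk_token_ids = all_input_ids[:cache_length] + chunk_token_ids

    return chunk_token_ids, chunk_logprobs

Later chunks skip the NaN padding because batch.prefill_logprob_tokens[i] is already populated by the time they are processed, which mirrors the past_prefill_logprob_tokens check in the diff.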
