diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py
index 19baeaff..b0bb2f91 100644
--- a/benchmarks/benchmark_serving.py
+++ b/benchmarks/benchmark_serving.py
@@ -54,22 +54,21 @@
 """
-import tensorflow as tf
-import tensorflow_text as tftxt
-
 import argparse
 import asyncio
-
 from dataclasses import dataclass
 from datetime import datetime
 import json
 import random
 import time
 from typing import Any, AsyncGenerator, List, Optional
+
 import grpc
 from jetstream.core.proto import jetstream_pb2
 from jetstream.core.proto import jetstream_pb2_grpc
 import numpy as np
+import tensorflow as tf
+import tensorflow_text as tftxt
 from tqdm.asyncio import tqdm
@@ -96,6 +95,7 @@ class InputRequest:
   output: str = ""
   output_len: int = 0
 
+
 @dataclass
 class RequestFuncOutput:
   input_request: InputRequest = None
@@ -109,12 +109,12 @@ class RequestFuncOutput:
   # Flatten the structure and return only the necessary results
   def to_dict(self):
     return {
-      "prompt": self.input_request.prompt,
-      "original_output": self.input_request.output,
-      "generated_text": self.generated_text,
-      "success": self.success,
-      "latency": self.latency,
-      "prompt_len": self.prompt_len
+        "prompt": self.input_request.prompt,
+        "original_output": self.input_request.output,
+        "generated_text": self.generated_text,
+        "success": self.success,
+        "latency": self.latency,
+        "prompt_len": self.prompt_len,
     }
@@ -123,12 +123,14 @@ def get_tokenizer(tokenizer_name: str) -> Any:
   if tokenizer_name == "test":
     return "test"
   else:
-    with tf.io.gfile.GFile(tokenizer_name, 'rb') as model_fp:
+    with tf.io.gfile.GFile(tokenizer_name, "rb") as model_fp:
       sp_model = model_fp.read()
     sp_tokenizer = tftxt.SentencepieceTokenizer(
-        model=sp_model, add_bos=True, add_eos=False, reverse=False)
+        model=sp_model, add_bos=True, add_eos=False, reverse=False
+    )
     return sp_tokenizer
 
+
 def load_sharegpt_dataset(
     dataset_path: str,
     conversation_starter: str,
@@ -141,7 +143,11 @@ def load_sharegpt_dataset(
 
   # Filter based on conversation starter
   if conversation_starter != "both":
-    dataset = [data for data in dataset if data["conversations"][0]["from"] == conversation_starter]
+    dataset = [
+        data
+        for data in dataset
+        if data["conversations"][0]["from"] == conversation_starter
+    ]
   # Only keep the first two turns of each conversation.
   dataset = [
       (data["conversations"][0]["value"], data["conversations"][1]["value"])
@@ -151,9 +157,7 @@ def load_sharegpt_dataset(
   return dataset
 
 
-def load_openorca_dataset(
-    dataset_path: str
-) -> List[tuple[str]]:
+def load_openorca_dataset(dataset_path: str) -> List[tuple[str]]:
   # Load the dataset.
   with open(dataset_path) as f:
     dataset = json.load(f)
@@ -187,23 +191,31 @@ def tokenize_dataset(
     prompt_len = len(prompt_token_ids[i])
     output_len = len(outputs_token_ids[i])
     tokenized_dataset.append(
-      (prompts[i], prompt_token_ids[i], outputs[i], prompt_len, output_len)
+        (prompts[i], prompt_token_ids[i], outputs[i], prompt_len, output_len)
     )
   return tokenized_dataset
 
 
 def filter_dataset(
-    tokenized_dataset: List[tuple[Any]],
-    max_output_length: Optional[int] = None
+    tokenized_dataset: List[tuple[Any]], max_output_length: Optional[int] = None
 ) -> List[InputRequest]:
   if max_output_length is None:
     print("In InputRequest, pass in actual output_length for each sample")
   else:
-    print(f"In InputRequest, pass in max_output_length: {max_output_length} for each sample")
+    print(
+        f"In InputRequest, pass in max_output_length: {max_output_length} for"
+        " each sample"
+    )
 
   # Filter out too long sequences.
   filtered_dataset: List[InputRequest] = []
-  for prompt, prompt_token_ids, output, prompt_len, output_len in tokenized_dataset:
+  for (
+      prompt,
+      prompt_token_ids,
+      output,
+      prompt_len,
+      output_len,
+  ) in tokenized_dataset:
     if prompt_len < 4 or output_len < 4:
       # Prune too short sequences.
       # This is because TGI causes errors when the input or output length
@@ -212,7 +224,9 @@ def filter_dataset(
     if prompt_len > 1024 or prompt_len + output_len > 2048:
       # Prune too long sequences.
       continue
-    request = InputRequest(prompt, prompt_len, output, max_output_length or output_len)
+    request = InputRequest(
+        prompt, prompt_len, output, max_output_length or output_len
+    )
     filtered_dataset.append(request)
 
   print(f"The dataset contains {len(tokenized_dataset)} samples.")
@@ -226,20 +240,26 @@ def sample_requests(
     tokenizer: Any,
     num_requests: int,
     max_output_length: Optional[int] = None,
-    oversample_multiplier: float=1.2,
-  ) -> List[InputRequest]:
+    oversample_multiplier: float = 1.2,
+) -> List[InputRequest]:
 
   # Original dataset size
   n = len(dataset)
 
   # Create necessary number of requests even if bigger than dataset size
   sampled_indices = random.sample(
-      range(n), min(int(num_requests * oversample_multiplier), n))
+      range(n), min(int(num_requests * oversample_multiplier), n)
+  )
 
   if num_requests > len(sampled_indices):
-    print(f"Number of requests {num_requests} is larger than size of dataset {n}.\n",
-          f"Repeating data to meet number of requests.\n")
-    sampled_indices = sampled_indices * int(np.ceil(num_requests / len(sampled_indices)))
+    print(
+        f"Number of requests {num_requests} is larger than size of dataset"
+        f" {n}.\n",
+        f"Repeating data to meet number of requests.\n",
+    )
+    sampled_indices = sampled_indices * int(
+        np.ceil(num_requests / len(sampled_indices))
+    )
 
   print(f"{len(sampled_indices)=}")
   # some of these will be filtered out, so sample more than we need
@@ -315,7 +335,9 @@ def calculate_metrics(
   return metrics
 
 
-async def grpc_async_request(api_url: str, request: Any) -> tuple[list[str], float, float]:
+async def grpc_async_request(
+    api_url: str, request: Any
+) -> tuple[list[str], float, float]:
   """Send grpc synchronous request since the current grpc server is sync."""
   options = [("grpc.keepalive_timeout_ms", 10000)]
   async with grpc.aio.insecure_channel(api_url, options=options) as channel:
@@ -351,7 +373,9 @@ async def send_request(
   output = RequestFuncOutput()
   output.input_request = input_request
   output.prompt_len = input_request.prompt_len
-  generated_token_list, ttft, latency = await grpc_async_request(api_url, request)
+  generated_token_list, ttft, latency = await grpc_async_request(
+      api_url, request
+  )
   output.ttft = ttft
   output.latency = latency
   output.generated_token_list = generated_token_list
@@ -453,14 +477,15 @@ def mock_requests(total_mock_requests: int):
 
 def sample_warmup_requests(requests):
   interesting_buckets = [
-    0,
-    16,
-    32,
-    64,
-    128,
-    256,
-    512,
-    1024,]
+      0,
+      16,
+      32,
+      64,
+      128,
+      256,
+      512,
+      1024,
+  ]
 
   for start, end in zip(interesting_buckets[:-1], interesting_buckets[1:]):
     for request in requests:
@@ -481,28 +506,30 @@ def main(args: argparse.Namespace):
   tokenizer = get_tokenizer(tokenizer_id)
 
   if tokenizer == "test" or args.dataset == "test":
-    input_requests = mock_requests(args.total_mock_requests) # e.g. [("AB", 2, "AB", 3)]
+    input_requests = mock_requests(
+        args.total_mock_requests
+    )  # e.g. [("AB", 2, "AB", 3)]
[("AB", 2, "AB", 3)] else: if args.dataset == "openorca": dataset = load_openorca_dataset(args.dataset_path) elif args.dataset == "sharegpt": dataset = load_sharegpt_dataset( - args.dataset_path, - args.conversation_starter, + args.dataset_path, + args.conversation_starter, ) # A given args.max_output_length value is the max generation step, # when the args.max_output_length is default to None, the sample's golden output length # will be used to decide the generation step input_requests = sample_requests( - dataset=dataset, - tokenizer=tokenizer, - num_requests=args.num_prompts, - max_output_length=args.max_output_length + dataset=dataset, + tokenizer=tokenizer, + num_requests=args.num_prompts, + max_output_length=args.max_output_length, ) if args.warmup_first: - print('Warm up start:' ) + print("Warm up start:") warmup_requests = list(sample_warmup_requests(input_requests)) * 2 benchmark_result, request_outputs = asyncio.run( benchmark( @@ -516,7 +543,7 @@ def main(args: argparse.Namespace): threads=args.threads, ) ) - print('Warm up done') + print("Warm up done") benchmark_result, request_outputs = asyncio.run( benchmark( @@ -561,7 +588,11 @@ def main(args: argparse.Namespace): if args.save_request_outputs: file_path = args.request_outputs_file_path with open(file_path, "w") as output_file: - json.dump([output.to_dict() for output in request_outputs], output_file, indent=4) + json.dump( + [output.to_dict() for output in request_outputs], + output_file, + indent=4, + ) if __name__ == "__main__": @@ -576,11 +607,13 @@ def main(args: argparse.Namespace): ) parser.add_argument("--port", type=str, default=9000) parser.add_argument( - "--dataset", type=str, default="test", choices=["test", "sharegpt", "openorca"], help="The dataset name." - ) - parser.add_argument( - "--dataset-path", type=str, help="Path to the dataset." + "--dataset", + type=str, + default="test", + choices=["test", "sharegpt", "openorca"], + help="The dataset name.", ) + parser.add_argument("--dataset-path", type=str, help="Path to the dataset.") parser.add_argument( "--model", type=str, @@ -637,7 +670,16 @@ def main(args: argparse.Namespace): "--max-output-length", type=int, default=None, - help="The maximum output length for reference request.", + help=( + "The maximum output length for reference request. It would be passed" + " to `max_tokens` parameter of the JetStream's DecodeRequest proto," + " and used in JetStream to control the output/decode length of a" + " sequence. It would not be used in the engine. We should always set" + " max_tokens <= (max_target_length - max_prefill_predict_length)." + " max_target_length is the maximum length of a sequence;" + " max_prefill_predict_length is the maximum length of the" + " input/prefill of a sequence." + ), ) parser.add_argument("--seed", type=int, default=0) @@ -678,26 +720,20 @@ def main(args: argparse.Namespace): "--request-outputs-file-path", type=str, default="/tmp/request-outputs.json", - help=( - "File path to store request outputs" - ), + help="File path to store request outputs", ) parser.add_argument( "--warmup-first", type=bool, default=False, - help=( - "Whether to send warmup req first" - ), + help="Whether to send warmup req first", ) parser.add_argument( "--conversation-starter", type=str, default="human", choices=["human", "gpt", "both"], - help=( - "What entity should be the one starting the conversations." 
-      ),
+      help="Which entity should start the conversations.",
   )
 
   args = parser.parse_args()
diff --git a/jetstream/core/proto/jetstream.proto b/jetstream/core/proto/jetstream.proto
index 1736f7a6..228ddd05 100644
--- a/jetstream/core/proto/jetstream.proto
+++ b/jetstream/core/proto/jetstream.proto
@@ -12,7 +12,6 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-
 syntax = "proto3";
 
 package jetstream_proto;
@@ -29,6 +28,12 @@ message DecodeRequest {
   // New text from a user or tool.
   string additional_text = 2;
   int32 priority = 3;
+  // The maximum output length of a sequence. It is used in JetStream to
+  // control the output/decode length of a sequence; it is not used in the
+  // engine. Always set max_tokens <= (max_target_length -
+  // max_prefill_predict_length), where max_target_length is the maximum
+  // length of a sequence and max_prefill_predict_length is the maximum
+  // length of the input/prefill of a sequence.
   int32 max_tokens = 4;
 }
 message DecodeResponse {
diff --git a/jetstream/tools/load_tester.py b/jetstream/tools/load_tester.py
index 5ee041a05..965c6419 100644
--- a/jetstream/tools/load_tester.py
+++ b/jetstream/tools/load_tester.py
@@ -17,12 +17,11 @@
 import concurrent.futures
 import functools
 import time
-from typing import Sequence, Iterator
+from typing import Iterator, Sequence
 from absl import app
 from absl import flags
 import grpc
-
 from jetstream.core.proto import jetstream_pb2
 from jetstream.core.proto import jetstream_pb2_grpc
 
@@ -31,7 +30,7 @@
 _PORT = flags.DEFINE_string('port', '9000', 'port to ping')
 _TEXT = flags.DEFINE_string('text', 'AB', 'The message')
 _MAX_TOKENS = flags.DEFINE_integer(
-    'max_tokens', 100, 'Maximum number of tokens'
+    'max_tokens', 100, 'Maximum number of output/decode tokens of a sequence'
 )
 
 
@@ -88,8 +87,7 @@ def load_test(
   number = list(range(len(text)))
   start = time.time()
   ping_partial = functools.partial(ping, stub)
-  with concurrent.futures.ThreadPoolExecutor(
-      max_workers=queries) as executor:
+  with concurrent.futures.ThreadPoolExecutor(max_workers=queries) as executor:
     responses = list(executor.map(ping_partial, text, number))
   time_taken = time.time() - start
   print(f'Time taken: {time_taken}')
@@ -100,10 +98,9 @@
 def main(argv: Sequence[str]):
   del argv
   address = f'{_SERVER.value}:{_PORT.value}'
-  # Note: Uses insecure_channel only for local testing. Please add grpc credentials for Production.
-  with grpc.insecure_channel(
-      address
-  ) as channel:
+  # Note: Uses insecure_channel only for local testing. Please add grpc
+  # credentials for production.
+  with grpc.insecure_channel(address) as channel:
     grpc.channel_ready_future(channel).result()
     stub = jetstream_pb2_grpc.OrchestratorStub(channel)
     _ = load_test(stub, text=[_TEXT.value], queries=64)
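To make the new `max_tokens` contract concrete, here is a minimal sketch of building a `DecodeRequest` that respects it. The two length values below are assumed serving-config numbers for illustration only, not values taken from this change:

from jetstream.core.proto import jetstream_pb2

max_target_length = 2048           # assumed: maximum total sequence length
max_prefill_predict_length = 1024  # assumed: maximum input/prefill length

# Per the comment added in jetstream.proto, keep
# max_tokens <= max_target_length - max_prefill_predict_length.
request = jetstream_pb2.DecodeRequest(
    additional_text="Today is a good day",
    priority=0,
    max_tokens=max_target_length - max_prefill_predict_length,
)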
diff --git a/jetstream/tools/maxtext/model_ckpt_conversion.sh b/jetstream/tools/maxtext/model_ckpt_conversion.sh
index 8a872b0d..c81ac6a1 100644
--- a/jetstream/tools/maxtext/model_ckpt_conversion.sh
+++ b/jetstream/tools/maxtext/model_ckpt_conversion.sh
@@ -29,7 +29,7 @@ export MODEL_VARIATION=$2
 export MODEL_NAME=${MODEL}-${MODEL_VARIATION}
 
 # After downloading checkpoints, copy them to GCS bucket at $CHKPT_BUCKET \
-# Please use seperate GCS paths for uploading open source model weights ($CHKPT_BUCKET) and MaxText compatible weights ($MODEL_BUCKET).
+# Please use separate GCS paths for uploading open source model weights ($CHKPT_BUCKET) and MaxText compatible weights ($MODEL_BUCKET).
 # Point these variables to a GCS bucket that you created.
 # An example of CHKPT_BUCKET could be: gs://${USER}-maxtext/chkpt/${MODEL}/${MODEL_VARIATION}
 export CHKPT_BUCKET=$3
@@ -38,7 +38,7 @@ export MODEL_BUCKET=gs://${USER}-maxtext
 
 # Point `BASE_OUTPUT_DIRECTORY` to a GCS bucket that you created, this bucket will store all the files generated by MaxText during a run.
 export BASE_OUTPUT_DIRECTORY=gs://${USER}-runner-maxtext-logs
 
-# Point `DATASET_PATH` to the GCS bucket where you have your training data
+# Point `DATASET_PATH` to the GCS bucket where you have your training data.
 export DATASET_PATH=gs://${USER}-maxtext-dataset
 
 export BUCKET_LOCATION=US
@@ -56,13 +56,18 @@ if [ "$MODEL" == "gemma" ]; then
     --maxtext_model_path ${MODEL_BUCKET}/${MODEL}/${MODEL_VARIATION}/${idx} \
     --model_size ${MODEL_VARIATION}
 else
-  # We install torch CPU because the checkpoint conversion script MaxText/llama_or_mistral_ckpt.py does not need a TPU/GPU
+  # We install torch CPU because the checkpoint conversion script MaxText/llama_or_mistral_ckpt.py does not need a TPU/GPU.
   pip install torch --index-url https://download.pytorch.org/whl/cpu
+  # llama_or_mistral_ckpt.py requires a local path, so we need to copy the checkpoint from CHKPT_BUCKET to local.
+  tmp_ckpt_path="/tmp/"
+  gcloud storage cp -r ${CHKPT_BUCKET} ${tmp_ckpt_path}
+  path_parts=(${CHKPT_BUCKET//\// })
+  directory_substring=${path_parts[-1]}
   CONVERT_CKPT_SCRIPT="llama_or_mistral_ckpt.py"
   JAX_PLATFORMS=cpu python MaxText/${CONVERT_CKPT_SCRIPT} \
-    --base-model-path ${CHKPT_BUCKET} \
+    --base-model-path ${tmp_ckpt_path}${directory_substring} \
     --maxtext-model-path ${MODEL_BUCKET}/${MODEL}/${MODEL_VARIATION}/${idx} \
-    --model-size ${MODEL_VARIATION}
+    --model-size ${MODEL_NAME}
 fi
 echo "Written MaxText compatible checkpoint to ${MODEL_BUCKET}/${MODEL}/${MODEL_VARIATION}/${idx}"
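The new `path_parts`/`directory_substring` lines above recover the last component of `CHKPT_BUCKET` so the script can locate the copy that `gcloud storage cp -r` placed under `/tmp/`. A Python rendering of that bash parameter expansion, using a made-up bucket path, shows the intent:

# Mirrors: path_parts=(${CHKPT_BUCKET//\// }); directory_substring=${path_parts[-1]}
chkpt_bucket = "gs://example-maxtext/chkpt/llama2/7b"  # hypothetical example path
directory_substring = chkpt_bucket.rstrip("/").split("/")[-1]  # -> "7b"
base_model_path = "/tmp/" + directory_substring  # -> "/tmp/7b"

Note that the bash version relies on a negative array subscript, which requires bash 4.3 or newer.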
diff --git a/jetstream/tools/maxtext/model_ckpt_finetune_with_aqt.sh b/jetstream/tools/maxtext/model_ckpt_finetune_with_aqt.sh
index 8404c285..7e6ff1f5 100644
--- a/jetstream/tools/maxtext/model_ckpt_finetune_with_aqt.sh
+++ b/jetstream/tools/maxtext/model_ckpt_finetune_with_aqt.sh
@@ -29,7 +29,7 @@ export MODEL_VARIATION=$2
 export MODEL_NAME=${MODEL}-${MODEL_VARIATION}
 
 # After downloading checkpoints, copy them to GCS bucket at $CHKPT_BUCKET \
-# Please use seperate GCS paths for uploading open source model weights ($CHKPT_BUCKET) and MaxText compatible weights ($MODEL_BUCKET).
+# Please use separate GCS paths for uploading open source model weights ($CHKPT_BUCKET) and MaxText compatible weights ($MODEL_BUCKET).
 # Point these variables to a GCS bucket that you created.
 # An example of CHKPT_BUCKET could be: gs://${USER}-maxtext/chkpt/${MODEL}/${MODEL_VARIATION}
 export CHKPT_BUCKET=$3
diff --git a/jetstream/tools/requester.py b/jetstream/tools/requester.py
index d99d0ee9..f6cf30b9 100644
--- a/jetstream/tools/requester.py
+++ b/jetstream/tools/requester.py
@@ -19,7 +19,6 @@
 from absl import app
 from absl import flags
 import grpc
-
 from jetstream.core.proto import jetstream_pb2
 from jetstream.core.proto import jetstream_pb2_grpc
 
@@ -31,7 +30,9 @@
 )
 _TEXT = flags.DEFINE_string('text', 'Today is a good day', 'The message')
 _PRIORITY = flags.DEFINE_integer('priority', 0, 'Message priority')
-_MAX_TOKENS = flags.DEFINE_integer('max_tokens', 3, 'Maximum number of tokens')
+_MAX_TOKENS = flags.DEFINE_integer(
+    'max_tokens', 3, 'Maximum number of output/decode tokens of a sequence'
+)
 
 
 def _GetResponseAsync(
@@ -41,7 +42,7 @@ def _GetResponseAsync(
   """Gets an async response."""
 
   response = stub.Decode(request)
-  output = ""
+  output = ''
   for token_list in response:
     output += token_list.response[0]
   print(f'Prompt: {_TEXT.value}')
@@ -50,11 +51,10 @@ def _GetResponseAsync(
 
 def main(argv: Sequence[str]) -> None:
   del argv
-  # Note: Uses insecure_channel only for local testing. Please add grpc credentials for Production.
+  # Note: Uses insecure_channel only for local testing. Please add grpc
+  # credentials for production.
   address = f'{_SERVER.value}:{_PORT.value}'
-  with grpc.insecure_channel(
-      address
-  ) as channel:
+  with grpc.insecure_channel(address) as channel:
     grpc.channel_ready_future(channel).result()
     stub = jetstream_pb2_grpc.OrchestratorStub(channel)
     print(f'Sending request to: {address}')
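Both tools above keep `insecure_channel` for local testing only. A TLS channel is the usual production replacement; here is a minimal sketch, assuming the server has TLS configured (the address is a placeholder, and this is one option rather than the project's prescribed setup):

import grpc
from jetstream.core.proto import jetstream_pb2_grpc

# Pass root_certificates=... to ssl_channel_credentials() if the server
# certificate is not signed by a CA in the default trust store.
credentials = grpc.ssl_channel_credentials()

with grpc.secure_channel("jetstream.example.com:9000", credentials) as channel:
  grpc.channel_ready_future(channel).result()
  stub = jetstream_pb2_grpc.OrchestratorStub(channel)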