From 2b685f2d17870da6fd83c3f224d1a862b20c1cff Mon Sep 17 00:00:00 2001 From: Pavel Esir Date: Thu, 31 Oct 2024 12:15:29 +0100 Subject: [PATCH 01/11] exclude prefill from calculating TPOT statistics --- src/cpp/src/perf_metrics.cpp | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/cpp/src/perf_metrics.cpp b/src/cpp/src/perf_metrics.cpp index 822246f01..07f8a34d5 100644 --- a/src/cpp/src/perf_metrics.cpp +++ b/src/cpp/src/perf_metrics.cpp @@ -29,7 +29,6 @@ ov::genai::MeanStdPair calc_mean_and_std(const std::vector start_time) { auto start_time_val = *start_time; auto& tok_times = raw_metrics.m_new_token_times; auto& batch_sizes = raw_metrics.m_batch_sizes; - raw_metrics.m_durations = std::vector(tok_times.size()); + raw_metrics.m_durations = std::vector(tok_times.size() - 1); auto ttft = tok_times[0] - start_time_val; raw_metrics.m_times_to_first_token = std::vector(); raw_metrics.m_times_to_first_token.emplace_back(ttft / batch_sizes[0]); num_generated_tokens = 0; - for (size_t i = 0; i < tok_times.size(); ++i) { + + // Exclude prefill from calculating TPOT. + // The very first duration used to calcualte TPOT is from the first token to the second token, + // not from the start time to the first token. + start_time_val = tok_times[0]; + for (size_t i = 1; i < tok_times.size(); ++i) { raw_metrics.m_durations[i] = tok_times[i] - start_time_val; // If in 10 ms a batch of 5 new tokens is generated then TPOT is 10 / 5 = 2 tok/ms. From b906f2d516afaf0358b1aacbdb6f60b667123579 Mon Sep 17 00:00:00 2001 From: Pavel Esir Date: Thu, 31 Oct 2024 15:11:02 +0100 Subject: [PATCH 02/11] add docs/example of usag of raw perf metrics --- src/README.md | 32 ++++++++++++++++++++++++++++++++ src/python/py_perf_metrics.cpp | 21 +++++++++++++++++++-- 2 files changed, 51 insertions(+), 2 deletions(-) diff --git a/src/README.md b/src/README.md index 71d1a230c..dfb753d33 100644 --- a/src/README.md +++ b/src/README.md @@ -304,6 +304,7 @@ mean_tpot 3.80 >**Note**: If the input prompt is just a string, the generate function returns only a string without perf_metrics. To obtain perf_metrics, provide the prompt as a list with at least one element or call generate with encoded inputs. +#### Accumulating metrics Several `perf_metrics` can be added to each other. In that case `raw_metrics` are concatenated and mean/std values are recalculated. This accumulates statistics from several `generate()` calls ```cpp @@ -338,6 +339,37 @@ print(f'TPOT: {perf_metrics.get_tpot().mean:.2f} ms/token') print(f'Throughput: {perf_metrics.get_throughput().mean:.2f} tokens/s') ``` +#### Using raw performance metrics +Additionally to mean and std values, `perf_metrics` object has a `raw_metrics` field which stored raw numbers with timesteps when batch of tokens was generated, with batch sizes for each timestamp, with tokenization duration and so on. 
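For example, the mean TPOT reported by `get_tpot()` can be reproduced directly from these raw timestamps. A minimal sketch follows (illustrative, not part of the patch itself); it assumes greedy decoding with batch size 1 and a `models_path` variable pointing to an exported model; with batched sampling each gap would additionally be divided by the corresponding batch size. Note that, matching the first patch in this series, the time to the first token (prefill) does not contribute a gap of its own:

```python
import numpy as np
import openvino_genai as ov_genai

pipe = ov_genai.LLMPipeline(models_path, "CPU")
result = pipe.generate(["The Sun is yellow because"], max_new_tokens=20)
perf_metrics = result.perf_metrics
raw_metrics = perf_metrics.raw_metrics

# One millisecond timestamp per generated token (or batch of tokens).
token_times = np.array(raw_metrics.m_new_token_times)

# Gaps between consecutive tokens; the first timestamp only anchors the series,
# so the slower prefill step is excluded from TPOT.
gaps = token_times[1:] - token_times[:-1]
print(f'Manually computed mean TPOT: {np.mean(gaps):.2f} ms/token')
print(f'Reported mean TPOT: {perf_metrics.get_tpot().mean:.2f} ms/token')
```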
+ +Getting timestamps for each generated token: +```python +import openvino_genai as ov_genai +pipe = ov_genai.LLMPipeline(models_path, "CPU") +result = pipe.generate(["The Sun is yellow because"], max_new_tokens=20) +perf_metrics = result.perf_metrics +raw_metrics = perf_metrics.raw_metrics + +print(f'Generate duration: {perf_metrics.get_generate_duration().mean:.2f}') +print(f'Throughput: {perf_metrics.get_throughput().mean:.2f} tokens/s') +print(f'Timestamps: {" ms, ".join(f"{i:.2f}" for i in raw_metrics.m_new_token_times[1:])}') +``` + +Example of using raw metrics to calculate median value of generate duration: +```python +import openvino_genai as ov_genai +import numpy as np +pipe = ov_genai.LLMPipeline(models_path, "CPU") +result = pipe.generate(["The Sun is yellow because"], max_new_tokens=20) +perf_metrics = result.perf_metrics +raw_metrics = perf_metrics.raw_metrics + +print(f'Generate duration: {perf_metrics.get_generate_duration().mean:.2f}') +print(f'Throughput: {perf_metrics.get_throughput().mean:.2f} tokens/s') +durations = np.array(raw_metrics.m_new_token_times[1:]) - np.array(raw_metrics.m_new_token_times[:-1]) +print(f'Median from token to token duration: {np.median(durations):.2f} ms') +``` + For more examples of how metrics are used, please refer to the Python [benchmark_genai.py](../samples/python/benchmark_genai/README.md) and C++ [benchmark_genai](../samples/cpp/benchmark_genai/README.md) samples. ## How It Works diff --git a/src/python/py_perf_metrics.cpp b/src/python/py_perf_metrics.cpp index 500ce0abe..3dc68140e 100644 --- a/src/python/py_perf_metrics.cpp +++ b/src/python/py_perf_metrics.cpp @@ -32,8 +32,8 @@ auto raw_perf_metrics_docstring = R"( :param m_times_to_first_token: Times to the first token for each call in microseconds. :type m_times_to_first_token: List[MicroSeconds] - :param m_new_token_times: Time points for each new token generated. - :type m_new_token_times: List[TimePoint] + :param m_new_token_times: Timestamps of generation every token or batch of tokens in milliseconds. + :type m_new_token_times: List[MilliSeconds] :param m_batch_sizes: Batch sizes for each generate call. :type m_batch_sizes: List[int] @@ -109,6 +109,20 @@ std::vector get_ms(const T& instance, U T::*member) { return res; } +template +std::vector timestamp_to_ms(const T& instance, U T::*member) { + // Converts c++ duration to double so that it can be used in Python. + // Use double instead of float to store more than 7 signficant digits. 
+ std::vector res; + const auto& timestamps = instance.*member; + res.reserve(timestamps.size()); + std::transform(timestamps.begin(), timestamps.end(), std::back_inserter(res), + [](const auto& timestamp) { + return std::chrono::duration(timestamp.time_since_epoch()).count(); + }); + return res; +} + } // namespace void init_perf_metrics(py::module_& m) { @@ -126,6 +140,9 @@ void init_perf_metrics(py::module_& m) { .def_property_readonly("m_times_to_first_token", [](const RawPerfMetrics &rw) { return get_ms(rw, &RawPerfMetrics::m_times_to_first_token); }) + .def_property_readonly("m_new_token_times", [](const RawPerfMetrics &rw) { + return timestamp_to_ms(rw, &RawPerfMetrics::m_new_token_times); + }) .def_property_readonly("m_durations", [](const RawPerfMetrics &rw) { return get_ms(rw, &RawPerfMetrics::m_durations); }) From 6f8f41a3757b8b2f3d51fddf648efefcf8951523 Mon Sep 17 00:00:00 2001 From: Pavel Esir Date: Fri, 1 Nov 2024 18:28:28 +0100 Subject: [PATCH 03/11] Microseconds -> Microsecond --- src/cpp/include/openvino/genai/perf_metrics.hpp | 16 ++++++++-------- src/cpp/src/llm_pipeline.cpp | 2 +- src/cpp/src/llm_pipeline_static.cpp | 2 +- src/cpp/src/lm_encoding.cpp | 6 +++--- src/cpp/src/perf_metrics.cpp | 10 +++++----- src/cpp/src/whisper/whisper.cpp | 8 ++++---- src/cpp/src/whisper_pipeline.cpp | 2 +- src/python/py_perf_metrics.cpp | 10 +++++----- 8 files changed, 28 insertions(+), 28 deletions(-) diff --git a/src/cpp/include/openvino/genai/perf_metrics.hpp b/src/cpp/include/openvino/genai/perf_metrics.hpp index 0a880c4a4..f730fcc27 100644 --- a/src/cpp/include/openvino/genai/perf_metrics.hpp +++ b/src/cpp/include/openvino/genai/perf_metrics.hpp @@ -13,7 +13,7 @@ namespace ov { namespace genai { using TimePoint = std::chrono::steady_clock::time_point; -using MicroSeconds = std::chrono::duration>; +using MicroSecond = std::chrono::duration>; /** * @brief Structure with raw performance metrics for each generation before any statistics are calculated. @@ -31,16 +31,16 @@ using MicroSeconds = std::chrono::duration>; * @param num_input_tokens Total number of tokens in the input prompt. 
*/ struct OPENVINO_GENAI_EXPORTS RawPerfMetrics { - std::vector generate_durations; - std::vector tokenization_durations; - std::vector detokenization_durations; + std::vector generate_durations; + std::vector tokenization_durations; + std::vector detokenization_durations; - std::vector m_times_to_first_token; + std::vector m_times_to_first_token; std::vector m_new_token_times; - std::vector m_token_infer_durations; + std::vector m_token_infer_durations; std::vector m_batch_sizes; - std::vector m_durations; - std::vector m_inference_durations; + std::vector m_durations; + std::vector m_inference_durations; }; /** diff --git a/src/cpp/src/llm_pipeline.cpp b/src/cpp/src/llm_pipeline.cpp index 26221fd5c..d1bf5e65d 100644 --- a/src/cpp/src/llm_pipeline.cpp +++ b/src/cpp/src/llm_pipeline.cpp @@ -152,7 +152,7 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase { auto& raw_counters = decoded_results.perf_metrics.raw_metrics; auto stop_time = std::chrono::steady_clock::now(); - raw_counters.generate_durations = std::vector(); + raw_counters.generate_durations = std::vector(); raw_counters.generate_durations.emplace_back(PerfMetrics::get_microsec(stop_time - start_time)); raw_counters.tokenization_durations.emplace_back(PerfMetrics::get_microsec(encode_stop_time - start_time)); raw_counters.detokenization_durations.emplace_back(PerfMetrics::get_microsec(decode_stop_time - decode_start_time)); diff --git a/src/cpp/src/llm_pipeline_static.cpp b/src/cpp/src/llm_pipeline_static.cpp index 12d6add5e..0f93ab86b 100644 --- a/src/cpp/src/llm_pipeline_static.cpp +++ b/src/cpp/src/llm_pipeline_static.cpp @@ -571,7 +571,7 @@ DecodedResults StaticLLMPipeline::generate( decoded_results.perf_metrics = encoded_results.perf_metrics; auto& raw_counters = decoded_results.perf_metrics.raw_metrics; auto stop_time = std::chrono::steady_clock::now(); - raw_counters.generate_durations = std::vector(); + raw_counters.generate_durations = std::vector(); raw_counters.generate_durations.emplace_back(PerfMetrics::get_microsec(stop_time - start_time)); raw_counters.tokenization_durations.emplace_back(PerfMetrics::get_microsec(encode_stop_time - start_time)); raw_counters.detokenization_durations.emplace_back(PerfMetrics::get_microsec(decode_stop_time - decode_start_time)); diff --git a/src/cpp/src/lm_encoding.cpp b/src/cpp/src/lm_encoding.cpp index a2ed15252..d8246b599 100644 --- a/src/cpp/src/lm_encoding.cpp +++ b/src/cpp/src/lm_encoding.cpp @@ -79,7 +79,7 @@ std::pair get_lm_encoded_results( raw_perf_counters.m_new_token_times.reserve(max_new_tokens); raw_perf_counters.m_batch_sizes.reserve(max_new_tokens); raw_perf_counters.m_token_infer_durations.reserve(max_new_tokens); - raw_perf_counters.m_inference_durations = {{ MicroSeconds(0.0f) }}; + raw_perf_counters.m_inference_durations = {{ MicroSecond(0.0f) }}; // Initialize inputs if (m_embedding.has_value()) @@ -105,7 +105,7 @@ std::pair get_lm_encoded_results( m_llm.infer(); const auto infer_end = std::chrono::steady_clock::now(); const auto infer_ms = PerfMetrics::get_microsec(infer_end - infer_start); - raw_perf_counters.m_inference_durations[0] += MicroSeconds(infer_ms); + raw_perf_counters.m_inference_durations[0] += MicroSecond(infer_ms); raw_perf_counters.m_token_infer_durations.emplace_back(infer_ms); raw_perf_counters.m_new_token_times.emplace_back(infer_end); raw_perf_counters.m_batch_sizes.emplace_back(batch_size); @@ -201,7 +201,7 @@ std::pair get_lm_encoded_results( m_llm.infer(); const auto infer_end = std::chrono::steady_clock::now(); const auto 
infer_ms = PerfMetrics::get_microsec(infer_end - infer_start); - raw_perf_counters.m_inference_durations[0] += MicroSeconds(infer_ms); + raw_perf_counters.m_inference_durations[0] += MicroSecond(infer_ms); raw_perf_counters.m_token_infer_durations.emplace_back(infer_ms); raw_perf_counters.m_new_token_times.emplace_back(infer_end); raw_perf_counters.m_batch_sizes.emplace_back(batch_size); diff --git a/src/cpp/src/perf_metrics.cpp b/src/cpp/src/perf_metrics.cpp index 07f8a34d5..570f87a64 100644 --- a/src/cpp/src/perf_metrics.cpp +++ b/src/cpp/src/perf_metrics.cpp @@ -9,19 +9,19 @@ namespace { -ov::genai::MeanStdPair calc_mean_and_std(const std::vector& durations) { +ov::genai::MeanStdPair calc_mean_and_std(const std::vector& durations) { if (durations.size() == 0) { return {-1, -1}; } // Accepts time durations in microseconds and returns standard deviation and mean in milliseconds. float mean = std::accumulate(durations.begin(), durations.end(), 0.0f, - [](const float& acc, const ov::genai::MicroSeconds& duration) -> float { + [](const float& acc, const ov::genai::MicroSecond& duration) -> float { return acc + duration.count() / 1000.0f; }); mean /= durations.size(); float sum_square_durations = std::accumulate(durations.begin(), durations.end(), 0.0f, - [](const float& acc, const ov::genai::MicroSeconds& duration) -> float { + [](const float& acc, const ov::genai::MicroSecond& duration) -> float { auto d = duration.count() / 1000.0f; return acc + d * d; }); @@ -101,10 +101,10 @@ void PerfMetrics::evaluate_statistics(std::optional start_time) { auto start_time_val = *start_time; auto& tok_times = raw_metrics.m_new_token_times; auto& batch_sizes = raw_metrics.m_batch_sizes; - raw_metrics.m_durations = std::vector(tok_times.size() - 1); + raw_metrics.m_durations = std::vector(tok_times.size() - 1); auto ttft = tok_times[0] - start_time_val; - raw_metrics.m_times_to_first_token = std::vector(); + raw_metrics.m_times_to_first_token = std::vector(); raw_metrics.m_times_to_first_token.emplace_back(ttft / batch_sizes[0]); num_generated_tokens = 0; diff --git a/src/cpp/src/whisper/whisper.cpp b/src/cpp/src/whisper/whisper.cpp index dacbabd9a..0f42c7f8d 100644 --- a/src/cpp/src/whisper/whisper.cpp +++ b/src/cpp/src/whisper/whisper.cpp @@ -19,7 +19,7 @@ #include "whisper_feature_extractor.hpp" #include "whisper_models.hpp" -using ov::genai::MicroSeconds; +using ov::genai::MicroSecond; namespace { @@ -44,7 +44,7 @@ ov::Tensor encode(ov::InferRequest& request, const auto infer_start = std::chrono::steady_clock::now(); request.infer(); const auto infer_ms = ov::genai::PerfMetrics::get_microsec(std::chrono::steady_clock::now() - infer_start); - raw_metrics.m_inference_durations[0] += MicroSeconds(infer_ms); + raw_metrics.m_inference_durations[0] += MicroSecond(infer_ms); // reset input tensor request.set_tensor("input_features", ov::Tensor(ov::element::f32, {0, feature_size, nb_max_frames})); @@ -84,7 +84,7 @@ void infer_with_perf_metrics(ov::InferRequest& request, ov::genai::RawPerfMetric request.infer(); const auto infer_end = std::chrono::steady_clock::now(); const auto infer_ms = ov::genai::PerfMetrics::get_microsec(infer_end - infer_start); - raw_metrics.m_inference_durations[0] += MicroSeconds(infer_ms); + raw_metrics.m_inference_durations[0] += MicroSecond(infer_ms); raw_metrics.m_token_infer_durations.emplace_back(infer_ms); raw_metrics.m_new_token_times.emplace_back(infer_end); raw_metrics.m_batch_sizes.emplace_back(1); @@ -293,7 +293,7 @@ WhisperGenerateResult whisper_generate(const 
ov::genai::WhisperGenerationConfig& raw_metrics.m_new_token_times.reserve(max_new_tokens); raw_metrics.m_batch_sizes.reserve(max_new_tokens); raw_metrics.m_token_infer_durations.reserve(max_new_tokens); - raw_metrics.m_inference_durations = {{MicroSeconds(0.0f)}}; + raw_metrics.m_inference_durations = {{MicroSecond(0.0f)}}; std::vector init_ids; std::vector& output_tokens = result.output_tokens; diff --git a/src/cpp/src/whisper_pipeline.cpp b/src/cpp/src/whisper_pipeline.cpp index a8e34b995..e27a69e81 100644 --- a/src/cpp/src/whisper_pipeline.cpp +++ b/src/cpp/src/whisper_pipeline.cpp @@ -107,7 +107,7 @@ class WhisperPipeline::WhisperPipelineStatefulImpl : public WhisperPipeline::Whi metrics.load_time = this->m_load_time_ms; auto stop_time = std::chrono::steady_clock::now(); metrics.raw_metrics.generate_durations.emplace_back(PerfMetrics::get_microsec(stop_time - start_time)); - result.perf_metrics.raw_metrics.tokenization_durations.emplace_back(MicroSeconds(0.0f)); + result.perf_metrics.raw_metrics.tokenization_durations.emplace_back(MicroSecond(0.0f)); metrics.evaluate_statistics(start_time); return result; diff --git a/src/python/py_perf_metrics.cpp b/src/python/py_perf_metrics.cpp index 3dc68140e..486a3d75e 100644 --- a/src/python/py_perf_metrics.cpp +++ b/src/python/py_perf_metrics.cpp @@ -21,16 +21,16 @@ auto raw_perf_metrics_docstring = R"( Structure with raw performance metrics for each generation before any statistics are calculated. :param generate_durations: Durations for each generate call in microseconds. - :type generate_durations: List[MicroSeconds] + :type generate_durations: List[MicroSecond] :param tokenization_durations: Durations for the tokenization process in microseconds. - :type tokenization_durations: List[MicroSeconds] + :type tokenization_durations: List[MicroSecond] :param detokenization_durations: Durations for the detokenization process in microseconds. - :type detokenization_durations: List[MicroSeconds] + :type detokenization_durations: List[MicroSecond] :param m_times_to_first_token: Times to the first token for each call in microseconds. - :type m_times_to_first_token: List[MicroSeconds] + :type m_times_to_first_token: List[MicroSecond] :param m_new_token_times: Timestamps of generation every token or batch of tokens in milliseconds. :type m_new_token_times: List[MilliSeconds] @@ -39,7 +39,7 @@ auto raw_perf_metrics_docstring = R"( :type m_batch_sizes: List[int] :param m_durations: Total durations for each generate call in microseconds. - :type m_durations: List[MicroSeconds] + :type m_durations: List[MicroSecond] :param num_generated_tokens: Total number of tokens generated. :type num_generated_tokens: int From 8e22fc61fb95947708c5bbcb121194fa218b75ad Mon Sep 17 00:00:00 2001 From: Pavel Esir Date: Fri, 1 Nov 2024 18:32:51 +0100 Subject: [PATCH 04/11] Revert "Microseconds -> Microsecond" This reverts commit 6f8f41a3757b8b2f3d51fddf648efefcf8951523. 
--- src/cpp/include/openvino/genai/perf_metrics.hpp | 16 ++++++++-------- src/cpp/src/llm_pipeline.cpp | 2 +- src/cpp/src/llm_pipeline_static.cpp | 2 +- src/cpp/src/lm_encoding.cpp | 6 +++--- src/cpp/src/perf_metrics.cpp | 10 +++++----- src/cpp/src/whisper/whisper.cpp | 8 ++++---- src/cpp/src/whisper_pipeline.cpp | 2 +- src/python/py_perf_metrics.cpp | 10 +++++----- 8 files changed, 28 insertions(+), 28 deletions(-) diff --git a/src/cpp/include/openvino/genai/perf_metrics.hpp b/src/cpp/include/openvino/genai/perf_metrics.hpp index f730fcc27..0a880c4a4 100644 --- a/src/cpp/include/openvino/genai/perf_metrics.hpp +++ b/src/cpp/include/openvino/genai/perf_metrics.hpp @@ -13,7 +13,7 @@ namespace ov { namespace genai { using TimePoint = std::chrono::steady_clock::time_point; -using MicroSecond = std::chrono::duration>; +using MicroSeconds = std::chrono::duration>; /** * @brief Structure with raw performance metrics for each generation before any statistics are calculated. @@ -31,16 +31,16 @@ using MicroSecond = std::chrono::duration>; * @param num_input_tokens Total number of tokens in the input prompt. */ struct OPENVINO_GENAI_EXPORTS RawPerfMetrics { - std::vector generate_durations; - std::vector tokenization_durations; - std::vector detokenization_durations; + std::vector generate_durations; + std::vector tokenization_durations; + std::vector detokenization_durations; - std::vector m_times_to_first_token; + std::vector m_times_to_first_token; std::vector m_new_token_times; - std::vector m_token_infer_durations; + std::vector m_token_infer_durations; std::vector m_batch_sizes; - std::vector m_durations; - std::vector m_inference_durations; + std::vector m_durations; + std::vector m_inference_durations; }; /** diff --git a/src/cpp/src/llm_pipeline.cpp b/src/cpp/src/llm_pipeline.cpp index d1bf5e65d..26221fd5c 100644 --- a/src/cpp/src/llm_pipeline.cpp +++ b/src/cpp/src/llm_pipeline.cpp @@ -152,7 +152,7 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase { auto& raw_counters = decoded_results.perf_metrics.raw_metrics; auto stop_time = std::chrono::steady_clock::now(); - raw_counters.generate_durations = std::vector(); + raw_counters.generate_durations = std::vector(); raw_counters.generate_durations.emplace_back(PerfMetrics::get_microsec(stop_time - start_time)); raw_counters.tokenization_durations.emplace_back(PerfMetrics::get_microsec(encode_stop_time - start_time)); raw_counters.detokenization_durations.emplace_back(PerfMetrics::get_microsec(decode_stop_time - decode_start_time)); diff --git a/src/cpp/src/llm_pipeline_static.cpp b/src/cpp/src/llm_pipeline_static.cpp index 0f93ab86b..12d6add5e 100644 --- a/src/cpp/src/llm_pipeline_static.cpp +++ b/src/cpp/src/llm_pipeline_static.cpp @@ -571,7 +571,7 @@ DecodedResults StaticLLMPipeline::generate( decoded_results.perf_metrics = encoded_results.perf_metrics; auto& raw_counters = decoded_results.perf_metrics.raw_metrics; auto stop_time = std::chrono::steady_clock::now(); - raw_counters.generate_durations = std::vector(); + raw_counters.generate_durations = std::vector(); raw_counters.generate_durations.emplace_back(PerfMetrics::get_microsec(stop_time - start_time)); raw_counters.tokenization_durations.emplace_back(PerfMetrics::get_microsec(encode_stop_time - start_time)); raw_counters.detokenization_durations.emplace_back(PerfMetrics::get_microsec(decode_stop_time - decode_start_time)); diff --git a/src/cpp/src/lm_encoding.cpp b/src/cpp/src/lm_encoding.cpp index d8246b599..a2ed15252 100644 --- a/src/cpp/src/lm_encoding.cpp +++ 
b/src/cpp/src/lm_encoding.cpp @@ -79,7 +79,7 @@ std::pair get_lm_encoded_results( raw_perf_counters.m_new_token_times.reserve(max_new_tokens); raw_perf_counters.m_batch_sizes.reserve(max_new_tokens); raw_perf_counters.m_token_infer_durations.reserve(max_new_tokens); - raw_perf_counters.m_inference_durations = {{ MicroSecond(0.0f) }}; + raw_perf_counters.m_inference_durations = {{ MicroSeconds(0.0f) }}; // Initialize inputs if (m_embedding.has_value()) @@ -105,7 +105,7 @@ std::pair get_lm_encoded_results( m_llm.infer(); const auto infer_end = std::chrono::steady_clock::now(); const auto infer_ms = PerfMetrics::get_microsec(infer_end - infer_start); - raw_perf_counters.m_inference_durations[0] += MicroSecond(infer_ms); + raw_perf_counters.m_inference_durations[0] += MicroSeconds(infer_ms); raw_perf_counters.m_token_infer_durations.emplace_back(infer_ms); raw_perf_counters.m_new_token_times.emplace_back(infer_end); raw_perf_counters.m_batch_sizes.emplace_back(batch_size); @@ -201,7 +201,7 @@ std::pair get_lm_encoded_results( m_llm.infer(); const auto infer_end = std::chrono::steady_clock::now(); const auto infer_ms = PerfMetrics::get_microsec(infer_end - infer_start); - raw_perf_counters.m_inference_durations[0] += MicroSecond(infer_ms); + raw_perf_counters.m_inference_durations[0] += MicroSeconds(infer_ms); raw_perf_counters.m_token_infer_durations.emplace_back(infer_ms); raw_perf_counters.m_new_token_times.emplace_back(infer_end); raw_perf_counters.m_batch_sizes.emplace_back(batch_size); diff --git a/src/cpp/src/perf_metrics.cpp b/src/cpp/src/perf_metrics.cpp index 570f87a64..07f8a34d5 100644 --- a/src/cpp/src/perf_metrics.cpp +++ b/src/cpp/src/perf_metrics.cpp @@ -9,19 +9,19 @@ namespace { -ov::genai::MeanStdPair calc_mean_and_std(const std::vector& durations) { +ov::genai::MeanStdPair calc_mean_and_std(const std::vector& durations) { if (durations.size() == 0) { return {-1, -1}; } // Accepts time durations in microseconds and returns standard deviation and mean in milliseconds. 
float mean = std::accumulate(durations.begin(), durations.end(), 0.0f, - [](const float& acc, const ov::genai::MicroSecond& duration) -> float { + [](const float& acc, const ov::genai::MicroSeconds& duration) -> float { return acc + duration.count() / 1000.0f; }); mean /= durations.size(); float sum_square_durations = std::accumulate(durations.begin(), durations.end(), 0.0f, - [](const float& acc, const ov::genai::MicroSecond& duration) -> float { + [](const float& acc, const ov::genai::MicroSeconds& duration) -> float { auto d = duration.count() / 1000.0f; return acc + d * d; }); @@ -101,10 +101,10 @@ void PerfMetrics::evaluate_statistics(std::optional start_time) { auto start_time_val = *start_time; auto& tok_times = raw_metrics.m_new_token_times; auto& batch_sizes = raw_metrics.m_batch_sizes; - raw_metrics.m_durations = std::vector(tok_times.size() - 1); + raw_metrics.m_durations = std::vector(tok_times.size() - 1); auto ttft = tok_times[0] - start_time_val; - raw_metrics.m_times_to_first_token = std::vector(); + raw_metrics.m_times_to_first_token = std::vector(); raw_metrics.m_times_to_first_token.emplace_back(ttft / batch_sizes[0]); num_generated_tokens = 0; diff --git a/src/cpp/src/whisper/whisper.cpp b/src/cpp/src/whisper/whisper.cpp index 0f42c7f8d..dacbabd9a 100644 --- a/src/cpp/src/whisper/whisper.cpp +++ b/src/cpp/src/whisper/whisper.cpp @@ -19,7 +19,7 @@ #include "whisper_feature_extractor.hpp" #include "whisper_models.hpp" -using ov::genai::MicroSecond; +using ov::genai::MicroSeconds; namespace { @@ -44,7 +44,7 @@ ov::Tensor encode(ov::InferRequest& request, const auto infer_start = std::chrono::steady_clock::now(); request.infer(); const auto infer_ms = ov::genai::PerfMetrics::get_microsec(std::chrono::steady_clock::now() - infer_start); - raw_metrics.m_inference_durations[0] += MicroSecond(infer_ms); + raw_metrics.m_inference_durations[0] += MicroSeconds(infer_ms); // reset input tensor request.set_tensor("input_features", ov::Tensor(ov::element::f32, {0, feature_size, nb_max_frames})); @@ -84,7 +84,7 @@ void infer_with_perf_metrics(ov::InferRequest& request, ov::genai::RawPerfMetric request.infer(); const auto infer_end = std::chrono::steady_clock::now(); const auto infer_ms = ov::genai::PerfMetrics::get_microsec(infer_end - infer_start); - raw_metrics.m_inference_durations[0] += MicroSecond(infer_ms); + raw_metrics.m_inference_durations[0] += MicroSeconds(infer_ms); raw_metrics.m_token_infer_durations.emplace_back(infer_ms); raw_metrics.m_new_token_times.emplace_back(infer_end); raw_metrics.m_batch_sizes.emplace_back(1); @@ -293,7 +293,7 @@ WhisperGenerateResult whisper_generate(const ov::genai::WhisperGenerationConfig& raw_metrics.m_new_token_times.reserve(max_new_tokens); raw_metrics.m_batch_sizes.reserve(max_new_tokens); raw_metrics.m_token_infer_durations.reserve(max_new_tokens); - raw_metrics.m_inference_durations = {{MicroSecond(0.0f)}}; + raw_metrics.m_inference_durations = {{MicroSeconds(0.0f)}}; std::vector init_ids; std::vector& output_tokens = result.output_tokens; diff --git a/src/cpp/src/whisper_pipeline.cpp b/src/cpp/src/whisper_pipeline.cpp index e27a69e81..a8e34b995 100644 --- a/src/cpp/src/whisper_pipeline.cpp +++ b/src/cpp/src/whisper_pipeline.cpp @@ -107,7 +107,7 @@ class WhisperPipeline::WhisperPipelineStatefulImpl : public WhisperPipeline::Whi metrics.load_time = this->m_load_time_ms; auto stop_time = std::chrono::steady_clock::now(); metrics.raw_metrics.generate_durations.emplace_back(PerfMetrics::get_microsec(stop_time - start_time)); - 
result.perf_metrics.raw_metrics.tokenization_durations.emplace_back(MicroSecond(0.0f)); + result.perf_metrics.raw_metrics.tokenization_durations.emplace_back(MicroSeconds(0.0f)); metrics.evaluate_statistics(start_time); return result; diff --git a/src/python/py_perf_metrics.cpp b/src/python/py_perf_metrics.cpp index 486a3d75e..3dc68140e 100644 --- a/src/python/py_perf_metrics.cpp +++ b/src/python/py_perf_metrics.cpp @@ -21,16 +21,16 @@ auto raw_perf_metrics_docstring = R"( Structure with raw performance metrics for each generation before any statistics are calculated. :param generate_durations: Durations for each generate call in microseconds. - :type generate_durations: List[MicroSecond] + :type generate_durations: List[MicroSeconds] :param tokenization_durations: Durations for the tokenization process in microseconds. - :type tokenization_durations: List[MicroSecond] + :type tokenization_durations: List[MicroSeconds] :param detokenization_durations: Durations for the detokenization process in microseconds. - :type detokenization_durations: List[MicroSecond] + :type detokenization_durations: List[MicroSeconds] :param m_times_to_first_token: Times to the first token for each call in microseconds. - :type m_times_to_first_token: List[MicroSecond] + :type m_times_to_first_token: List[MicroSeconds] :param m_new_token_times: Timestamps of generation every token or batch of tokens in milliseconds. :type m_new_token_times: List[MilliSeconds] @@ -39,7 +39,7 @@ auto raw_perf_metrics_docstring = R"( :type m_batch_sizes: List[int] :param m_durations: Total durations for each generate call in microseconds. - :type m_durations: List[MicroSecond] + :type m_durations: List[MicroSeconds] :param num_generated_tokens: Total number of tokens generated. :type num_generated_tokens: int From 20a147846c2999672c3db7080d853eefcbca5c7a Mon Sep 17 00:00:00 2001 From: Pavel Esir Date: Fri, 1 Nov 2024 19:06:52 +0100 Subject: [PATCH 05/11] some corrections --- src/README.md | 29 +++++++++++++++++++++++++++-- src/cpp/src/perf_metrics.cpp | 7 ++++--- src/python/py_perf_metrics.cpp | 3 ++- 3 files changed, 33 insertions(+), 6 deletions(-) diff --git a/src/README.md b/src/README.md index dfb753d33..35312dddc 100644 --- a/src/README.md +++ b/src/README.md @@ -340,7 +340,15 @@ print(f'Throughput: {perf_metrics.get_throughput().mean:.2f} tokens/s') ``` #### Using raw performance metrics -Additionally to mean and std values, `perf_metrics` object has a `raw_metrics` field which stored raw numbers with timesteps when batch of tokens was generated, with batch sizes for each timestamp, with tokenization duration and so on. +In addition to mean and standard deviation values, the `perf_metrics` object has a `raw_metrics` field. This field stores raw data, including: + +- Timestamps for each batch of generated tokens +- Batch sizes for each timestamp +- Tokenization durations +- Detokenization durations +- Other relevant metrics + +These metrics can be use for more fine grained analysis, such as getting exact calculating median values, percentiles, etc. Below are a few examples of how to use raw metrics. 
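As one such example, percentiles of the time to first token can be estimated by accumulating metrics from several `generate()` calls (using the `+` operation described above) and reading the concatenated raw values. A sketch, assuming `models_path` is defined and using illustrative prompts:

```python
import numpy as np
import openvino_genai as ov_genai

pipe = ov_genai.LLMPipeline(models_path, "CPU")
prompts = ["The Sun is yellow because", "Deep learning is", "The capital of France is"]

# Accumulate metrics from several generate() calls; raw_metrics are concatenated.
perf_metrics = pipe.generate([prompts[0]], max_new_tokens=20).perf_metrics
for prompt in prompts[1:]:
    perf_metrics = perf_metrics + pipe.generate([prompt], max_new_tokens=20).perf_metrics

# One entry per generate() call; see the RawPerfMetrics docstring for the unit.
ttfts = np.array(perf_metrics.raw_metrics.m_times_to_first_token)
print(f'TTFT p50: {np.percentile(ttfts, 50):.2f}, p90: {np.percentile(ttfts, 90):.2f}')
```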
Getting timestamps for each generated token: ```python @@ -352,7 +360,24 @@ raw_metrics = perf_metrics.raw_metrics print(f'Generate duration: {perf_metrics.get_generate_duration().mean:.2f}') print(f'Throughput: {perf_metrics.get_throughput().mean:.2f} tokens/s') -print(f'Timestamps: {" ms, ".join(f"{i:.2f}" for i in raw_metrics.m_new_token_times[1:])}') +print(f'Timestamps: {" ms, ".join(f"{i:.2f}" for i in raw_metrics.m_new_token_times)}') +``` + +Getting pure inference time without tokenizatin and detokenization duration: +```python +import openvino_genai as ov_genai +import openvino_genai as ov_genai +import numpy as np +pipe = ov_genai.LLMPipeline(models_path, "CPU") +result = pipe.generate(["The Sun is yellow because"], max_new_tokens=20) +perf_metrics = result.perf_metrics +print(f'Generate duration: {perf_metrics.get_generate_duration().mean:.2f}') + +raw_metrics = perf_metrics.raw_metrics +generate_duration = np.array(raw_metrics.generate_durations) +tok_detok_duration = np.array(raw_metrics.tokenization_durations) - np.array(raw_metrics.detokenization_durations) +pure_inference_duration = np.mean(generate_duration - tok_detok_duration) / 1000 # in seconds +print(f'Pure Inference duration: {pure_inference_duration:.2f} ms') ``` Example of using raw metrics to calculate median value of generate duration: diff --git a/src/cpp/src/perf_metrics.cpp b/src/cpp/src/perf_metrics.cpp index 07f8a34d5..47da2014a 100644 --- a/src/cpp/src/perf_metrics.cpp +++ b/src/cpp/src/perf_metrics.cpp @@ -108,9 +108,10 @@ void PerfMetrics::evaluate_statistics(std::optional start_time) { raw_metrics.m_times_to_first_token.emplace_back(ttft / batch_sizes[0]); num_generated_tokens = 0; - // Exclude prefill from calculating TPOT. - // The very first duration used to calcualte TPOT is from the first token to the second token, - // not from the start time to the first token. + // The very first infer request (prefill stage) is slower than sunsequent ones since we process a sequence of tokens. + // To have a clearer TPOT number, the time taken to generate the very first token at the prefill stage + // must not be included in the TPOT calculation. The first duration used for TPOT is from the first token + // to the second token, not from the start time to the first token. start_time_val = tok_times[0]; for (size_t i = 1; i < tok_times.size(); ++i) { raw_metrics.m_durations[i] = tok_times[i] - start_time_val; diff --git a/src/python/py_perf_metrics.cpp b/src/python/py_perf_metrics.cpp index 3dc68140e..5e2fb1221 100644 --- a/src/python/py_perf_metrics.cpp +++ b/src/python/py_perf_metrics.cpp @@ -112,7 +112,8 @@ std::vector get_ms(const T& instance, U T::*member) { template std::vector timestamp_to_ms(const T& instance, U T::*member) { // Converts c++ duration to double so that it can be used in Python. - // Use double instead of float to store more than 7 signficant digits. + // Use double instead of float bacuse timestamp in ms contains 14 digits + // while float only allows to store ~7 significant digits. 
std::vector res; const auto& timestamps = instance.*member; res.reserve(timestamps.size()); From 450c1a58364ea38c1032853e24bdee08c33144d4 Mon Sep 17 00:00:00 2001 From: Pavel Esir Date: Sat, 2 Nov 2024 10:33:40 +0100 Subject: [PATCH 06/11] fix segfault --- src/cpp/src/perf_metrics.cpp | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/src/cpp/src/perf_metrics.cpp b/src/cpp/src/perf_metrics.cpp index 47da2014a..84a5b1707 100644 --- a/src/cpp/src/perf_metrics.cpp +++ b/src/cpp/src/perf_metrics.cpp @@ -101,25 +101,21 @@ void PerfMetrics::evaluate_statistics(std::optional start_time) { auto start_time_val = *start_time; auto& tok_times = raw_metrics.m_new_token_times; auto& batch_sizes = raw_metrics.m_batch_sizes; - raw_metrics.m_durations = std::vector(tok_times.size() - 1); + raw_metrics.m_durations.reserve(tok_times.size()); auto ttft = tok_times[0] - start_time_val; raw_metrics.m_times_to_first_token = std::vector(); raw_metrics.m_times_to_first_token.emplace_back(ttft / batch_sizes[0]); - num_generated_tokens = 0; + num_generated_tokens = batch_sizes[0]; // The very first infer request (prefill stage) is slower than sunsequent ones since we process a sequence of tokens. // To have a clearer TPOT number, the time taken to generate the very first token at the prefill stage // must not be included in the TPOT calculation. The first duration used for TPOT is from the first token // to the second token, not from the start time to the first token. - start_time_val = tok_times[0]; for (size_t i = 1; i < tok_times.size(); ++i) { - raw_metrics.m_durations[i] = tok_times[i] - start_time_val; - // If in 10 ms a batch of 5 new tokens is generated then TPOT is 10 / 5 = 2 tok/ms. - raw_metrics.m_durations[i] /= batch_sizes[i]; + raw_metrics.m_durations.emplace_back((tok_times[i] - tok_times[i - 1]) / batch_sizes[i]); num_generated_tokens += batch_sizes[i]; - start_time_val = tok_times[i]; } } From fd95bfb923babeb74bfd74898bf3487e2d1a0e01 Mon Sep 17 00:00:00 2001 From: Pavel Esir Date: Sun, 3 Nov 2024 11:57:27 +0400 Subject: [PATCH 07/11] Update src/README.md --- src/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/README.md b/src/README.md index 35312dddc..d6d26574b 100644 --- a/src/README.md +++ b/src/README.md @@ -376,7 +376,7 @@ print(f'Generate duration: {perf_metrics.get_generate_duration().mean:.2f}') raw_metrics = perf_metrics.raw_metrics generate_duration = np.array(raw_metrics.generate_durations) tok_detok_duration = np.array(raw_metrics.tokenization_durations) - np.array(raw_metrics.detokenization_durations) -pure_inference_duration = np.mean(generate_duration - tok_detok_duration) / 1000 # in seconds +pure_inference_duration = np.sum(generate_duration - tok_detok_duration) / 1000 # in seconds print(f'Pure Inference duration: {pure_inference_duration:.2f} ms') ``` From efb19f99965e33cc2250fe52a9edd994389cbd89 Mon Sep 17 00:00:00 2001 From: Andrei Kochin Date: Mon, 4 Nov 2024 12:31:25 +0400 Subject: [PATCH 08/11] Update src/README.md Co-authored-by: Vladimir Zlobin --- src/README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/src/README.md b/src/README.md index d6d26574b..fb5f9caae 100644 --- a/src/README.md +++ b/src/README.md @@ -366,7 +366,6 @@ print(f'Timestamps: {" ms, ".join(f"{i:.2f}" for i in raw_metrics.m_new_token_ti Getting pure inference time without tokenizatin and detokenization duration: ```python import openvino_genai as ov_genai -import openvino_genai as ov_genai import numpy as np pipe = 
ov_genai.LLMPipeline(models_path, "CPU") result = pipe.generate(["The Sun is yellow because"], max_new_tokens=20) From 8100497f79e077ad91c88dbd0d8adcffa2787e70 Mon Sep 17 00:00:00 2001 From: Andrei Kochin Date: Mon, 4 Nov 2024 12:31:39 +0400 Subject: [PATCH 09/11] Update src/cpp/src/perf_metrics.cpp Co-authored-by: Vladimir Zlobin --- src/cpp/src/perf_metrics.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cpp/src/perf_metrics.cpp b/src/cpp/src/perf_metrics.cpp index 84a5b1707..6ec88fa31 100644 --- a/src/cpp/src/perf_metrics.cpp +++ b/src/cpp/src/perf_metrics.cpp @@ -108,7 +108,7 @@ void PerfMetrics::evaluate_statistics(std::optional start_time) { raw_metrics.m_times_to_first_token.emplace_back(ttft / batch_sizes[0]); num_generated_tokens = batch_sizes[0]; - // The very first infer request (prefill stage) is slower than sunsequent ones since we process a sequence of tokens. + // The very first infer request (prefill stage) is slower than subsequent ones since we process a sequence of tokens. // To have a clearer TPOT number, the time taken to generate the very first token at the prefill stage // must not be included in the TPOT calculation. The first duration used for TPOT is from the first token // to the second token, not from the start time to the first token. From 7947aa13c7bbc0b56ef2ccfc770a9c9d34cde91a Mon Sep 17 00:00:00 2001 From: Andrei Kochin Date: Mon, 4 Nov 2024 12:31:48 +0400 Subject: [PATCH 10/11] Update src/python/py_perf_metrics.cpp Co-authored-by: Vladimir Zlobin --- src/python/py_perf_metrics.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/python/py_perf_metrics.cpp b/src/python/py_perf_metrics.cpp index 5e2fb1221..679acc2b9 100644 --- a/src/python/py_perf_metrics.cpp +++ b/src/python/py_perf_metrics.cpp @@ -114,6 +114,7 @@ std::vector timestamp_to_ms(const T& instance, U T::*member) { // Converts c++ duration to double so that it can be used in Python. // Use double instead of float bacuse timestamp in ms contains 14 digits // while float only allows to store ~7 significant digits. + // And the current timestamp (number of secs from 1970) is already 11 digits. std::vector res; const auto& timestamps = instance.*member; res.reserve(timestamps.size()); From 8cf0685170fa28acda545b082c038bdb67ae1187 Mon Sep 17 00:00:00 2001 From: Pavel Esir Date: Mon, 4 Nov 2024 12:53:55 +0400 Subject: [PATCH 11/11] Apply suggestions from code review --- src/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/README.md b/src/README.md index fb5f9caae..9a96daa9d 100644 --- a/src/README.md +++ b/src/README.md @@ -370,12 +370,12 @@ import numpy as np pipe = ov_genai.LLMPipeline(models_path, "CPU") result = pipe.generate(["The Sun is yellow because"], max_new_tokens=20) perf_metrics = result.perf_metrics -print(f'Generate duration: {perf_metrics.get_generate_duration().mean:.2f}') +print(f'Generate duration: {perf_metrics.get_generate_duration().mean:.2f} ms') raw_metrics = perf_metrics.raw_metrics generate_duration = np.array(raw_metrics.generate_durations) tok_detok_duration = np.array(raw_metrics.tokenization_durations) - np.array(raw_metrics.detokenization_durations) -pure_inference_duration = np.sum(generate_duration - tok_detok_duration) / 1000 # in seconds +pure_inference_duration = np.sum(generate_duration - tok_detok_duration) / 1000 # in milliseconds print(f'Pure Inference duration: {pure_inference_duration:.2f} ms') ```
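Beyond the statistics above, the raw series documented in this patch set can be exported as-is for offline analysis. A minimal sketch (illustrative, not taken from the patches above) that assumes `models_path` is defined and uses only the raw attributes referenced earlier in this series:

```python
import json
import openvino_genai as ov_genai

pipe = ov_genai.LLMPipeline(models_path, "CPU")
result = pipe.generate(["The Sun is yellow because"], max_new_tokens=20)
raw_metrics = result.perf_metrics.raw_metrics

# Collect the raw series into plain lists so they can be stored and analyzed offline.
dump = {
    "generate_durations": list(raw_metrics.generate_durations),
    "tokenization_durations": list(raw_metrics.tokenization_durations),
    "detokenization_durations": list(raw_metrics.detokenization_durations),
    "times_to_first_token": list(raw_metrics.m_times_to_first_token),
    "new_token_times": list(raw_metrics.m_new_token_times),
    "durations": list(raw_metrics.m_durations),
}
with open("raw_perf_metrics.json", "w") as f:
    json.dump(dump, f, indent=2)
```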