diff --git a/runtime/hsa-runtime/core/common/hsa_table_interface.cpp b/runtime/hsa-runtime/core/common/hsa_table_interface.cpp index fc42ac8a1..3129d8874 100644 --- a/runtime/hsa-runtime/core/common/hsa_table_interface.cpp +++ b/runtime/hsa-runtime/core/common/hsa_table_interface.cpp @@ -909,6 +909,16 @@ hsa_status_t HSA_API return amdExtTable->hsa_amd_async_function_fn(callback, arg); } +// Mirrors Amd Extension Apis +uint32_t HSA_API hsa_amd_signal_wait_all( + uint32_t signal_count, hsa_signal_t *signals, hsa_signal_condition_t *conds, + hsa_signal_value_t *values, uint64_t timeout_hint, + hsa_wait_state_t wait_hint, hsa_signal_value_t *satisfying_values) { + return amdExtTable->hsa_amd_signal_wait_all_fn(signal_count, signals, conds, + values, timeout_hint, + wait_hint, satisfying_values); +} + // Mirrors Amd Extension Apis uint32_t HSA_API hsa_amd_signal_wait_any(uint32_t signal_count, hsa_signal_t* signals, diff --git a/runtime/hsa-runtime/core/inc/hsa_ext_amd_impl.h b/runtime/hsa-runtime/core/inc/hsa_ext_amd_impl.h index 5109d3976..4b6b0920a 100644 --- a/runtime/hsa-runtime/core/inc/hsa_ext_amd_impl.h +++ b/runtime/hsa-runtime/core/inc/hsa_ext_amd_impl.h @@ -100,6 +100,14 @@ hsa_status_t hsa_amd_signal_create(hsa_signal_value_t initial_value, uint32_t nu const hsa_agent_t* consumers, uint64_t attributes, hsa_signal_t* signal); +// Mirrors Amd Extension Apis +uint32_t hsa_amd_signal_wait_all(uint32_t signal_count, hsa_signal_t *signals, + hsa_signal_condition_t *conds, + hsa_signal_value_t *values, + uint64_t timeout_hint, + hsa_wait_state_t wait_hint, + hsa_signal_value_t *satisfying_values); + // Mirrors Amd Extension Apis uint32_t hsa_amd_signal_wait_any(uint32_t signal_count, hsa_signal_t* signals, diff --git a/runtime/hsa-runtime/core/inc/signal.h b/runtime/hsa-runtime/core/inc/signal.h index f960bc34b..33c4fc340 100644 --- a/runtime/hsa-runtime/core/inc/signal.h +++ b/runtime/hsa-runtime/core/inc/signal.h @@ -347,14 +347,29 @@ class Signal { /// Returns NULL for DefaultEvent Type. virtual HsaEvent* EopEvent() = 0; - /// @brief Waits until any signal in the list satisfies its condition or - /// timeout is reached. - /// Returns the index of a satisfied signal. Returns -1 on timeout and - /// errors. - static uint32_t WaitAny(uint32_t signal_count, const hsa_signal_t* hsa_signals, - const hsa_signal_condition_t* conds, const hsa_signal_value_t* values, - uint64_t timeout_hint, hsa_wait_state_t wait_hint, - hsa_signal_value_t* satisfying_value); + /// @brief Waits until multiple signals in the list satisfy their conditions + /// or a timeout is reached. + /// @param signal_count Number of hsa_signals in the list. + /// @param hsa_signals Pointer to array of HSA signals. + /// @param conds Pointer to array of signal conditions. + /// @param values Pointer to array of signal values. + /// @param timeout Timeout hint value. + /// @param wait_hint Hint about wait state. + /// @param satisfying_values Vector of satisfying values. If \p wait_on_all + /// is false (then we are waiting on any signal in the list) this will contain + /// only the first satisfying value. + /// @param wait_on_all Wait on all signals in the list to satisfy their + /// conditions if true, else wait on any signal in the list to satisfy its + /// condition. + /// @return Return the index of the first signal in the list that satisfies + /// its condition or -1 on a timeout. Note that if \p wait_on_all is true, + /// then all signals in the list satisfy their conditions, thus the index will + /// always be 0. + static uint32_t WaitMultiple( + uint32_t signal_count, const hsa_signal_t *hsa_signals, + const hsa_signal_condition_t *conds, const hsa_signal_value_t *values, + uint64_t timeout, hsa_wait_state_t wait_hint, + std::vector &satisfying_values, bool wait_on_all); /// @brief Dedicated funtion to wait on signals that are not of type HSA_EVENTTYPE_SIGNAL /// these events can only be received by calling the underlying driver (i.e via the hsaKmtWaitOnMultipleEvents_Ext diff --git a/runtime/hsa-runtime/core/runtime/hsa_api_trace.cpp b/runtime/hsa-runtime/core/runtime/hsa_api_trace.cpp index 8e42cae8b..c696ff23d 100644 --- a/runtime/hsa-runtime/core/runtime/hsa_api_trace.cpp +++ b/runtime/hsa-runtime/core/runtime/hsa_api_trace.cpp @@ -87,7 +87,7 @@ void HsaApiTable::Init() { // they can add preprocessor macros on the new functions constexpr size_t expected_core_api_table_size = 1016; - constexpr size_t expected_amd_ext_table_size = 584; + constexpr size_t expected_amd_ext_table_size = 592; constexpr size_t expected_image_ext_table_size = 120; constexpr size_t expected_finalizer_ext_table_size = 64; constexpr size_t expected_tools_table_size = 64; @@ -412,6 +412,7 @@ void HsaApiTable::UpdateAmdExts() { amd_ext_api.hsa_amd_profiling_convert_tick_to_system_domain_fn = AMD::hsa_amd_profiling_convert_tick_to_system_domain; amd_ext_api.hsa_amd_signal_async_handler_fn = AMD::hsa_amd_signal_async_handler; amd_ext_api.hsa_amd_async_function_fn = AMD::hsa_amd_async_function; + amd_ext_api.hsa_amd_signal_wait_all_fn = AMD::hsa_amd_signal_wait_all; amd_ext_api.hsa_amd_signal_wait_any_fn = AMD::hsa_amd_signal_wait_any; amd_ext_api.hsa_amd_queue_cu_set_mask_fn = AMD::hsa_amd_queue_cu_set_mask; amd_ext_api.hsa_amd_queue_cu_get_mask_fn = AMD::hsa_amd_queue_cu_get_mask; diff --git a/runtime/hsa-runtime/core/runtime/hsa_ext_amd.cpp b/runtime/hsa-runtime/core/runtime/hsa_ext_amd.cpp index 615eee7f6..cc5d6b150 100644 --- a/runtime/hsa-runtime/core/runtime/hsa_ext_amd.cpp +++ b/runtime/hsa-runtime/core/runtime/hsa_ext_amd.cpp @@ -40,13 +40,14 @@ // //////////////////////////////////////////////////////////////////////////////// -#include -#include +#include #include +#include +#include +#include #include +#include #include -#include -#include #include #include "core/inc/agent.h" @@ -570,6 +571,40 @@ hsa_status_t hsa_amd_signal_value_pointer(hsa_signal_t hsa_signal, CATCH; } +uint32_t hsa_amd_signal_wait_all(uint32_t signal_count, + hsa_signal_t *hsa_signals, + hsa_signal_condition_t *conds, + hsa_signal_value_t *values, + uint64_t timeout_hint, + hsa_wait_state_t wait_hint, + hsa_signal_value_t *satisfying_values) { + TRY; + if (!core::Runtime::runtime_singleton_->IsOpen()) { + assert(false && "hsa_amd_signal_wait_all called while not initialized."); + return 0; + } + // Do not check for signal invalidation. Invalidation may occur during async + // signal handler loop and is not an error. + for (int i = 0; i < signal_count; ++i) + assert(hsa_signals[i].handle != 0 && + core::SharedSignal::Convert(hsa_signals[i])->IsValid() && + "Invalid signal."); + + std::vector satisfying_values_vec; + satisfying_values_vec.resize(signal_count); + uint32_t first_satysifying_signal_id = core::Signal::WaitMultiple( + signal_count, hsa_signals, conds, values, timeout_hint, wait_hint, + satisfying_values_vec, true); + + if (satisfying_values) { + std::copy(satisfying_values_vec.begin(), satisfying_values_vec.end(), + satisfying_values); + } + + return first_satysifying_signal_id; + CATCHRET(uint32_t); +} + uint32_t hsa_amd_signal_wait_any(uint32_t signal_count, hsa_signal_t* hsa_signals, hsa_signal_condition_t* conds, hsa_signal_value_t* values, uint64_t timeout_hint, hsa_wait_state_t wait_hint, @@ -585,8 +620,15 @@ uint32_t hsa_amd_signal_wait_any(uint32_t signal_count, hsa_signal_t* hsa_signal assert(hsa_signals[i].handle != 0 && core::SharedSignal::Convert(hsa_signals[i])->IsValid() && "Invalid signal."); - return core::Signal::WaitAny(signal_count, hsa_signals, conds, values, - timeout_hint, wait_hint, satisfying_value); + std::vector satisfying_value_vec(1); + uint32_t satisfying_signal_id = core::Signal::WaitMultiple( + signal_count, hsa_signals, conds, values, timeout_hint, wait_hint, + satisfying_value_vec, false); + + if (satisfying_value) + *satisfying_value = satisfying_value_vec.at(0); + + return satisfying_signal_id; CATCHRET(uint32_t); } diff --git a/runtime/hsa-runtime/core/runtime/signal.cpp b/runtime/hsa-runtime/core/runtime/signal.cpp index 8a60e2599..3f6f22a73 100644 --- a/runtime/hsa-runtime/core/runtime/signal.cpp +++ b/runtime/hsa-runtime/core/runtime/signal.cpp @@ -2,24 +2,24 @@ // // The University of Illinois/NCSA // Open Source License (NCSA) -// +// // Copyright (c) 2014-2020, Advanced Micro Devices, Inc. All rights reserved. -// +// // Developed by: -// +// // AMD Research and AMD HSA Software Development -// +// // Advanced Micro Devices, Inc. -// +// // www.amd.com -// +// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to // deal with the Software without restriction, including without limitation // the rights to use, copy, modify, merge, publish, distribute, sublicense, // and/or sell copies of the Software, and to permit persons to whom the // Software is furnished to do so, subject to the following conditions: -// +// // - Redistributions of source code must retain the above copyright notice, // this list of conditions and the following disclaimers. // - Redistributions in binary form must reproduce the above copyright @@ -29,7 +29,7 @@ // nor the names of its contributors may be used to endorse or promote // products derived from this Software without specific prior written // permission. -// +// // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL @@ -46,6 +46,9 @@ #include "core/inc/signal.h" #include +#include +#include + #include "core/util/timer.h" #include "core/inc/runtime.h" @@ -177,10 +180,13 @@ Signal::~Signal() { } } -uint32_t Signal::WaitAny(uint32_t signal_count, const hsa_signal_t* hsa_signals, - const hsa_signal_condition_t* conds, const hsa_signal_value_t* values, - uint64_t timeout, hsa_wait_state_t wait_hint, - hsa_signal_value_t* satisfying_value) { +uint32_t Signal::WaitMultiple(uint32_t signal_count, + const hsa_signal_t *hsa_signals, + const hsa_signal_condition_t *conds, + const hsa_signal_value_t *values, + uint64_t timeout, hsa_wait_state_t wait_hint, + std::vector &satisfying_value, + bool wait_on_all) { hsa_signal_handle* signals = reinterpret_cast(const_cast(hsa_signals)); @@ -251,10 +257,15 @@ uint32_t Signal::WaitAny(uint32_t signal_count, const hsa_signal_t* hsa_signals, timer::duration_from_seconds( double(timeout) / double(hsa_freq)); - bool condition_met = false; + std::vector unmet_condition_ids(signal_count); + std::iota(unmet_condition_ids.begin(), unmet_condition_ids.end(), 0); + while (true) { // Cannot mwaitx - polling multiple signals - for (uint32_t i = 0; i < signal_count; i++) { + for (auto it = unmet_condition_ids.begin(); + it != unmet_condition_ids.end();) { + auto i = *it; + bool condition_met = false; if (!signals[i]->IsValid()) return uint32_t(-1); @@ -282,8 +293,14 @@ uint32_t Signal::WaitAny(uint32_t signal_count, const hsa_signal_t* hsa_signals, return uint32_t(-1); } if (condition_met) { - if (satisfying_value != NULL) *satisfying_value = value; - return i; + it = unmet_condition_ids.erase(it); + satisfying_value[i] = value; + if (!wait_on_all) + return i; + else if (unmet_condition_ids.empty()) + return 0; + } else { + ++it; } } @@ -306,7 +323,8 @@ uint32_t Signal::WaitAny(uint32_t signal_count, const hsa_signal_t* hsa_signals, uint64_t ct=timer::duration_cast( time_remaining).count(); wait_ms = (ct>0xFFFFFFFEu) ? 0xFFFFFFFEu : ct; - hsaKmtWaitOnMultipleEvents_Ext(evts, unique_evts, false, wait_ms, event_age); + hsaKmtWaitOnMultipleEvents_Ext(evts, unique_evts, wait_on_all, wait_ms, + event_age); } } diff --git a/runtime/hsa-runtime/hsacore.so.def b/runtime/hsa-runtime/hsacore.so.def index 6f1a019fb..27f65019f 100644 --- a/runtime/hsa-runtime/hsacore.so.def +++ b/runtime/hsa-runtime/hsacore.so.def @@ -174,6 +174,7 @@ global: hsa_amd_profiling_get_async_copy_time; hsa_amd_profiling_convert_tick_to_system_domain; hsa_amd_signal_create; + hsa_amd_signal_wait_all; hsa_amd_signal_wait_any; hsa_amd_signal_async_handler; hsa_amd_async_function; diff --git a/runtime/hsa-runtime/inc/hsa_api_trace.h b/runtime/hsa-runtime/inc/hsa_api_trace.h index e0063e6da..171db3954 100644 --- a/runtime/hsa-runtime/inc/hsa_api_trace.h +++ b/runtime/hsa-runtime/inc/hsa_api_trace.h @@ -203,6 +203,7 @@ struct AmdExtTable { decltype(hsa_amd_profiling_convert_tick_to_system_domain)* hsa_amd_profiling_convert_tick_to_system_domain_fn; decltype(hsa_amd_signal_async_handler)* hsa_amd_signal_async_handler_fn; decltype(hsa_amd_async_function)* hsa_amd_async_function_fn; + decltype(hsa_amd_signal_wait_all) *hsa_amd_signal_wait_all_fn; decltype(hsa_amd_signal_wait_any)* hsa_amd_signal_wait_any_fn; decltype(hsa_amd_queue_cu_set_mask)* hsa_amd_queue_cu_set_mask_fn; decltype(hsa_amd_memory_pool_get_info)* hsa_amd_memory_pool_get_info_fn; diff --git a/runtime/hsa-runtime/inc/hsa_ext_amd.h b/runtime/hsa-runtime/inc/hsa_ext_amd.h index 2f3b48c1d..e7f93371a 100644 --- a/runtime/hsa-runtime/inc/hsa_ext_amd.h +++ b/runtime/hsa-runtime/inc/hsa_ext_amd.h @@ -1183,6 +1183,20 @@ hsa_status_t HSA_API hsa_signal_value_t value, hsa_amd_signal_handler handler, void* arg); +/** + * @brief Wait for all signal-condition pairs to be satisfied. + * + * @details Allows waiting for all of several signal and condition pairs to be + * satisfied. The function returns 0 if all signals met their conditions and -1 + * on a timeout. The value of each signal's satisfying value is returned in + * satisfying_value unless satisfying_value is nullptr. This function provides + * only relaxed memory semantics. + */ +uint32_t HSA_API hsa_amd_signal_wait_all( + uint32_t signal_count, hsa_signal_t *signals, hsa_signal_condition_t *conds, + hsa_signal_value_t *values, uint64_t timeout_hint, + hsa_wait_state_t wait_hint, hsa_signal_value_t *satisfying_values); + /** * @brief Wait for any signal-condition pair to be satisfied. * @@ -1433,7 +1447,7 @@ typedef enum { * following its memory access model. The actual placement may vary or migrate * due to the system's NUMA policy and state, which is beyond the scope of * HSA APIs. - */ + */ typedef struct hsa_amd_memory_pool_s { /** * Opaque handle. @@ -2976,7 +2990,7 @@ typedef enum hsa_amd_svm_attribute_s { HSA_AMD_SVM_ATTRIB_ACCESS_QUERY = 0x203, } hsa_amd_svm_attribute_t; -// List type for hsa_amd_svm_attributes_set/get. +// List type for hsa_amd_svm_attributes_set/get. typedef struct hsa_amd_svm_attribute_pair_s { // hsa_amd_svm_attribute_t value. uint64_t attribute;