Skip to content

Commit

Permalink
unify thread pool
Browse files Browse the repository at this point in the history
  • Loading branch information
snnn committed Aug 14, 2019
1 parent a6a5ace commit f2476a8
Show file tree
Hide file tree
Showing 68 changed files with 520 additions and 501 deletions.
5 changes: 0 additions & 5 deletions cmake/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ option(onnxruntime_USE_OPENVINO "Build with OpenVINO support" OFF)
option(onnxruntime_USE_NSYNC "Build with NSYNC support. This option only takes effect on Linux" OFF)
option(onnxruntime_USE_EIGEN_FOR_BLAS "Use eign for blas" ON)
option(onnxruntime_USE_NNAPI "Build with DNNLibrary for Android NNAPI support" OFF)
option(onnxruntime_USE_MLAS "Use optimized blas library for GEMM and 2D Convolution" ON)
option(onnxruntime_USE_MKLDNN "Build with MKL-DNN support" OFF)
option(onnxruntime_USE_MKLML "Build MKL-DNN with MKL-ML binary dependency" OFF)
option(onnxruntime_USE_NGRAPH "Build with nGraph support" OFF)
Expand Down Expand Up @@ -367,10 +366,6 @@ if (onnxruntime_RUN_ONNX_TESTS)
add_definitions(-DORT_RUN_EXTERNAL_ONNX_TESTS)
endif()

if (onnxruntime_USE_MLAS)
add_definitions(-DUSE_MLAS)
endif()

#Adjust warning flags
if (WIN32)
add_definitions(-DPLATFORM_WINDOWS -DNOGDI -DNOMINMAX -D_USE_MATH_DEFINES)
Expand Down
2 changes: 1 addition & 1 deletion cmake/onnxruntime_mlas.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -192,4 +192,4 @@ endif()

add_library(onnxruntime_mlas STATIC ${mlas_common_srcs} ${mlas_platform_srcs})
target_include_directories(onnxruntime_mlas PRIVATE ${ONNXRUNTIME_ROOT}/core/mlas/inc ${ONNXRUNTIME_ROOT}/core/mlas/lib)
set_target_properties(onnxruntime_mlas PROPERTIES FOLDER "ONNXRuntime")
set_target_properties(onnxruntime_mlas PROPERTIES FOLDER "ONNXRuntime")
3 changes: 1 addition & 2 deletions cmake/onnxruntime_unittests.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -468,8 +468,7 @@ set(onnx_test_runner_common_srcs
${onnx_test_runner_src_dir}/runner.h
${onnx_test_runner_src_dir}/runner.cc
${onnx_test_runner_src_dir}/TestCase.cc
${onnx_test_runner_src_dir}/TestCase.h
${onnx_test_runner_src_dir}/onnxruntime_event.h
${onnx_test_runner_src_dir}/TestCase.h
${onnx_test_runner_src_dir}/sync_api.h
${onnx_test_runner_src_dir}/sync_api.cc
${onnx_test_runner_src_dir}/callback.h
Expand Down
2 changes: 1 addition & 1 deletion include/onnxruntime/core/framework/op_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ struct KernelCreateInfo {
: kernel_def(std::move(definition)),
kernel_create_func(create_func) {}

KernelCreateInfo(KernelCreateInfo&& other)
KernelCreateInfo(KernelCreateInfo&& other) noexcept
: kernel_def(std::move(other.kernel_def)),
kernel_create_func(std::move(other.kernel_create_func)) {}
};
Expand Down
4 changes: 2 additions & 2 deletions include/onnxruntime/core/framework/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,9 @@ class Tensor final {
//Move is allowed
ORT_DISALLOW_COPY_AND_ASSIGNMENT(Tensor);

Tensor(Tensor&& other);
Tensor(Tensor&& other) noexcept;

Tensor& operator=(Tensor&& other);
Tensor& operator=(Tensor&& other) noexcept;

/**
Returns the data type.
Expand Down
12 changes: 11 additions & 1 deletion include/onnxruntime/core/session/environment.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@

#include <atomic>
#include <memory>
#include <mutex>
#include <thread>
#include "core/common/common.h"
#include "core/common/status.h"
#include "core/platform/threadpool.h"

namespace onnxruntime {
/**
Expand All @@ -29,13 +32,20 @@ class Environment {
Returns whether any runtime environment instance has been initialized.
*/
static bool IsInitialized() { return is_initialized_; }
concurrency::ThreadPool* GetDefaultThreadPool() {
std::call_once(tp_once_, [this]() {
tp_ = new concurrency::ThreadPool("default", std::max<int>(std::thread::hardware_concurrency() - 1, 1));
});
return tp_;
}

private:
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Environment);

Environment() = default;
Status Initialize();

concurrency::ThreadPool* tp_ = nullptr;
std::once_flag tp_once_;
static std::atomic<bool> is_initialized_;
};
} // namespace onnxruntime
19 changes: 11 additions & 8 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ extern "C" {
#define _Inout_
#define _Inout_opt_
#define _Frees_ptr_opt_
#define _Ret_maybenull_
#define _Ret_notnull_
#define _Check_return_
#define _Success_(X)
#define ORT_ALL_ARGS_NONNULL __attribute__((nonnull))
#else
#include <specstrings.h>
Expand Down Expand Up @@ -127,11 +131,11 @@ typedef enum OrtErrorCode {
ORT_EXPORT RETURN_TYPE ORT_API_CALL NAME(__VA_ARGS__) NO_EXCEPTION

#define ORT_API_STATUS(NAME, ...) \
ORT_EXPORT OrtStatus* ORT_API_CALL NAME(__VA_ARGS__) NO_EXCEPTION ORT_MUST_USE_RESULT
ORT_EXPORT _Check_return_ _Success_(return == 0) _Ret_maybenull_ OrtStatus* ORT_API_CALL NAME(__VA_ARGS__) NO_EXCEPTION ORT_MUST_USE_RESULT

// Used in *.cc files. Almost as same as ORT_API_STATUS, except without ORT_MUST_USE_RESULT
#define ORT_API_STATUS_IMPL(NAME, ...) \
ORT_EXPORT OrtStatus* ORT_API_CALL NAME(__VA_ARGS__) NO_EXCEPTION
ORT_EXPORT _Check_return_ _Success_(return == 0) _Ret_maybenull_ OrtStatus* ORT_API_CALL NAME(__VA_ARGS__) NO_EXCEPTION

#define ORT_RUNTIME_CLASS(X) \
struct Ort##X; \
Expand All @@ -140,12 +144,11 @@ typedef enum OrtErrorCode {

// The actual types defined have an Ort prefix
ORT_RUNTIME_CLASS(Env);
ORT_RUNTIME_CLASS(Status); // nullptr for Status* indicates success
ORT_RUNTIME_CLASS(Status);
ORT_RUNTIME_CLASS(Provider);
ORT_RUNTIME_CLASS(AllocatorInfo);
ORT_RUNTIME_CLASS(Session);
ORT_RUNTIME_CLASS(Value);
ORT_RUNTIME_CLASS(ValueList);
ORT_RUNTIME_CLASS(RunOptions);
ORT_RUNTIME_CLASS(TypeInfo);
ORT_RUNTIME_CLASS(TensorTypeAndShapeInfo);
Expand Down Expand Up @@ -328,14 +331,14 @@ ORT_API_STATUS(OrtFillStringTensor, _Inout_ OrtValue* value, _In_ const char* co
* \param value A tensor created from OrtCreateTensor... function.
* \param len total data length, not including the trailing '\0' chars.
*/
ORT_API_STATUS(OrtGetStringTensorDataLength, _In_ const OrtValue* value, _Out_ size_t* len);
ORT_API_STATUS(OrtGetStringTensorDataLength, _Out_ const OrtValue* value, _Out_ size_t* len);

/**
* \param s string contents. Each string is NOT null-terminated.
* \param value A tensor created from OrtCreateTensor... function.
* \param s_len total data length, get it from OrtGetStringTensorDataLength
*/
ORT_API_STATUS(OrtGetStringTensorContent, _In_ const OrtValue* value, _Out_ void* s, size_t s_len,
ORT_API_STATUS(OrtGetStringTensorContent, _In_ const OrtValue* value, _In_ void* s, size_t s_len,
_Out_ size_t* offsets, size_t offsets_len);

/**
Expand Down Expand Up @@ -440,8 +443,8 @@ ORT_API(const char*, OrtGetVersionString);
/**
* \param msg A null-terminated string. Its content will be copied into the newly created OrtStatus
*/
ORT_API(OrtStatus*, OrtCreateStatus, OrtErrorCode code, _In_ const char* msg)
ORT_ALL_ARGS_NONNULL;
ORT_EXPORT _Check_return_ _Ret_notnull_ OrtStatus* ORT_API_CALL OrtCreateStatus(OrtErrorCode code, _In_ const char* msg) NO_EXCEPTION
ORT_ALL_ARGS_NONNULL;

ORT_API(OrtErrorCode, OrtGetErrorCode, _In_ const OrtStatus* status)
ORT_ALL_ARGS_NONNULL;
Expand Down
25 changes: 13 additions & 12 deletions onnxruntime/contrib_ops/cpu/attnlstm/attention_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,16 @@ template <typename T>
AttentionWrapper<T>::AttentionWrapper(AllocatorPtr alloc, const logging::Logger& logger,
int batch_size, int attn_context_depth, int attn_layer_depth,
int inner_cell_hidden_size, bool has_attn_layer,
const IAttentionMechanism<T>& attention_mechanism)
const IAttentionMechanism<T>& attention_mechanism, concurrency::ThreadPool* threadpool)
: allocator_(alloc),
logger_(logger),
batch_size_(batch_size),
attn_context_depth_(attn_context_depth),
attn_layer_depth_(attn_layer_depth),
inner_cell_hidden_size_(inner_cell_hidden_size),
has_attn_layer_(has_attn_layer),
attention_mechanism_(attention_mechanism) {
attention_mechanism_(attention_mechanism),
ttp_(threadpool) {
auto mem_max_steps = attention_mechanism_.GetMaxMemorySteps();
prev_alignments_ = Allocate(allocator_, batch_size_ * mem_max_steps, prev_alignments_ptr_, true);
alignments_ = Allocate(allocator_, batch_size_ * mem_max_steps, alignments_ptr_, true);
Expand All @@ -37,11 +38,11 @@ template <typename T>
void AttentionWrapper<T>::ProcessOutput(const gsl::span<const T>& rnn_cell_output) {
if (has_attn_layer_) {
// rnn_cell_output * cell_weights, (part of the attention layer above the attention mechanism).
math::GemmEx<T, CPUMathUtil>(CblasNoTrans, CblasNoTrans,
batch_size_, attn_layer_depth_, inner_cell_hidden_size_, T{1.0},
rnn_cell_output.data(), inner_cell_hidden_size_,
attn_layer_cell_weights_.data(), attn_layer_depth_, T{0.0},
attn_states_.data(), attn_layer_depth_, &CPUMathUtil::Instance());
math::GemmEx<T>(CblasNoTrans, CblasNoTrans,
batch_size_, attn_layer_depth_, inner_cell_hidden_size_, T{1.0},
rnn_cell_output.data(), inner_cell_hidden_size_,
attn_layer_cell_weights_.data(), attn_layer_depth_, T{0.0},
attn_states_.data(), attn_layer_depth_, ttp_);
}

// Get the context which is calculated within attention mechanism.
Expand All @@ -54,11 +55,11 @@ void AttentionWrapper<T>::ProcessOutput(const gsl::span<const T>& rnn_cell_outpu
//concat([p_cell_output, context]) * stack([attn_layer_cell_weights_, attn_layer_attn_weights_]) =
// p_cell_output * attn_layer_cell_weights_ + context * attn_layer_attn_weights_
// The first part is calulated above. Here just add the later.
math::GemmEx<T, CPUMathUtil>(CblasNoTrans, CblasNoTrans,
batch_size_, attn_layer_depth_, attn_context_depth_, T{1.0},
attn_context_.data(), attn_context_depth_,
attn_layer_attn_weights_.data(), attn_layer_depth_, T{1.0},
attn_states_.data(), attn_layer_depth_, &CPUMathUtil::Instance());
math::GemmEx<T>(CblasNoTrans, CblasNoTrans,
batch_size_, attn_layer_depth_, attn_context_depth_, T{1.0},
attn_context_.data(), attn_context_depth_,
attn_layer_attn_weights_.data(), attn_layer_depth_, T{1.0},
attn_states_.data(), attn_layer_depth_, ttp_);
}
}

Expand Down
4 changes: 3 additions & 1 deletion onnxruntime/contrib_ops/cpu/attnlstm/attention_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "core/common/common.h"
#include "core/common/logging/logging.h"
#include "core/framework/allocator.h"
#include "core/platform/threadpool.h"

namespace onnxruntime {
namespace contrib {
Expand All @@ -22,7 +23,7 @@ class AttentionWrapper {
int attn_layer_depth,
int inner_cell_hidden_size,
bool has_attn_layer,
const IAttentionMechanism<T>& attention_mechanism);
const IAttentionMechanism<T>& attention_mechanism, concurrency::ThreadPool* threadpool);

virtual ~AttentionWrapper() = default;

Expand Down Expand Up @@ -69,6 +70,7 @@ class AttentionWrapper {
bool has_attn_layer_;

const IAttentionMechanism<T>& attention_mechanism_;
concurrency::ThreadPool* ttp_;
};

} // namespace contrib
Expand Down
34 changes: 17 additions & 17 deletions onnxruntime/contrib_ops/cpu/attnlstm/bahdanau_attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ namespace contrib {
template <typename T>
BahdanauAttention<T>::BahdanauAttention(AllocatorPtr allocator, const logging::Logger& logger,
int batch_size, int max_memory_step, int memory_depth,
int query_depth, int attn_depth, bool normalize)
: allocator_(allocator), logger_(logger), batch_size_(batch_size), max_memory_steps_(max_memory_step), memory_depth_(memory_depth), query_depth_(query_depth), attn_depth_(attn_depth), normalize_(normalize) {
int query_depth, int attn_depth, bool normalize, concurrency::ThreadPool* threadpool)
: allocator_(allocator), logger_(logger), batch_size_(batch_size), max_memory_steps_(max_memory_step), memory_depth_(memory_depth), query_depth_(query_depth), attn_depth_(attn_depth), normalize_(normalize), ttp_(threadpool) {
values_ = Allocate(allocator_, batch_size_ * max_memory_steps_ * memory_depth_, values_ptr_, true);
keys_ = Allocate(allocator_, batch_size_ * max_memory_steps_ * attn_depth_, keys_ptr_, true);
processed_query_ = Allocate(allocator_, batch_size_ * attn_depth_, processed_query_ptr_, true);
Expand Down Expand Up @@ -72,11 +72,11 @@ void BahdanauAttention<T>::PrepareMemory(
"Real memory steps ", mem_steps, " is not in (0, ", max_memory_steps_, "]");
}

math::GemmEx<T, CPUMathUtil>(CblasNoTrans, CblasNoTrans,
batch_size_ * max_memory_steps_, attn_depth_, memory_depth_, T{1.0},
memory.data(), memory_depth_,
memory_layer_weights_.data(), attn_depth_, T{0.0},
keys_.data(), attn_depth_, &CPUMathUtil::Instance());
math::GemmEx<T>(CblasNoTrans, CblasNoTrans,
batch_size_ * max_memory_steps_, attn_depth_, memory_depth_, T{1.0},
memory.data(), memory_depth_,
memory_layer_weights_.data(), attn_depth_, T{0.0},
keys_.data(), attn_depth_, ttp_);
}

template <typename T>
Expand Down Expand Up @@ -115,11 +115,11 @@ void BahdanauAttention<T>::Compute(
const gsl::span<T>& output,
const gsl::span<T>& aligns) const {
//process query in dense query layer without bias
math::GemmEx<T, CPUMathUtil>(CblasNoTrans, CblasNoTrans,
batch_size_, attn_depth_, query_depth_, T{1.0},
queries.data(), query_depth_,
query_layer_weights_.data(), attn_depth_, T{0.0},
processed_query_.data(), attn_depth_, &CPUMathUtil::Instance());
math::GemmEx<T>(CblasNoTrans, CblasNoTrans,
batch_size_, attn_depth_, query_depth_, T{1.0},
queries.data(), query_depth_,
query_layer_weights_.data(), attn_depth_, T{0.0},
processed_query_.data(), attn_depth_, ttp_);

std::fill(aligns.begin(), aligns.end(), T{});

Expand All @@ -146,11 +146,11 @@ void BahdanauAttention<T>::Compute(
// Calculate the context
auto outspan = output.subspan(b * memory_depth_);
auto values = values_.subspan(b * max_memory_steps_ * memory_depth_);
math::GemmEx<T, CPUMathUtil>(CblasNoTrans, CblasNoTrans,
1, memory_depth_, max_memory_steps_, T{1.0},
alignments, max_memory_steps_,
values.data(), memory_depth_, T{0.0},
outspan.data(), memory_depth_, &CPUMathUtil::Instance());
math::GemmEx<T>(CblasNoTrans, CblasNoTrans,
1, memory_depth_, max_memory_steps_, T{1.0},
alignments, max_memory_steps_,
values.data(), memory_depth_, T{0.0},
outspan.data(), memory_depth_, ttp_);
}
}

Expand Down
3 changes: 2 additions & 1 deletion onnxruntime/contrib_ops/cpu/attnlstm/bahdanau_attention.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class BahdanauAttention : public IAttentionMechanism<T> {
int memory_depth,
int query_depth,
int attn_depth,
bool normalize);
bool normalize, concurrency::ThreadPool* threadpool);

void SetWeights(
const gsl::span<const T>& attn_weights,
Expand Down Expand Up @@ -77,6 +77,7 @@ class BahdanauAttention : public IAttentionMechanism<T> {
gsl::span<int> mem_seq_lengths_;

bool normalize_;
concurrency::ThreadPool* ttp_;
};

} // namespace contrib
Expand Down
Loading

0 comments on commit f2476a8

Please sign in to comment.