Let mlas use session thread pool #1609

Merged: Aug 16, 2019 (8 commits)
Changes from 2 commits
5 changes: 0 additions & 5 deletions cmake/CMakeLists.txt
@@ -50,7 +50,6 @@ option(onnxruntime_USE_OPENVINO "Build with OpenVINO support" OFF)
option(onnxruntime_USE_NSYNC "Build with NSYNC support. This option only takes effect on Linux" OFF)
option(onnxruntime_USE_EIGEN_FOR_BLAS "Use eign for blas" ON)
option(onnxruntime_USE_NNAPI "Build with DNNLibrary for Android NNAPI support" OFF)
-option(onnxruntime_USE_MLAS "Use optimized blas library for GEMM and 2D Convolution" ON)
option(onnxruntime_USE_MKLDNN "Build with MKL-DNN support" OFF)
option(onnxruntime_USE_MKLML "Build MKL-DNN with MKL-ML binary dependency" OFF)
option(onnxruntime_USE_AUTOML "Build AutoML support" ON)
@@ -368,10 +367,6 @@ if (onnxruntime_RUN_ONNX_TESTS)
add_definitions(-DORT_RUN_EXTERNAL_ONNX_TESTS)
endif()

-if (onnxruntime_USE_MLAS)
-add_definitions(-DUSE_MLAS)
-endif()

#Adjust warning flags
if (WIN32)
add_definitions(-DPLATFORM_WINDOWS -DNOGDI -DNOMINMAX -D_USE_MATH_DEFINES)
2 changes: 1 addition & 1 deletion include/onnxruntime/core/framework/op_kernel.h
@@ -210,7 +210,7 @@ struct KernelCreateInfo {
: kernel_def(std::move(definition)),
kernel_create_func(create_func) {}

-KernelCreateInfo(KernelCreateInfo&& other)
+KernelCreateInfo(KernelCreateInfo&& other) noexcept
: kernel_def(std::move(other.kernel_def)),
kernel_create_func(std::move(other.kernel_create_func)) {}
};
4 changes: 2 additions & 2 deletions include/onnxruntime/core/framework/tensor.h
@@ -78,9 +78,9 @@ class Tensor final {
//Move is allowed
ORT_DISALLOW_COPY_AND_ASSIGNMENT(Tensor);

-Tensor(Tensor&& other);
+Tensor(Tensor&& other) noexcept;

-Tensor& operator=(Tensor&& other);
+Tensor& operator=(Tensor&& other) noexcept;

/**
Returns the data type.
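The noexcept added to these move members is not cosmetic: std::vector uses std::move_if_noexcept when it grows, so a potentially-throwing move constructor forces copies for copyable types and forfeits the strong exception guarantee for move-only types like Tensor. A minimal compile-time check of the pattern, using a hypothetical MyTensor stand-in rather than the real Tensor:

    #include <type_traits>

    // Hypothetical stand-in for a move-only, noexcept-movable type like Tensor.
    struct MyTensor {
      MyTensor() = default;
      MyTensor(const MyTensor&) = delete;            // copying disallowed
      MyTensor& operator=(const MyTensor&) = delete;
      MyTensor(MyTensor&&) noexcept = default;       // noexcept move: containers relocate by moving
      MyTensor& operator=(MyTensor&&) noexcept = default;
    };

    // With noexcept moves, std::vector<MyTensor> can grow by moving elements
    // while keeping the strong exception guarantee.
    static_assert(std::is_nothrow_move_constructible<MyTensor>::value,
                  "vector reallocation will move, not copy");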
12 changes: 11 additions & 1 deletion include/onnxruntime/core/session/environment.h
@@ -5,8 +5,11 @@

#include <atomic>
#include <memory>
+#include <mutex>
+#include <thread>
#include "core/common/common.h"
#include "core/common/status.h"
#include "core/platform/threadpool.h"

namespace onnxruntime {
/**
@@ -29,13 +32,20 @@ class Environment {
Returns whether any runtime environment instance has been initialized.
*/
static bool IsInitialized() { return is_initialized_; }
+concurrency::ThreadPool* GetThreadPool() {
+std::call_once(tp_once_, [this]() {
+tp_ = new concurrency::ThreadPool("default", std::max<int>(std::thread::hardware_concurrency() - 1, 1));
+});
+return tp_;
+}

private:
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Environment);

Environment() = default;
Status Initialize();

+concurrency::ThreadPool* tp_ = nullptr;
+std::once_flag tp_once_;
static std::atomic<bool> is_initialized_;
};
} // namespace onnxruntime
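GetThreadPool() creates one default pool per Environment, on first use: std::call_once makes concurrent first callers race safely, the pool is sized to hardware_concurrency() - 1 because the calling thread also participates in work, and tp_ is never deleted so it outlives every session. A standalone sketch of the same pattern, with a hypothetical Pool type standing in for concurrency::ThreadPool:

    #include <algorithm>
    #include <mutex>
    #include <thread>

    struct Pool {  // hypothetical stand-in for concurrency::ThreadPool
      Pool(const char* /*name*/, int /*num_threads*/) { /* spawn workers */ }
    };

    class Env {
     public:
      Pool* GetThreadPool() {
        // Exactly one caller runs the lambda, even under concurrent first use;
        // every other caller blocks until tp_ has been published.
        std::call_once(tp_once_, [this]() {
          int n = std::max<int>(std::thread::hardware_concurrency() - 1, 1);
          tp_ = new Pool("default", n);  // never freed: shared for the process lifetime
        });
        return tp_;
      }

     private:
      Pool* tp_ = nullptr;
      std::once_flag tp_once_;
    };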
19 changes: 11 additions & 8 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -23,6 +23,10 @@ extern "C" {
#define _Inout_
#define _Inout_opt_
#define _Frees_ptr_opt_
+#define _Ret_maybenull_
+#define _Ret_notnull_
+#define _Check_return_
+#define _Success_(X)
#define ORT_ALL_ARGS_NONNULL __attribute__((nonnull))
#else
#include <specstrings.h>
@@ -127,11 +131,11 @@ typedef enum OrtErrorCode {
ORT_EXPORT RETURN_TYPE ORT_API_CALL NAME(__VA_ARGS__) NO_EXCEPTION

#define ORT_API_STATUS(NAME, ...) \
-ORT_EXPORT OrtStatus* ORT_API_CALL NAME(__VA_ARGS__) NO_EXCEPTION ORT_MUST_USE_RESULT
+ORT_EXPORT _Check_return_ _Success_(return == 0) _Ret_maybenull_ OrtStatus* ORT_API_CALL NAME(__VA_ARGS__) NO_EXCEPTION ORT_MUST_USE_RESULT

// Used in *.cc files. Almost as same as ORT_API_STATUS, except without ORT_MUST_USE_RESULT
#define ORT_API_STATUS_IMPL(NAME, ...) \
-ORT_EXPORT OrtStatus* ORT_API_CALL NAME(__VA_ARGS__) NO_EXCEPTION
+ORT_EXPORT _Check_return_ _Success_(return == 0) _Ret_maybenull_ OrtStatus* ORT_API_CALL NAME(__VA_ARGS__) NO_EXCEPTION

#define ORT_RUNTIME_CLASS(X) \
struct Ort##X; \
@@ -140,12 +144,11 @@

// The actual types defined have an Ort prefix
ORT_RUNTIME_CLASS(Env);
-ORT_RUNTIME_CLASS(Status); // nullptr for Status* indicates success
+ORT_RUNTIME_CLASS(Status);
ORT_RUNTIME_CLASS(Provider);
ORT_RUNTIME_CLASS(AllocatorInfo);
-ORT_RUNTIME_CLASS(Session);
+ORT_RUNTIME_CLASS(Session); //Don't call OrtReleaseSession from Dllmain (because session owns a thread pool)
ORT_RUNTIME_CLASS(Value);
-ORT_RUNTIME_CLASS(ValueList);
ORT_RUNTIME_CLASS(RunOptions);
ORT_RUNTIME_CLASS(TypeInfo);
ORT_RUNTIME_CLASS(TensorTypeAndShapeInfo);
@@ -339,7 +342,7 @@ ORT_API_STATUS(OrtGetStringTensorDataLength, _In_ const OrtValue* value, _Out_ s
* \param value A tensor created from OrtCreateTensor... function.
* \param s_len total data length, get it from OrtGetStringTensorDataLength
*/
-ORT_API_STATUS(OrtGetStringTensorContent, _In_ const OrtValue* value, _Out_ void* s, size_t s_len,
+ORT_API_STATUS(OrtGetStringTensorContent, _In_ const OrtValue* value, _In_ void* s, size_t s_len,
_Out_ size_t* offsets, size_t offsets_len);

/**
@@ -444,8 +447,8 @@ ORT_API(const char*, OrtGetVersionString);
/**
* \param msg A null-terminated string. Its content will be copied into the newly created OrtStatus
*/
-ORT_API(OrtStatus*, OrtCreateStatus, OrtErrorCode code, _In_ const char* msg)
-ORT_ALL_ARGS_NONNULL;
+ORT_EXPORT _Check_return_ _Ret_notnull_ OrtStatus* ORT_API_CALL OrtCreateStatus(OrtErrorCode code, _In_ const char* msg) NO_EXCEPTION
+ORT_ALL_ARGS_NONNULL;

ORT_API(OrtErrorCode, OrtGetErrorCode, _In_ const OrtStatus* status)
ORT_ALL_ARGS_NONNULL;
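The new SAL annotations make the C API's error contract machine-checkable: every ORT_API_STATUS function returns nullptr on success (_Success_(return == 0), _Ret_maybenull_), OrtCreateStatus always returns a non-null status (_Ret_notnull_), and _Check_return_ flags callers that ignore the result. A hedged sketch of the resulting call pattern; OrtGetErrorMessage and OrtReleaseStatus are part of this API, but the include path and the OrtCreateEnv usage in the comment are assumptions:

    #include <stdio.h>
    #include "onnxruntime_c_api.h"  // include path assumed

    // Mirrors the annotated contract: a null OrtStatus* means success and there
    // is nothing to free; a non-null status must be read and then released.
    static int CheckStatus(OrtStatus* status) {
      if (status == nullptr) return 0;  // _Success_(return == 0)
      fprintf(stderr, "ORT error: %s\n", OrtGetErrorMessage(status));
      OrtReleaseStatus(status);
      return 1;
    }

    // Usage sketch (signature per the 2019-era C API; treat as illustrative):
    //   OrtEnv* env = nullptr;
    //   if (CheckStatus(OrtCreateEnv(ORT_LOGGING_LEVEL_WARNING, "demo", &env))) return 1;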
2 changes: 2 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
@@ -99,6 +99,7 @@ struct Env;
struct TypeInfo;
struct Value;

+//Don't put such an object as a global(or thread local) variable in a DLL
struct Env : Base<OrtEnv> {
Env(nullptr_t) {}
Env(OrtLoggingLevel default_logging_level, _In_ const char* logid);
@@ -156,6 +157,7 @@ struct SessionOptions : Base<OrtSessionOptions> {
SessionOptions& Add(OrtCustomOpDomain* custom_op_domain);
};

+//Don't put such an object as a global(or thread local) variable in a DLL
struct Session : Base<OrtSession> {
explicit Session(nullptr_t) {}
Session(Env& env, const ORTCHAR_T* model_path, const SessionOptions& options);
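Both warnings have the same cause: Env and Session now own thread pools, and a global (or thread-local) instance in a DLL runs its destructor during DLL unload under the loader lock, where worker threads cannot be joined, so the process deadlocks. A minimal sketch of the safe shape; the model path and options are illustrative, and on Windows the path argument is an ORTCHAR_T (wide) string:

    #include "onnxruntime_cxx_api.h"  // include path assumed

    // Correct: objects are function-scoped, so they are destroyed on a normal
    // code path where their worker threads can still be joined.
    int RunModel() {
      Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "demo");
      Ort::SessionOptions options;
      Ort::Session session(env, "model.onnx", options);  // L"model.onnx" on Windows
      // ... bind inputs and call session.Run(...) ...
      return 0;
    }

    // Wrong (sketch): `static Ort::Session g_session(...)` at namespace scope in a
    // DLL destructs under the loader lock, exactly the case warned about above.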
25 changes: 13 additions & 12 deletions onnxruntime/contrib_ops/cpu/attnlstm/attention_wrapper.cc
@@ -16,15 +16,16 @@ template <typename T>
AttentionWrapper<T>::AttentionWrapper(AllocatorPtr alloc, const logging::Logger& logger,
int batch_size, int attn_context_depth, int attn_layer_depth,
int inner_cell_hidden_size, bool has_attn_layer,
-const IAttentionMechanism<T>& attention_mechanism)
+const IAttentionMechanism<T>& attention_mechanism, concurrency::ThreadPool* threadpool)
: allocator_(alloc),
logger_(logger),
batch_size_(batch_size),
attn_context_depth_(attn_context_depth),
attn_layer_depth_(attn_layer_depth),
inner_cell_hidden_size_(inner_cell_hidden_size),
has_attn_layer_(has_attn_layer),
-attention_mechanism_(attention_mechanism) {
+attention_mechanism_(attention_mechanism),
+ttp_(threadpool) {
auto mem_max_steps = attention_mechanism_.GetMaxMemorySteps();
prev_alignments_ = Allocate(allocator_, batch_size_ * mem_max_steps, prev_alignments_ptr_, true);
alignments_ = Allocate(allocator_, batch_size_ * mem_max_steps, alignments_ptr_, true);
@@ -37,11 +38,11 @@
void AttentionWrapper<T>::ProcessOutput(const gsl::span<const T>& rnn_cell_output) {
if (has_attn_layer_) {
// rnn_cell_output * cell_weights, (part of the attention layer above the attention mechanism).
-math::GemmEx<T, CPUMathUtil>(CblasNoTrans, CblasNoTrans,
-batch_size_, attn_layer_depth_, inner_cell_hidden_size_, T{1.0},
-rnn_cell_output.data(), inner_cell_hidden_size_,
-attn_layer_cell_weights_.data(), attn_layer_depth_, T{0.0},
-attn_states_.data(), attn_layer_depth_, &CPUMathUtil::Instance());
+math::GemmEx<T>(CblasNoTrans, CblasNoTrans,
+batch_size_, attn_layer_depth_, inner_cell_hidden_size_, T{1.0},
+rnn_cell_output.data(), inner_cell_hidden_size_,
+attn_layer_cell_weights_.data(), attn_layer_depth_, T{0.0},
+attn_states_.data(), attn_layer_depth_, ttp_);
}

// Get the context which is calculated within attention mechanism.
@@ -54,11 +55,11 @@ void AttentionWrapper<T>::ProcessOutput(const gsl::span<const T>& rnn_cell_outpu
//concat([p_cell_output, context]) * stack([attn_layer_cell_weights_, attn_layer_attn_weights_]) =
// p_cell_output * attn_layer_cell_weights_ + context * attn_layer_attn_weights_
// The first part is calulated above. Here just add the later.
-math::GemmEx<T, CPUMathUtil>(CblasNoTrans, CblasNoTrans,
-batch_size_, attn_layer_depth_, attn_context_depth_, T{1.0},
-attn_context_.data(), attn_context_depth_,
-attn_layer_attn_weights_.data(), attn_layer_depth_, T{1.0},
-attn_states_.data(), attn_layer_depth_, &CPUMathUtil::Instance());
+math::GemmEx<T>(CblasNoTrans, CblasNoTrans,
+batch_size_, attn_layer_depth_, attn_context_depth_, T{1.0},
+attn_context_.data(), attn_context_depth_,
+attn_layer_attn_weights_.data(), attn_layer_depth_, T{1.0},
+attn_states_.data(), attn_layer_depth_, ttp_);
}
}

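Every call site in this file changes the same way: GemmEx is no longer parameterized on CPUMathUtil, and the trailing &CPUMathUtil::Instance() becomes the ThreadPool* (ttp_), so MLAS can split the GEMM across the session's pool. The numeric semantics are unchanged row-major BLAS: C = alpha*A*B + beta*C. A naive single-threaded reference of what the first call computes, useful for checking the shape bookkeeping (names are illustrative):

    // Reference semantics of the row-major GEMM above:
    // C[M x N] = alpha * A[M x K] * B[K x N] + beta * C[M x N].
    void GemmRef(int M, int N, int K, float alpha,
                 const float* A, int lda, const float* B, int ldb,
                 float beta, float* C, int ldc) {
      for (int m = 0; m < M; ++m)
        for (int n = 0; n < N; ++n) {
          float acc = 0.0f;
          for (int k = 0; k < K; ++k) acc += A[m * lda + k] * B[k * ldb + n];
          C[m * ldc + n] = alpha * acc + beta * C[m * ldc + n];
        }
    }

    // First call above: M = batch_size_, N = attn_layer_depth_,
    // K = inner_cell_hidden_size_, A = rnn_cell_output (lda = K),
    // B = attn_layer_cell_weights_ (ldb = N), beta = 0 so attn_states_ is overwritten;
    // the second call uses beta = 1 to accumulate the context term into attn_states_.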
4 changes: 3 additions & 1 deletion onnxruntime/contrib_ops/cpu/attnlstm/attention_wrapper.h
@@ -8,6 +8,7 @@
#include "core/common/common.h"
#include "core/common/logging/logging.h"
#include "core/framework/allocator.h"
#include "core/platform/threadpool.h"

namespace onnxruntime {
namespace contrib {
@@ -22,7 +23,7 @@ class AttentionWrapper {
int attn_layer_depth,
int inner_cell_hidden_size,
bool has_attn_layer,
-const IAttentionMechanism<T>& attention_mechanism);
+const IAttentionMechanism<T>& attention_mechanism, concurrency::ThreadPool* threadpool);

virtual ~AttentionWrapper() = default;

@@ -69,6 +70,7 @@ class AttentionWrapper {
bool has_attn_layer_;

const IAttentionMechanism<T>& attention_mechanism_;
+concurrency::ThreadPool* ttp_;
};

} // namespace contrib
34 changes: 17 additions & 17 deletions onnxruntime/contrib_ops/cpu/attnlstm/bahdanau_attention.cc
@@ -15,8 +15,8 @@ namespace contrib {
template <typename T>
BahdanauAttention<T>::BahdanauAttention(AllocatorPtr allocator, const logging::Logger& logger,
int batch_size, int max_memory_step, int memory_depth,
-int query_depth, int attn_depth, bool normalize)
-: allocator_(allocator), logger_(logger), batch_size_(batch_size), max_memory_steps_(max_memory_step), memory_depth_(memory_depth), query_depth_(query_depth), attn_depth_(attn_depth), normalize_(normalize) {
+int query_depth, int attn_depth, bool normalize, concurrency::ThreadPool* threadpool)
+: allocator_(allocator), logger_(logger), batch_size_(batch_size), max_memory_steps_(max_memory_step), memory_depth_(memory_depth), query_depth_(query_depth), attn_depth_(attn_depth), normalize_(normalize), ttp_(threadpool) {
values_ = Allocate(allocator_, batch_size_ * max_memory_steps_ * memory_depth_, values_ptr_, true);
keys_ = Allocate(allocator_, batch_size_ * max_memory_steps_ * attn_depth_, keys_ptr_, true);
processed_query_ = Allocate(allocator_, batch_size_ * attn_depth_, processed_query_ptr_, true);
@@ -72,11 +72,11 @@ void BahdanauAttention<T>::PrepareMemory(
"Real memory steps ", mem_steps, " is not in (0, ", max_memory_steps_, "]");
}

-math::GemmEx<T, CPUMathUtil>(CblasNoTrans, CblasNoTrans,
-batch_size_ * max_memory_steps_, attn_depth_, memory_depth_, T{1.0},
-memory.data(), memory_depth_,
-memory_layer_weights_.data(), attn_depth_, T{0.0},
-keys_.data(), attn_depth_, &CPUMathUtil::Instance());
+math::GemmEx<T>(CblasNoTrans, CblasNoTrans,
+batch_size_ * max_memory_steps_, attn_depth_, memory_depth_, T{1.0},
+memory.data(), memory_depth_,
+memory_layer_weights_.data(), attn_depth_, T{0.0},
+keys_.data(), attn_depth_, ttp_);
}

template <typename T>
@@ -115,11 +115,11 @@ void BahdanauAttention<T>::Compute(
const gsl::span<T>& output,
const gsl::span<T>& aligns) const {
//process query in dense query layer without bias
-math::GemmEx<T, CPUMathUtil>(CblasNoTrans, CblasNoTrans,
-batch_size_, attn_depth_, query_depth_, T{1.0},
-queries.data(), query_depth_,
-query_layer_weights_.data(), attn_depth_, T{0.0},
-processed_query_.data(), attn_depth_, &CPUMathUtil::Instance());
+math::GemmEx<T>(CblasNoTrans, CblasNoTrans,
+batch_size_, attn_depth_, query_depth_, T{1.0},
+queries.data(), query_depth_,
+query_layer_weights_.data(), attn_depth_, T{0.0},
+processed_query_.data(), attn_depth_, ttp_);

std::fill(aligns.begin(), aligns.end(), T{});

@@ -146,11 +146,11 @@
// Calculate the context
auto outspan = output.subspan(b * memory_depth_);
auto values = values_.subspan(b * max_memory_steps_ * memory_depth_);
-math::GemmEx<T, CPUMathUtil>(CblasNoTrans, CblasNoTrans,
-1, memory_depth_, max_memory_steps_, T{1.0},
-alignments, max_memory_steps_,
-values.data(), memory_depth_, T{0.0},
-outspan.data(), memory_depth_, &CPUMathUtil::Instance());
+math::GemmEx<T>(CblasNoTrans, CblasNoTrans,
+1, memory_depth_, max_memory_steps_, T{1.0},
+alignments, max_memory_steps_,
+values.data(), memory_depth_, T{0.0},
+outspan.data(), memory_depth_, ttp_);
}
}

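The per-batch GEMM in the loop above is a 1 x T times T x D product (T = max_memory_steps_, D = memory_depth_): the context vector is the alignment-weighted sum of the memory values. An equivalent explicit form for one batch entry (illustrative sketch):

    #include <vector>

    // ctx[d] = sum over t of align[t] * values[t * D + d] -- the 1 x T by T x D
    // GEMM written out as a weighted sum over memory steps.
    std::vector<float> Context(const std::vector<float>& align,   // size T
                               const std::vector<float>& values,  // size T * D, row-major
                               int T, int D) {
      std::vector<float> ctx(D, 0.0f);
      for (int t = 0; t < T; ++t)
        for (int d = 0; d < D; ++d)
          ctx[d] += align[t] * values[t * D + d];
      return ctx;
    }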
3 changes: 2 additions & 1 deletion onnxruntime/contrib_ops/cpu/attnlstm/bahdanau_attention.h
@@ -23,7 +23,7 @@ class BahdanauAttention : public IAttentionMechanism<T> {
int memory_depth,
int query_depth,
int attn_depth,
-bool normalize);
+bool normalize, concurrency::ThreadPool* threadpool);

void SetWeights(
const gsl::span<const T>& attn_weights,
@@ -77,6 +77,7 @@ class BahdanauAttention : public IAttentionMechanism<T> {
gsl::span<int> mem_seq_lengths_;

bool normalize_;
+concurrency::ThreadPool* ttp_;
};

} // namespace contrib