Merge branch 'gt/master' into gt/kedeng/imatmul
KeDengMS committed Jul 17, 2019
2 parents a78ff58 + c2aa205 commit 04dc1b7
Showing 77 changed files with 3,528 additions and 634 deletions.
28 changes: 14 additions & 14 deletions cmake/onnxruntime_mlas.cmake
@@ -17,13 +17,7 @@ set(mlas_common_srcs

if(MSVC)

if(CMAKE_GENERATOR_PLATFORM STREQUAL "ARM")

set(mlas_platform_srcs
${ONNXRUNTIME_ROOT}/core/mlas/lib/arm/sgemmc.cpp
)

elseif(CMAKE_GENERATOR_PLATFORM STREQUAL "ARM64")
if(CMAKE_GENERATOR_PLATFORM STREQUAL "ARM64")

set(asm_filename ${ONNXRUNTIME_ROOT}/core/mlas/lib/arm64/sgemma.asm)
set(pre_filename ${CMAKE_CURRENT_BINARY_DIR}/sgemma.i)
@@ -45,17 +39,13 @@ if(MSVC)

set(mlas_platform_srcs ${obj_filename})

elseif(CMAKE_GENERATOR_PLATFORM STREQUAL "Win32")

enable_language(ASM_MASM)

set(CMAKE_ASM_MASM_FLAGS "${CMAKE_ASM_MASM_FLAGS} /safeseh")
elseif(CMAKE_GENERATOR_PLATFORM STREQUAL "ARM" OR CMAKE_GENERATOR MATCHES "ARM")

set(mlas_platform_srcs
${ONNXRUNTIME_ROOT}/core/mlas/lib/i386/sgemma.asm
${ONNXRUNTIME_ROOT}/core/mlas/lib/arm/sgemmc.cpp
)

elseif(CMAKE_GENERATOR_PLATFORM STREQUAL "x64")
elseif(CMAKE_GENERATOR_PLATFORM STREQUAL "x64" OR CMAKE_GENERATOR MATCHES "Win64")

enable_language(ASM_MASM)

@@ -78,6 +68,16 @@ if(MSVC)
${ONNXRUNTIME_ROOT}/core/mlas/lib/amd64/ErfKernelFma3.asm
)

else()

enable_language(ASM_MASM)

set(CMAKE_ASM_MASM_FLAGS "${CMAKE_ASM_MASM_FLAGS} /safeseh")

set(mlas_platform_srcs
${ONNXRUNTIME_ROOT}/core/mlas/lib/i386/sgemma.asm
)

endif()

elseif(CMAKE_SYSTEM_NAME STREQUAL "Android")
9 changes: 0 additions & 9 deletions cmake/onnxruntime_unittests.cmake
@@ -623,15 +623,6 @@ if (onnxruntime_BUILD_SHARED_LIB)
protobuf::libprotobuf
DEPENDS ${all_dependencies}
)
#demo
message("PNG Lib Dir = ${PNG_LIBRARIES}")
message("PNG Include Dir = ${PNG_INCLUDE_DIRS}")
if(PNG_FOUND AND NOT WIN32) # for some reason some symbols are not found in Win32 PNG module
add_executable(fns_candy_style_transfer "${ONNXRUNTIME_ROOT}/test/shared_lib/fns_candy_style_transfer.c")
target_include_directories(fns_candy_style_transfer PRIVATE "${TEST_SRC_DIR}/util/include" ${PNG_INCLUDE_DIRS})
target_link_libraries(fns_candy_style_transfer PRIVATE onnxruntime ${PNG_LIBRARIES})
set_target_properties(fns_candy_style_transfer PROPERTIES FOLDER "ONNXRuntimeTest")
endif()
endif()

if (onnxruntime_BUILD_SERVER)
10 changes: 10 additions & 0 deletions include/onnxruntime/core/graph/graph.h
@@ -279,6 +279,10 @@ class Node {
return !attr_to_subgraph_map_.empty();
}

/** Get the const subgraphs from a node.
@remarks Creates a new vector so calling ContainsSubgraphs first is preferred. */
std::vector<gsl::not_null<const Graph*>> GetSubgraphs() const;

/** Gets a map of attribute name to the mutable Graph instances for all subgraphs of the Node.
@returns Map of the attribute name that defines the subgraph to the subgraph's Graph instance.
nullptr if the Node has no subgraphs.
@@ -756,6 +760,9 @@ class Graph {
/** Returns the parent graph if this is a subgraph */
const Graph* ParentGraph() const { return parent_graph_; }

/** Returns the mutable parent graph if this is a subgraph */
Graph* MutableParentGraph() { return parent_graph_; }

/** Construct a Graph instance for a subgraph that is created from a GraphProto attribute in a Node.
Inherits some properties from the parent graph.
@param parent_graph The Graph containing the Node which has a GraphProto attribute.
@@ -980,6 +987,9 @@ class Graph {
// NodeArgs that come from outer scope. Used when building a graph so that
// these don't get recorded as graph inputs in the GraphProto.
std::unordered_set<std::string> outer_scope_node_arg_names_;

// number of times Resolve has run.
int num_resolves_ = 0;
};

} // namespace onnxruntime
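A minimal usage sketch of the two new accessors (assumed caller code, not part of this commit; the helper names are hypothetical):

#include <iostream>
#include "core/graph/graph.h"

// Walk the const subgraphs of a node. GetSubgraphs() builds a new vector, so the
// existing ContainsSubgraph() check (as recommended by the remark above) is called first.
void PrintSubgraphNames(const onnxruntime::Node& node) {
  if (node.ContainsSubgraph()) {
    for (const auto& subgraph : node.GetSubgraphs()) {
      std::cout << "subgraph: " << subgraph->Name() << std::endl;
    }
  }
}

// Climb to the outermost graph through the new mutable parent-graph accessor.
onnxruntime::Graph& OutermostGraph(onnxruntime::Graph& graph) {
  onnxruntime::Graph* current = &graph;
  while (current->MutableParentGraph() != nullptr) {
    current = current->MutableParentGraph();
  }
  return *current;
}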
31 changes: 31 additions & 0 deletions onnxruntime/contrib_ops/cpu/fused_activation.cc
@@ -0,0 +1,31 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "contrib_ops/cpu/fused_activation.h"

namespace onnxruntime {

common::Status GetFusedActivationAttr(const OpKernelInfo& info, MLAS_ACTIVATION& activation) {
// Convert the activation parameters from the node into a MLAS_ACTIVATION.
activation.ActivationKind = MlasIdentityActivation;

std::string activation_type;
if (info.GetAttr<std::string>("activation", &activation_type).IsOK()) {
if (activation_type == "Relu") {
activation.ActivationKind = MlasReluActivation;
} else if (activation_type == "LeakyRelu") {
activation.ActivationKind = MlasLeakyReluActivation;
activation.alpha = info.GetAttrOrDefault<float>("alpha", 0.01f);
} else if (activation_type == "Tanh") {
activation.ActivationKind = MlasTanhActivation;
} else if (activation_type == "Sigmoid") {
activation.ActivationKind = MlasLogisticActivation;
} else {
return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Unimplemented activation: " + activation_type);
}
}

return Status::OK();
}

} // namespace onnxruntime
14 changes: 14 additions & 0 deletions onnxruntime/contrib_ops/cpu/fused_activation.h
@@ -0,0 +1,14 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include "core/common/common.h"
#include "core/framework/op_kernel.h"
#include "core/util/math.h"
#include "core/mlas/inc/mlas.h"

namespace onnxruntime {

common::Status GetFusedActivationAttr(const OpKernelInfo& info, MLAS_ACTIVATION& activation);

} // namespace onnxruntime
14 changes: 12 additions & 2 deletions onnxruntime/contrib_ops/cpu/fused_conv.cc
@@ -1,16 +1,26 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "fused_conv.h"
#include "core/providers/cpu/nn/conv.h"
#include "contrib_ops/cpu/fused_activation.h"

namespace onnxruntime {
namespace contrib {

class FusedConvFloat final : public Conv<float> {
public:
FusedConvFloat(const OpKernelInfo& info) : Conv<float>(info) {
ORT_ENFORCE(GetFusedActivationAttr(info, activation_).IsOK());
}
};

ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
FusedConv,
1,
float,
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
FusedConv<float>);
FusedConvFloat);

} // namespace contrib
} // namespace onnxruntime
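A hedged sketch of the producing side (not part of this commit): a graph transformer in the com.microsoft domain might emit a FusedConv node like the one below, whose "activation"/"alpha" attributes FusedConvFloat then reads through GetFusedActivationAttr in its constructor. The node name and the conv_inputs/conv_outputs NodeArg vectors are assumptions.

// Assumed fusion step inside a graph transformer:
Node& fused_node = graph.AddNode("conv_leakyrelu_fused", "FusedConv",
                                 "Conv followed by LeakyRelu",
                                 conv_inputs, conv_outputs,
                                 nullptr, kMSDomain);
fused_node.AddAttribute("activation", std::string("LeakyRelu"));
fused_node.AddAttribute("alpha", 0.01f);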
24 changes: 0 additions & 24 deletions onnxruntime/contrib_ops/cpu/fused_conv.h

This file was deleted.

15 changes: 13 additions & 2 deletions onnxruntime/contrib_ops/cpu/fused_gemm.cc
@@ -1,15 +1,26 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "fused_gemm.h"
#include "core/providers/cpu/math/gemm.h"

namespace onnxruntime {
namespace contrib {

template <typename T>
class FusedGemm final : public Gemm<T> {
public:
FusedGemm(const OpKernelInfo& info) : Gemm<T>(info) {
Gemm<T>::activation_ = info.GetAttrOrDefault<std::string>("activation", "");
Gemm<T>::leaky_relu_alpha_ = info.GetAttrOrDefault("leaky_relu_alpha", 0.01f);
}
};

ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
FusedGemm,
1,
float,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
FusedGemm<float, float, float, float>);
FusedGemm<float>);

} // namespace contrib
} // namespace onnxruntime
26 changes: 0 additions & 26 deletions onnxruntime/contrib_ops/cpu/fused_gemm.h

This file was deleted.

61 changes: 22 additions & 39 deletions onnxruntime/contrib_ops/cpu/nchwc_ops.cc
@@ -34,7 +34,7 @@ ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL(
KernelDefBuilder()
.MayInplace(3, 0)
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
NchwcConv<float>);
NchwcConv);

ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL(
MaxPool,
@@ -70,39 +70,38 @@ ONNX_CPU_OPERATOR_TYPED_NCHWC_KERNEL(

template <typename T>
Status ReorderInput<T>::Compute(OpKernelContext* context) const {
const Tensor* X = context->Input<Tensor>(0);
const TensorShape& X_shape = X->Shape();
const auto* X = context->Input<Tensor>(0);
const auto& X_shape = X->Shape();
ORT_ENFORCE(X_shape.NumDimensions() == 4);
ORT_ENFORCE((X_shape[1] % MlasNchwcGetBlockSize()) == 0);
Tensor* Y = context->Output(0, X_shape);
auto* Y = context->Output(0, X_shape);
MlasReorderInput(X_shape.GetDims().data(), X->template Data<T>(), Y->template MutableData<T>());
return Status::OK();
}

template <typename T>
Status ReorderOutput<T>::Compute(OpKernelContext* context) const {
const Tensor* X = context->Input<Tensor>(0);
const TensorShape& X_shape = X->Shape();
const auto* X = context->Input<Tensor>(0);
const auto& X_shape = X->Shape();
ORT_ENFORCE(X_shape.NumDimensions() == 4);
std::vector<int64_t> Y_shape(X_shape.GetDims());
ORT_ENFORCE(channels_ <= Y_shape[1]);
Y_shape[1] = channels_;
Tensor* Y = context->Output(0, Y_shape);
auto* Y = context->Output(0, Y_shape);
MlasReorderOutput(Y_shape.data(), X->template Data<T>(), Y->template MutableData<T>());
return Status::OK();
}

template <typename T>
Status NchwcConv<T>::Compute(OpKernelContext* context) const {
const Tensor* X = context->Input<Tensor>(0);
const Tensor* W = context->Input<Tensor>(1);
const Tensor* B = context->Input<Tensor>(2);
const Tensor* Sum = context->Input<Tensor>(3);
Status NchwcConv::Compute(OpKernelContext* context) const {
const auto* X = context->Input<Tensor>(0);
const auto* W = context->Input<Tensor>(1);
const auto* B = context->Input<Tensor>(2);
const auto* Sum = context->Input<Tensor>(3);

ORT_RETURN_IF_ERROR(ConvBase::ValidateInputShape(X, W));

const TensorShape& X_shape = X->Shape();
const TensorShape& W_shape = W->Shape();
const auto& X_shape = X->Shape();
const auto& W_shape = W->Shape();
ORT_ENFORCE(X_shape.NumDimensions() == 4);

const size_t nchwc_block_size = MlasNchwcGetBlockSize();
@@ -131,36 +130,20 @@ Status NchwcConv<T>::Compute(OpKernelContext* context) const {
Y_dims.insert(Y_dims.begin(), {X_shape[0], W_shape[0]});
TensorShape input_shape = X->Shape().Slice(2);
ORT_RETURN_IF_ERROR(ConvBase::InferOutputShape(input_shape, kernel_shape, strides, dilations, &pads, &Y_dims));
Tensor* Y = context->Output(0, Y_dims);
T* y_data = Y->template MutableData<T>();
auto* Y = context->Output(0, Y_dims);
auto* y_data = Y->template MutableData<float>();

// Check for the optional Conv/Sum fusion.
if (Sum != nullptr) {
const auto& sum_shape = Sum->Shape();
ORT_RETURN_IF_NOT(Y->Shape() == sum_shape, "output and sum shape must match");
// If the output was not allocated inplace with the sum tensor, then copy here.
const float* sum_data = Sum->template Data<T>();
const auto* sum_data = Sum->template Data<float>();
if (y_data != sum_data) {
memcpy(y_data, sum_data, sum_shape.Size() * sizeof(T));
memcpy(y_data, sum_data, sum_shape.Size() * sizeof(float));
}
}

MLAS_ACTIVATION Activation;
if (ConvBase::activation_.empty()) {
Activation.ActivationKind = MlasIdentityActivation;
} else if (ConvBase::activation_ == "Relu") {
Activation.ActivationKind = MlasReluActivation;
} else if (ConvBase::activation_ == "LeakyRelu") {
Activation.ActivationKind = MlasLeakyReluActivation;
Activation.alpha = ConvBase::alpha_;
} else if (ConvBase::activation_ == "Tanh") {
Activation.ActivationKind = MlasTanhActivation;
} else if (ConvBase::activation_ == "Sigmoid") {
Activation.ActivationKind = MlasLogisticActivation;
} else {
ORT_NOT_IMPLEMENTED("Not implemented fused activation: ", ConvBase::activation_);
}

MlasNchwcConv(kernel_shape.size(),
X_shape.GetDims().data(),
kernel_shape.data(),
@@ -173,17 +156,17 @@ Status NchwcConv<T>::Compute(OpKernelContext* context) const {
W->template Data<float>(),
B != nullptr ? B->template Data<float>() : nullptr,
y_data,
&Activation,
&activation_,
Sum == nullptr,
const_cast<concurrency::ThreadPool*>(static_cast<OpKernelContextInternal*>(context)->GetOperatorThreadPool()));

return Status::OK();
}

Status NchwcPoolBase::NchwcPool(OpKernelContext* context, MLAS_POOLING_KIND kind) const {
const Tensor* X = context->Input<Tensor>(0);
const auto* X = context->Input<Tensor>(0);

const TensorShape& X_shape = X->Shape();
const auto& X_shape = X->Shape();
ORT_ENFORCE(X_shape.NumDimensions() == 4);
ORT_ENFORCE((X_shape[1] % MlasNchwcGetBlockSize()) == 0);

Expand All @@ -193,7 +176,7 @@ Status NchwcPoolBase::NchwcPool(OpKernelContext* context, MLAS_POOLING_KIND kind

std::vector<int64_t> pads = pads_;
std::vector<int64_t> output_dims = PoolBase::SetOutputSize(X_shape, X_shape[1], &pads, dilations_, ceil_mode_);
Tensor* Y = context->Output(0, output_dims);
auto* Y = context->Output(0, output_dims);

MlasNchwcPool(kind,
2,
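For reference, one plausible shape of the updated NchwcConv class in nchwc_ops.h (that header is not shown in this excerpt, so this is an assumption based on the Compute body above): the activation attributes are parsed once in the constructor via GetFusedActivationAttr, and the resulting MLAS_ACTIVATION member is what Compute passes to MlasNchwcConv.

// Assumed sketch of the non-templated kernel after this change (not the verbatim header):
class NchwcConv final : public OpKernel, public ConvBase {
 public:
  NchwcConv(const OpKernelInfo& info) : OpKernel(info), ConvBase(info) {
    // Parse "activation"/"alpha" once instead of on every Compute call.
    ORT_ENFORCE(GetFusedActivationAttr(info, activation_).IsOK());
  }

  Status Compute(OpKernelContext* context) const override;

 private:
  MLAS_ACTIVATION activation_;
};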
