Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 0 additions & 22 deletions onnxruntime/core/providers/openvino/backend_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -389,18 +389,6 @@ static bool IsQDQGraph(const onnxruntime::GraphViewer& graph_viewer) {
return false;
}

// Returns true iff any node in the graph produces a BFloat16-typed output.
// Used to decide whether the bfloat16 -> float16 rewrite pass must run.
static bool IsModelBF16(const onnxruntime::GraphViewer& graph_viewer) {
  for (const auto node_index : graph_viewer.GetNodesInTopologicalOrder()) {
    gsl::not_null<const onnxruntime::Node*> node(graph_viewer.GetNode(node_index));
    for (const auto* output : node->OutputDefs()) {
      const auto elem_type = output->ToProto().type().tensor_type().elem_type();
      if (elem_type == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16) {
        return true;
      }
    }
  }
  return false;
}

static bool Is16BitTensor(const onnxruntime::NodeArg* node_arg) {
const auto* type_proto = node_arg ? node_arg->TypeAsProto() : nullptr;
return type_proto && type_proto->has_tensor_type() &&
Expand Down Expand Up @@ -598,16 +586,6 @@ BackendManager::GetModelProtoFromFusedNode(const onnxruntime::Node& fused_node,
DumpOpenVINOEPModel(onnx_model_path_name, model_proto.get(), fused_node);
ORT_ENFORCE(status.IsOK(), status.ErrorMessage());
return model_proto;
} else if (IsModelBF16(subgraph)) {
LOGS_DEFAULT(INFO) << "[OpenVINO-EP] OVEP bfloat16->float16 optimization pass is enabled";
std::unique_ptr<onnxruntime::Model> model;
Status status = bfloat16_fix::Transform(subgraph, logger, model);
auto model_proto = model->ToProto();
model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
print_model_proto_duration();
DumpOpenVINOEPModel(onnx_model_path_name, model_proto.get(), fused_node);
ORT_ENFORCE(status.IsOK(), status.ErrorMessage());
return model_proto;
} else {
LOGS_DEFAULT(INFO) << "[OpenVINO-EP] OVEP QDQ optimization pass is disabled";

Expand Down
4 changes: 1 addition & 3 deletions onnxruntime/core/providers/openvino/ov_versions/data_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -561,9 +561,7 @@ bool DataOps::type_is_supported(const NodeArg* node_arg, bool is_initializer) {
}

auto dtype = type_proto->tensor_type().elem_type();
// Enable bfloat16 -> float16 on-the-fly conversion
if (dtype == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BFLOAT16 ||
dtype == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT16 ||
if (dtype == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT16 ||
dtype == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT16)
return true;
if (is_initializer) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
#include "qdq_scales_fix.h"
#include "core/providers/openvino/ov_protobuf_utils.h"
#include "core/framework/ort_value.h"
#include "core/common/float16.h"

#include <fstream>
#include <list>
Expand Down Expand Up @@ -955,60 +954,5 @@ Status Transform(const GraphViewer& src_graph_viewer,
return status;
}
} // namespace qdq_scales_fix

namespace bfloat16_fix {
// Rewrites every BFloat16 occurrence in the copied graph to Float16 so the
// model can be consumed by backends without bfloat16 support. Three passes:
//   1. patch Cast "to" attributes and node output element types,
//   2. inline any in-memory external tensor data,
//   3. retag BF16 initializers as FP16 and convert their raw payload bits.
// NOTE(review): mutates protos in place via const_cast — assumes gen_graph
// wraps a private copy of the model (see copy_model in Transform); confirm.
void replace_bf16_with_fp16(qdq_scales_fix::CustomGraph& gen_graph) {
  for (auto& const_node : gen_graph.original_graph.Nodes()) {
    auto node = const_cast<ONNX_NAMESPACE::Node*>(const_node);
    if (node->OpType() == "Cast") {
      // Retarget Cast-to-BF16 nodes so they cast to FP16 instead.
      for (auto& [name, const_attribute] : node->GetAttributes()) {
        auto& attribute = const_cast<ONNX_NAMESPACE::AttributeProto&>(const_attribute);
        if (name == "to" && attribute.type() == ONNX_NAMESPACE::AttributeProto_AttributeType_INT)
          if (attribute.i() == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16)
            attribute.set_i(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16);
      }
    }
    // Retype every BF16 output value-info to FP16.
    for (auto& output : node->OutputDefs()) {
      auto& output_proto = const_cast<ONNX_NAMESPACE::TypeProto&>(output->ToProto().type());
      if (output_proto.mutable_tensor_type()->elem_type() == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16)
        output_proto.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16);
    }
  }

  // Materialize in-memory external data inline so the initializer pass below
  // sees the actual tensor bytes.
  for (auto& node : gen_graph.original_graph.Nodes()) {
    for (auto& input_def : node->InputDefs()) {
      ORT_THROW_IF_ERROR(graph_utils::ConvertInMemoryDataToInline(gen_graph.original_graph, input_def->Name()));
    }
  }

  const auto& init_set = gen_graph.original_graph.GetAllInitializedTensors();
  for (auto& [key, const_tensor_proto] : init_set) {
    auto tensor_proto = const_cast<ONNX_NAMESPACE::TensorProto*>(const_tensor_proto);
    auto dt = tensor_proto->data_type();
    if (dt == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16) {
      // Only raw_data-backed initializers are converted; BF16 initializers
      // stored in typed fields (e.g. int32_data) would be left untouched.
      auto raw_data = tensor_proto->has_raw_data() ? reinterpret_cast<std::uint16_t*>(tensor_proto->mutable_raw_data()->data()) : nullptr;
      if (raw_data) {
        tensor_proto->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16);
        std::int64_t size = 1;
        for (int i = 0; i < tensor_proto->dims_size(); ++i)
          size *= tensor_proto->dims()[i];
        // Both formats are 16-bit, so the conversion is done in place:
        // reinterpret each element's bits as BF16 and re-encode as FP16.
        for (std::int64_t i = 0; i < size; ++i) {
          raw_data[i] = onnxruntime::MLFloat16(onnxruntime::BFloat16::FromBits(raw_data[i])).val;
        }
      }
    }
  }
}

// Copies the source graph into `model`, then rewrites all BFloat16 usage
// (Cast targets, value-info types, initializer payloads) to Float16.
// Params:
//   src_graph_viewer - graph to transform (read-only).
//   logger           - logger forwarded to the model copy.
//   model            - out: the converted model on success.
// Returns the status of the underlying model copy.
Status Transform(const GraphViewer& src_graph_viewer,
                 const logging::Logger& logger,
                 /*out*/ std::unique_ptr<onnxruntime::Model>& model) {
  auto status = qdq_scales_fix::copy_model(src_graph_viewer, logger, model);
  // Bail out early on copy failure: dereferencing `model` below would
  // otherwise crash on a null/partial model instead of reporting the error.
  if (!status.IsOK()) {
    return status;
  }
  auto g = qdq_scales_fix::generate_graph_from_onnx(model->MainGraph());

  replace_bf16_with_fp16(g);
  return status;
}
} // namespace bfloat16_fix
} // namespace openvino_ep
} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,5 @@ Status Transform(const GraphViewer& src_graph,
const logging::Logger& logger,
/*out*/ std::unique_ptr<onnxruntime::Model>& model);
}
namespace bfloat16_fix {
// Clones `src_graph` into `model` and rewrites BFloat16 tensors, Cast
// targets, and initializers to Float16. Returns the status of the copy.
Status Transform(const GraphViewer& src_graph,
                 const logging::Logger& logger,
                 /*out*/ std::unique_ptr<onnxruntime::Model>& model);
}
} // namespace openvino_ep
} // namespace onnxruntime
116 changes: 0 additions & 116 deletions onnxruntime/test/providers/openvino/openvino_ep_bfloat16_pass_test.cc

This file was deleted.