diff --git a/cmake/onnxruntime_providers_vitisai.cmake b/cmake/onnxruntime_providers_vitisai.cmake index 764cde9491da8..561a323533f48 100644 --- a/cmake/onnxruntime_providers_vitisai.cmake +++ b/cmake/onnxruntime_providers_vitisai.cmake @@ -12,6 +12,7 @@ file(GLOB onnxruntime_providers_vitisai_cc_srcs CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/core/providers/vitisai/*.cc" "${ONNXRUNTIME_ROOT}/core/providers/vitisai/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/vitisai/include/vaip/*.h" "${ONNXRUNTIME_ROOT}/core/providers/vitisai/imp/*.cc" "${ONNXRUNTIME_ROOT}/core/providers/vitisai/imp/*.h" "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.h" diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index 34319287a80fd..5681dd77a6ab5 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -994,6 +994,7 @@ struct ProviderHost { bool include_outer_scope_args, int execution_order) noexcept = 0; virtual const Node* GraphViewer__GetProducerNode(const GraphViewer* p, const std::string& node_arg_name) const = 0; + virtual IOnnxRuntimeOpSchemaCollectionPtr GraphViewer__GetSchemaRegistry(const GraphViewer* p) const = 0; // OpKernel virtual const Node& OpKernel__Node(const OpKernel* p) = 0; diff --git a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h index 4644f703dcb5d..bd1c289587177 100644 --- a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h +++ b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h @@ -1064,6 +1064,7 @@ class GraphViewer final { g_host->GraphViewer__ToProto(this, graph_proto, include_initializers, include_outer_scope_args, execution_order); } const Node* GetProducerNode(const std::string& node_arg_name) const { return g_host->GraphViewer__GetProducerNode(this, node_arg_name); } + IOnnxRuntimeOpSchemaCollectionPtr GetSchemaRegistry() const { return g_host->GraphViewer__GetSchemaRegistry(this); } GraphViewer() = delete; GraphViewer(const GraphViewer&) = delete; diff --git a/onnxruntime/core/providers/vitisai/imp/global_api.cc b/onnxruntime/core/providers/vitisai/imp/global_api.cc index 41885721e7b9a..9d6c036aa53c5 100644 --- a/onnxruntime/core/providers/vitisai/imp/global_api.cc +++ b/onnxruntime/core/providers/vitisai/imp/global_api.cc @@ -428,6 +428,11 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { } }; the_global_api.node_arg_external_location = vaip::node_arg_external_location; + the_global_api.model_to_proto = [](onnxruntime::Model& model) { return model.ToProto().release(); }; + the_global_api.model_proto_serialize_as_string = [](ONNX_NAMESPACE::ModelProto& model_proto) { + return vaip_core::DllSafe(model_proto.SerializeAsString()); + }; + the_global_api.model_proto_delete = [](ONNX_NAMESPACE::ModelProto* p) { delete p; }; if (!s_library_vitisaiep.vaip_get_version) { return reinterpret_cast(&(the_global_api.host_)); } else { diff --git a/onnxruntime/core/providers/vitisai/include/vaip/custom_op.h b/onnxruntime/core/providers/vitisai/include/vaip/custom_op.h index 5d020e00ff5b7..64cf52ec0a404 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/custom_op.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/custom_op.h @@ -25,18 +25,18 @@ class ExecutionProvider { virtual DllSafe> get_meta_def_nodes() const = 0; virtual DllSafe> get_meta_def_constant_initializer() const = 0; + virtual bool get_meta_def_fallback_CPU() const { return false; }; virtual std::unique_ptr compile() const = 0; public: - inline void set_fused_node(const onnxruntime::Node* fused_node) { - fused_node_ = fused_node; - } - inline const onnxruntime::Node* get_fused_node() const { - return fused_node_; - } + inline void set_fused_node(const onnxruntime::Node* fused_node) { fused_node_ = fused_node; } + inline const onnxruntime::Node* get_fused_node() const { return fused_node_; } + inline void set_model(onnxruntime::Model* model) { model_ = model; } + inline onnxruntime::Model* get_model() const { return model_; } private: const onnxruntime::Node* fused_node_ = nullptr; + onnxruntime::Model* model_ = nullptr; }; class CustomOp { diff --git a/onnxruntime/core/providers/vitisai/include/vaip/my_ort.h b/onnxruntime/core/providers/vitisai/include/vaip/my_ort.h index 74482d8e9ee0e..7628e45d2b933 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/my_ort.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/my_ort.h @@ -20,6 +20,7 @@ struct NodeAttributes; namespace ONNX_NAMESPACE { struct AttributeProto; struct TensorProto; +struct ModelProto; #ifndef USE_VITISAI enum TensorProto_DataType : int { TensorProto_DataType_UNDEFINED = 0, @@ -70,6 +71,7 @@ enum AttributeProto_AttributeType : int { namespace vaip_core { class GraphHolder; using ONNX_NAMESPACE::AttributeProto; +using ONNX_NAMESPACE::ModelProto; using ONNX_NAMESPACE::TensorProto; using onnxruntime::Graph; using onnxruntime::GraphViewer; diff --git a/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h b/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h index bbe8b6e6e4934..288cfd6850d06 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h @@ -13,7 +13,7 @@ struct OrtApi; namespace vaip_core { -#define VAIP_ORT_API_MAJOR (10u) +#define VAIP_ORT_API_MAJOR (11u) #define VAIP_ORT_API_MINOR (0u) #define VAIP_ORT_API_PATCH (0u) struct OrtApiForVaip { @@ -231,6 +231,9 @@ struct OrtApiForVaip { gsl::span inputs); // [92] int (*node_arg_external_location)(const Graph& graph, const NodeArg& node_arg, std::string& file, size_t& offset, size_t& size, size_t& checksum); // [93] void (*session_option_configuration)(void* mmap, void* session_option, void (*push)(void* mmap, const char* name, const char* value)); // [94] + ModelProto* (*model_to_proto)(Model& model); // [95] + DllSafe (*model_proto_serialize_as_string)(ModelProto& model_proto); // [96] + void (*model_proto_delete)(ModelProto* p); // [97] }; #ifndef USE_VITISAI diff --git a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc index 09b115b4a57fc..4e0d1e1df9753 100644 --- a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc +++ b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc @@ -76,7 +76,17 @@ common::Status VitisAIExecutionProvider::Compile(const std::vectorexecution_providers_)[index]->set_fused_node(&fused_node_graph.fused_node.get()); + auto& ep = (**this->execution_providers_)[index]; + ep->set_fused_node(&fused_node_graph.fused_node.get()); + if (ep->get_meta_def_fallback_CPU()) { + auto& subgraph = fused_node_graph.filtered_graph.get(); + auto& logger = logging::LoggingManager::DefaultLogger(); + auto model_proto = subgraph.CreateModel(logger)->ToProto(); + subgraph.ToProto(*model_proto->mutable_graph(), true, true); + auto local_registries = IOnnxRuntimeOpSchemaRegistryList{subgraph.GetSchemaRegistry()}; + auto model = Model::Create(std::move(*model_proto), subgraph.ModelPath(), &local_registries, logger); + ep->set_model(model.release()); + } compute_info.create_state_func = [this, index](ComputeContext* context, FunctionState* state) { auto* p = (**this->execution_providers_)[index]->compile().release(); *state = p; diff --git a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h index 05d2a976815b9..8e70c59280aa1 100644 --- a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h +++ b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h @@ -48,7 +48,6 @@ class VitisAIExecutionProvider : public IExecutionProvider { ProviderOptions info_; std::vector custom_op_domains_; std::shared_ptr registry_; - std::set vitisai_optypes_; // EP context related. bool ep_ctx_enabled_ = false; bool ep_ctx_embed_mode_ = true; diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 85079ef78c8d3..7ed69cf5dcad1 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -1205,6 +1205,7 @@ struct ProviderHostImpl : ProviderHost { GraphViewerToProto(*p, graph_proto, include_initializers, include_outer_scope_args, static_cast(execution_order)); } const Node* GraphViewer__GetProducerNode(const GraphViewer* p, const std::string& node_arg_name) const override { return p->GetProducerNode(node_arg_name); } + IOnnxRuntimeOpSchemaCollectionPtr GraphViewer__GetSchemaRegistry(const GraphViewer* p) const override { return p->GetSchemaRegistry(); } // OpKernel (direct) const Node& OpKernel__Node(const OpKernel* p) override { return p->OpKernel::Node(); }