Skip to content

Commit f0ef163

Browse files
joker-ephtensorflower-gardener
authored andcommitted
Add an MLIR tracing implementation to the C unified API
This is plumbing just enough to pass all the unit-tests. The conversion to the function library is quite inefficient, but it isn't clear if we want to optimize this or just focus on TFRT moving forward. PiperOrigin-RevId: 313356850 Change-Id: I83815317d4958786d0103168b5d88498f89511ed
1 parent 7738c18 commit f0ef163

File tree

10 files changed

+632
-4
lines changed

10 files changed

+632
-4
lines changed

tensorflow/c/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,7 @@ tf_cuda_library(
216216
],
217217
visibility = [
218218
"//tensorflow/c:__subpackages__",
219+
"//tensorflow/compiler/mlir/tensorflow/c:__subpackages__",
219220
],
220221
deps = select({
221222
"//tensorflow:android": [

tensorflow/c/eager/BUILD

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,24 @@ cc_library(
144144
],
145145
)
146146

147+
cc_library(
148+
name = "c_api_unified_internal",
149+
hdrs = [
150+
"c_api_unified_experimental_internal.h",
151+
],
152+
visibility = [
153+
"//tensorflow:internal",
154+
],
155+
deps = [
156+
":c_api",
157+
":c_api_experimental",
158+
"//tensorflow/c:c_api_internal",
159+
"//tensorflow/c:tf_status",
160+
"//tensorflow/core/platform:casts",
161+
"//tensorflow/core/platform:types",
162+
],
163+
)
164+
147165
cc_library(
148166
name = "tensor_handle_interface",
149167
hdrs = ["tensor_handle_interface.h"],
@@ -514,6 +532,7 @@ tf_cuda_cc_test(
514532
"//tensorflow/c:c_api",
515533
"//tensorflow/c:c_test_util",
516534
"//tensorflow/cc/profiler",
535+
"//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
517536
"//tensorflow/core:lib",
518537
"//tensorflow/core:protos_all_cc",
519538
"//tensorflow/core:test",

tensorflow/c/eager/c_api_unified_experimental_internal.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ T* dyncast(S source) {
5858
// GraphContext and vice-versa).
5959
class AbstractTensor {
6060
protected:
61-
enum AbstractTensorKind { kGraphTensor, kEagerTensor, kMLIRTensor };
61+
enum AbstractTensorKind { kMlirTensor, kGraphTensor, kEagerTensor };
6262
explicit AbstractTensor(AbstractTensorKind kind) : kind_(kind) {}
6363

6464
public:
@@ -101,7 +101,7 @@ class AbstractFunction {
101101
// on a given context, with the same or different input tensors.
102102
class AbstractOp {
103103
protected:
104-
enum AbstractOpKind { kGraphOp, kEagerOp };
104+
enum AbstractOpKind { kMlirOp, kGraphOp, kEagerOp };
105105
explicit AbstractOp(AbstractOpKind kind) : kind_(kind) {}
106106

107107
public:
@@ -129,7 +129,7 @@ class AbstractOp {
129129
// eager implementation or to a graph implementation.
130130
struct ExecutionContext {
131131
protected:
132-
enum ExecutionContextKind { kGraphContext, kEagerContext };
132+
enum ExecutionContextKind { kMlirContext, kGraphContext, kEagerContext };
133133
explicit ExecutionContext(ExecutionContextKind kind) : k(kind) {}
134134

135135
public:

tensorflow/c/eager/c_api_unified_experimental_test.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -477,7 +477,8 @@ TEST_P(UnifiedCAPI, TestExecutingGraphOpInEagerModeRaises) {
477477
TF_DeleteExecutionContext(eager_execution_ctx);
478478
}
479479

480-
INSTANTIATE_TEST_SUITE_P(Tracing, UnifiedCAPI, ::testing::Values("graphdef"));
480+
INSTANTIATE_TEST_SUITE_P(Tracing, UnifiedCAPI,
481+
::testing::Values("graphdef", "mlir"));
481482

482483
} // namespace
483484
} // namespace tensorflow

tensorflow/compiler/mlir/tensorflow/BUILD

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -788,6 +788,9 @@ cc_library(
788788
name = "convert_type",
789789
srcs = ["utils/convert_type.cc"],
790790
hdrs = ["utils/convert_type.h"],
791+
visibility = [
792+
"//visibility:public",
793+
],
791794
deps = [
792795
":tensorflow_types",
793796
"//tensorflow/core:framework",
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
load(
2+
"//tensorflow:tensorflow.bzl",
3+
"tf_copts",
4+
"tf_cuda_library",
5+
"tfe_xla_copts",
6+
)
7+
8+
package(
9+
default_visibility = [":friends"],
10+
licenses = ["notice"], # Apache 2.0
11+
)
12+
13+
package_group(
14+
name = "friends",
15+
packages = ["//tensorflow/..."],
16+
)
17+
18+
tf_cuda_library(
19+
name = "mlir_c_api",
20+
srcs = [
21+
"c_api_unified_experimental_mlir.cc",
22+
],
23+
copts = tf_copts() + tfe_xla_copts(),
24+
deps = [
25+
"//tensorflow/c:c_api",
26+
"//tensorflow/c:tf_status_helper",
27+
"//tensorflow/c:tf_status_internal",
28+
"//tensorflow/c/eager:c_api",
29+
"//tensorflow/c/eager:c_api_internal",
30+
"//tensorflow/c/eager:c_api_unified_internal",
31+
"//tensorflow/compiler/mlir/tensorflow",
32+
"//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
33+
"//tensorflow/compiler/mlir/tensorflow:convert_type",
34+
"//tensorflow/compiler/mlir/tensorflow:error_util",
35+
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
36+
"//tensorflow/core:framework",
37+
"//tensorflow/core:protos_all_cc",
38+
"//tensorflow/core/platform:casts",
39+
"@llvm-project//llvm:support",
40+
"@llvm-project//mlir:IR",
41+
"@llvm-project//mlir:Pass",
42+
"@llvm-project//mlir:StandardOps",
43+
"@llvm-project//mlir:Support",
44+
],
45+
)
46+
47+
cc_library(
48+
name = "mlir_c_api_registration",
49+
srcs = ["c_api_unified_experimental_mlir_registration.cc"],
50+
deps = [
51+
":mlir_c_api",
52+
"//tensorflow/c/eager:c_api_unified_internal",
53+
],
54+
alwayslink = 1,
55+
)

0 commit comments

Comments
 (0)