Skip to content

Commit 463c305

Browse files
An implementation of a multithreaded runtime for receiving outfeed data
from devices and pushing it back to Python. Design and implementation notes in outfeed_received.cc PiperOrigin-RevId: 313373094 Change-Id: I4278d6ebf4e204b0e91c536d1de8c3f49dca6a34
1 parent 9b7b8f1 commit 463c305

File tree

7 files changed

+1071
-0
lines changed

7 files changed

+1071
-0
lines changed

tensorflow/compiler/xla/python/BUILD

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
load("//tensorflow/core/platform:build_config.bzl", "pyx_library")
22
load("//tensorflow/compiler/xla:xla.bzl", "xla_py_test_deps")
3+
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
34

45
# buildifier: disable=same-origin-load
56
load("//tensorflow:tensorflow.bzl", "pybind_extension")
@@ -212,6 +213,63 @@ cc_library(
212213
],
213214
)
214215

216+
cc_library(
217+
name = "outfeed_receiver",
218+
srcs = ["outfeed_receiver.cc"],
219+
hdrs = ["outfeed_receiver.h"],
220+
deps = [
221+
"//tensorflow/compiler/xla:literal",
222+
"//tensorflow/compiler/xla:shape_util",
223+
"//tensorflow/compiler/xla:statusor",
224+
"//tensorflow/compiler/xla:util",
225+
"//tensorflow/compiler/xla/client:xla_builder",
226+
"//tensorflow/compiler/xla/client:xla_computation",
227+
"//tensorflow/compiler/xla/pjrt:pjrt_client",
228+
"//tensorflow/core/profiler/lib:traceme",
229+
"@com_google_absl//absl/container:flat_hash_map",
230+
"@com_google_absl//absl/memory",
231+
"@com_google_absl//absl/strings:str_format",
232+
],
233+
)
234+
235+
tf_cc_test(
236+
name = "cpu_outfeed_receiver_test",
237+
size = "small",
238+
srcs = ["outfeed_receiver_test.cc"],
239+
deps = [
240+
":outfeed_receiver",
241+
"//tensorflow/compiler/jit:xla_cpu_jit",
242+
"//tensorflow/compiler/xla:test",
243+
"//tensorflow/compiler/xla/client:executable_build_options",
244+
"//tensorflow/compiler/xla/client:xla_builder",
245+
"//tensorflow/compiler/xla/pjrt:cpu_device",
246+
"//tensorflow/compiler/xla/pjrt:pjrt_client",
247+
"//tensorflow/core:test",
248+
"//tensorflow/core:test_main",
249+
"@com_google_absl//absl/synchronization",
250+
],
251+
)
252+
253+
cc_library(
254+
name = "outfeed_receiver_py",
255+
srcs = ["outfeed_receiver_py.cc"],
256+
hdrs = ["outfeed_receiver_py.h"],
257+
copts = [
258+
"-fexceptions",
259+
"-fno-strict-aliasing",
260+
],
261+
features = ["-use_header_modules"],
262+
deps = [
263+
":outfeed_receiver",
264+
":types",
265+
"//tensorflow/compiler/xla/client:xla_builder",
266+
"//tensorflow/compiler/xla/pjrt:pjrt_client",
267+
"@com_google_absl//absl/memory",
268+
"@com_google_absl//absl/synchronization",
269+
"@pybind11",
270+
],
271+
)
272+
215273
config_setting(
216274
name = "enable_gpu",
217275
values = {"define": "xla_python_enable_gpu=true"},
@@ -233,6 +291,7 @@ pybind_extension(
233291
":dlpack",
234292
":ops",
235293
":python_ref_manager",
294+
":outfeed_receiver_py",
236295
":types",
237296
"@com_google_absl//absl/base",
238297
"@com_google_absl//absl/hash",

0 commit comments

Comments
 (0)