@@ -171,13 +171,96 @@ cc_library(
171171 ],
172172)
173173
174+ cc_library (
175+ name = "gradients" ,
176+ srcs = [
177+ "gradients.cc" ,
178+ "gradients_internal.h" ,
179+ ],
180+ hdrs = [
181+ "gradients.h" ,
182+ ],
183+ visibility = [
184+ "//tensorflow:internal" ,
185+ ],
186+ deps = [
187+ ":abstract_context" ,
188+ ":abstract_operation" ,
189+ ":abstract_tensor_handle" ,
190+ ":c_api_unified_internal" ,
191+ ":tape" ,
192+ "//tensorflow/core/common_runtime/eager:attr_builder" ,
193+ "//tensorflow/core/lib/llvm_rtti" ,
194+ "@com_google_absl//absl/container:flat_hash_map" ,
195+ "@com_google_absl//absl/strings" ,
196+ ],
197+ )
198+
199+ cc_library (
200+ name = "gradients_internal" ,
201+ srcs = [
202+ "gradients.cc" ,
203+ ],
204+ hdrs = [
205+ "gradients.h" ,
206+ "gradients_internal.h" ,
207+ ],
208+ visibility = [
209+ "//tensorflow:internal" ,
210+ ],
211+ deps = [
212+ ":abstract_context" ,
213+ ":abstract_operation" ,
214+ ":abstract_tensor_handle" ,
215+ ":c_api_unified_internal" ,
216+ ":tape" ,
217+ "//tensorflow/core/common_runtime/eager:attr_builder" ,
218+ "//tensorflow/core/lib/llvm_rtti" ,
219+ "@com_google_absl//absl/container:flat_hash_map" ,
220+ "@com_google_absl//absl/strings" ,
221+ ],
222+ )
223+
224+ tf_cuda_cc_test (
225+ name = "gradients_test" ,
226+ size = "small" ,
227+ srcs = [
228+ "gradients_test.cc" ,
229+ ],
230+ args = ["--heap_check=local" ],
231+ extra_copts = tfe_xla_copts (),
232+ linkstatic = tf_kernel_tests_linkstatic (),
233+ tags = tf_cuda_tests_tags () + ["nomac" ],
234+ deps = [
235+ ":abstract_tensor_handle" ,
236+ ":c_api_experimental" ,
237+ ":c_api_test_util" ,
238+ ":c_api_unified_internal" ,
239+ ":gradients_internal" ,
240+ "//tensorflow/c:c_api" ,
241+ "//tensorflow/c:c_test_util" ,
242+ "//tensorflow/c:tf_status_helper" ,
243+ "//tensorflow/cc/profiler" ,
244+ "//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration" ,
245+ "//tensorflow/core:lib" ,
246+ "//tensorflow/core:protos_all_cc" ,
247+ "//tensorflow/core:test" ,
248+ "//tensorflow/core:test_main" ,
249+ "//tensorflow/core/lib/llvm_rtti" ,
250+ "@com_google_absl//absl/strings" ,
251+ "@com_google_absl//absl/types:span" ,
252+ ],
253+ )
254+
174255cc_library (
175256 name = "abstract_tensor_handle" ,
176257 hdrs = ["abstract_tensor_handle.h" ],
177258 visibility = [
178259 "//tensorflow:internal" ,
179260 ],
180- deps = [],
261+ deps = [
262+ "//tensorflow/core:protos_all_cc" ,
263+ ],
181264)
182265
183266cc_library (
@@ -480,6 +563,7 @@ tf_cuda_cc_test(
480563 "//tensorflow/core:test" ,
481564 "//tensorflow/core:test_main" ,
482565 "//tensorflow/core/common_runtime:function_optimization_registry" ,
566+ "//tensorflow/core/common_runtime:optimization_registry" ,
483567 "//tensorflow/core/common_runtime/eager:eager_operation" ,
484568 "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib" ,
485569 "@com_google_absl//absl/strings" ,
@@ -744,6 +828,7 @@ filegroup(
744828 "c_api_unified_experimental_eager.cc" ,
745829 "c_api_unified_experimental_graph.cc" ,
746830 "c_api_unified_experimental_internal.h" ,
831+ "gradients.cc" , # Uses RTTI.
747832 "*test*" ,
748833 "*dlpack*" ,
749834 ],
0 commit comments