@@ -16,11 +16,13 @@ limitations under the License.
1616#include " tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h"
1717
1818#include < fstream>
19+ #include < future>
1920#include < map>
2021#include < memory>
2122#include < string>
2223#include < utility>
2324
25+ #include " absl/base/call_once.h"
2426#include " absl/memory/memory.h"
2527#include " absl/strings/str_cat.h"
2628#include " absl/strings/str_replace.h"
@@ -63,6 +65,7 @@ limitations under the License.
6365#include " tensorflow/core/platform/logging.h"
6466#include " tensorflow/core/platform/tracing.h"
6567#include " tensorflow/core/profiler/lib/traceme.h"
68+ #include " tensorflow/core/util/env_var.h"
6669
6770namespace xla {
6871namespace gpu {
@@ -524,6 +527,28 @@ StatusOr<string> CompileToPtx(llvm::Module* module, GpuVersion gpu_version,
524527} // namespace nvptx
525528
526529namespace {
530+ static std::string hsaco_cache_dir;
531+
532+ static void InitHsacoCacheDir () {
533+ static absl::once_flag init_once;
534+ absl::call_once (init_once, [] {
535+ auto env = tensorflow::Env::Default ();
536+ tensorflow::ReadStringFromEnvVar (" TF_XLA_HSACO_CACHE_DIR" , " /tmp" ,
537+ &hsaco_cache_dir);
538+ if (hsaco_cache_dir.empty ()) {
539+ LOG (INFO) << " Will not cache XLA HSACOs. "
540+ << " This line is logged at most "
541+ << " once for the lifetime of the process." ;
542+ } else {
543+ if (!env->IsDirectory (hsaco_cache_dir).ok ()){
544+ env->CreateDir (hsaco_cache_dir);
545+ }
546+ LOG (INFO) << " Cache XLA HSACOs in " << hsaco_cache_dir << " . "
547+ << " This line is logged at most "
548+ << " once for the lifetime of the process." ;
549+ }
550+ });
551+ }
527552
528553// Gets the ROCm-Device-Libs filenames for a particular AMDGPU version.
529554static std::vector<string> GetROCDLPaths (int amdgpu_version,
@@ -548,6 +573,18 @@ static std::vector<string> GetROCDLPaths(int amdgpu_version,
548573 return result;
549574}
550575
576+ Status ReadHsaco (std::string hsaco_path, std::vector<uint8>& hsaco){
577+ if (tensorflow::Env::Default ()->FileExists (hsaco_path).ok ()){
578+ std::ifstream hsaco_file (hsaco_path, std::ios::binary | std::ios::ate);
579+ std::ifstream::pos_type hsaco_file_size = hsaco_file.tellg ();
580+ hsaco = std::vector<uint8>(hsaco_file_size);
581+ hsaco_file.seekg (0 , std::ios::beg);
582+ hsaco_file.read (reinterpret_cast <char *>(&hsaco[0 ]), hsaco_file_size);
583+ return Status::OK ();
584+ }
585+ return xla::InternalErrorStrCat (" Can't find Hsaco: " , hsaco_path);
586+ }
587+
551588// Emits the given module to HSA Code Object. target_machine is an initialized
552589// TargetMachine for the AMDGPU target.
553590StatusOr<std::vector<uint8>> EmitModuleToHsaco (
@@ -609,12 +646,8 @@ StatusOr<std::vector<uint8>> EmitModuleToHsaco(
609646 }
610647
611648 // Read HSACO.
612- std::ifstream hsaco_file (hsaco_path, std::ios::binary | std::ios::ate);
613- std::ifstream::pos_type hsaco_file_size = hsaco_file.tellg ();
614-
615- std::vector<uint8> hsaco (hsaco_file_size);
616- hsaco_file.seekg (0 , std::ios::beg);
617- hsaco_file.read (reinterpret_cast <char *>(&hsaco[0 ]), hsaco_file_size);
649+ std::vector<uint8> hsaco;
650+ ReadHsaco (hsaco_path, hsaco);
618651 return hsaco;
619652}
620653
@@ -655,6 +688,7 @@ void AMDGPUBackendInit(const HloModuleConfig& hlo_module_config) {
655688 // its specific initialization functions instead of the catch-all
656689 // InitializeAll*.
657690#if TENSORFLOW_USE_ROCM
691+ InitHsacoCacheDir ();
658692 LLVMInitializeAMDGPUTarget ();
659693 LLVMInitializeAMDGPUTargetInfo ();
660694 LLVMInitializeAMDGPUTargetMC ();
@@ -669,62 +703,6 @@ void AMDGPUBackendInit(const HloModuleConfig& hlo_module_config) {
669703
670704namespace amdgpu {
671705
672-
673- struct HsacoCacheEntry {
674- uint64_t hash;
675- std::string ir;
676- std::string gfx;
677- std::vector<uint8_t > hsaco;
678- };
679-
680- struct HsacoCache {
681- protected:
682- std::vector<HsacoCacheEntry> cache;
683- std::mutex m_mutex;
684- int request_count = 0 ;
685- int hit_count = 0 ;
686-
687- public:
688- static bool Find (const std::string& ir, uint64_t & hash,
689- const std::string& gfx, std::vector<uint8_t >& hsaco);
690- static void Add (const std::string& ir, uint64_t hash, const std::string& gfx,
691- const std::vector<uint8_t >& hsaco);
692- };
693-
694- static HsacoCache g_hsacoCache; // NOLINT: static/global vars forbidden
695-
696- bool HsacoCache::Find (const std::string& ir, uint64_t & hash,
697- const std::string& gfx, std::vector<uint8_t >& hsaco) {
698- std::lock_guard<std::mutex> lg (g_hsacoCache.m_mutex );
699- hash = std::hash<std::string>{}(ir);
700- bool hit = false ;
701- for (auto & x : g_hsacoCache.cache ) {
702- if (x.hash != hash) continue ;
703- if (x.gfx != gfx) continue ;
704- if (x.ir != ir) continue ;
705- hsaco = x.hsaco ;
706- hit = true ;
707- break ;
708- }
709- g_hsacoCache.request_count ++;
710- if (hit) g_hsacoCache.hit_count ++;
711- if (!(g_hsacoCache.request_count % 50 ))
712- VLOG (1 ) << " HSACO cache: " << g_hsacoCache.request_count << " requests, "
713- << g_hsacoCache.hit_count << " hits" ;
714- return hit;
715- }
716-
717- void HsacoCache::Add (const std::string& ir, uint64_t hash,
718- const std::string& gfx,
719- const std::vector<uint8_t >& hsaco) {
720- std::lock_guard<std::mutex> lg (g_hsacoCache.m_mutex );
721- g_hsacoCache.cache .resize (g_hsacoCache.cache .size () + 1 );
722- g_hsacoCache.cache .back ().ir = ir;
723- g_hsacoCache.cache .back ().hash = hash;
724- g_hsacoCache.cache .back ().gfx = gfx;
725- g_hsacoCache.cache .back ().hsaco = hsaco;
726- }
727-
728706StatusOr<std::vector<uint8>> CompileToHsaco (
729707 llvm::Module* module , GpuVersion gpu_version,
730708 const HloModuleConfig& hlo_module_config, const string& rocdl_dir_path) {
@@ -760,9 +738,12 @@ StatusOr<std::vector<uint8>> CompileToHsaco(
760738 return xla::InternalError (
761739 " Incompatible AMD GCN ISA version was specified." );
762740 }
741+
742+ std::string hsaco_filename =
743+ absl::StrCat (module ->getModuleIdentifier (), " .hsaco" );
744+ std::string hsaco_path = tensorflow::io::JoinPath (hsaco_cache_dir, hsaco_filename);
763745
764- uint64_t hash;
765- if (HsacoCache::Find (str, hash, std::to_string (*amdgpu_version), hsaco)) {
746+ if (ReadHsaco (hsaco_path, hsaco).ok ()) {
766747 VLOG (1 ) << " HSACO cache hit" ;
767748 return hsaco;
768749 }
@@ -775,35 +756,22 @@ StatusOr<std::vector<uint8>> CompileToHsaco(
775756 hlo_module_config);
776757
777758 auto * env = tensorflow::Env::Default ();
778- std::vector<std::string> tempdir_vector;
779- env->GetLocalTempDirectories (&tempdir_vector);
780- if (tempdir_vector.empty ()) {
781- return xla::InternalError (
782- " Unable to locate a temporary directory for compile-time artifacts." );
783- }
784- std::string tempdir_name = tempdir_vector.front ();
785- VLOG (1 ) << " Compile-time artifacts located at: " << tempdir_name;
786-
787759 // Prepare filenames for all stages of compilation:
788760 // IR, binary ISA, and HSACO.
789- std::string ir_filename = absl::StrCat (module ->getModuleIdentifier (), " .ll" );
790- std::string ir_path = tensorflow::io::JoinPath (tempdir_name, ir_filename);
761+ std::string module_path;
762+ if (!env->LocalTempFilename (&module_path)) {
763+ return xla::InternalError (
764+ " Could not get temporary filenames for modules." );
765+ }
766+ std::string ir_path = absl::StrCat (module_path, " .ll" );
791767
792- std::string linked_ir_filename = absl::StrCat (module ->getModuleIdentifier (), " -linked.ll" );
793- std::string linked_ir_path = tensorflow::io::JoinPath (tempdir_name, linked_ir_filename);
768+ std::string linked_ir_path = absl::StrCat (module_path, " -linked.ll" );
794769
795- std::string optimized_ir_filename = absl::StrCat (module ->getModuleIdentifier (), " -opt.ll" );
796- std::string optimized_ir_path = tensorflow::io::JoinPath (tempdir_name, optimized_ir_filename);
770+ std::string optimized_ir_path = absl::StrCat (module_path, " -opt.ll" );
797771
798- std::string isabin_filename =
799- absl::StrCat (module ->getModuleIdentifier (), " .o" );
800772 std::string isabin_path =
801- tensorflow::io::JoinPath (tempdir_name, isabin_filename );
773+ absl::StrCat (module_path, " .o " );
802774
803- std::string hsaco_filename =
804- absl::StrCat (module ->getModuleIdentifier (), " .hsaco" );
805- std::string hsaco_path =
806- tensorflow::io::JoinPath (tempdir_name, hsaco_filename);
807775
808776 // Link with ROCm-Device-Libs, and optimize the LLVM module.
809777 TF_RETURN_IF_ERROR (LinkAndOptimizeModule (
@@ -813,7 +781,12 @@ StatusOr<std::vector<uint8>> CompileToHsaco(
813781
814782 // Lower optimized LLVM module to HSA code object.
815783 TF_ASSIGN_OR_RETURN (hsaco, EmitModuleToHsaco (module , target_machine.get (), optimized_ir_path, isabin_path, hsaco_path));
816- HsacoCache::Add (str, hash, std::to_string (*amdgpu_version), hsaco);
784+ std::async (std::launch::async, [](std::vector<std::string> files){
785+ for (auto & file : files){
786+ tensorflow::Env::Default ()->DeleteFile (file);
787+ }
788+ }, std::vector<std::string>{ir_path, linked_ir_path, optimized_ir_path, isabin_path});
789+
817790 }
818791 return hsaco;
819792}
0 commit comments