Skip to content

Commit

Permalink
Apply similar changes to AZURE
Browse files Browse the repository at this point in the history
  • Loading branch information
oandreeva-nv committed Sep 13, 2023
1 parent 2bd9a25 commit 6903e1c
Showing 1 changed file with 34 additions and 22 deletions.
56 changes: 34 additions & 22 deletions src/filesystem/implementations/as.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ class ASFileSystem : public FileSystem {

Status DownloadFolder(
const std::string& container, const std::string& path,
const std::string& dest);
const std::string& dest, const bool recursive);

std::shared_ptr<asb::BlobServiceClient> client_;
re2::RE2 as_regex_;
Expand Down Expand Up @@ -392,7 +392,7 @@ ASFileSystem::FileExists(const std::string& path, bool* exists)
Status
ASFileSystem::DownloadFolder(
const std::string& container, const std::string& path,
const std::string& dest)
const std::string& dest, const bool recursive)
{
auto container_client = client_->GetBlobContainerClient(container);
auto func = [&](const std::vector<asb::Models::BlobItem>& blobs,
Expand All @@ -408,17 +408,20 @@ ASFileSystem::DownloadFolder(
"Failed to download file at " + blob_item.Name + ":" + ex.what());
}
}
for (const auto& directory_item : blob_prefixes) {
const auto& local_path = JoinPath({dest, BaseName(directory_item)});
int status = mkdir(
const_cast<char*>(local_path.c_str()), S_IRUSR | S_IWUSR | S_IXUSR);
if (status == -1) {
return Status(
Status::Code::INTERNAL,
"Failed to create local folder: " + local_path +
", errno:" + strerror(errno));
if (recursive) {
for (const auto& directory_item : blob_prefixes) {
const auto& local_path = JoinPath({dest, BaseName(directory_item)});
int status = mkdir(
const_cast<char*>(local_path.c_str()), S_IRUSR | S_IWUSR | S_IXUSR);
if (status == -1 && errno != EEXIST) {
return Status(
Status::Code::INTERNAL,
"Failed to create local folder: " + local_path +
", errno:" + strerror(errno));
}
RETURN_IF_ERROR(
DownloadFolder(container, directory_item, local_path, recursive));
}
RETURN_IF_ERROR(DownloadFolder(container, directory_item, local_path));
}
return Status::Success;
};
Expand All @@ -445,21 +448,30 @@ ASFileSystem::LocalizePath(
"AS file localization not yet implemented " + path);
}

std::string folder_template = "/tmp/folderXXXXXX";
char* tmp_folder = mkdtemp(const_cast<char*>(folder_template.c_str()));
if (tmp_folder == nullptr) {
return Status(
Status::Code::INTERNAL,
"Failed to create local temp folder: " + folder_template +
", errno:" + strerror(errno));
// Create a local directory for s3 model store.
// If `mount_dir` or ENV variable are not set,
// creates a temporary directory under `/tmp` with the format: "folderXXXXXX".
// Otherwise, will create a folder under specified directory with the name
// indicated in path (i.e. everything after the last encounter of `/`).
const char* env_mount_dir = std::getenv("TRITON_AZURE_MOUNT_DIRECTORY");
std::string tmp_folder;
if (mount_dir.empty() && env_mount_dir == nullptr) {
RETURN_IF_ERROR(triton::core::MakeTemporaryDirectory(
FileSystemType::LOCAL, &tmp_folder));
} else {
tmp_folder = mount_dir.empty() ? std::string(env_mount_dir) : mount_dir;
tmp_folder =
JoinPath({tmp_folder, path.substr(path.find_last_of('/') + 1)});
RETURN_IF_ERROR(triton::core::MakeDirectory(
tmp_folder, true /*recursive*/, true /*allow_dir_exist*/));
}
localized->reset(new LocalizedPath(path, tmp_folder));

std::string dest(folder_template);
localized->reset(new LocalizedPath(path, tmp_folder));

std::string dest(tmp_folder);
std::string container, blob;
RETURN_IF_ERROR(ParsePath(path, &container, &blob));
return DownloadFolder(container, blob, dest);
return DownloadFolder(container, blob, dest, recursive);
}

Status
Expand Down

0 comments on commit 6903e1c

Please sign in to comment.