Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[aot] [llvm] LLVM AOT Field #2: Updated LLVM AOTModuleLoader & AOTModuleBuilder to support Fields #5120

Merged
merged 8 commits into from
Jun 13, 2022
5 changes: 5 additions & 0 deletions taichi/backends/cpu/aot_module_builder_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@ namespace lang {
namespace cpu {

class AotModuleBuilderImpl : public LlvmAotModuleBuilder {
public:
explicit AotModuleBuilderImpl(LlvmProgramImpl *prog)
: LlvmAotModuleBuilder(prog) {
}

private:
CodeGenLLVM::CompiledData compile_kernel(Kernel *kernel) override;
};
Expand Down
5 changes: 0 additions & 5 deletions taichi/backends/cpu/aot_module_loader_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,6 @@ class AotModuleImpl : public LlvmAotModule {
TI_NOT_IMPLEMENTED;
return nullptr;
}

std::unique_ptr<aot::Field> make_new_field(const std::string &name) override {
TI_NOT_IMPLEMENTED;
return nullptr;
}
};

} // namespace
Expand Down
5 changes: 5 additions & 0 deletions taichi/backends/cuda/aot_module_builder_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@ namespace lang {
namespace cuda {

class AotModuleBuilderImpl : public LlvmAotModuleBuilder {
public:
explicit AotModuleBuilderImpl(LlvmProgramImpl *prog)
: LlvmAotModuleBuilder(prog) {
}

private:
CodeGenLLVM::CompiledData compile_kernel(Kernel *kernel) override;
};
Expand Down
5 changes: 0 additions & 5 deletions taichi/backends/cuda/aot_module_loader_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,6 @@ class AotModuleImpl : public LlvmAotModule {
TI_NOT_IMPLEMENTED;
return nullptr;
}

std::unique_ptr<aot::Field> make_new_field(const std::string &name) override {
TI_NOT_IMPLEMENTED;
return nullptr;
}
};

} // namespace
Expand Down
2 changes: 1 addition & 1 deletion taichi/ir/snode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ void SNode::set_snode_tree_id(int id) {
snode_tree_id_ = id;
}

int SNode::get_snode_tree_id() {
int SNode::get_snode_tree_id() const {
return snode_tree_id_;
}

Expand Down
2 changes: 1 addition & 1 deletion taichi/ir/snode.h
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ class SNode {

void set_snode_tree_id(int id);

int get_snode_tree_id();
int get_snode_tree_id() const;

static void reset_counter() {
counter = 0;
Expand Down
33 changes: 33 additions & 0 deletions taichi/llvm/llvm_aot_module_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <algorithm>
#include "taichi/llvm/launch_arg_info.h"
#include "taichi/llvm/llvm_program.h"

namespace taichi {
namespace lang {
Expand Down Expand Up @@ -34,5 +35,37 @@ void LlvmAotModuleBuilder::add_per_backend(const std::string &identifier,
cache_.kernels[identifier] = std::move(kcache);
}

void LlvmAotModuleBuilder::add_field_per_backend(const std::string &identifier,
const SNode *rep_snode,
bool is_scalar,
DataType dt,
std::vector<int> shape,
int row_num,
int column_num) {
// Field refers to a leaf node(Place SNode) in a SNodeTree.
// It makes no sense to just serialize the leaf node or its corresponding
// branch. Instead, the minimal unit we have to serialize is the entire
// SNodeTree. Note that SNodeTree's uses snode_tree_id as its identifier,
// rather than the field's name. (multiple fields may end up referring to the
// same SNodeTree)

// 1. Find snode_tree_id
int snode_tree_id = rep_snode->get_snode_tree_id();

// 2. Fetch Cache from the Program
// Kernel compilation is not allowed until all the Fields are finalized,
// so we finished SNodeTree compilation during AOTModuleBuilder construction.
//
// By the time "add_field_per_backend()" is called,
// SNodeTrees should have already been finalized,
// with compiled info stored in LlvmProgramImpl::cache_data_.
TI_ASSERT(prog_ != nullptr);
LlvmOfflineCache::FieldCacheData field_cache =
prog_->get_cached_field(snode_tree_id);

// 3. Update AOT Cache
cache_.fields[snode_tree_id] = std::move(field_cache);
}

} // namespace lang
} // namespace taichi
12 changes: 12 additions & 0 deletions taichi/llvm/llvm_aot_module_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,27 @@ namespace lang {

class LlvmAotModuleBuilder : public AotModuleBuilder {
public:
explicit LlvmAotModuleBuilder(LlvmProgramImpl *prog) : prog_(prog) {
}

void dump(const std::string &output_dir,
const std::string &filename) const override;

protected:
void add_per_backend(const std::string &identifier, Kernel *kernel) override;
virtual CodeGenLLVM::CompiledData compile_kernel(Kernel *kernel) = 0;

void add_field_per_backend(const std::string &identifier,
const SNode *rep_snode,
bool is_scalar,
DataType dt,
std::vector<int> shape,
int row_num,
int column_num) override;

private:
mutable LlvmOfflineCache cache_;
LlvmProgramImpl *prog_ = nullptr;
};

} // namespace lang
Expand Down
55 changes: 55 additions & 0 deletions taichi/llvm/llvm_aot_module_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,24 @@ class KernelImpl : public aot::Kernel {
FunctionType fn_;
};

class FieldImpl : public aot::Field {
public:
explicit FieldImpl(const LlvmOfflineCache::FieldCacheData &field)
: field_(field) {
}

explicit FieldImpl(LlvmOfflineCache::FieldCacheData &&field)
: field_(std::move(field)) {
}

LlvmOfflineCache::FieldCacheData get_field() const {
return field_;
}

private:
LlvmOfflineCache::FieldCacheData field_;
};

} // namespace

LlvmOfflineCache::KernelCacheData LlvmAotModule::load_kernel_from_cache(
Expand All @@ -37,5 +55,42 @@ std::unique_ptr<aot::Kernel> LlvmAotModule::make_new_kernel(
return std::make_unique<KernelImpl>(fn);
}

std::unique_ptr<aot::Field> LlvmAotModule::make_new_field(
const std::string &name) {
// Check if "name" represents snode_tree_id.
// Avoid using std::atoi due to its poor error handling.
char *end;
int snode_tree_id = static_cast<int>(strtol(name.c_str(), &end, 10 /*base*/));

TI_ASSERT(end != name.c_str());
TI_ASSERT(*end == '\0');

// Load FieldCache
LlvmOfflineCache::FieldCacheData loaded;
auto ok = cache_reader_->get_field_cache(loaded, snode_tree_id);
TI_ERROR_IF(!ok, "Failed to load field with id={}", snode_tree_id);

return std::make_unique<FieldImpl>(std::move(loaded));
}

void finalize_aot_field(aot::Module *aot_module,
aot::Field *aot_field,
uint64 *result_buffer) {
auto *llvm_aot_module = dynamic_cast<LlvmAotModule *>(aot_module);
auto *aot_field_impl = dynamic_cast<FieldImpl *>(aot_field);

TI_ASSERT(llvm_aot_module != nullptr);
TI_ASSERT(aot_field_impl != nullptr);

auto *llvm_prog = llvm_aot_module->get_program();
const auto &field_cache = aot_field_impl->get_field();

int snode_tree_id = field_cache.tree_id;
if (!llvm_aot_module->is_snode_tree_initialized(snode_tree_id)) {
llvm_prog->initialize_llvm_runtime_snodes(field_cache, result_buffer);
llvm_aot_module->set_initialized_snode_tree(snode_tree_id);
}
}

} // namespace lang
} // namespace taichi
21 changes: 21 additions & 0 deletions taichi/llvm/llvm_aot_module_loader.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
namespace taichi {
namespace lang {

TI_DLL_EXPORT void finalize_aot_field(aot::Module *aot_module,
aot::Field *aot_field,
uint64 *result_buffer);

class LlvmAotModule : public aot::Module {
public:
explicit LlvmAotModule(const std::string &module_path,
Expand All @@ -27,6 +31,18 @@ class LlvmAotModule : public aot::Module {
return 0;
}

LlvmProgramImpl *const get_program() {
return program_;
}

void set_initialized_snode_tree(int snode_tree_id) {
initialized_snode_tree_ids.insert(snode_tree_id);
}

bool is_snode_tree_initialized(int snode_tree_id) {
return initialized_snode_tree_ids.count(snode_tree_id);
}

protected:
virtual FunctionType convert_module_to_function(
const std::string &name,
Expand All @@ -38,8 +54,13 @@ class LlvmAotModule : public aot::Module {
std::unique_ptr<aot::Kernel> make_new_kernel(
const std::string &name) override;

std::unique_ptr<aot::Field> make_new_field(const std::string &name) override;

LlvmProgramImpl *const program_{nullptr};
std::unique_ptr<LlvmOfflineCacheFileReader> cache_reader_{nullptr};

// To prevent repeated SNodeTree initialization
std::unordered_set<int> initialized_snode_tree_ids;
};

} // namespace lang
Expand Down
2 changes: 1 addition & 1 deletion taichi/llvm/llvm_offline_cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ struct LlvmOfflineCache {
std::unordered_map<std::string, KernelCacheData>
kernels; // key = kernel_name

TI_IO_DEF(kernels);
TI_IO_DEF(fields, kernels);
};

class LlvmOfflineCacheFileReader {
Expand Down
66 changes: 39 additions & 27 deletions taichi/llvm/llvm_program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -273,37 +273,22 @@ std::unique_ptr<StructCompiler> LlvmProgramImpl::compile_snode_tree_types_impl(
}

void LlvmProgramImpl::compile_snode_tree_types(SNodeTree *tree) {
compile_snode_tree_types_impl(tree);
}

static LlvmOfflineCache::FieldCacheData construct_filed_cache_data(
const SNodeTree &tree,
const StructCompiler &struct_compiler) {
LlvmOfflineCache::FieldCacheData ret;
ret.tree_id = tree.id();
ret.root_id = tree.root()->id;
ret.root_size = struct_compiler.root_size;

const auto &snodes = struct_compiler.snodes;
for (size_t i = 0; i < snodes.size(); i++) {
LlvmOfflineCache::FieldCacheData::SNodeCacheData snode_cache_data;
snode_cache_data.id = snodes[i]->id;
snode_cache_data.type = snodes[i]->type;
snode_cache_data.cell_size_bytes = snodes[i]->cell_size_bytes;
snode_cache_data.chunk_size = snodes[i]->chunk_size;

ret.snode_metas.emplace_back(std::move(snode_cache_data));
}
auto struct_compiler = compile_snode_tree_types_impl(tree);
int snode_tree_id = tree->id();
int root_id = tree->root()->id;

return ret;
// Add compiled result to Cache
cache_field(snode_tree_id, root_id, *struct_compiler);
}

void LlvmProgramImpl::materialize_snode_tree(SNodeTree *tree,
uint64 *result_buffer) {
auto struct_compiler = compile_snode_tree_types_impl(tree);
compile_snode_tree_types(tree);
int snode_tree_id = tree->id();

auto field_cache_data = construct_filed_cache_data(*tree, *struct_compiler);
initialize_llvm_runtime_snodes(field_cache_data, result_buffer);
TI_ASSERT(cache_data_.fields.find(snode_tree_id) != cache_data_.fields.end());
initialize_llvm_runtime_snodes(cache_data_.fields.at(snode_tree_id),
result_buffer);
}

uint64 LlvmProgramImpl::fetch_result_uint64(int i, uint64 *result_buffer) {
Expand Down Expand Up @@ -365,12 +350,12 @@ void LlvmProgramImpl::print_list_manager_info(void *list_manager,

std::unique_ptr<AotModuleBuilder> LlvmProgramImpl::make_aot_module_builder() {
if (config->arch == Arch::x64 || config->arch == Arch::arm64) {
return std::make_unique<cpu::AotModuleBuilderImpl>();
return std::make_unique<cpu::AotModuleBuilderImpl>(this);
}

#if defined(TI_WITH_CUDA)
if (config->arch == Arch::cuda) {
return std::make_unique<cuda::AotModuleBuilderImpl>();
return std::make_unique<cuda::AotModuleBuilderImpl>(this);
}
#endif

Expand Down Expand Up @@ -701,6 +686,33 @@ void LlvmProgramImpl::cache_kernel(
kernel_cache.offloaded_task_list = std::move(offloaded_task_list);
}

void LlvmProgramImpl::cache_field(int snode_tree_id,
int root_id,
const StructCompiler &struct_compiler) {
if (cache_data_.fields.find(snode_tree_id) != cache_data_.fields.end()) {
// [TODO] check and update the Cache, instead of simply return.
return;
}

LlvmOfflineCache::FieldCacheData ret;
ret.tree_id = snode_tree_id;
ret.root_id = root_id;
ret.root_size = struct_compiler.root_size;

const auto &snodes = struct_compiler.snodes;
for (size_t i = 0; i < snodes.size(); i++) {
LlvmOfflineCache::FieldCacheData::SNodeCacheData snode_cache_data;
snode_cache_data.id = snodes[i]->id;
snode_cache_data.type = snodes[i]->type;
snode_cache_data.cell_size_bytes = snodes[i]->cell_size_bytes;
snode_cache_data.chunk_size = snodes[i]->chunk_size;

ret.snode_metas.emplace_back(std::move(snode_cache_data));
}

cache_data_.fields[snode_tree_id] = std::move(ret);
}

void LlvmProgramImpl::dump_cache_data_to_disk() {
if (config->offline_cache && !cache_data_.kernels.empty()) {
LlvmOfflineCacheFileWriter writer{};
Expand Down
23 changes: 17 additions & 6 deletions taichi/llvm/llvm_program.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,23 +118,34 @@ class LlvmProgramImpl : public ProgramImpl {
std::vector<LlvmOfflineCache::OffloadedTaskCacheData>
&&offloaded_task_list);

void cache_field(int snode_tree_id,
int root_id,
const StructCompiler &struct_compiler);

LlvmOfflineCache::FieldCacheData get_cached_field(int snode_tree_id) const {
TI_ASSERT(cache_data_.fields.find(snode_tree_id) !=
cache_data_.fields.end());
return cache_data_.fields.at(snode_tree_id);
}

Device *get_compute_device() override {
return device_.get();
}

/**
* Initializes the SNodes for LLVM based backends.
*/
void initialize_llvm_runtime_snodes(
const LlvmOfflineCache::FieldCacheData &field_cache_data,
uint64 *result_buffer);

private:
std::unique_ptr<llvm::Module> clone_struct_compiler_initial_context(
bool has_multiple_snode_trees,
TaichiLLVMContext *tlctx);

std::unique_ptr<StructCompiler> compile_snode_tree_types_impl(
SNodeTree *tree);
/**
* Initializes the SNodes for LLVM based backends.
*/
void initialize_llvm_runtime_snodes(
const LlvmOfflineCache::FieldCacheData &field_cache_data,
uint64 *result_buffer);

uint64 fetch_result_uint64(int i, uint64 *result_buffer);

Expand Down