Skip to content

Commit 9deaf74

Browse files
committed
[Executorch] Make module constructors uniform across
Existing constructors dont compose well such that if you want data loader or data files constructor then you cannot get to override memory allocator. Fix that. Differential Revision: [D86120037](https://our.internmc.facebook.com/intern/diff/D86120037/) [ghstack-poisoned]
1 parent f16910d commit 9deaf74

File tree

2 files changed

+30
-6
lines changed

2 files changed

+30
-6
lines changed

extension/module/module.cpp

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -78,11 +78,17 @@ runtime::Result<std::unique_ptr<runtime::DataLoader>> make_data_loader(
7878
Module::Module(
7979
const std::string& file_path,
8080
const LoadMode load_mode,
81+
std::unique_ptr<runtime::MemoryAllocator> memory_allocator,
82+
std::unique_ptr<runtime::MemoryAllocator> temp_allocator,
8183
std::unique_ptr<runtime::EventTracer> event_tracer)
8284
: file_path_(file_path),
8385
load_mode_(load_mode),
84-
memory_allocator_(std::make_unique<MallocMemoryAllocator>()),
85-
temp_allocator_(std::make_unique<MallocMemoryAllocator>()),
86+
memory_allocator_(
87+
memory_allocator ? std::move(memory_allocator)
88+
: std::make_unique<MallocMemoryAllocator>()),
89+
temp_allocator_(
90+
temp_allocator ? std::move(temp_allocator)
91+
: std::make_unique<MallocMemoryAllocator>()),
8692
event_tracer_(std::move(event_tracer)) {
8793
runtime::runtime_init();
8894
}
@@ -91,11 +97,17 @@ Module::Module(
9197
const std::string& file_path,
9298
const std::string& data_map_path,
9399
const LoadMode load_mode,
100+
std::unique_ptr<runtime::MemoryAllocator> memory_allocator,
101+
std::unique_ptr<runtime::MemoryAllocator> temp_allocator,
94102
std::unique_ptr<runtime::EventTracer> event_tracer)
95103
: file_path_(file_path),
96104
load_mode_(load_mode),
97-
memory_allocator_(std::make_unique<MallocMemoryAllocator>()),
98-
temp_allocator_(std::make_unique<MallocMemoryAllocator>()),
105+
memory_allocator_(
106+
memory_allocator ? std::move(memory_allocator)
107+
: std::make_unique<MallocMemoryAllocator>()),
108+
temp_allocator_(
109+
temp_allocator ? std::move(temp_allocator)
110+
: std::make_unique<MallocMemoryAllocator>()),
99111
event_tracer_(std::move(event_tracer)) {
100112
if (!data_map_path.empty()) {
101113
data_files_.push_back(data_map_path);
@@ -107,12 +119,18 @@ Module::Module(
107119
const std::string& file_path,
108120
std::vector<std::string> data_files,
109121
const LoadMode load_mode,
122+
std::unique_ptr<runtime::MemoryAllocator> memory_allocator,
123+
std::unique_ptr<runtime::MemoryAllocator> temp_allocator,
110124
std::unique_ptr<runtime::EventTracer> event_tracer)
111125
: file_path_(file_path),
112126
data_files_(std::move(data_files)),
113127
load_mode_(load_mode),
114-
memory_allocator_(std::make_unique<MallocMemoryAllocator>()),
115-
temp_allocator_(std::make_unique<MallocMemoryAllocator>()),
128+
memory_allocator_(
129+
memory_allocator ? std::move(memory_allocator)
130+
: std::make_unique<MallocMemoryAllocator>()),
131+
temp_allocator_(
132+
temp_allocator ? std::move(temp_allocator)
133+
: std::make_unique<MallocMemoryAllocator>()),
116134
event_tracer_(std::move(event_tracer)) {
117135
runtime::runtime_init();
118136
}

extension/module/module.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ class Module {
6363
explicit Module(
6464
const std::string& file_path,
6565
const LoadMode load_mode = LoadMode::File,
66+
std::unique_ptr<runtime::MemoryAllocator> memory_allocator = nullptr,
67+
std::unique_ptr<runtime::MemoryAllocator> temp_allocator = nullptr,
6668
std::unique_ptr<runtime::EventTracer> event_tracer = nullptr);
6769

6870
/**
@@ -78,6 +80,8 @@ class Module {
7880
const std::string& file_path,
7981
const std::string& data_map_path,
8082
const LoadMode load_mode = LoadMode::File,
83+
std::unique_ptr<runtime::MemoryAllocator> memory_allocator = nullptr,
84+
std::unique_ptr<runtime::MemoryAllocator> temp_allocator = nullptr,
8185
std::unique_ptr<runtime::EventTracer> event_tracer = nullptr);
8286

8387
/**
@@ -93,6 +97,8 @@ class Module {
9397
const std::string& file_path,
9498
std::vector<std::string> data_files,
9599
const LoadMode load_mode = LoadMode::File,
100+
std::unique_ptr<runtime::MemoryAllocator> memory_allocator = nullptr,
101+
std::unique_ptr<runtime::MemoryAllocator> temp_allocator = nullptr,
96102
std::unique_ptr<runtime::EventTracer> event_tracer = nullptr);
97103

98104
/**

0 commit comments

Comments
 (0)