Commit 9229898

gguf: split gguf writer into base and buf impl
1 parent bbbf5ec commit 9229898

1 file changed: +43 −22 lines changed

ggml/src/gguf.cpp

Lines changed: 43 additions & 22 deletions
@@ -1166,50 +1166,51 @@ void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const vo
     ctx->info[tensor_id].t.data = (void *)(uintptr_t)data; // double cast suppresses warning about casting away const
 }
 
-struct gguf_writer {
-    std::vector<int8_t> & buf;
+struct gguf_writer_base {
+    size_t written_bytes {0u};
+
+    ~gguf_writer_base(void) {}
 
-    gguf_writer(std::vector<int8_t> & buf) : buf(buf) {}
+    // we bet on devirtualization
+    virtual void write(int8_t val) = 0;
+    virtual void write(const std::vector<int8_t> & val) = 0;
+    virtual void write_tensor_data(const struct gguf_tensor_info & info, size_t offset_data, size_t alignment) = 0;
 
     template <typename T>
-    void write(const T & val) const {
+    void write(const T & val) {
         for (size_t i = 0; i < sizeof(val); ++i) {
-            buf.push_back(reinterpret_cast<const int8_t *>(&val)[i]);
+            write(reinterpret_cast<const int8_t *>(&val)[i]);
         }
     }
 
-    void write(const std::vector<int8_t> & val) const {
-        buf.insert(buf.end(), val.begin(), val.end());
-    }
-
-    void write(const bool & val) const {
+    void write(const bool & val) {
         const int8_t val8 = val ? 1 : 0;
         write(val8);
     }
 
-    void write(const std::string & val) const {
+    void write(const std::string & val) {
         {
             const uint64_t n = val.length();
             write(n);
         }
         for (size_t i = 0; i < val.length(); ++i) {
-            buf.push_back(reinterpret_cast<const int8_t *>(val.data())[i]);
+            write((val.data())[i]);
         }
     }
 
-    void write(const char * val) const {
+    void write(const char * val) {
         write(std::string(val));
     }
 
-    void write(const enum ggml_type & val) const {
+    void write(const enum ggml_type & val) {
         write(int32_t(val));
     }
 
-    void write(const enum gguf_type & val) const {
+    void write(const enum gguf_type & val) {
         write(int32_t(val));
     }
 
-    void write(const struct gguf_kv & kv) const {
+    void write(const struct gguf_kv & kv) {
         const uint64_t ne = kv.get_ne();
 
         write(kv.get_key());
@@ -1250,7 +1251,7 @@ struct gguf_writer {
         }
     }
 
-    void write_tensor_meta(const struct gguf_tensor_info & info) const {
+    void write_tensor_meta(const struct gguf_tensor_info & info) {
         write(info.t.name);
 
         const uint32_t n_dims = ggml_n_dims(&info.t);
@@ -1263,14 +1264,33 @@
         write(info.offset);
     }
 
-    void pad(const size_t alignment) const {
-        while (buf.size() % alignment != 0) {
+    void pad(const size_t alignment) {
+        while (written_bytes % alignment != 0) {
             const int8_t zero = 0;
             write(zero);
         }
     }
+};
+
+// vector buffer based writer
+struct gguf_writer_buf final : public gguf_writer_base {
+    std::vector<int8_t> & buf;
+
+    gguf_writer_buf(std::vector<int8_t> & buf) : buf(buf) {}
+
+    using gguf_writer_base::write;
+
+    void write(const int8_t val) override {
+        buf.push_back(val);
+        written_bytes++;
+    }
+
+    void write(const std::vector<int8_t> & val) override {
+        buf.insert(buf.end(), val.begin(), val.end());
+        written_bytes += val.size();
+    }
 
-    void write_tensor_data(const struct gguf_tensor_info & info, const size_t offset_data, const size_t alignment) const {
+    void write_tensor_data(const struct gguf_tensor_info & info, const size_t offset_data, const size_t alignment) override {
         GGML_ASSERT(buf.size() - offset_data == info.offset);
 
         GGML_ASSERT(ggml_is_contiguous(&info.t));
@@ -1284,13 +1304,14 @@ struct gguf_writer {
             GGML_ASSERT(info.t.data);
             memcpy(buf.data() + offset, info.t.data, nbytes);
         }
+        written_bytes += nbytes;
 
         pad(alignment);
     }
 };
 
 void gguf_write_to_buf(const struct gguf_context * ctx, std::vector<int8_t> & buf, bool only_meta) {
-    const struct gguf_writer gw(buf);
+    gguf_writer_buf gw(buf);
 
     const int64_t n_kv = gguf_get_n_kv(ctx);
     const int64_t n_tensors = gguf_get_n_tensors(ctx);
@@ -1321,7 +1342,7 @@ void gguf_write_to_buf(const struct gguf_context * ctx, std::vector<int8_t> & bu
         return;
     }
 
-    const size_t offset_data = gw.buf.size();
+    const size_t offset_data = gw.written_bytes;
 
     // write tensor data
     for (int64_t i = 0; i < n_tensors; ++i) {
