Skip to content

Commit

Permalink
Make replace_all() have linear complexity
Browse files Browse the repository at this point in the history
  • Loading branch information
jart committed Aug 25, 2024
1 parent fa4c4e7 commit 2c940da
Show file tree
Hide file tree
Showing 13 changed files with 59 additions and 89 deletions.
22 changes: 14 additions & 8 deletions llama.cpp/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "json-schema-to-grammar.h"
#include "llama.h"
#include "llamafile/debug.h"
#include "string.h"

#include <cosmo.h>
#include <algorithm>
Expand Down Expand Up @@ -1731,15 +1732,20 @@ std::string string_get_sortable_timestamp() {
return std::string(timestamp_no_ns) + "." + std::string(timestamp_ns);
}

// Returns a copy of `s` in which every occurrence of `search` has been
// substituted with `replace`.
//
// Runs in linear time: unmatched spans and replacements are appended once
// each into a pre-reserved builder string, instead of calling
// std::string::replace repeatedly, which shifts the tail of the string on
// every hit and degrades to quadratic when `replace` is longer than `search`.
//
// An empty `search` returns `s` unchanged (find("") matches at every
// position, which would otherwise scan forever).
std::string replace_all(const std::string & s, const std::string & search, const std::string & replace) {
    if (search.empty())
        return s;
    std::string builder;
    builder.reserve(s.length()); // lower bound on the result size
    size_t pos = 0;
    size_t last_pos = 0;
    while ((pos = s.find(search, last_pos)) != std::string::npos) {
        builder.append(s, last_pos, pos - last_pos); // copy text before the match
        builder.append(replace);
        last_pos = pos + search.length();
    }
    builder.append(s, last_pos, std::string::npos); // copy the tail
    return builder;
}

void string_process_escapes(std::string & input) {
Expand Down
16 changes: 0 additions & 16 deletions llama.cpp/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -289,11 +289,6 @@ std::string gpt_params_get_system_info(const gpt_params & params);

std::vector<std::string> string_split(std::string input, char separator);

std::string string_strip(const std::string & str);
std::string string_get_sortable_timestamp();

void string_replace_all(std::string & s, const std::string & search, const std::string & replace);

template<class T>
static std::vector<T> string_split(const std::string & str, char delim) {
std::vector<T> values;
Expand All @@ -309,17 +304,6 @@ static std::vector<T> string_split(const std::string & str, char delim) {
}

bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
void string_process_escapes(std::string & input);

//
// Filesystem utils
//

bool fs_validate_filename(const std::string & filename);
bool fs_create_directory_with_parents(const std::string & path);

std::string fs_get_cache_directory();
std::string fs_get_cache_file(const std::string & filename);

//
// Model utils
Expand Down
18 changes: 4 additions & 14 deletions llama.cpp/llama-bench/llama-bench.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,16 +91,6 @@ static T stdev(const std::vector<T> & v) {
return stdev;
}

// Returns `str` with every occurrence of `from` replaced by `to`.
//
// Builds the result in a single pass (O(n)) instead of repeated in-place
// std::string::replace, which shifts the tail on every hit. Also guards
// against an empty `from`: find("") matches at every position, and the
// original loop spun forever when `to` was empty as well.
static std::string replaceAll(std::string str, const std::string& from, const std::string& to) {
    if (from.empty())
        return str;
    std::string out;
    out.reserve(str.size()); // lower bound on the result size
    size_t last = 0;
    for (size_t pos; (pos = str.find(from, last)) != std::string::npos; last = pos + from.size()) {
        out.append(str, last, pos - last); // text before the match
        out.append(to);
    }
    out.append(str, last, std::string::npos); // trailing text
    return out;
}


#ifdef __x86_64__
static void cpuid(unsigned leaf, unsigned subleaf, unsigned *info) {
asm("movq\t%%rbx,%%rsi\n\t"
Expand Down Expand Up @@ -159,9 +149,9 @@ static std::string get_cpu_info() { // [jart]
}
}
#endif
id = replaceAll(id, " 96-Cores", "");
id = replaceAll(id, "(TM)", "");
id = replaceAll(id, "(R)", "");
id = replace_all(id, " 96-Cores", "");
id = replace_all(id, "(TM)", "");
id = replace_all(id, "(R)", "");

std::string march;
#ifdef __x86_64__
Expand Down Expand Up @@ -1257,7 +1247,7 @@ struct markdown_printer : public printer {
snprintf(buf, sizeof(buf), "%.2f", t.avg_ts());
value = buf;
} else if (vmap.find(field) != vmap.end()) {
value = replaceAll(replaceAll(vmap.at(field), ".gguf", ""), ".llamafile", ""); // [jart]
value = replace_all(replace_all(vmap.at(field), ".gguf", ""), ".llamafile", ""); // [jart]
} else {
assert(false);
exit(1);
Expand Down
15 changes: 0 additions & 15 deletions llama.cpp/llama-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,3 @@ void llama_log_callback_default(ggml_log_level level, const char * text, void *
#define LLAMA_LOG_INFO(...) llama_log_internal(GGML_LOG_LEVEL_INFO , __VA_ARGS__)
#define LLAMA_LOG_WARN(...) llama_log_internal(GGML_LOG_LEVEL_WARN , __VA_ARGS__)
#define LLAMA_LOG_ERROR(...) llama_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)

//
// helpers
//

// Replaces, in place, every occurrence of `search` in `s` with `replace`.
//
// Builds the result once in a side buffer and swaps it in, so the whole
// operation is O(n); the previous repeated std::string::replace shifted the
// tail of the string on every match, which is quadratic when `replace` is
// longer than `search`.
static void replace_all(std::string & s, const std::string & search, const std::string & replace) {
    if (search.empty()) {
        return; // find("") matches everywhere -> would loop forever
    }
    std::string out;
    out.reserve(s.size()); // lower bound on the result size
    size_t last = 0;
    size_t pos;
    while ((pos = s.find(search, last)) != std::string::npos) {
        out.append(s, last, pos - last); // text before the match
        out.append(replace);
        last = pos + search.length();
    }
    out.append(s, last, std::string::npos); // trailing text
    s = std::move(out);
}
5 changes: 3 additions & 2 deletions llama.cpp/llama-vocab.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "llama-vocab.h"

#include "unicode.h"
#include "string.h"

#include <algorithm>
#include <cassert>
Expand Down Expand Up @@ -152,11 +153,11 @@ static uint8_t llama_token_to_byte(const llama_vocab & vocab, llama_token id) {
}

// Escapes whitespace in `text` by substituting each ' ' with the UTF-8
// sequence "\xe2\x96\x81" (U+2581), using the value-returning replace_all.
static void llama_escape_whitespace(std::string & text) {
    text = replace_all(text, " ", "\xe2\x96\x81");
}

// Reverses llama_escape_whitespace: turns each UTF-8 "\xe2\x96\x81"
// (U+2581) sequence in `word` back into a ' ' character.
static void llama_unescape_whitespace(std::string & word) {
    word = replace_all(word, "\xe2\x96\x81", " ");
}

struct llm_symbol {
Expand Down
11 changes: 6 additions & 5 deletions llama.cpp/llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "llama-sampling.h"

#include "unicode.h"
#include "string.h"

#include "ggml.h"
#include "ggml-alloc.h"
Expand Down Expand Up @@ -1424,8 +1425,8 @@ static std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) {
if (arr_type == GGUF_TYPE_STRING) {
std::string val = gguf_get_arr_str(ctx_gguf, i, j);
// escape quotes
replace_all(val, "\\", "\\\\");
replace_all(val, "\"", "\\\"");
val = replace_all(val, "\\", "\\\\");
val = replace_all(val, "\"", "\\\"");
ss << '"' << val << '"';
} else if (arr_type == GGUF_TYPE_ARRAY) {
ss << "???";
Expand Down Expand Up @@ -3563,7 +3564,7 @@ struct llama_model_loader {
if (value.size() > MAX_VALUE_LEN) {
value = format("%s...", value.substr(0, MAX_VALUE_LEN - 3).c_str());
}
replace_all(value, "\n", "\\n");
value = replace_all(value, "\n", "\\n");

LLAMA_LOG_INFO("%s: - kv %3d: %42s %-16s = %s\n", __func__, i, name, type_name.c_str(), value.c_str());
}
Expand Down Expand Up @@ -16397,14 +16398,14 @@ static void llama_lora_adapter_init_internal(struct llama_model * model, const c
for (ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) {
std::string name(cur->name);
if (str_endswith(name, ".lora_a")) {
replace_all(name, ".lora_a", "");
name = replace_all(name, ".lora_a", "");
if (ab_map.find(name) == ab_map.end()) {
ab_map[name] = llama_lora_weight(cur, nullptr);
} else {
ab_map[name].a = cur;
}
} else if (str_endswith(name, ".lora_b")) {
replace_all(name, ".lora_b", "");
name = replace_all(name, ".lora_b", "");
if (ab_map.find(name) == ab_map.end()) {
ab_map[name] = llama_lora_weight(nullptr, cur);
} else {
Expand Down
18 changes: 4 additions & 14 deletions llama.cpp/llava/clip.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "llama.cpp/ggml-backend.h"
#include "llama.cpp/ggml-cuda.h"
#include "llama.cpp/ggml-metal.h"
#include "llama.cpp/string.h"

#include "stb/stb_image.h"

Expand Down Expand Up @@ -202,17 +203,6 @@ static std::string gguf_data_to_str(enum gguf_type type, const void * data, int
}
}

// Replaces, in place, every occurrence of `search` in `s` with `replace`.
//
// Single-pass O(n) implementation: unmatched spans and replacements are
// appended into a side buffer that is then moved into `s`. The previous
// repeated std::string::replace shifted the string tail on every match,
// which is quadratic when `replace` is longer than `search`.
static void replace_all(std::string & s, const std::string & search, const std::string & replace) {
    if (search.empty()) {
        return; // find("") matches everywhere -> would loop forever
    }
    std::string out;
    out.reserve(s.size()); // lower bound on the result size
    size_t last = 0;
    size_t pos;
    while ((pos = s.find(search, last)) != std::string::npos) {
        out.append(s, last, pos - last); // text before the match
        out.append(replace);
        last = pos + search.length();
    }
    out.append(s, last, std::string::npos); // trailing text
    s = std::move(out);
}

static std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) {
const enum gguf_type type = gguf_get_kv_type(ctx_gguf, i);

Expand All @@ -230,8 +220,8 @@ static std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) {
if (arr_type == GGUF_TYPE_STRING) {
std::string val = gguf_get_arr_str(ctx_gguf, i, j);
// escape quotes
replace_all(val, "\\", "\\\\");
replace_all(val, "\"", "\\\"");
val = replace_all(val, "\\", "\\\\");
val = replace_all(val, "\"", "\\\"");
ss << '"' << val << '"';
} else if (arr_type == GGUF_TYPE_ARRAY) {
ss << "???";
Expand Down Expand Up @@ -1073,7 +1063,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
if (value.size() > MAX_VALUE_LEN) {
value = format("%s...", value.substr(0, MAX_VALUE_LEN - 3).c_str());
}
replace_all(value, "\n", "\\n");
value = replace_all(value, "\n", "\\n");

LOG_TEE("%s: - kv %3d: %42s %-16s = %s\n", __func__, i, name, type_name.c_str(), value.c_str());
}
Expand Down
4 changes: 3 additions & 1 deletion llama.cpp/llava/llava.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi

#include "clip.h"
#include "llama.cpp/common.h"
#include "llama.cpp/string.h"
#include "llama.cpp/llama.h"
#include "llama.cpp/log.h"
#include "llama.cpp/common.h"
#include "llava.h"

#include <cstdio>
Expand Down
1 change: 1 addition & 0 deletions llama.cpp/llava/llava.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ LLAVA_API void llava_image_embed_free(struct llava_image_embed * embed);
/** write the image represented by embed into the llama context with batch size n_batch, starting at context pos n_past. on completion, n_past points to the next position in the context after the image embed. */
LLAVA_API bool llava_eval_image_embed(struct llama_context * ctx_llama, const struct llava_image_embed * embed, int n_batch, int * n_past);

struct gpt_params;
LLAVA_API int llava_cli(int argc, char ** argv, gpt_params & params); // [jart]

#ifdef __cplusplus
Expand Down
1 change: 1 addition & 0 deletions llama.cpp/main/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

#include "llamafile/version.h"
#include "llama.cpp/llama.h"
#include "llama.cpp/string.h"
#include "llama.cpp/common.h"
#include "llama.cpp/console.h"
#include "llama.cpp/ggml-cuda.h"
Expand Down
1 change: 1 addition & 0 deletions llama.cpp/perplexity/perplexity.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi
#include "llama.cpp/common.h"
#include "llama.cpp/llama.h"
#include "llama.cpp/string.h"

#include <cmath>
#include <cstdio>
Expand Down
17 changes: 17 additions & 0 deletions llama.cpp/string.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;tab-width:8;coding:utf-8 -*-
// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi

// String and filesystem helpers shared across the llama.cpp tree
// (definitions live in common.cpp).

#pragma once
#include <string>

// Presumably trims leading/trailing whitespace from `str` — confirm in common.cpp.
std::string string_strip(const std::string & str);
// Current time formatted so that string comparison sorts chronologically.
std::string string_get_sortable_timestamp();
// Returns `s` with every occurrence of `search` replaced by `replace`.
// Linear time; an empty `search` returns `s` unchanged.
std::string replace_all(const std::string & s, const std::string & search, const std::string & replace);

// Expands escape sequences in `input` in place — see common.cpp for the exact set.
void string_process_escapes(std::string & input);

// Filesystem utils.
bool fs_validate_filename(const std::string & filename);
bool fs_create_directory_with_parents(const std::string & path);

std::string fs_get_cache_directory();
std::string fs_get_cache_file(const std::string & filename);
19 changes: 5 additions & 14 deletions whisper.cpp/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,12 @@

#include "llamafile/llamafile.h"
#include "llamafile/debug.h"
#include "llama.cpp/string.h"

#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif

// helper function to replace substrings
//
// Replaces, in place, every occurrence of `search` in `s` with `replace`.
// Rewritten as a single O(n) pass over `s`: the old erase()+insert() pair
// shifted the string tail twice per match (quadratic overall) and spun
// forever when `search` was empty, since find("") matches at every position.
static void replace_all(std::string & s, const std::string & search, const std::string & replace) {
    if (search.empty()) {
        return; // nothing sensible to replace; also avoids an infinite loop
    }
    std::string out;
    out.reserve(s.size()); // lower bound on the result size
    size_t last = 0;
    size_t pos;
    while ((pos = s.find(search, last)) != std::string::npos) {
        out.append(s, last, pos - last); // text before the match
        out.append(replace);
        last = pos + search.length();
    }
    out.append(s, last, std::string::npos); // trailing text
    s = std::move(out);
}

int cpu_get_num_math();

// command-line parameters
Expand Down Expand Up @@ -870,10 +861,10 @@ static bool output_wts(struct whisper_context * ctx, const char * fname, const c
}
}

::replace_all(txt_bg, "'", "\u2019");
::replace_all(txt_bg, "\"", "\\\"");
::replace_all(txt_fg, "'", "\u2019");
::replace_all(txt_fg, "\"", "\\\"");
txt_bg = replace_all(txt_bg, "'", "\u2019");
txt_bg = replace_all(txt_bg, "\"", "\\\"");
txt_fg = replace_all(txt_fg, "'", "\u2019");
txt_fg = replace_all(txt_fg, "\"", "\\\"");
}

if (is_first) {
Expand Down

0 comments on commit 2c940da

Please sign in to comment.