Skip to content

Efficient preloading for mmap() #869

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 124 additions & 1 deletion llama_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@

#ifndef LLAMA_UTIL_H
#define LLAMA_UTIL_H

#define _PRELOAD_MMAP_FILE 1 // when using mmap, preload the entire file to prevent loading during first token inference
#include <cstdio>
#include <cstdint>
#include <cerrno>
#include <cstring>
#include <cstdarg>
#include <cstdlib>
#include <climits>
#include <thread>

#include <string>
#include <vector>
Expand All @@ -30,6 +31,34 @@
#include <windows.h>
#include <io.h>
#include <stdio.h> // for _fseeki64
typedef volatile LONG atomic_int;
typedef atomic_int atomic_bool;

typedef HANDLE pthread_t;
typedef DWORD thread_ret_t;

// Minimal pthread_create shim over the Win32 thread API.
// The attribute argument is accepted for signature compatibility and ignored.
static int pthread_create(pthread_t *out, void *unused, thread_ret_t (*func)(void *), void *arg)
{
    (void) unused;
    // Default security descriptor, default stack size, run immediately.
    HANDLE h = CreateThread(NULL, 0, (LPTHREAD_START_ROUTINE)func, arg, 0, NULL);
    if (h != NULL)
    {
        *out = h;
        return 0;
    }
    return EAGAIN; // mirror pthreads' "insufficient resources" error code
}

// Minimal pthread_join shim over the Win32 thread API.
// Fix: the original leaked the thread HANDLE — WaitForSingleObject only
// waits, it does not release the handle; CloseHandle must be called or
// every joined thread leaks a kernel object.
static int pthread_join(pthread_t thread, void *unused)
{
    (void) unused;
    int ret = (int)WaitForSingleObject(thread, INFINITE);
    CloseHandle(thread);
    return ret;
}
#else
#include <unistd.h>
#include <pthread.h>
#include <stdatomic.h>

typedef void *thread_ret_t;
#endif

#define LLAMA_ASSERT(x) \
Expand Down Expand Up @@ -156,7 +185,96 @@ static std::string llama_format_win_err(DWORD err) {
struct llama_mmap {
void * addr;
size_t size;
// Per-worker description of the page range to fault in.
typedef struct
{
    size_t start;   // first byte offset of the chunk (inclusive)
    size_t end;     // end byte offset of the chunk (exclusive), clamped to the mapping size
    void *addr;     // base address of the mmap'd region
    int n_threads;  // total number of workers striding over the chunk
    int n_thread;   // this worker's index, 0 .. n_threads-1
    int page_size;  // system page size in bytes
} thread_data_t;
// Touch one byte of every n_threads-th page in [start, end) so the OS faults
// the whole chunk into the page cache (pseudo-sequential, interleaved reads).
static thread_ret_t worker_preload_memory(void *arg)
{
    thread_data_t *data = (thread_data_t *)arg;
    volatile char buffer;
    // Fix: the range is half-open — the original used `offset <= data->end`,
    // which read the byte AT `end`; when end == mapping size and the size is
    // an exact page multiple, that byte is one past the mapping (SIGSEGV risk).
    for (size_t offset = data->start + (size_t)data->n_thread * data->page_size; offset < data->end; offset += (size_t)data->n_threads * data->page_size)
    {
        volatile void *buffer_ptr = &buffer;
        memcpy((void *)buffer_ptr, (char *)data->addr + offset, sizeof(buffer));
        // Never-true branch (n_thread < n_threads always): kept on purpose to
        // defeat compiler optimization — a plain volatile access was reportedly
        // elided inside thread workers.
        if (data->n_threads < data->n_thread && buffer == 0) exit(-1);
    }
    return NULL;
}
void preload_mmap_file(void *addr, size_t length, int n_threads)
{
#ifndef _PRELOAD_MMAP_FILE
return;
#endif
// Get the page size of the system
#if defined(_WIN32)
SYSTEM_INFO si;
GetSystemInfo(&si);
long page_size = si.dwPageSize;
#else
long page_size = sysconf(_SC_PAGE_SIZE); // in windows we can use GetSystemInfo:
#endif

if (page_size == -1)
{
perror("sysconf");
return;
}
#ifdef _WIN32
HANDLE hProcess = GetCurrentProcess();
WIN32_MEMORY_RANGE_ENTRY range;
range.VirtualAddress = addr;
range.NumberOfBytes = length;
// if (!VirtualLock(addr, length)) { }; // no benefit. for systems with too little RAM we should lock a part and restrict the preload to that new length
if (!PrefetchVirtualMemory(hProcess, 1, &range, 0)) { }; // Prefetches part of the data and signals readahead to the file system
#else
// todo
//if (posix_madvise(addr, length, POSIX_MADV_WILLNEED) == -1) { };
// readahead() should be the equivalent method for Linux. I don't think madvise will cause a full fetch
// the multi threaded read below is pseudo sequential, it also needs a test without OS level readahead in place (worst case set threads to 1 in linux or return)
#endif

if (n_threads > 32)
n_threads = 32;
pthread_t threads[32];
thread_data_t thread_data[32];

// we split the pages between the threads - that was the only reliable solution I could find
size_t num_pages_per_thread = (length / page_size) / n_threads;
int pages = ceil(length / page_size);
for (int page_start = 0; page_start < pages; page_start += n_threads * num_pages_per_thread)
{
size_t chunk_start = page_start * page_size;
size_t chunk_end = chunk_start + page_size * n_threads * num_pages_per_thread;
for (int i = 0; i < n_threads; ++i)
{
thread_data[i].start = chunk_start;
thread_data[i].end = chunk_end;
if (thread_data[i].end > length)
{
thread_data[i].end = length;
}
thread_data[i].addr = addr;
thread_data[i].page_size = page_size;
thread_data[i].n_threads = n_threads;
thread_data[i].n_thread = i;
pthread_create(&threads[i], NULL, worker_preload_memory, &thread_data[i]);
if (thread_data[i].end == length)
break;
}

for (int i = 0; i < n_threads; ++i)
{
pthread_join(threads[i], NULL);
}

}
}
llama_mmap(const llama_mmap &) = delete;

#ifdef _POSIX_MAPPED_FILES
Expand All @@ -180,6 +298,8 @@ struct llama_mmap {
fprintf(stderr, "warning: madvise(.., MADV_WILLNEED) failed: %s\n",
strerror(errno));
}
// if _PRELOAD_MMAP_FILE is defined, this will preload the file into the page cache efficiently
preload_mmap_file(addr, file->size);
}

~llama_mmap() {
Expand Down Expand Up @@ -217,6 +337,9 @@ struct llama_mmap {
fprintf(stderr, "warning: PrefetchVirtualMemory failed: %s\n",
llama_format_win_err(GetLastError()).c_str());
}

// if _PRELOAD_MMAP_FILE is defined, this will preload the file into the page cache efficiently
preload_mmap_file(addr, file->size, std::thread::hardware_concurrency()/2);
}

~llama_mmap() {
Expand Down