Skip to content

Commit

Permalink
speed up the snapshot downloader
Browse files Browse the repository at this point in the history
  • Loading branch information
Rexicon226 committed Sep 30, 2024
1 parent 23f6b1d commit 7b5058d
Showing 1 changed file with 102 additions and 47 deletions.
149 changes: 102 additions & 47 deletions src/accountsdb/download.zig
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ const ThreadSafeContactInfo = sig.gossip.data.ThreadSafeContactInfo;
const GossipService = sig.gossip.GossipService;
const Logger = sig.trace.Logger;

const DOWNLOAD_PROGRESS_UPDATES_NS = 30 * std.time.ns_per_s;
const assert = std.debug.assert;

const DOWNLOAD_PROGRESS_UPDATES_NS = 6 * std.time.ns_per_s;

/// Analogous to [PeerSnapshotHash](https://github.com/anza-xyz/agave/blob/f868aa38097094e4fb78a885b6fb27ce0e43f5c7/validator/src/bootstrap.rs#L342)
const PeerSnapshotHash = struct {
Expand Down Expand Up @@ -53,7 +55,7 @@ pub fn findPeersToDownloadFromAssumeCapacity(
) !PeerSearchResult {
// clear the list
valid_peers.clearRetainingCapacity();
std.debug.assert(valid_peers.capacity >= contact_infos.len);
assert(valid_peers.capacity >= contact_infos.len);

const TrustedMapType = std.AutoHashMap(
SlotAndHash, // full snapshot hash
Expand Down Expand Up @@ -295,81 +297,91 @@ pub fn downloadSnapshotsFromGossip(
}

const DownloadProgress = struct {
mmap: []align(std.mem.page_size) u8,
download_size: usize,
file: std.fs.File,
min_mb_per_second: ?usize,
logger: Logger,

mb_timer: std.time.Timer,
bytes_read: usize = 0,
file_memory_index: usize = 0,
bytes_read: u64 = 0,
total_read: u64 = 0,
has_checked_speed: bool = false,

const Self = @This();

pub fn init(
fn init(
logger: Logger,
output_dir: std.fs.Dir,
filename: []const u8,
download_size: usize,
min_mb_per_second: ?usize,
) !Self {
const file = try output_dir.createFile(filename, .{ .read = true });
defer file.close();
const file = try output_dir.createFile(filename, .{});

// resize the file
try file.seekTo(download_size - 1);
_ = try file.write(&[_]u8{1});
try file.seekTo(0);

const file_memory = try std.posix.mmap(
null,
download_size,
std.posix.PROT.READ | std.posix.PROT.WRITE,
std.posix.MAP{ .TYPE = .SHARED },
file.handle,
0,
);

return .{
.logger = logger,
.mmap = file_memory,
.download_size = download_size,
.file = file,
.min_mb_per_second = min_mb_per_second,
.mb_timer = try std.time.Timer.start(),
.mb_timer = undefined,
};
}

pub fn deinit(self: *Self) void {
std.posix.munmap(self.mmap);
fn deinit(self: *Self) void {
self.file.close();
}

pub fn bufferWriteCallback(ptr: [*c]c_char, size: c_uint, nmemb: c_uint, user_data: *anyopaque) callconv(.C) c_uint {
fn writeCallback(
ptr: ?[*:0]c_char,
size: c_uint,
nmemb: c_uint,
user_data: *anyopaque,
) callconv(.C) c_uint {
assert(size == 1); // size will always be 1
const len = size * nmemb;
var self: *Self = @alignCast(@ptrCast(user_data));
var typed_data: [*]u8 = @ptrCast(ptr);
const self: *Self = @alignCast(@ptrCast(user_data));
var typed_data: [*]u8 = @ptrCast(ptr.?);
const buf = typed_data[0..len];

@memcpy(self.mmap[self.file_memory_index..][0..len], buf);
self.file_memory_index += len;
self.file.writeAll(buf) catch |err|
std.debug.panic("failed to write to file: {s}", .{@errorName(err)});
self.bytes_read += len;
self.total_read += len;

return len;
}

fn progressCallback(
user_data: *anyopaque,
download_total: c_ulong,
download_now: c_ulong,
upload_total: c_ulong,
upload_now: c_ulong,
) callconv(.C) c_uint {
const self: *Self = @alignCast(@ptrCast(user_data));

// we're only downloading
assert(upload_total == 0);
assert(upload_now == 0);
const elapsed_ns = self.mb_timer.read();
if (elapsed_ns > DOWNLOAD_PROGRESS_UPDATES_NS) {
// each MB
defer {
self.bytes_read = 0;
self.mb_timer.reset();
}

const mb_read = self.bytes_read / 1024 / 1024;
if (mb_read == 0) {
self.logger.infof("download speed is too slow (<1MB/s) -- disconnecting", .{});
return 0;
return 1; // abort from callback
}

defer {
self.bytes_read = 0;
self.mb_timer.reset();
}
const elapsed_sec = elapsed_ns / std.time.ns_per_s;
const ns_per_mb = elapsed_ns / mb_read;
const mb_left = (self.download_size - self.file_memory_index) / 1024 / 1024;
const mb_left = (download_total - download_now) / 1024 / 1024;
const time_left_ns = mb_left * ns_per_mb;
const mb_per_second = mb_read / elapsed_sec;

Expand All @@ -379,23 +391,26 @@ const DownloadProgress = struct {
self.has_checked_speed = true;
if (mb_per_second < self.min_mb_per_second.?) {
// not fast enough => abort
self.logger.infof("[download progress]: speed is too slow ({d} MB/s) -- disconnecting", .{mb_per_second});
return 0;
self.logger.infof(
"[download progress]: speed is too slow ({}/s) -- disconnecting",
.{std.fmt.fmtIntSizeDec(download_now / elapsed_sec)},
);
return 1; // abort from callback
} else {
self.logger.infof("[download progress]: speed is ok ({d} MB/s) -- maintaining", .{mb_per_second});
}
}

self.logger.infof("[download progress]: {d}% done ({d} MB/s - {d}/{d}) (time left: {d})", .{
self.file_memory_index * 100 / self.download_size,
mb_per_second,
self.file_memory_index,
self.download_size,
self.logger.infof("[download progress]: {d}% done ({:.4}/s - {:.4}/{:.4}) (time left: {d})", .{
self.total_read * 100 / download_total,
std.fmt.fmtIntSizeDec(self.bytes_read / elapsed_sec),
std.fmt.fmtIntSizeDec(download_now),
std.fmt.fmtIntSizeDec(download_total),
std.fmt.fmtDuration(time_left_ns),
});
}

return len;
return 0;
}
};

Expand All @@ -408,8 +423,44 @@ fn checkCode(code: curl.libcurl.CURLcode) !void {
return error.Unexpected;
}

pub fn setNoBody(self: curl.Easy, no_body: bool) !void {
try checkCode(curl.libcurl.curl_easy_setopt(self.handle, curl.libcurl.CURLOPT_NOBODY, @as(c_long, @intFromBool(no_body))));
fn setNoBody(self: curl.Easy, no_body: bool) !void {
try checkCode(curl.libcurl.curl_easy_setopt(
self.handle,
curl.libcurl.CURLOPT_NOBODY,
@as(c_long, @intFromBool(no_body)),
));
}

fn setProgressFunction(
self: curl.Easy,
func: *const fn (*anyopaque, c_ulong, c_ulong, c_ulong, c_ulong) callconv(.C) c_uint,
) !void {
try checkCode(curl.libcurl.curl_easy_setopt(
self.handle,
curl.libcurl.CURLOPT_XFERINFOFUNCTION,
func,
));
}

fn setProgressData(
self: curl.Easy,
data: *const anyopaque,
) !void {
try checkCode(curl.libcurl.curl_easy_setopt(
self.handle,
curl.libcurl.CURLOPT_XFERINFODATA,
data,
));
}

fn enableProgress(
self: curl.Easy,
) !void {
try checkCode(curl.libcurl.curl_easy_setopt(
self.handle,
curl.libcurl.CURLOPT_NOPROGRESS,
@as(c_long, 0),
));
}

/// downloads a file from a url into output_dir/filename
Expand Down Expand Up @@ -457,12 +508,16 @@ pub fn downloadFile(
try easy.setUrl(url);
try easy.setMethod(.GET);
try easy.setWritedata(&download_progress);
try easy.setWritefunction(DownloadProgress.bufferWriteCallback);
try easy.setWritefunction(DownloadProgress.writeCallback);
try setProgressData(easy, &download_progress);
try setProgressFunction(easy, DownloadProgress.progressCallback);
try enableProgress(easy);

download_progress.mb_timer = try std.time.Timer.start();
var resp = try easy.perform();
defer resp.deinit();

const full_download = download_progress.file_memory_index == download_size;
const full_download = download_progress.total_read == download_size;
// this if block should only be hit if the download was too slow
if (!full_download) {
return error.TooSlow;
Expand Down

0 comments on commit 7b5058d

Please sign in to comment.