Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

apply timeouts to rewards tree downloads #446

Merged
merged 1 commit into from
Feb 15, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 62 additions & 52 deletions shared/services/rewards/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -273,73 +273,83 @@ func (i *IntervalInfo) DownloadRewardsFile(cfg *config.RocketPoolConfig, isDaemo

// Attempt downloads
errBuilder := strings.Builder{}
for _, url := range urls {
resp, err := http.Get(url)
if err != nil {
errBuilder.WriteString(fmt.Sprintf("Downloading %s failed (%s)\n", url, err.Error()))
continue
// ipfs http services are very unreliable and like to hold the connection open for several
// minutes before returning a 504. Force a short timeout, but if all sources fail,
// gradually increase the timeout to be unreasonably long.
for _, timeout := range []time.Duration{200 * time.Millisecond, 2 * time.Second, 60 * time.Second} {
client := http.Client{
Timeout: timeout,
}
defer resp.Body.Close()
for _, url := range urls {
resp, err := client.Get(url)
if err != nil {
errBuilder.WriteString(fmt.Sprintf("Downloading %s failed (%s)\n", url, err.Error()))
continue
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
errBuilder.WriteString(fmt.Sprintf("Downloading %s failed with status %s\n", url, resp.Status))
continue
}
// If we got here, we have a successful download
bytes, err := io.ReadAll(resp.Body)
if err != nil {
errBuilder.WriteString(fmt.Sprintf("Error reading response bytes from %s: %s\n", url, err.Error()))
continue
}
writeBytes := bytes
if strings.HasSuffix(url, config.RewardsTreeIpfsExtension) {
// Decompress it
writeBytes, err = decompressFile(bytes)
if resp.StatusCode != http.StatusOK {
errBuilder.WriteString(fmt.Sprintf("Downloading %s failed with status %s\n", url, resp.Status))
continue
}
// If we got here, we have a successful download
bytes, err := io.ReadAll(resp.Body)
if err != nil {
errBuilder.WriteString(fmt.Sprintf("Error decompressing %s: %s\n", url, err.Error()))
errBuilder.WriteString(fmt.Sprintf("Error reading response bytes from %s: %s\n", url, err.Error()))
continue
}
}
writeBytes := bytes
if strings.HasSuffix(url, config.RewardsTreeIpfsExtension) {
// Decompress it
writeBytes, err = decompressFile(bytes)
if err != nil {
errBuilder.WriteString(fmt.Sprintf("Error decompressing %s: %s\n", url, err.Error()))
continue
}
}

deserializedRewardsFile, err := DeserializeRewardsFile(writeBytes)
if err != nil {
return fmt.Errorf("Error deserializing file %s: %w", rewardsTreePath, err)
}
deserializedRewardsFile, err := DeserializeRewardsFile(writeBytes)
if err != nil {
return fmt.Errorf("Error deserializing file %s: %w", rewardsTreePath, err)
}

// Get the original merkle root
downloadedRoot := deserializedRewardsFile.GetHeader().MerkleRoot
// Get the original merkle root
downloadedRoot := deserializedRewardsFile.GetHeader().MerkleRoot

// Clear the merkle root so we have a safer comparison after calculating it again
deserializedRewardsFile.GetHeader().MerkleRoot = ""
// Clear the merkle root so we have a safer comparison after calculating it again
deserializedRewardsFile.GetHeader().MerkleRoot = ""

// Reconstruct the merkle tree from the file data, this should overwrite the stored Merkle Root with a new one
deserializedRewardsFile.generateMerkleTree()
// Reconstruct the merkle tree from the file data, this should overwrite the stored Merkle Root with a new one
deserializedRewardsFile.generateMerkleTree()

// Get the resulting merkle root
calculatedRoot := deserializedRewardsFile.GetHeader().MerkleRoot
// Get the resulting merkle root
calculatedRoot := deserializedRewardsFile.GetHeader().MerkleRoot

// Compare the merkle roots to see if the original is correct
if !strings.EqualFold(downloadedRoot, calculatedRoot) {
return fmt.Errorf("the merkle root from %s does not match the root generated by its tree data (had %s, but generated %s)", url, downloadedRoot, calculatedRoot)
}
// Compare the merkle roots to see if the original is correct
if !strings.EqualFold(downloadedRoot, calculatedRoot) {
return fmt.Errorf("the merkle root from %s does not match the root generated by its tree data (had %s, but generated %s)", url, downloadedRoot, calculatedRoot)
}

// Make sure the calculated root matches the canonical one
if !strings.EqualFold(calculatedRoot, expectedRoot.Hex()) {
return fmt.Errorf("the merkle root from %s does not match the canonical one (had %s, but generated %s)", url, calculatedRoot, expectedRoot.Hex())
}
// Make sure the calculated root matches the canonical one
if !strings.EqualFold(calculatedRoot, expectedRoot.Hex()) {
return fmt.Errorf("the merkle root from %s does not match the canonical one (had %s, but generated %s)", url, calculatedRoot, expectedRoot.Hex())
}

// Serialize again so we're sure to have all the correct proofs that we've generated (instead of verifying every proof on the file)
localRewardsFile := NewLocalFile[IRewardsFile](
deserializedRewardsFile,
rewardsTreePath,
)
err = localRewardsFile.Write()
if err != nil {
return fmt.Errorf("error saving interval %d file to %s: %w", interval, rewardsTreePath, err)
}
// Serialize again so we're sure to have all the correct proofs that we've generated (instead of verifying every proof on the file)
localRewardsFile := NewLocalFile[IRewardsFile](
deserializedRewardsFile,
rewardsTreePath,
)
err = localRewardsFile.Write()
if err != nil {
return fmt.Errorf("error saving interval %d file to %s: %w", interval, rewardsTreePath, err)
}

return nil

return nil
}

errBuilder.WriteString(fmt.Sprintf("Downloading files with timeout %v failed.\n", timeout))
}

return fmt.Errorf(errBuilder.String())
Expand Down