Skip to content

Commit

Permalink
feat: add GITHUB_TOKEN as a Bearer token when calling the GitHub API …
Browse files Browse the repository at this point in the history
…in order to increase the rate limit
  • Loading branch information
brianheineman committed Feb 6, 2024
1 parent 154be3e commit 9e77f29
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 29 deletions.
75 changes: 46 additions & 29 deletions postgresql_archive/src/archive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ use crate::version::Version;
use bytes::Bytes;
use flate2::bufread::GzDecoder;
use regex::Regex;
use reqwest::header;
use reqwest::header::HeaderMap;
use reqwest::{header, RequestBuilder};
use std::fs::{create_dir_all, File};
use std::io::{copy, BufReader, Cursor};
use std::path::Path;
Expand All @@ -17,6 +18,13 @@ use tar::Archive;
const GITHUB_API_VERSION_HEADER: &str = "X-GitHub-Api-Version";
const GITHUB_API_VERSION: &str = "2022-11-28";

lazy_static! {
static ref GITHUB_TOKEN: Option<String> = match std::env::var("GITHUB_TOKEN") {
Ok(token) => Some(token),
Err(_) => None,
};
}

lazy_static! {
static ref USER_AGENT: String = format!(
"{PACKAGE}/{VERSION}",
Expand All @@ -25,20 +33,42 @@ lazy_static! {
);
}

/// Adds GitHub headers to the request builder.
trait GitHubHeaders {
/// Adds GitHub headers to the request builder. If a GitHub token is set, then it is added as a
/// bearer token. This is used to authenticate with the GitHub API to increase the rate limit.
fn add_github_headers(self) -> anyhow::Result<RequestBuilder>;
}

/// Implementation that adds GitHub headers to a request builder.
impl GitHubHeaders for RequestBuilder {
/// Adds GitHub headers to the request builder. If a GitHub token is set, then it is added as a
/// bearer token. This is used to authenticate with the GitHub API to increase the rate limit.
fn add_github_headers(self) -> anyhow::Result<RequestBuilder> {
let mut headers = HeaderMap::new();

headers.append(GITHUB_API_VERSION_HEADER, GITHUB_API_VERSION.parse()?);
headers.append(header::USER_AGENT, USER_AGENT.parse()?);

if let Some(token) = &*GITHUB_TOKEN {
headers.append(header::AUTHORIZATION, format!("Bearer {token}").parse()?);
}

Ok(self.headers(headers))
}
}

/// Gets a release from GitHub for a given [`version`](Version) of PostgreSQL. If a release for the
/// [`version`](Version) is not found, then a [`ReleaseNotFound`] error is returned.
async fn get_release(version: &Version) -> Result<Release> {
let url = "https://api.github.com/repos/theseus-rs/postgresql-binaries/releases";
let client = reqwest::Client::new();

if version.minor.is_some() && version.release.is_some() {
let response = client
let request = client
.get(format!("{url}/tags/{version}"))
.header(GITHUB_API_VERSION_HEADER, GITHUB_API_VERSION)
.header(header::USER_AGENT, USER_AGENT.as_str())
.send()
.await?
.error_for_status()?;
.add_github_headers()?;
let response = request.send().await?.error_for_status()?;
let release = response.json::<Release>().await?;

return Ok(release);
Expand All @@ -48,15 +78,11 @@ async fn get_release(version: &Version) -> Result<Release> {
let mut page = 1;

loop {
let response = client
let request = client
.get(url)
.header(GITHUB_API_VERSION_HEADER, GITHUB_API_VERSION)
.header(header::USER_AGENT, USER_AGENT.as_str())
.query(&[("page", page.to_string().as_str()), ("per_page", "100")])
.send()
.await?
.error_for_status()?;

.add_github_headers()?
.query(&[("page", page.to_string().as_str()), ("per_page", "100")]);
let response = request.send().await?.error_for_status()?;
let response_releases = response.json::<Vec<Release>>().await?;
if response_releases.is_empty() {
break;
Expand Down Expand Up @@ -145,14 +171,10 @@ pub async fn get_archive_for_target<S: AsRef<str>>(
) -> Result<(Version, Bytes, String)> {
let (asset_version, asset, asset_hash) = get_asset(version, target).await?;
let client = reqwest::Client::new();

let response = client
let request = client
.get(asset_hash.browser_download_url)
.header(GITHUB_API_VERSION_HEADER, GITHUB_API_VERSION)
.header(header::USER_AGENT, USER_AGENT.as_str())
.send()
.await?
.error_for_status()?;
.add_github_headers()?;
let response = request.send().await?.error_for_status()?;
let text = response.text().await?;
let re = Regex::new(r"[0-9a-f]{64}")?;
let hash = match re.find(&text) {
Expand All @@ -161,13 +183,8 @@ pub async fn get_archive_for_target<S: AsRef<str>>(
};

let asset_url = asset.browser_download_url;
let response = client
.get(asset_url)
.header(GITHUB_API_VERSION_HEADER, GITHUB_API_VERSION)
.header(header::USER_AGENT, USER_AGENT.as_str())
.send()
.await?
.error_for_status()?;
let request = client.get(asset_url).add_github_headers()?;
let response = request.send().await?.error_for_status()?;
let archive: Bytes = response.bytes().await?;

Ok((asset_version, archive, hash))
Expand Down
7 changes: 7 additions & 0 deletions postgresql_archive/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,10 @@ impl From<std::path::StripPrefixError> for ArchiveError {
ArchiveError::ParseError(error.into())
}
}

/// Converts a [`anyhow::Error`] into an [`Unexpected`](ArchiveError::Unexpected)
impl From<anyhow::Error> for ArchiveError {
fn from(error: anyhow::Error) -> Self {
ArchiveError::Unexpected(error.to_string())
}
}

0 comments on commit 9e77f29

Please sign in to comment.