[FEAT]: huggingface integration #2701

Merged 16 commits on Aug 22, 2024
Changes from all commits
4 changes: 4 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default.

4 changes: 3 additions & 1 deletion daft/daft.pyi
@@ -448,7 +448,9 @@ class HTTPConfig:
    I/O configuration for accessing HTTP systems
    """

    user_agent: str | None
    bearer_token: str | None

    def __init__(self, bearer_token: str | None = None): ...

class S3Config:
    """
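The updated stub adds ``user_agent`` and ``bearer_token`` fields plus an optional ``bearer_token`` constructor argument. A minimal sketch of the constructor in use, matching the stub above (the token value is a placeholder):

.. code:: python

    from daft.io import IOConfig, HTTPConfig

    # No token: requests carry only daft's default user agent.
    plain = HTTPConfig()

    # With a token: sent as "Authorization: Bearer <token>" on HTTP reads.
    authed = HTTPConfig(bearer_token="hf_xxx")  # placeholder token

    io_config = IOConfig(http=authed)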
1 change: 1 addition & 0 deletions docs/source/user_guide/integrations.rst
@@ -11,3 +11,4 @@ Integrations
    integrations/microsoft-azure
    integrations/aws
    integrations/sql
    integrations/huggingface
64 changes: 64 additions & 0 deletions docs/source/user_guide/integrations/huggingface.rst
@@ -0,0 +1,64 @@
Huggingface Datasets
====================

Daft is able to read datasets directly from Huggingface via the ``hf://`` protocol.

Since Huggingface will `automatically convert <https://huggingface.co/docs/dataset-viewer/en/parquet>`_ all public datasets to parquet format,
we can read these datasets using the ``read_parquet`` method.

.. NOTE::
    This is limited to public datasets, or to PRO/Enterprise datasets.

For other file formats, you will need to manually specify the path or glob pattern to the files you want to read, similar to how you would read from a local file system.


Reading Public Datasets
-----------------------

.. code:: python

    import daft

    df = daft.read_parquet("hf://username/dataset_name")

This will read the entire dataset into a daft DataFrame.

Not only can you read entire datasets, but you can also read individual files from a dataset.

.. code:: python

    import daft

    df = daft.read_parquet("hf://username/dataset_name/file_name.parquet")
    # or a csv file
    df = daft.read_csv("hf://username/dataset_name/file_name.csv")

    # or a glob pattern
    df = daft.read_parquet("hf://username/dataset_name/**/*.parquet")


Authorization
-------------

For authenticated datasets:

.. code:: python

    from daft.io import IOConfig, HTTPConfig

    io_config = IOConfig(http=HTTPConfig(bearer_token="your_token"))
    df = daft.read_parquet("hf://username/dataset_name", io_config=io_config)


It's important to note that this will not work with standard-tier private datasets:
Huggingface does not auto-convert private datasets to parquet format, so you will need to specify the path to the files you want to read.

.. code:: python

    df = daft.read_parquet("hf://username/my_private_dataset", io_config=io_config)  # Errors

To get around this, you can read all files using a glob pattern *(assuming they are in parquet format)*:

.. code:: python

    df = daft.read_parquet("hf://username/my_private_dataset/**/*.parquet", io_config=io_config)  # Works
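As a usage note, the token does not have to be hardcoded: it can be pulled from the environment at runtime. A minimal sketch, assuming the token is stored in an ``HF_TOKEN`` environment variable (an illustrative name; daft does not read it automatically):

.. code:: python

    import os

    import daft
    from daft.io import IOConfig, HTTPConfig

    # Illustrative variable name; daft does not pick this up on its own.
    token = os.environ["HF_TOKEN"]

    io_config = IOConfig(http=HTTPConfig(bearer_token=token))
    df = daft.read_parquet(
        "hf://username/my_private_dataset/**/*.parquet",
        io_config=io_config,
    )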
33 changes: 31 additions & 2 deletions src/common/io-config/src/http.rs
@@ -4,22 +4,40 @@ use std::fmt::Formatter
use serde::Deserialize;
use serde::Serialize;

use crate::ObfuscatedString;

#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct HTTPConfig {
    pub user_agent: String,
    pub bearer_token: Option<ObfuscatedString>,
}

impl Default for HTTPConfig {
    fn default() -> Self {
        HTTPConfig {
            user_agent: "daft/0.0.1".to_string(), // NOTE: Ideally we grab the version of Daft, but that requires a dependency on daft-core
            bearer_token: None,
        }
    }
}

impl HTTPConfig {
    pub fn new<S: Into<ObfuscatedString>>(bearer_token: Option<S>) -> Self {
        HTTPConfig {
            bearer_token: bearer_token.map(|t| t.into()),
            ..Default::default()
        }
    }
}

impl HTTPConfig {
    pub fn multiline_display(&self) -> Vec<String> {
        let mut v = vec![format!("user_agent = {}", self.user_agent)];
        if let Some(bearer_token) = &self.bearer_token {
            v.push(format!("bearer_token = {}", bearer_token));
        }

        v
    }
}

@@ -30,6 +48,17 @@ impl Display for HTTPConfig {
            "HTTPConfig
    user_agent: {}",
            self.user_agent,
        )?;

        if let Some(bearer_token) = &self.bearer_token {
            write!(
                f,
                "
    bearer_token: {}",
                bearer_token
            )
        } else {
            Ok(())
        }
    }
}
17 changes: 16 additions & 1 deletion src/common/io-config/src/python.rs
@@ -134,9 +134,10 @@ pub struct IOConfig {
///
/// Args:
/// user_agent (str, optional): The value for the user-agent header, defaults to "daft/{__version__}" if not provided
/// bearer_token (str, optional): Bearer token to use for authentication. This will be used as the value for the `Authorization` header, e.g. "Authorization: Bearer xxx"
///
/// Example:
/// >>> io_config = IOConfig(http=HTTPConfig(user_agent="my_application/0.0.1", bearer_token="xxx"))
/// >>> daft.read_parquet("http://some-path", io_config=io_config)
#[derive(Clone, Default)]
#[pyclass]
@@ -901,6 +902,20 @@ impl From<config::IOConfig> for IOConfig {
    }
}

#[pymethods]
impl HTTPConfig {
    #[new]
    pub fn new(bearer_token: Option<String>) -> Self {
        HTTPConfig {
            config: crate::HTTPConfig::new(bearer_token),
        }
    }

    pub fn __repr__(&self) -> PyResult<String> {
        Ok(format!("{}", self.config))
    }
}
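For illustration, the binding above can be exercised from Python roughly as follows; a sketch assuming a build that includes this PR. ``__repr__`` delegates to the Rust ``Display`` impl shown earlier, and since the token is stored as an ``ObfuscatedString``, it is presumably redacted rather than printed in plain text:

.. code:: python

    from daft.io import HTTPConfig

    config = HTTPConfig(bearer_token="hf_xxx")  # placeholder token

    # Prints the user_agent line and, when a token is set, a bearer_token
    # line rendered by the Rust Display impl.
    print(repr(config))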

pub fn register_modules(_py: Python, parent: &PyModule) -> PyResult<()> {
    parent.add_class::<AzureConfig>()?;
    parent.add_class::<GCSConfig>()?;
11 changes: 9 additions & 2 deletions src/daft-io/Cargo.toml
@@ -15,6 +15,7 @@ azure_storage_blobs = {version = "0.17.0", features = ["enable_reqwest"], defaul
bytes = {workspace = true}
common-error = {path = "../common/error", default-features = false}
common-io-config = {path = "../common/io-config", default-features = false}
common-py-serde = {path = "../common/py-serde", default-features = false}
futures = {workspace = true}
globset = "0.4"
google-cloud-storage = {version = "0.15.0", default-features = false, features = ["default-tls", "auth"]}
@@ -29,22 +30,28 @@ openssl-sys = {version = "0.9.102", features = ["vendored"]}
pyo3 = {workspace = true, optional = true}
rand = "0.8.5"
regex = {version = "1.10.4"}
serde = {workspace = true}
snafu = {workspace = true}
tokio = {workspace = true}
tokio-stream = {workspace = true}
url = {workspace = true}

[dependencies.reqwest]
default-features = false
features = ["stream", "native-tls", "json"]
version = "0.11.18"

[dev-dependencies]
md5 = "0.7.0"
tempfile = "3.8.1"

[features]
python = [
  "dep:pyo3",
  "common-error/python",
  "common-io-config/python",
  "common-py-serde/python"
]

[package]
edition = {workspace = true}
3 changes: 2 additions & 1 deletion src/daft-io/src/azure_blob.rs
@@ -15,7 +15,7 @@ use crate::{
    object_io::{FileMetadata, FileType, LSResult, ObjectSource},
    stats::IOStatsRef,
    stream_utils::io_stats_on_bytestream,
    FileFormat, GetResult,
};
use common_io_config::AzureConfig;

@@ -577,6 +577,7 @@ impl ObjectSource for AzureBlobSource {
        page_size: Option<i32>,
        limit: Option<usize>,
        io_stats: Option<IOStatsRef>,
        _file_format: Option<FileFormat>,
    ) -> super::Result<BoxStream<'static, super::Result<FileMetadata>>> {
        use crate::object_store_glob::glob;

58 changes: 58 additions & 0 deletions src/daft-io/src/file_format.rs
@@ -0,0 +1,58 @@
use std::str::FromStr;

use common_error::{DaftError, DaftResult};
use common_py_serde::impl_bincode_py_state_serialization;
#[cfg(feature = "python")]
use pyo3::prelude::*;

use serde::{Deserialize, Serialize};

/// Format of a file, e.g. Parquet, CSV, JSON.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, Copy)]
#[cfg_attr(feature = "python", pyclass(module = "daft.daft"))]
pub enum FileFormat {
    Parquet,
    Csv,
    Json,
    Database,
    Python,
}

#[cfg(feature = "python")]
#[pymethods]
impl FileFormat {
    fn ext(&self) -> &'static str {
        match self {
            Self::Parquet => "parquet",
            Self::Csv => "csv",
            Self::Json => "json",
            Self::Database => "db",
            Self::Python => "py",
        }
    }
}

impl FromStr for FileFormat {
    type Err = DaftError;

    fn from_str(file_format: &str) -> DaftResult<Self> {
        use FileFormat::*;

        if file_format.trim().eq_ignore_ascii_case("parquet") {
            Ok(Parquet)
        } else if file_format.trim().eq_ignore_ascii_case("csv") {
            Ok(Csv)
        } else if file_format.trim().eq_ignore_ascii_case("json") {
            Ok(Json)
        } else if file_format.trim().eq_ignore_ascii_case("database") {
            Ok(Database)
        } else {
            Err(DaftError::TypeError(format!(
                "FileFormat {} not supported!",
                file_format
            )))
        }
    }
}

impl_bincode_py_state_serialization!(FileFormat);
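Since ``FileFormat`` is declared as a pyclass under the ``daft.daft`` module, something like the following should work from Python; a sketch, assuming the class is actually registered and exported under that module path (its registration is not shown in this diff):

.. code:: python

    from daft.daft import FileFormat

    # ext() maps each variant to its canonical file extension.
    assert FileFormat.Parquet.ext() == "parquet"
    assert FileFormat.Csv.ext() == "csv"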
2 changes: 2 additions & 0 deletions src/daft-io/src/google_cloud.rs
@@ -23,6 +23,7 @@ use crate::object_io::LSResult;
use crate::object_io::ObjectSource;
use crate::stats::IOStatsRef;
use crate::stream_utils::io_stats_on_bytestream;
use crate::FileFormat;
use crate::GetResult;
use common_io_config::GCSConfig;

@@ -436,6 +437,7 @@ impl ObjectSource for GCSSource {
        page_size: Option<i32>,
        limit: Option<usize>,
        io_stats: Option<IOStatsRef>,
        _file_format: Option<FileFormat>,
    ) -> super::Result<BoxStream<'static, super::Result<FileMetadata>>> {
        use crate::object_store_glob::glob;

4 changes: 3 additions & 1 deletion src/daft-io/src/http.rs
@@ -15,6 +15,7 @@ use crate::{
    object_io::{FileMetadata, FileType, LSResult},
    stats::IOStatsRef,
    stream_utils::io_stats_on_bytestream,
    FileFormat,
};

use super::object_io::{GetResult, ObjectSource};
@@ -140,7 +141,7 @@ fn _get_file_metadata_from_html(path: &str, text: &str) -> super::Result<Vec<Fil
}

pub(crate) struct HttpSource {
    pub(crate) client: reqwest::Client,
}

impl From<Error> for super::Error {
@@ -276,6 +277,7 @@ impl ObjectSource for HttpSource {
        _page_size: Option<i32>,
        limit: Option<usize>,
        io_stats: Option<IOStatsRef>,
        _file_format: Option<FileFormat>,
    ) -> super::Result<BoxStream<'static, super::Result<FileMetadata>>> {
        use crate::object_store_glob::glob;
