Skip to content

Commit

Permalink
update dataverse module
Browse files Browse the repository at this point in the history
transfer code from pachterlab#124
  • Loading branch information
abearab committed Dec 24, 2024
1 parent 9b3bffa commit 84f0e64
Show file tree
Hide file tree
Showing 5 changed files with 144 additions and 1 deletion.
1 change: 1 addition & 0 deletions gget/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from .gget_opentargets import opentargets
from .gget_cbio import cbio_plot, cbio_search
from .gget_bgee import bgee
from .gget_dataverse import dataverse

import logging
# Mute numexpr threads info
Expand Down
3 changes: 3 additions & 0 deletions gget/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@
# OpenTargets API endpoint
OPENTARGETS_GRAPHQL_API = "https://api.platform.opentargets.org/api/v4/graphql"

# Harvard dataverse API server
DATAVERSE_GET_URL = "https://dataverse.harvard.edu/api/access/datafile/"

# CBIO data
CBIO_CANCER_TYPE_TO_TISSUE_DICTIONARY = {
"Acute Leukemias of Ambiguous Lineage": "leukemia",
Expand Down
89 changes: 89 additions & 0 deletions gget/gget_dataverse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import os
import requests
from tqdm import tqdm
import pandas as pd
import pandas as pd
from .utils import print_sys
from .constants import DATAVERSE_GET_URL

def dataverse_downloader(url, path, file_name):
    """Dataverse download helper with a progress bar.

    Streams the response to disk in 1 KiB chunks so large files are never
    held in memory, updating a tqdm progress bar as data arrives.

    Args:
        url (str): the url of the dataset to download
        path (str): the path to save the dataset locally
        file_name (str): the name of the file to save locally

    Raises:
        requests.HTTPError: if the server responds with an error status
            (previously the error body was silently written to disk as
            if it were the requested file).
    """
    save_path = os.path.join(path, file_name)
    response = requests.get(url, stream=True)
    # Fail fast on HTTP errors instead of saving the error page as the file
    response.raise_for_status()
    total_size_in_bytes = int(response.headers.get("content-length", 0))
    block_size = 1024
    # Context managers guarantee both the file handle and the progress bar
    # are closed even if the connection drops mid-download.
    with open(save_path, "wb") as file, tqdm(
        total=total_size_in_bytes, unit="iB", unit_scale=True
    ) as progress_bar:
        for data in response.iter_content(block_size):
            progress_bar.update(len(data))
            file.write(data)


def download_wrapper(entry, path, return_type=None):
    """Wrapper for downloading a dataset given the name and path, for csv, pkl, tsv or similar files.

    Skips the download when a local copy of the file already exists.

    Args:
        entry (dict): the entry of the dataset to download. Must include 'id', 'name', 'type' keys
        path (str): the path to save the dataset locally
        return_type (str or list, optional): the return type. Defaults to None.
            Can be "url", "filename", or ["url", "filename"]

    Returns:
        str or tuple or None: the url and/or filename as requested by
            return_type; None when return_type is None or unrecognized.
    """
    url = DATAVERSE_GET_URL + str(entry['id'])

    # makedirs creates missing parent directories and tolerates the directory
    # already existing (os.mkdir raised in both of those cases)
    os.makedirs(path, exist_ok=True)

    filename = f"{entry['name']}.{entry['type']}"

    # NOTE(review): the literal "(unknown)" in these messages looks like a
    # placeholder that was meant to be the filename — confirm intended text.
    if os.path.exists(os.path.join(path, filename)):
        print_sys(f"Found local copy for {entry['id']} datafile as (unknown) ...")
    else:
        print_sys(f"Downloading {entry['id']} datafile as (unknown) ...")
        dataverse_downloader(url, path, filename)

    if return_type == "url":
        return url
    elif return_type == "filename":
        return filename
    elif return_type == ["url", "filename"]:
        return url, filename


def dataverse(df, path, sep=","):
    """Download datasets from dataverse for a given dataframe.

    Input dataframe must have 'name', 'id', 'type' columns.
    - 'name' is the dataset name for single file
    - 'id' is the unique identifier for the file
    - 'type' is the file type (e.g. csv, tsv, pkl)

    Args:
        df (pd.DataFrame or str): the dataframe or path to the csv/tsv file
        path (str): the path to save the dataset locally
        sep (str, optional): field separator used when df is a file path. Defaults to ",".

    Raises:
        FileNotFoundError: if df is a path that does not exist
        ValueError: if df is neither a DataFrame nor a path string, or is
            missing one of the required columns
    """
    # isinstance instead of type() == so str/DataFrame subclasses are accepted
    if isinstance(df, str):
        if not os.path.exists(df):
            raise FileNotFoundError(f"File {df} not found")
        df = pd.read_csv(df, sep=sep)
    elif not isinstance(df, pd.DataFrame):
        raise ValueError("Input must be a pandas dataframe or a path to a csv / tsv file")

    # Validate the schema up front: a clear error here beats a KeyError
    # raised halfway through a batch of downloads
    missing = {"name", "id", "type"} - set(df.columns)
    if missing:
        raise ValueError(f"Input dataframe is missing required columns: {sorted(missing)}")

    print_sys(f"Searching for {len(df)} datafiles in dataverse ...")

    # run the download wrapper for each entry in the dataframe
    for _, entry in df.iterrows():
        download_wrapper(entry, path)

    print_sys(f"Download completed, saved to `{path}`.")
43 changes: 42 additions & 1 deletion gget/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
from .gget_opentargets import opentargets, OPENTARGETS_RESOURCES
from .gget_cbio import cbio_plot, cbio_search
from .gget_bgee import bgee

from .gget_dataverse import dataverse

# Custom formatter for help messages that preserved the text formatting and adds the default value to the end of the help message
class CustomHelpFormatter(argparse.RawTextHelpFormatter):
Expand Down Expand Up @@ -2335,6 +2335,32 @@ def main():
help="Does not print progress information.",
)

## dataverse parser arguments
dataverse_desc = "Download datasets from the Dataverse repositories."
parser_dataverse = parent_subparsers.add_parser(
"dataverse",
parents=[parent],
description=dataverse_desc,
help=dataverse_desc,
add_help=True,
formatter_class=CustomHelpFormatter,
)
parser_dataverse.add_argument(
"-o",
"--path",
type=str,
required=True,
help="Path to the directory the datasets will be saved in, e.g. 'path/to/directory'.",
)
parser_dataverse.add_argument(
"-t",
"--table",
type=str,
default=None,
required=False,
help="File containing the dataset IDs to download, e.g. 'datasets.tsv'.",
)

### Define return values
args = parent_parser.parse_args()

Expand Down Expand Up @@ -2386,6 +2412,7 @@ def main():
"opentargets": parser_opentargets,
"cbio": parser_cbio,
"bgee": parser_bgee,
"dataverse": parser_dataverse,
}

if len(sys.argv) == 2:
Expand Down Expand Up @@ -3295,3 +3322,17 @@ def main():
print(
bgee_results.to_json(orient="records", force_ascii=False, indent=4)
)

## dataverse return
if args.command == "dataverse":
# Define separator based on file extension
if '.csv' in args.table:
sep = ','
elif '.tsv' in args.table:
sep = '\t'
# Run gget dataverse function
dataverse(
df = args.table,
path = args.out,
sep = sep,
)
9 changes: 9 additions & 0 deletions gget/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# import time
import re
import os
import sys
import uuid
import pandas as pd
import numpy as np
Expand Down Expand Up @@ -66,6 +67,14 @@ def flatten(xss):
return [x for xs in xss for x in xs]


def print_sys(s):
    """Print *s* to stderr, flushing immediately so progress messages
    appear in order and are kept separate from stdout results.

    Args:
        s (str): the string to print
    """
    print(s, file=sys.stderr, flush=True)


def get_latest_cosmic():
html = requests.get(COSMIC_RELEASE_URL)
if html.status_code != 200:
Expand Down

0 comments on commit 84f0e64

Please sign in to comment.