Skip to content

Commit

Permalink
add Python API for sample deletion (#759)
Browse files Browse the repository at this point in the history
  • Loading branch information
awenocur authored Aug 28, 2024
1 parent 867f130 commit b9fd438
Show file tree
Hide file tree
Showing 10 changed files with 121 additions and 22 deletions.
4 changes: 4 additions & 0 deletions apis/python/src/tiledbvcf/binding/libtiledbvcf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,10 @@ PYBIND11_MODULE(libtiledbvcf, m) {
"ingest_samples",
&Writer::ingest_samples,
py::call_guard<py::gil_scoped_release>())
.def(
"delete_samples",
&Writer::delete_samples,
py::call_guard<py::gil_scoped_release>())
.def("get_schema_version", &Writer::get_schema_version)
.def("set_tiledb_config", &Writer::set_tiledb_config)
.def("set_sample_batch_size", &Writer::set_sample_batch_size)
Expand Down
12 changes: 12 additions & 0 deletions apis/python/src/tiledbvcf/binding/writer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,18 @@ void Writer::ingest_samples() {
check_error(writer, tiledb_vcf_writer_store(writer));
}

/**
 * Deletes the named samples through the C API writer handle.
 *
 * The C API expects an array of C strings, so a parallel vector of pointers
 * into `samples_to_delete` is built first. The pointers stay valid because
 * `samples_to_delete` is owned by this call and is not modified afterwards.
 */
void Writer::delete_samples(std::vector<std::string> samples_to_delete) {
  std::vector<const char*> c_names;
  for (const auto& name : samples_to_delete)
    c_names.push_back(name.c_str());

  auto handle = ptr.get();
  const auto status = tiledb_vcf_writer_delete_samples(
      handle, c_names.data(), c_names.size());
  check_error(handle, status);
}

/**
 * Releases a C API writer handle by passing the address of the local pointer
 * to tiledb_vcf_writer_free.
 * NOTE(review): presumably installed as the custom deleter of the smart
 * pointer held in `ptr` -- confirm against the class declaration.
 */
void Writer::deleter(tiledb_vcf_writer_t* w) {
tiledb_vcf_writer_free(&w);
}
Expand Down
2 changes: 2 additions & 0 deletions apis/python/src/tiledbvcf/binding/writer.h
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,8 @@ class Writer {

void ingest_samples();

/**
 * Deletes the given samples from the dataset through the C API writer.
 *
 * @param samples Names of the samples to delete
 */
void delete_samples(std::vector<std::string> samples);

/** Returns schema version number of the TileDB VCF dataset */
int32_t get_schema_version();

Expand Down
8 changes: 8 additions & 0 deletions apis/python/src/tiledbvcf/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -851,6 +851,14 @@ def ingest_samples(
self.writer.register_samples()
self.writer.ingest_samples()

def delete_samples(
    self,
    sample_uris: List[str] = None,
):
    """
    Delete samples from the dataset.

    Parameters
    ----------
    sample_uris
        Names of the samples to delete. Despite the parameter name, the
        values are forwarded unchanged to the writer as sample names.
        If ``None`` (the default), nothing is deleted.

    Raises
    ------
    Exception
        If the dataset is not open in write mode.
    """
    if self.mode != "w":
        raise Exception("Dataset not open in write mode")
    if sample_uris is None:
        # The native writer expects a sequence of strings; forwarding the
        # default None would fail inside the binding, so treat it as
        # "no samples requested".
        return
    self.writer.delete_samples(sample_uris)

def tiledb_stats(self) -> str:
"""
Get TileDB stats as a string.
Expand Down
78 changes: 57 additions & 21 deletions apis/python/tests/test_tiledbvcf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1197,15 +1197,8 @@ def test_ingest_mode_merged(tmp_path):
assert ds.count(regions=["chrX:9032893-9032893"]) == 0


# Ok to skip if missing bcftools in Windows CI job
@pytest.mark.skipif(
os.environ.get("CI") == "true"
and platform.system() == "Windows"
and shutil.which("bcftools") is None,
reason="no bcftools",
)
def test_ingest_with_stats_v3(tmp_path):
# tiledbvcf.config_logging("debug")
@pytest.fixture
def test_stats_bgzipped_inputs(tmp_path):
tmp_path_contents = os.listdir(tmp_path)
if "stats" in tmp_path_contents:
shutil.rmtree(os.path.join(tmp_path, "stats"))
Expand All @@ -1221,23 +1214,46 @@ def test_ingest_with_stats_v3(tmp_path):
check=True,
)
bgzipped_inputs = glob.glob(os.path.join(tmp_path, "stats", "*.gz"))
# print(f"bgzipped inputs: {bgzipped_inputs}")
for vcf_file in bgzipped_inputs:
assert subprocess.run("bcftools index " + vcf_file, shell=True).returncode == 0
if "outputs" in tmp_path_contents:
shutil.rmtree(os.path.join(tmp_path, "outputs"))
if "stats_test" in tmp_path_contents:
shutil.rmtree(os.path.join(tmp_path, "stats_test"))
# tiledbvcf.config_logging("trace")
return bgzipped_inputs


@pytest.fixture
def test_stats_sample_names(test_stats_bgzipped_inputs):
    """
    Sample names derived from the bgzipped input paths: the part of each
    file's base name that precedes the first dot.
    """
    assert len(test_stats_bgzipped_inputs) == 8
    names = []
    for path in test_stats_bgzipped_inputs:
        filename = os.path.basename(path)
        names.append(filename.split(".")[0])
    return names


@pytest.fixture
def test_stats_v3_ingestion(tmp_path, test_stats_bgzipped_inputs):
assert len(test_stats_bgzipped_inputs) == 8
# print(f"bgzipped inputs: {test_stats_bgzipped_inputs}")
ds = tiledbvcf.Dataset(uri=os.path.join(tmp_path, "stats_test"), mode="w")
ds.create_dataset(
enable_variant_stats=True, enable_allele_count=True, variant_stats_version=3
)
ds.ingest_samples(bgzipped_inputs)
ds.ingest_samples(test_stats_bgzipped_inputs)
ds = tiledbvcf.Dataset(uri=os.path.join(tmp_path, "stats_test"), mode="r")
sample_names = [os.path.basename(file).split(".")[0] for file in bgzipped_inputs]
data_frame = ds.read(
samples=sample_names,
return ds


# Ok to skip if missing bcftools in Windows CI job
@pytest.mark.skipif(
os.environ.get("CI") == "true"
and platform.system() == "Windows"
and shutil.which("bcftools") is None,
reason="no bcftools",
)
def test_ingest_with_stats_v3(
tmp_path, test_stats_v3_ingestion, test_stats_sample_names
):
data_frame = test_stats_v3_ingestion.read(
samples=test_stats_sample_names,
attrs=["contig", "pos_start", "id", "qual", "info_TILEDB_IAF", "sample_name"],
set_af_filter="<0.2",
)
Expand All @@ -1249,8 +1265,8 @@ def test_ingest_with_stats_v3(tmp_path):
data_frame[data_frame["sample_name"] == "second"]["info_TILEDB_IAF"].iloc[0][0]
== 0.9375
)
data_frame = ds.read(
samples=sample_names,
data_frame = test_stats_v3_ingestion.read(
samples=test_stats_sample_names,
attrs=["contig", "pos_start", "id", "qual", "info_TILEDB_IAF", "sample_name"],
scan_all_samples=True,
)
Expand All @@ -1260,25 +1276,45 @@ def test_ingest_with_stats_v3(tmp_path):
]["info_TILEDB_IAF"].iloc[0][0]
== 0.9375
)
ds = tiledbvcf.Dataset(uri=os.path.join(tmp_path, "stats_test"), mode="r")
df = ds.read_variant_stats("chr1:1-10000")
df = test_stats_v3_ingestion.read_variant_stats("chr1:1-10000")
assert df.shape == (13, 5)
df = tiledbvcf.allele_frequency.read_allele_frequency(
os.path.join(tmp_path, "stats_test"), "chr1:1-10000"
)
assert df.pos.is_monotonic_increasing
df["an_check"] = (df.ac / df.af).round(0).astype("int32")
assert df.an_check.equals(df.an)
df = ds.read_variant_stats("chr1:1-10000")
df = test_stats_v3_ingestion.read_variant_stats("chr1:1-10000")
assert df.shape == (13, 5)
df = df.to_pandas()
df = ds.read_allele_count("chr1:1-10000")
df = test_stats_v3_ingestion.read_allele_count("chr1:1-10000")
assert df.shape == (7, 6)
df = df.to_pandas()
assert sum(df["pos"] == (0, 1, 1, 2, 2, 2, 3)) == 7
assert sum(df["count"] == (8, 5, 3, 4, 2, 2, 1)) == 7


@pytest.mark.skipif(
    os.environ.get("CI") == "true"
    and platform.system() == "Windows"
    and shutil.which("bcftools") is None,
    reason="no bcftools",
)
def test_delete_samples(tmp_path, test_stats_v3_ingestion, test_stats_sample_names):
    """
    Delete two ingested samples and check that they no longer appear in the
    dataset's sample list while an untouched sample still does.
    """
    dataset_uri = os.path.join(tmp_path, "stats_test")
    removed = ["second", "fifth"]
    kept = "third"

    # Precondition: every sample involved was part of the ingested inputs.
    for name in removed:
        assert name in test_stats_sample_names
    assert kept in test_stats_sample_names

    # Delete through a write-mode handle, then re-open read-only to verify.
    dataset = tiledbvcf.Dataset(uri=dataset_uri, mode="w")
    dataset.delete_samples(removed)
    dataset = tiledbvcf.Dataset(uri=dataset_uri, mode="r")
    remaining = dataset.samples()

    for name in removed:
        assert name not in remaining
    assert kept in remaining


# Ok to skip if missing bcftools in Windows CI job
@pytest.mark.skipif(
os.environ.get("CI") == "true"
Expand Down
15 changes: 15 additions & 0 deletions libtiledbvcf/src/c_api/tiledbvcf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1823,6 +1823,21 @@ int32_t tiledb_vcf_writer_set_variant_stats_version(
return TILEDB_VCF_OK;
}

/**
 * Deletes the named samples from the writer's dataset.
 *
 * Returns TILEDB_VCF_ERR when the writer fails its sanity check, when
 * `samples` is null while `nsamples` is non-zero, when any entry of
 * `samples` is null, or when the underlying delete reports an error.
 * Returns TILEDB_VCF_OK otherwise.
 */
int32_t tiledb_vcf_writer_delete_samples(
    tiledb_vcf_writer_t* writer, const char** samples, size_t nsamples) {
  // Validate the handle and the array pointer before reading from either.
  if (sanity_check(writer) == TILEDB_VCF_ERR ||
      (samples == nullptr && nsamples > 0))
    return TILEDB_VCF_ERR;

  std::vector<std::string> encoded_samples;
  for (size_t i = 0; i < nsamples; i++) {
    // Building a std::string from a null pointer is undefined behavior, so
    // reject null entries instead of passing them to the constructor.
    if (samples[i] == nullptr)
      return TILEDB_VCF_ERR;
    encoded_samples.emplace_back(samples[i]);
  }

  if (SAVE_ERROR_CATCH(
          writer, writer->writer_->delete_samples(encoded_samples)))
    return TILEDB_VCF_ERR;

  return TILEDB_VCF_OK;
}

/* ********************************* */
/* ERROR */
/* ********************************* */
Expand Down
10 changes: 10 additions & 0 deletions libtiledbvcf/src/c_api/tiledbvcf.h
Original file line number Diff line number Diff line change
Expand Up @@ -1706,6 +1706,16 @@ tiledb_vcf_writer_set_compression_level(tiledb_vcf_writer_t* writer, int level);
TILEDBVCF_EXPORT int32_t tiledb_vcf_writer_set_variant_stats_version(
tiledb_vcf_writer_t* writer, uint8_t version);

/**
 * Deletes samples from the dataset.
 *
 * @param writer VCF writer object
 * @param samples Array of `nsamples` null-terminated sample names to delete
 * @param nsamples Number of entries in `samples`
 * @return `TILEDB_VCF_OK` for success or `TILEDB_VCF_ERR` for error
 */
TILEDBVCF_EXPORT int32_t tiledb_vcf_writer_delete_samples(

tiledb_vcf_writer_t* writer, const char** samples, size_t nsamples);

/* ********************************* */
/* ERROR */
/* ********************************* */
Expand Down
4 changes: 3 additions & 1 deletion libtiledbvcf/src/dataset/tiledbvcfdataset.cc
Original file line number Diff line number Diff line change
Expand Up @@ -938,7 +938,9 @@ void TileDBVCFDataset::delete_samples(
const std::vector<std::string>& sample_names,
const std::vector<std::string>& tiledb_config) {
// Open dataset in read mode, required before calling `sample_exists`.
open(uri);
if (!open_) {
open(uri, tiledb_config);
}

// Define a function that deletes a sample from an array
auto delete_sample = [&](Array& array, const std::string& sample) {
Expand Down
5 changes: 5 additions & 0 deletions libtiledbvcf/src/write/writer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1484,5 +1484,10 @@ void Writer::set_variant_stats_array_version(uint8_t version) {
creation_params_.variant_stats_array_version = version;
}

/**
 * Forwards the sample names to the dataset's delete_samples, together with
 * the dataset URI and TileDB config stored in the ingestion parameters.
 */
void Writer::delete_samples(std::vector<std::string> samples) {
  auto& dataset_uri = ingestion_params_.uri;
  auto& config = ingestion_params_.tiledb_config;
  dataset_->delete_samples(dataset_uri, samples, config);
}

} // namespace vcf
} // namespace tiledb
5 changes: 5 additions & 0 deletions libtiledbvcf/src/write/writer.h
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,11 @@ class Writer {
/** Set variant stats array version */
void set_variant_stats_array_version(uint8_t version);

/**
 * @brief Delete samples from the writer's dataset.
 *
 * The dataset URI and TileDB config are taken from the writer's ingestion
 * parameters.
 *
 * @param samples Names of the samples to delete
 */
void delete_samples(std::vector<std::string> samples);

private:
/* ********************************* */
/* PRIVATE ATTRIBUTES */
Expand Down

0 comments on commit b9fd438

Please sign in to comment.