Skip to content

Commit

Permalink
Update SEG-Y import for cloud native support and add test
Browse files Browse the repository at this point in the history
  • Loading branch information
tasansal committed Nov 21, 2024
1 parent 343e33d commit 04cce83
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 9 deletions.
2 changes: 1 addition & 1 deletion src/mdio/commands/segy.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@


@cli.command(name="import")
@argument("segy-path", type=Path(exists=True))
@argument("segy-path", type=STRING)
@argument("mdio-path", type=STRING)
@option(
"-loc",
Expand Down
10 changes: 9 additions & 1 deletion src/mdio/segy/_workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import os
from typing import TYPE_CHECKING
from typing import Any

Expand Down Expand Up @@ -37,7 +38,14 @@ def header_scan_worker(
Returns:
HeaderArray parsed from SEG-Y library.
"""
return segy_file.header[slice(*trace_range)]
slice_ = slice(*trace_range)

cloud_native_mode = os.getenv("MDIO__IMPORT__CLOUD_NATIVE", default="False")

if cloud_native_mode.lower() in {"true", "1"}:
return segy_file.trace[slice_].header

return segy_file.header[slice_]


def trace_worker(
Expand Down
11 changes: 8 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,17 @@ def fake_segy_tmp(tmp_path_factory):


@pytest.fixture(scope="session")
def segy_input(tmp_path_factory):
def segy_input_uri():
"""Path to dome dataset for cloud testing."""
return "http://s3.amazonaws.com/teapot/filt_mig.sgy"


@pytest.fixture(scope="session")
def segy_input(segy_input_uri, tmp_path_factory):
"""Download teapot dome dataset for testing."""
url = "http://s3.amazonaws.com/teapot/filt_mig.sgy"
tmp_dir = tmp_path_factory.mktemp("segy")
tmp_file = path.join(tmp_dir, "teapot.segy")
urlretrieve(url, tmp_file) # noqa: S310
urlretrieve(segy_input_uri, tmp_file) # noqa: S310

return tmp_file

Expand Down
22 changes: 18 additions & 4 deletions tests/test_main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Test cases for the __main__ module."""

import os
from pathlib import Path

import pytest
Expand All @@ -8,18 +9,31 @@
from mdio import __main__


@pytest.fixture()
@pytest.fixture
def runner() -> CliRunner:
"""Fixture for invoking command-line interfaces."""
return CliRunner()


@pytest.mark.dependency()
@pytest.mark.dependency
def test_main_succeeds(runner: CliRunner, segy_input: str, zarr_tmp: Path) -> None:
"""It exits with a status code of zero."""
cli_args = ["segy", "import", segy_input, str(zarr_tmp)]
cli_args.extend(["-loc", "181,185"])
cli_args.extend(["-names", "inline,crossline"])
cli_args.extend(["--header-locations", "181,185"])
cli_args.extend(["--header-names", "inline,crossline"])

result = runner.invoke(__main__.main, args=cli_args)
assert result.exit_code == 0


@pytest.mark.dependency(depends=["test_main_succeeds"])
def test_main_cloud(runner: CliRunner, segy_input_uri: str, zarr_tmp: Path) -> None:
"""It exits with a status code of zero."""
os.environ["MDIO__IMPORT__CLOUD_NATIVE"] = "true"
cli_args = ["segy", "import", str(segy_input_uri), str(zarr_tmp)]
cli_args.extend(["--header-locations", "181,185"])
cli_args.extend(["--header-names", "inline,crossline"])
cli_args.extend(["--overwrite"])

result = runner.invoke(__main__.main, args=cli_args)
assert result.exit_code == 0
Expand Down

0 comments on commit 04cce83

Please sign in to comment.