Skip to content

Commit

Permalink
Reintroduce tpu- prefix; add tpu vendor alias (#1587)
Browse files Browse the repository at this point in the history
Closes: #1586
  • Loading branch information
un-def authored Aug 21, 2024
1 parent 74f7f39 commit 75d5f55
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 18 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def get_long_description():
"cachetools",
"dnspython",
"grpcio>=1.50", # indirect
"gpuhunt>=0.0.13",
"gpuhunt>=0.0.15,<0.1.0",
"sentry-sdk[fastapi]",
"httpx",
"aiorwlock",
Expand Down
10 changes: 6 additions & 4 deletions src/dstack/_internal/cli/services/configurators/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@

_KNOWN_AMD_GPUS = {gpu.name.lower() for gpu in gpuhunt.KNOWN_AMD_GPUS}
_KNOWN_NVIDIA_GPUS = {gpu.name.lower() for gpu in gpuhunt.KNOWN_NVIDIA_GPUS}
_KNOWN_TPUS = {gpu.name.lower() for gpu in gpuhunt.KNOWN_TPUS}
_KNOWN_TPU_VERSIONS = {gpu.name.lower() for gpu in gpuhunt.KNOWN_TPUS}

_BIND_ADDRESS_ARG = "bind_address"

Expand Down Expand Up @@ -316,10 +316,12 @@ def validate_gpu_vendor_and_image(self, conf: BaseRunConfiguration) -> None:
vendors.add(gpuhunt.AcceleratorVendor.NVIDIA)
elif name in _KNOWN_AMD_GPUS:
vendors.add(gpuhunt.AcceleratorVendor.AMD)
elif name in _KNOWN_TPUS:
vendors.add(gpuhunt.AcceleratorVendor.GOOGLE)
else:
vendors.add(None)
maybe_tpu_version, _, maybe_tpu_cores = name.partition("-")
if maybe_tpu_version in _KNOWN_TPU_VERSIONS and maybe_tpu_cores.isdigit():
vendors.add(gpuhunt.AcceleratorVendor.GOOGLE)
else:
vendors.add(None)
if len(vendors) == 1:
# Only one vendor or all names are not known.
vendor = next(iter(vendors))
Expand Down
34 changes: 30 additions & 4 deletions src/dstack/_internal/core/models/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
from typing_extensions import Annotated

from dstack._internal.core.models.common import CoreModel
from dstack._internal.utils.logging import get_logger

logger = get_logger(__name__)


T = TypeVar("T", bound=Union[int, float])

Expand Down Expand Up @@ -163,7 +167,7 @@ def parse(cls, v: Any) -> Any:
if not token:
raise ValueError(f"GPU spec contains empty token: {v}")
try:
vendor = gpuhunt.AcceleratorVendor.cast(token)
vendor = cls._vendor_from_string(token)
except ValueError:
vendor = None
if vendor:
Expand All @@ -189,16 +193,38 @@ def parse(cls, v: Any) -> Any:

@validator("name", pre=True)
def _validate_name(cls, v: Any) -> Any:
if v is not None and not isinstance(v, list):
return [v]
return v
if v is None:
return None
if not isinstance(v, list):
v = [v]
validated: List[Any] = []
has_tpu_prefix = False
for name in v:
if isinstance(name, str) and name.startswith("tpu-"):
name = name[4:]
has_tpu_prefix = True
validated.append(name)
if has_tpu_prefix:
logger.warning("`tpu-` prefix is deprecated, specify gpu_vendor instead")
return validated

@validator("vendor", pre=True)
def _validate_vendor(
cls, v: Union[str, gpuhunt.AcceleratorVendor, None]
) -> Optional[gpuhunt.AcceleratorVendor]:
if v is None:
return None
if isinstance(v, gpuhunt.AcceleratorVendor):
return v
if isinstance(v, str):
return cls._vendor_from_string(v)
raise TypeError(f"Unsupported type: {v!r}")

@classmethod
def _vendor_from_string(cls, v: str) -> gpuhunt.AcceleratorVendor:
v = v.lower()
if v == "tpu":
return gpuhunt.AcceleratorVendor.GOOGLE
return gpuhunt.AcceleratorVendor.cast(v)


Expand Down
21 changes: 14 additions & 7 deletions src/tests/_internal/cli/services/configurators/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,11 +125,18 @@ def test_zero_gpu(self):
self.validate(conf)
assert conf.resources.gpu.vendor is None

@pytest.mark.parametrize("gpu_spec", ["nvidia", "google"])
def test_non_amd_vendor_declared(self, gpu_spec):
@pytest.mark.parametrize(
["gpu_spec", "expected_vendor"],
[
["nvidia", AcceleratorVendor.NVIDIA],
["tpu", AcceleratorVendor.GOOGLE],
["google", AcceleratorVendor.GOOGLE],
],
)
def test_non_amd_vendor_declared(self, gpu_spec, expected_vendor):
conf = self.prepare_conf(gpu_spec=gpu_spec)
self.validate(conf)
assert conf.resources.gpu.vendor == AcceleratorVendor.cast(gpu_spec)
assert conf.resources.gpu.vendor == expected_vendor

def test_amd_vendor_declared_with_image(self):
conf = self.prepare_conf(image="tgi:rocm", gpu_spec="AMD")
Expand All @@ -140,7 +147,7 @@ def test_amd_vendor_declared_with_image(self):
["gpu_spec", "expected_vendor"],
[
["a40,l40", AcceleratorVendor.NVIDIA], # lowercase
["V4", AcceleratorVendor.GOOGLE], # uppercase
["V3-64", AcceleratorVendor.GOOGLE], # uppercase
],
)
def test_one_non_amd_vendor_inferred(self, gpu_spec, expected_vendor):
Expand All @@ -164,7 +171,7 @@ def test_one_unknown_vendor_inferred(self, gpu_spec):
"gpu_spec",
[
"A1000,v4", # Nvidia and Google
"v3,foo", # Google and unknown
"v3-64,foo", # Google and unknown
],
)
def test_two_non_amd_vendors_inferred(self, gpu_spec):
Expand All @@ -176,7 +183,7 @@ def test_two_non_amd_vendors_inferred(self, gpu_spec):
"gpu_spec",
[
"A1000,mi300x", # Nvidia and AMD (lowercase)
"MI300x,v3", # AMD (mixedcase) and Google
"MI300x,v3-64", # AMD (mixedcase) and Google
"foo,MI300X", # unknown and AMD (uppercase)
],
)
Expand All @@ -200,7 +207,7 @@ def test_amd_vendor_inferred_no_image(self, gpu_spec):
"gpu_spec",
[
"A1000,mi300x", # Nvidia and AMD (lowercase)
"MI300x,v3", # AMD (mixedcase) and Google
"MI300x,v3-64", # AMD (mixedcase) and Google
"foo,MI300X", # unknown and AMD (uppercase)
],
)
Expand Down
20 changes: 18 additions & 2 deletions src/tests/_internal/core/models/test_resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,20 @@ def test_count(self):
"Nvidia", {"vendor": AcceleratorVendor.NVIDIA}, id="vendor-only-mixedcase"
),
pytest.param(
"google:v3",
{"vendor": AcceleratorVendor.GOOGLE, "name": ["v3"]},
"google:v3-64",
{"vendor": AcceleratorVendor.GOOGLE, "name": ["v3-64"]},
id="vendor-lowercase-and-name",
),
pytest.param(
"tpu:v5p-1024",
{"vendor": AcceleratorVendor.GOOGLE, "name": ["v5p-1024"]},
id="tpu-lowercase-and-name",
),
pytest.param(
"v5litepod-64:TPU",
{"vendor": AcceleratorVendor.GOOGLE, "name": ["v5litepod-64"]},
id="name-and-tpu-uppercase",
),
pytest.param(
"MI300X:AMD",
{"vendor": AcceleratorVendor.AMD, "name": ["MI300X"]},
Expand All @@ -132,6 +142,8 @@ def test_vendor_in_string_form(self, value, expected):
pytest.param("NVIDIA", AcceleratorVendor.NVIDIA, id="uppercase"),
pytest.param("amd", AcceleratorVendor.AMD, id="lowercase"),
pytest.param("Google", AcceleratorVendor.GOOGLE, id="mixedcase"),
pytest.param("tpu", AcceleratorVendor.GOOGLE, id="tpu-lowercase"),
pytest.param("TPU", AcceleratorVendor.GOOGLE, id="tpu-uppercase"),
pytest.param(AcceleratorVendor.GOOGLE, AcceleratorVendor.GOOGLE, id="enum-value"),
],
)
Expand All @@ -143,6 +155,10 @@ def test_vendor_in_object_form(self, value, expected):
def test_name(self):
assert parse_obj_as(GPUSpec, "A100") == parse_obj_as(GPUSpec, {"name": ["A100"]})

def test_name_with_tpu_prefix(self):
spec = parse_obj_as(GPUSpec, "tpu-v3-2048")
assert spec.name == ["v3-2048"]

def test_memory(self):
assert parse_obj_as(GPUSpec, "16GB") == parse_obj_as(GPUSpec, {"memory": "16GB"})

Expand Down

0 comments on commit 75d5f55

Please sign in to comment.