Skip to content

Commit

Permalink
Feat: python -m zeus.show_env (#119)
Browse files Browse the repository at this point in the history
  • Loading branch information
jaywonchung authored Sep 9, 2024
1 parent 78106db commit 941591d
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 11 deletions.
2 changes: 1 addition & 1 deletion zeus/device/cpu/rapl.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ def __init__(self) -> None:
self._init_cpus()

@property
def cpus(self) -> Sequence[cpu_common.CPU]:
def cpus(self) -> Sequence[RAPLCPU]:
"""Returns a list of CPU objects being tracked."""
return self._cpus

Expand Down
2 changes: 1 addition & 1 deletion zeus/device/gpu/amd.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ def __init__(self, ensure_homogeneous: bool = False) -> None:
raise exception_class(e.msg) from e

@property
def gpus(self) -> Sequence[gpu_common.GPU]:
def gpus(self) -> Sequence[AMDGPU]:
"""Return a list of AMDGPU objects being tracked."""
return self._gpus

Expand Down
2 changes: 1 addition & 1 deletion zeus/device/gpu/nvidia.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ def __init__(self, ensure_homogeneous: bool = False) -> None:
) from e

@property
def gpus(self) -> Sequence[gpu_common.GPU]:
def gpus(self) -> Sequence[NVIDIAGPU]:
"""Return a list of NVIDIAGPU objects being tracked."""
return self._gpus

Expand Down
107 changes: 107 additions & 0 deletions zeus/show_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
"""Collect information about the environment and display it.
- Python version
- Package availability and versions: Zeus, PyTorch, JAX.
- NVIDIA GPU availability: Number of GPUs and models.
- AMD GPU availability: Number of GPUs and models.
- Intel RAPL availability: Number of CPUs and whether DRAM measurements are available.
"""

from __future__ import annotations

import platform

import zeus
from zeus.utils import framework
from zeus.device import get_gpus, get_cpus
from zeus.device.cpu import RAPLCPUs
from zeus.device.gpu.common import ZeusGPUInitError, EmptyGPUs
from zeus.device.cpu.common import ZeusCPUInitError, EmptyCPUs


SECTION_SEPARATOR = "=" * 80 + "\n"


def _package_line(display_name: str, is_available, cache_key: str) -> str:
    """Return one report line describing a framework's availability.

    Probes `is_available` (e.g. `framework.torch_is_available`) first with its
    default `ensure_cuda=True`; a `RuntimeError` means the package is installed
    but lacks CUDA support, so we re-probe with `ensure_cuda=False`.

    NOTE(review): the two probes use different argument tuples, so both calls
    hit distinct `lru_cache` entries in `zeus.utils.framework` — intended.

    Args:
        display_name: Human-readable package name used in the report line.
        is_available: Availability predicate from `zeus.utils.framework`.
        cache_key: Key under which the imported module is stored in
            `framework.MODULE_CACHE` once the probe succeeds.
    """
    try:
        available = is_available()
        cuda_available = True
    except RuntimeError:
        available = is_available(ensure_cuda=False)
        cuda_available = False

    if not available:
        return f" {display_name}: not available\n"

    module = framework.MODULE_CACHE[cache_key]
    qualifier = "with" if cuda_available else "without"
    return f" {display_name}: {module.__version__} ({qualifier} CUDA support)\n"


def _gpu_report() -> str:
    """Build the GPU availability section, one line per detected GPU."""
    try:
        gpus = get_gpus()
    except ZeusGPUInitError:
        # No usable GPU backend (no NVML/AMDSMI); report an empty container.
        gpus = EmptyGPUs()
    report = "\nGPU availability:\n"
    if len(gpus) > 0:
        for i in range(len(gpus)):
            report += f" GPU {i}: {gpus.getName(i)}\n"
    else:
        report += " No GPUs available.\n"
    return report


def _cpu_report() -> str:
    """Build the CPU/RAPL availability section, including DRAM measurability."""
    try:
        cpus = get_cpus()
    except ZeusCPUInitError:
        # RAPL sysfs interface not present; report an empty container.
        cpus = EmptyCPUs()
    report = "\nCPU availability:\n"
    if len(cpus) > 0:
        # Non-empty CPU managers are RAPL-backed; needed for `.cpus[i].rapl_file`.
        assert isinstance(cpus, RAPLCPUs)
        for i in range(len(cpus)):
            report += f" CPU {i}:\n CPU measurements available ({cpus.cpus[i].rapl_file.path})\n"
            if cpus.supportsGetDramEnergyConsumption(i):
                dram = cpus.cpus[i].dram
                assert dram is not None
                report += f" DRAM measurements available ({dram.path})\n"
            else:
                report += " DRAM measurements unavailable\n"
    else:
        report += " No CPUs available.\n"
    return report


def show_env():
    """Collect information about the environment and display it.

    Prints, in order: Python version, package availability (Zeus, PyTorch,
    JAX, each with/without CUDA support), GPU availability, and CPU/RAPL
    availability, separated by `SECTION_SEPARATOR` lines.
    """
    print(SECTION_SEPARATOR)
    print(f"Python version: {platform.python_version()}\n")

    print(SECTION_SEPARATOR)
    package_availability = "\nPackage availability and versions:\n"
    package_availability += f" Zeus: {zeus.__version__}\n"
    package_availability += _package_line("PyTorch", framework.torch_is_available, "torch")
    package_availability += _package_line("JAX", framework.jax_is_available, "jax")
    print(package_availability)

    print(SECTION_SEPARATOR)
    print(_gpu_report())

    print(SECTION_SEPARATOR)
    print(_cpu_report())

    print(SECTION_SEPARATOR)


# Entry point for `python -m zeus.show_env`.
if __name__ == "__main__":
    show_env()
23 changes: 15 additions & 8 deletions zeus/utils/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,19 @@


@lru_cache(maxsize=1)
def torch_is_available(ensure_available: bool = False):
def torch_is_available(ensure_available: bool = False, ensure_cuda: bool = True):
"""Check if PyTorch is available."""
try:
import torch

assert (
torch.cuda.is_available()
), "PyTorch is available but does not have CUDA support."
cuda_available = torch.cuda.is_available()
if ensure_cuda and not cuda_available:
raise RuntimeError("PyTorch is available but does not have CUDA support.")
MODULE_CACHE["torch"] = torch
logger.info("PyTorch with CUDA support is available.")
logger.info(
"PyTorch %s CUDA support is available.",
"with" if cuda_available else "without",
)
return True
except ImportError as e:
logger.info("PyTorch is not available.")
Expand All @@ -32,14 +35,18 @@ def torch_is_available(ensure_available: bool = False):


@lru_cache(maxsize=1)
def jax_is_available(ensure_available: bool = False):
def jax_is_available(ensure_available: bool = False, ensure_cuda: bool = True):
"""Check if JAX is available."""
try:
import jax # type: ignore

assert jax.devices("gpu"), "JAX is available but does not have CUDA support."
cuda_available = jax.devices("gpu")
if ensure_cuda and not cuda_available:
raise RuntimeError("JAX is available but does not have CUDA support.")
MODULE_CACHE["jax"] = jax
logger.info("JAX with CUDA support is available.")
logger.info(
"JAX %s CUDA support is available.", "with" if cuda_available else "without"
)
return True
except ImportError as e:
logger.info("JAX is not available")
Expand Down

0 comments on commit 941591d

Please sign in to comment.