Skip to content

Commit

Permalink
Remove power_limit_optimizer and bring back the original code for image processing
Browse files Browse the repository at this point in the history
  • Loading branch information
sharonsyh committed Dec 10, 2024
1 parent f18ecb9 commit 2276ac2
Showing 1 changed file with 27 additions and 37 deletions.
64 changes: 27 additions & 37 deletions examples/prometheus/train_single.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,13 @@
import torch.utils.data
from torch.utils.data import DataLoader
import torch.utils.data.distributed
from torch.utils.data import Subset
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
from multiprocessing import set_start_method
from PIL import Image, ImageFile, UnidentifiedImageError
#set_start_method("fork", force=True)

# ZEUS
from zeus.monitor import ZeusMonitor
from zeus.monitor import PowerMonitor
from zeus.optimizer.power_limit import MaxSlowdownConstraint, GlobalPowerLimitOptimizer
from zeus.utils.env import get_env
from zeus.metric import EnergyHistogram, EnergyCumulativeCounter, PowerGauge

Expand Down Expand Up @@ -112,21 +107,6 @@ def parse_args() -> argparse.Namespace:

return parser.parse_args()

ImageFile.LOAD_TRUNCATED_IMAGES = True # Optionally allow truncated images

def remove_corrupted_images(dataset_dir):
    """Remove corrupted or truncated image files from the dataset directory.

    Walks ``dataset_dir`` recursively and, for every file found, asks Pillow
    to verify it and then to decode it to RGB. Any file that fails either
    step is reported on stdout and deleted from disk.

    Note that *every* file under ``dataset_dir`` is checked, so files that
    Pillow cannot identify as images at all are deleted too.

    Args:
        dataset_dir: Root directory of the dataset to clean in place.
    """
    for root, _, files in os.walk(dataset_dir):
        for file in files:
            img_path = os.path.join(root, file)
            try:
                # Cheap structural check that does not decode pixel data.
                with Image.open(img_path) as img:
                    img.verify()
                # Pillow requires the file to be reopened after verify();
                # decoding on the verified object can fail for valid images,
                # which would make this function delete good data.
                with Image.open(img_path) as img:
                    img.convert("RGB")  # Force a full decode of the pixel data
            except (UnidentifiedImageError, OSError, SyntaxError):
                # SyntaxError is included because Pillow's PNG verification
                # reports broken files with it rather than with OSError.
                print(f"Removing corrupted or truncated file: {img_path}")
                os.remove(img_path)

def main():
"""Main function that prepares values and spawns/calls the worker function."""
args = parse_args()
Expand All @@ -152,19 +132,16 @@ def main():
scheduler = StepLR(optimizer, step_size=30, gamma=0.1)

traindir = os.path.join(args.data, "train")
#remove_corrupted_images(traindir)

valdir = os.path.join(args.data, "val")
#remove_corrupted_images(valdir)

normalize = transforms.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
)
train_dataset = datasets.ImageFolder(
traindir,
transforms.Compose(
[
transforms.Resize((224, 224)),
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
]
Expand All @@ -174,7 +151,8 @@ def main():
valdir,
transforms.Compose(
[
transforms.Resize((224, 224)),
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
normalize,
]
Expand All @@ -196,27 +174,41 @@ def main():
pin_memory=True,
)

train_dataset = Subset(train_dataset, range(5))
val_dataset = Subset(val_dataset, range(2))

################################## The important part #####################################
# Histogram to track energy consumption over time
energy_histogram = EnergyHistogram(cpu_indices=[0,1], gpu_indices=[0], prometheus_url='http://localhost:9091', job='training_energy_histogram')
energy_histogram = EnergyHistogram(
cpu_indices=[0,1],
gpu_indices=[0],
prometheus_url='http://localhost:9091',
job='training_energy_histogram'
)
# Gauge to track power consumption over time
power_gauge = PowerGauge(gpu_indices=[0], update_period=2, prometheus_url='http://localhost:9091', job='training_power_gauge')
power_gauge = PowerGauge(
gpu_indices=[0],
update_period=2,
prometheus_url='http://localhost:9091',
job='training_power_gauge'
)
# Counter to track energy consumption over time
energy_counter = EnergyCumulativeCounter(cpu_indices=[0,1], gpu_indices=[0], update_period=2, prometheus_url='http://localhost:9091', job='training_energy_counter')
energy_counter = EnergyCumulativeCounter(
cpu_indices=[0,1],
gpu_indices=[0],
update_period=2,
prometheus_url='http://localhost:9091',
job='training_energy_counter'
)

power_gauge.begin_window("epoch_power")
energy_counter.begin_window("epoch_energy")

for epoch in range(args.epochs):
acc1 = validate(val_loader, model, criterion, args)
energy_histogram.begin_window("training_energy")
energy_histogram.end_window("training_energy")
train(train_loader, model, criterion, optimizer, epoch, args)
energy_histogram.end_window("training_energy")
print(f"Top-1 accuracy: {acc1}")


# Allow metrics to capture remaining data before shutting down monitoring.
time.sleep(10)

energy_counter.end_window("epoch_energy")
Expand Down Expand Up @@ -244,7 +236,6 @@ def train(

end = time.time()
for i, (images, target) in enumerate(train_loader):
#power_limit_optimizer.on_step_begin() # Mark the beginning of one training step.

# Load data to GPU
images = images.cuda(args.gpu, non_blocking=True)
Expand All @@ -268,8 +259,6 @@ def train(
loss.backward()
optimizer.step()

#power_limit_optimizer.on_step_end() # Mark the end of one training step.

# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
Expand Down Expand Up @@ -418,4 +407,5 @@ def accuracy(output, target, topk=(1,)):

if __name__ == "__main__":
main()


0 comments on commit 2276ac2

Please sign in to comment.