diff --git a/README.md b/README.md index 16d5ee3..a4fe6f2 100644 --- a/README.md +++ b/README.md @@ -3,20 +3,20 @@ A simple training platform for PyTorch. The models are implemented in [mitorch-models](https://github.com/shonohs/mitorch_models). -There are two ways to use this platform. Use as a simple training command on local machine, or use as a service. - -# Usage -To install, +# Install +Use python 3.8+. ```bash pip install mitorch ``` +# Usage + ## Training Command ```bash mitrain [-w ] [-o ] [-d] ``` - config_filepath - - Json-serialized training config. For the detail, please see the sample configs in samples/ directory. + - Json-serialized training config. - train_filepath / val_filepath - Filepath to the training / validation dataset - weights_filepath @@ -28,99 +28,76 @@ mitrain [-w ] [- The validatoin results will be printed on stdout. -## Validation Command -```bash -mitest -w [-d] +## Training config +The definition of training config is in mitorch/common/training_config.py. + +Here is an example configuration. ``` +{ + "model": { + "input_size": 224, + "name": "MobileNetV3" + }, + "optimizer": { + "name": "sgd", + "momentum": 0.9, + "weight_decay": 0.0001 + }, + "lr_scheduler": { + "name": "cosine_annealing", + "base_lr": 0.01 + }, + "augmentation": { # The names are defined in mitorch/datasets/factory.py. + "train": "random_resize", + "val": "center_crop" + }, + "dataset": { # This setting is used by mitorch-agent. + "train": "mnist/train_images.txt", + "val": "mnist/test_images.txt" + }, + "batch_size": 2, + "max_epochs": 5, + "task_type": "multiclass_classification", + "num_processes": 1 # For mitorch-agent. Specify the number of GPU/CPU for the training. +} +``` + +## Dataset format +See [simpledataset](https://github.com/shonohs/simpledataset). -# Usage as a service -This library can be used as a service using AzureML and MongoDB. Queue training jobs to MongoDB, and the trainings will be run on AzureML instances. The training results will be stored on MongoDB after the trainings. +# Advanced usage: experiment management +You can manage experiments on remote machines using this framework. ## Setup -First, you need to createa AzureML and MongoDB resource on Azure portal. For the detail of this step, please read the Azure official documents. Once you set up the resources, collect the following informations. -- Subscription ID -- AzureML workspace name -- AzureML compute cluster name -- Username/password to access the AzureML resource. (Service Principal is recommended) -- MongoDB Endpoing with the access token +First, please set up an Azure Blob Storage and a mongo DB account. + +- Blob Storage URL with SAS token +- MongoDB URL with the access token Set those information to the following environment variables. ```bash -export MITORCH_AZURE_SUBSCRIPTION_ID= -export MITORCH_AML_WORKSPACE= -export MITORCH_AML_COMPUTE= -export MITORCH_AML_AUTH=: -export MITORCH_DB_URI= +export MITORCH_STORAGE_URL= +export MITORCH_DATABASE_URL= ``` -Second, upload the datasets to Azure Blob Storage. Our dataset format is described in the later section. Register the dataset infomation to the MongoDB using midataset command. 
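As a quick sanity check before queueing a job, the example config above can be parsed with the same `TrainingConfig` dataclass that `misubmit` and `mitrain` use internally. The sketch below assumes the config is saved as `config.json` (a placeholder name); note that the inline `#` comments in the sample config are annotations only and must be omitted from an actual file, since JSON does not allow comments.
```python
# Minimal sketch: parse and inspect a training config the same way
# mitorch/commands/submit.py and train.py do. "config.json" is a placeholder.
import pathlib

import jsons

from mitorch.common import TrainingConfig

config = jsons.loads(pathlib.Path('config.json').read_text(), TrainingConfig)
print(config.model.name, config.lr_scheduler.base_lr, config.max_epochs)
```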
+## Queue a job ```bash -midataset register -``` -The format of the dataset definition is: -```javascript -[{"name": "dataset_name", - "version": 0, - "train": {"path": "path", "support_files": ["path"]}, // "path" is a Azure Blob storage URL with a sas token - "val": {"path": "path", "support_files": ["path"]}}] +misubmit ``` +This command will send a config file to the Mongo DB. -Third, run micontrol every 5 minutes. You can use any method to achive this step as long as the environment varialbes are correctly provided. You can manually execute them every 5 minutes, you can set up a cron job, or deploy to Azure Functions (recommended). - -That's it. Now you are ready to use the service. +## Run an agent +On a powerful machine you want to use, run the follwing command. +```bash +miagent --data +``` +It will get a job from the Mongo DB, train it, and save the results to the MongoDB and the Blob storage. ## Commands ```bash # Queue a new training misubmit [--priority ] -# Queue a hyper-parameter search job -misubmit [--priority ] - -# Get status of a training -miquery - -# Launch a web UI for managing the service -miviewer -``` - -## Data structures -### Training job config -```javascript -{"_id": "", - "status": "", // "new", "running", "failed", or "completed". - "prority": 100, // lower has more priority. - "created_at": "", - "dataset": "", - "config": {} // Training configs. - } -``` -### Training config -```javascript -{"base": "", // Existing training id - "augmentation": {}, - "lr_scheduler": {}, - "model": {}, - "optimizer": {} -} -``` -### Job config structure -```javascript -{"job_type": "search", // Only "search" job is supported now. -} +# Get status of a training. If a job_id is not provided, it shows a list of jobs. +miquery [--job_id JOB_ID] ``` -### Dataset format -TBD - -### MongoDB database structure -This library will create one database on the given MongoDB endpoint. The database name is "mitorch" by default. - -The database has the following collections. -- trainings - - Each record represents one training. New record will be added when a new training job is queued. The record will track the status of the job. Final evaluation results will be added to this record. -- training_results - - Training loss/validation loss for each training epochs. Those records will be updated real-time during the trainings. -- jobs - - Hyper-parameter search jobs will be stored in this collection. One job can create multiple trainings. -- datasets - - Information of registered datasets. This collection needs to be created manually before the trainings. \ No newline at end of file diff --git a/mitorch/azureml/__init__.py b/mitorch/azureml/__init__.py deleted file mode 100644 index e7750fc..0000000 --- a/mitorch/azureml/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .manager import AzureMLManager - -__all__ = ['AzureMLManager'] diff --git a/mitorch/azureml/manager.py b/mitorch/azureml/manager.py deleted file mode 100644 index 4ea8dd3..0000000 --- a/mitorch/azureml/manager.py +++ /dev/null @@ -1,106 +0,0 @@ -import os -import tempfile -import azureml.core -from azureml.core.authentication import ServicePrincipalAuthentication -from ..settings import AzureMLSetting - -EXPERIMENT_NAME = 'mitorch' - - -class AzureMLManager: - """Manage AzureML runs. Submit a new run and query the status of a run. 
- """ - def __init__(self, settings): - assert isinstance(settings, list) - self.managers = [AzureMLSingleResourceManager(s) for s in settings] - - def get_num_available_nodes(self): - return sum(m.get_num_available_nodes() for m in self.managers) - - def submit(self, *args): - for manager in self.managers: - if manager.get_num_available_nodes() > 0: - aml_run_id, aml_run_url = manager.submit(*args) - return aml_run_id, aml_run_url, manager.region - return None - - def query(self, run_id, region): - for manager in self.managers: - if manager.region == region: - return manager.query(run_id) - return None - - -class AzureMLSingleResourceManager: - def __init__(self, setting): - assert isinstance(setting, AzureMLSetting) - self.setting = setting - if setting.sp_tenant_id and setting.sp_username and setting.sp_password: - auth = ServicePrincipalAuthentication(tenant_id=setting.sp_tenant_id, - service_principal_id=setting.sp_username, - service_principal_password=setting.sp_password) - else: - print("Use interactive authentication...") - auth = None - - self.workspace = azureml.core.Workspace.get(name=setting.workspace_name, subscription_id=setting.subscription_id, auth=auth) - if not self.workspace: - raise RuntimeError(f"Workspace {setting.workspace_name} not found") - self.experiment = azureml.core.Experiment(workspace=self.workspace, name=EXPERIMENT_NAME) - self.cluster = azureml.core.compute.ComputeTarget(workspace=self.workspace, name=setting.cluster_name) - if not self.cluster: - raise RuntimeError(f"Cluster {setting.cluster_name} doesn't exist in workspace {setting.workspace_name}") - - @property - def region(self): - return self.setting.region_name - - def submit(self, db_url, job_id): - run_config = azureml.core.runconfig.RunConfiguration() - run_config.target = self.cluster - dependencies = azureml.core.conda_dependencies.CondaDependencies() - dependencies.set_python_version('3.7') - dependencies.add_pip_package('torch~=1.5') - dependencies.add_pip_package('torchvision~=0.6') - run_config.environment.python.conda_dependencies = dependencies - # Specify a docker base image since the default one is ubuntu 16.04. - run_config.environment.docker.enabled = True - run_config.environment.docker.base_image = 'mcr.microsoft.com/azureml/base-gpu:openmpi3.1.2-cuda10.1-cudnn7-ubuntu18.04' - run_config.environment.docker.shm_size = '16g' - - with tempfile.TemporaryDirectory() as temp_dir: - self._generate_bootstrap(temp_dir) - args = [str(job_id), '"' + db_url + '"'] - script_run_config = azureml.core.ScriptRunConfig(source_directory=temp_dir, script='boot.py', arguments=args, run_config=run_config) - run = self.experiment.submit(config=script_run_config) - run_id = run.get_details()['runId'] - run_url = run.get_portal_url() - return run_id, run_url - - def query(self, run_id): - """Get the status of the specified run. - Returns: - (str) running, failed, or completed. 
- """ - try: - run = azureml.core.run.Run(self.experiment, run_id) - return run.get_status().lower() - except Exception as e: - print(e) - return None - - def get_num_available_nodes(self): - """Get the number of available nodes""" - status = self.cluster.get_status() - s = status.serialize() - return s['scaleSettings']['maxNodeCount'] - s['currentNodeCount'] + s['nodeStateCounts']['idleNodeCount'] - - @staticmethod - def _generate_bootstrap(directory): - filepath = os.path.join(directory, 'boot.py') - with open(filepath, 'w') as f: - f.write('import os\n') - f.write('import sys\n') - f.write('os.system("pip install https://github.com/shonohs/mitorch_models/archive/dev.zip")\n') - f.write('os.system("pip install https://github.com/shonohs/mitorch/archive/dev.zip")\n') - f.write('os.system("miamlrun " + " ".join(sys.argv[1:]))') diff --git a/mitorch/azureml/runner.py b/mitorch/azureml/runner.py deleted file mode 100644 index b04d340..0000000 --- a/mitorch/azureml/runner.py +++ /dev/null @@ -1,169 +0,0 @@ -"""Runner script on AzureML instance - -The job of this script is: -- Download a training config from MongoDB. - - ID will be given as a commandline argument. -- Download training datasets from Azure Storage. -- Download a base weights from Azure Storage if it is specified. -- Run Train command -- Upload the trained model to Azure Storage -- Run Test command -- Upload the standard outputs to Azure Storage - -""" -import argparse -import json -import logging -import os -import pathlib -import shutil -import subprocess -import tempfile -import time -import urllib -import uuid -import requests -import tenacity -import torch -from mitorch.service import DatabaseClient - -_logger = logging.getLogger(__name__) - - -class AzureMLRunner: - def __init__(self, db_url, job_id): - assert isinstance(job_id, uuid.UUID) - self.db_url = db_url - self.job_id = job_id - self.client = DatabaseClient(self.db_url) - - def run(self): - """First method to be run on AzureML instance""" - _logger.info("Started") - # Get the job description from the database. - job = self.client.find_training_by_id(self.job_id) - if not job: - raise RuntimeError(f"Unknown job id {self.job_id}.") - - config = job['config'] - dataset_name = config['dataset'] - region = job['region'] - _logger.info(f"Training config: {config}") - - settings = self.client.get_settings() - self.dataset_base_url = settings.dataset_url[region] - self.blob_storage_url = settings.storage_url - - # Record machine setup. - os.system('df -h') - num_gpus = torch.cuda.device_count() - self.client.start_training(self.job_id, num_gpus) - - with tempfile.TemporaryDirectory() as work_dir: - work_dir = pathlib.Path(work_dir) - os.mkdir(work_dir / 'outputs') - output_filepath = work_dir / 'outputs' / 'model.pth' - config_filepath = work_dir / 'config.json' - config_filepath.write_text(json.dumps(config)) - train_filepath, val_filepath = self.download_dataset(dataset_name, work_dir) - weights_filepath = self.download_weights(uuid.UUID(config['base']), work_dir) if 'base' in config else None - - command = ['mitrain', str(config_filepath), str(train_filepath), str(val_filepath), '--output_filepath', str(output_filepath), '--job_id', str(self.job_id), '--db_url', self.db_url] - if weights_filepath: - command.extend(['--weights_filepath', str(weights_filepath)]) - - _logger.info(f"Starting the training. command: {command}") - proc = subprocess.run(command) - - if proc.returncode != 0: - _logger.error(f"Training failed. 
return code is {proc.returncode}") - self.client.fail_training(self.job_id) - return - - if not output_filepath.exists(): - _logger.error(f"Training failed to generate the output file {output_filepath}") - self.client.fail_training(self.job_id) - return - - _logger.info("Training completed.") - - self.upload_files([output_filepath]) - - command = ['mitest', str(config_filepath), str(train_filepath), str(val_filepath), '--weights_filepath', str(output_filepath), - '--job_id', str(self.job_id), '--db_url', self.db_url] - _logger.info(f"Starting test. command: {command}") - proc = subprocess.run(command) - _logger.info(f"Test completed. returncode: {proc.returncode}") - - if proc.returncode == 0: - self.client.complete_training(self.job_id) - else: - self.client.fail_training(self.job_id) - - _logger.info("All completed.") - - def download_dataset(self, dataset_name, directory): - dataset = self.client.find_dataset_by_name(dataset_name) - train_filepath = self._download_blob_file(self.dataset_base_url, dataset['train']['path'], directory) - val_filepath = self._download_blob_file(self.dataset_base_url, dataset['val']['path'], directory) - - files = set(dataset['train']['support_files'] + dataset['val']['support_files']) - for uri in files: - self._download_blob_file(self.dataset_base_url, uri, directory) - - return train_filepath, val_filepath - - def download_weights(self, base_job_id, directory): - blob_path = os.path.join(base_job_id.hex, 'model.pth') - return self._download_blob_file(self.blob_storage_url, blob_path, directory) - - def upload_files(self, files): - for filepath in files: - blob_path = os.path.join(self.job_id.hex, filepath.name) - self._upload_blob_file(self.blob_storage_url, filepath, blob_path) - - @staticmethod - def _upload_blob_file(base_blob_uri, local_filepath, blob_path): - parts = urllib.parse.urlparse(base_blob_uri) - path = os.path.join(parts[2], blob_path) - url = urllib.parse.urlunparse((parts[0], parts[1], path, parts[3], parts[4], parts[5])) - _logger.info(f"Uploading {local_filepath} to {url}") - requests.put(url=url, data=local_filepath.read_bytes(), headers={'Content-Type': 'application/octet-stream', 'x-ms-blob-type': 'BlockBlob'}) - - @staticmethod - def _download_blob_file(base_blob_uri, blob_path, directory): - parts = urllib.parse.urlparse(base_blob_uri) - path = os.path.join(parts[2], blob_path) - url = urllib.parse.urlunparse((parts[0], parts[1], path, parts[3], parts[4], parts[5])) - return AzureMLRunner._download_file(url, directory) - - @staticmethod - @tenacity.retry(stop=tenacity.stop_after_attempt(3)) - def _download_file(url, directory): - filename = os.path.basename(urllib.parse.urlparse(url).path) - filepath = directory / filename - _logger.info(f"Downloading {url} to {filepath}") - start = time.time() - with requests.get(url, stream=True, allow_redirects=True) as r: - with open(filepath, 'wb') as f: - shutil.copyfileobj(r.raw, f, length=4194304) # 4MB - _logger.debug(f"Downloaded. 
{time.time() - start}s.") - return filepath - - -def main(): - logging.getLogger().setLevel(logging.INFO) - logging.getLogger('mitorch').setLevel(logging.DEBUG) - - parser = argparse.ArgumentParser("Run training on AzureML") - parser.add_argument('job_id', type=uuid.UUID, help="Guid for the target run") - parser.add_argument('db_url', help="MongoDB URI for training management") - - args = parser.parse_args() - - runner = AzureMLRunner(args.db_url, args.job_id) - runner.run() - - -if __name__ == '__main__': - main() diff --git a/mitorch/builders/__init__.py b/mitorch/builders/__init__.py index c0ef489..8df94a3 100644 --- a/mitorch/builders/__init__.py +++ b/mitorch/builders/__init__.py @@ -1,6 +1,7 @@ from .dataloader_builder import DataLoaderBuilder +from .evaluator_builder import EvaluatorBuilder from .lr_scheduler_builder import LrSchedulerBuilder from .model_builder import ModelBuilder from .optimizer_builder import OptimizerBuilder -__all__ = ['DataLoaderBuilder', 'LrSchedulerBuilder', 'ModelBuilder', 'OptimizerBuilder'] +__all__ = ['DataLoaderBuilder', 'EvaluatorBuilder', 'LrSchedulerBuilder', 'ModelBuilder', 'OptimizerBuilder'] diff --git a/mitorch/builders/dataloader_builder.py b/mitorch/builders/dataloader_builder.py index a23210c..e5920d8 100644 --- a/mitorch/builders/dataloader_builder.py +++ b/mitorch/builders/dataloader_builder.py @@ -1,9 +1,9 @@ import functools import logging import torch -from ..datasets import (ImageDataset, CenterCropTransform, CenterCropTransformV2, CenterCropTransformV3, ResizeTransform, ResizeFlipTransform, - RandomResizedCropTransform, RandomResizedCropTransformV2, RandomResizedCropTransformV3, RandomResizedCropTransformV4, - RandomSizedBBoxSafeCropTransform, InceptionTransform, DevTransform, Dev2Transform) +from mitorch.datasets import ImageDataset, TransformFactory + +NUM_WORKERS = 4 def _default_collate(task_type, batch): @@ -18,45 +18,34 @@ def _default_collate(task_type, batch): class DataLoaderBuilder: def __init__(self, config): - self.augmentation_config = config['augmentation'] - self.task_type = config['task_type'] - self.input_size = config['input_size'] - self.batch_size = config['batch_size'] + self.augmentation_config = config.augmentation + self.task_type = config.task_type + self.input_size = config.model.input_size + self.batch_size = config.batch_size def build(self, train_dataset_filepath, val_dataset_filepath): logging.info(f"Building a data_loader. 
train: {train_dataset_filepath}, val: {val_dataset_filepath}, augmentation: {self.augmentation_config}, " f"task: {self.task_type}, input_size: {self.input_size}, batch_size: {self.batch_size}") is_object_detection = self.task_type == 'object_detection' - train_augmentation = self.build_augmentation(self.augmentation_config['train'], self.input_size, is_object_detection) - train_dataset = ImageDataset.from_file(train_dataset_filepath, train_augmentation) + collate_fn = functools.partial(_default_collate, self.task_type) - val_augmentation = self.build_augmentation(self.augmentation_config['val'], self.input_size, is_object_detection) - val_dataset = ImageDataset.from_file(val_dataset_filepath, val_augmentation) + train_augmentation = self.build_augmentation(self.augmentation_config.train, self.input_size, is_object_detection) + train_dataset = ImageDataset.from_file(train_dataset_filepath, train_augmentation) + train_dataloader = torch.utils.data.DataLoader(train_dataset, self.batch_size, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True, collate_fn=collate_fn) - collate_fn = functools.partial(_default_collate, self.task_type) - train_dataloader = torch.utils.data.DataLoader(train_dataset, self.batch_size, shuffle=True, num_workers=8, pin_memory=True, collate_fn=collate_fn) - val_dataloader = torch.utils.data.DataLoader(val_dataset, self.batch_size, shuffle=False, num_workers=8, pin_memory=True, collate_fn=collate_fn) + if val_dataset_filepath: + val_augmentation = self.build_augmentation(self.augmentation_config.val, self.input_size, is_object_detection) + val_dataset = ImageDataset.from_file(val_dataset_filepath, val_augmentation) + val_dataloader = torch.utils.data.DataLoader(val_dataset, self.batch_size, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True, collate_fn=collate_fn) + else: + val_dataloader = None return train_dataloader, val_dataloader @staticmethod def build_augmentation(name, input_size, is_object_detection): - augmentation_class = {'center_crop': CenterCropTransform, - 'center_crop_v2': CenterCropTransformV2, - 'center_crop_v3': CenterCropTransformV3, - 'dev': DevTransform, - 'dev2': Dev2Transform, - 'inception': InceptionTransform, - 'resize': ResizeTransform, - 'resize_flip': ResizeFlipTransform, - 'random_resize': RandomResizedCropTransform, - 'random_resize_v2': RandomResizedCropTransformV2, - 'random_resize_v3': RandomResizedCropTransformV3, - 'random_resize_v4': RandomResizedCropTransformV4, - 'random_resize_bbox': RandomSizedBBoxSafeCropTransform}.get(name) - - if not augmentation_class: + transform = TransformFactory(is_object_detection).create(name, input_size) + if not transform: raise NotImplementedError(f"Non supported augmentation: {name}") - - return augmentation_class(input_size, is_object_detection) + return transform diff --git a/mitorch/builders/evaluator_builder.py b/mitorch/builders/evaluator_builder.py new file mode 100644 index 0000000..4e7e0dd --- /dev/null +++ b/mitorch/builders/evaluator_builder.py @@ -0,0 +1,13 @@ +from mitorch.evaluators import MulticlassClassificationEvaluator, MultilabelClassificationEvaluator, ObjectDetectionEvaluator + + +class EvaluatorBuilder: + def __init__(self, config): + self._task_type = config.task_type + + def build(self): + mappings = {'multiclass_classification': MulticlassClassificationEvaluator, + 'multilabel_classification': MultilabelClassificationEvaluator, + 'object_detection': ObjectDetectionEvaluator} + assert self._task_type in mappings + return mappings[self._task_type]() diff --git 
a/mitorch/builders/lr_scheduler_builder.py b/mitorch/builders/lr_scheduler_builder.py index 7448e91..3709e43 100644 --- a/mitorch/builders/lr_scheduler_builder.py +++ b/mitorch/builders/lr_scheduler_builder.py @@ -49,28 +49,28 @@ def get_lr(self): class LrSchedulerBuilder: def __init__(self, config): - self.config = config['lr_scheduler'] - self.max_epochs = config['max_epochs'] + self.config = config.lr_scheduler + self.max_epochs = config.max_epochs def build(self, optimizer, num_epoch_iters): total_iters = num_epoch_iters * self.max_epochs logging.info(f"Building a lr_scheduler. total_iters: {total_iters}, config: {self.config}") - if self.config['name'] == 'cosine_annealing': + if self.config.name == 'cosine_annealing': lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, total_iters) - elif self.config['name'] == 'linear_decreasing': + elif self.config.name == 'linear_decreasing': lr_scheduler = LinearDecreasingLR(optimizer, total_iters) - elif self.config['name'] == 'step': - step_size = self.config['step_size'] * num_epoch_iters - lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size, self.config['step_gamma']) + elif self.config.name == 'step': + step_size = self.config.step_size * num_epoch_iters + lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size, self.config.step_gamma) else: raise NotImplementedError(f"Unsupported LR scheduler: {self.config['name']}") - warmup_scheduler = self.config.get('warmup') + warmup_scheduler = self.config.warmup if warmup_scheduler: - warmup_epochs = self.config.get('warmup_epochs') + warmup_epochs = self.config.warmup_epochs warmup_iters = warmup_epochs * num_epoch_iters - warmup_factor = self.config.get('warmup_factor', 0.01) + warmup_factor = self.config.warmup_factor if warmup_scheduler == 'const': lr_scheduler = WarmupLR(lr_scheduler, warmup_iters, warmup_factor) elif warmup_scheduler == 'linear': diff --git a/mitorch/builders/model_builder.py b/mitorch/builders/model_builder.py index 983e3a9..d248217 100644 --- a/mitorch/builders/model_builder.py +++ b/mitorch/builders/model_builder.py @@ -7,19 +7,19 @@ class ModelBuilder: def __init__(self, config): - self.config = config['model'] - self.task_type = config['task_type'] + self.config = config.model + self.task_type = config.task_type def build(self, num_classes, weights_filepath=None): logging.info(f"Building a model. weights: {weights_filepath}, config: {self.config}") - model_options = set(self.config['options']) + model_options = set(self.config.options) if self.task_type == 'multiclass_classification' and 'multilabel' in model_options: model_options.remove('multilabel') elif self.task_type == 'multilabel_classification': model_options.add('multilabel') - model = ModelFactory.create(self.config['name'], num_classes, list(model_options)) + model = ModelFactory.create(self.config.name, num_classes, list(model_options)) if weights_filepath: self._load_weights(model, weights_filepath) diff --git a/mitorch/builders/optimizer_builder.py b/mitorch/builders/optimizer_builder.py index 1e60fc0..46ea70f 100644 --- a/mitorch/builders/optimizer_builder.py +++ b/mitorch/builders/optimizer_builder.py @@ -4,13 +4,13 @@ class OptimizerBuilder: def __init__(self, config): - self.config = config['optimizer'] - self.base_lr = config['lr_scheduler']['base_lr'] + self.config = config.optimizer + self.base_lr = config.lr_scheduler.base_lr def build(self, model): logging.info(f"Building a optimizer. 
base_lr: {self.base_lr}, config: {self.config}") - momentum = self.config['momentum'] - weight_decay = self.config['weight_decay'] + momentum = self.config.momentum + weight_decay = self.config.weight_decay assert isinstance(self.base_lr, float) and self.base_lr > 0 assert isinstance(momentum, float) and momentum > 0 @@ -27,9 +27,9 @@ def build(self, model): params = [{'params': params_with_decay, 'weight_decay': weight_decay}, {'params': params_no_decay, 'weight_decay': 0}] - if self.config['name'] == 'adam': + if self.config.name == 'adam': return torch.optim.Adam(params, lr=self.base_lr, weight_decay=0) - if self.config['name'] == 'sgd': + if self.config.name == 'sgd': return torch.optim.SGD(params, lr=self.base_lr, momentum=momentum, weight_decay=0) else: raise NotImplementedError(f"Non-supported optimizer: {self.config['name']}") diff --git a/mitorch/commands/__init__.py b/mitorch/commands/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mitorch/commands/agent.py b/mitorch/commands/agent.py new file mode 100644 index 0000000..b7e581c --- /dev/null +++ b/mitorch/commands/agent.py @@ -0,0 +1,104 @@ +"""Get training configs from database and run mitrain with it.""" +import argparse +import logging +import pathlib +import tempfile +import torch +from mitorch.common import Environment, JobRepository, ModelRepository +from mitorch.commands.train import train +from mitorch.commands.common import init_logging + +logger = logging.getLogger(__name__) + + +def process_one_job(job, db_url, model_repository, data_dir): + # Get the next training config. + train_dataset_filepath = data_dir / job.config.dataset.train + val_dataset_filepath = data_dir / job.config.dataset.val + + with tempfile.TemporaryDirectory() as temp_dir: + temp_dir = pathlib.Path(temp_dir) + output_filepath = temp_dir / 'trained_weights.pth' + + if job.base_job_id: + pretrained_weights_filepath = temp_dir / 'pretrained_weights.pth' + model_repository.download_weights(job.base_job_id, pretrained_weights_filepath) + else: + pretrained_weights_filepath = None + + log_filepath = temp_dir / 'training.log' + tb_log_dir = temp_dir / 'tensorboard/' + log_handler = logging.FileHandler(log_filepath) + logging.getLogger().addHandler(log_handler) + train(job.config, train_dataset_filepath, val_dataset_filepath, pretrained_weights_filepath, output_filepath, job.job_id, db_url, tb_log_dir) + logging.getLogger().removeHandler(log_handler) + + model_repository.upload_weights(job.job_id, output_filepath) + + # Optional file uploads. + try: + model_repository.upload_config(job.job_id, job.config) + except Exception: + logger.exception("Failed to upload a job config.") + + try: + model_repository.upload_file(job.job_id, log_filepath) + except Exception: + logger.exception("Failed to upload a log file.") + + try: + model_repository.upload_dir(job.job_id, tb_log_dir) + except Exception: + logger.exception("Failed to upload a tensorboard log.") + + +def run_agent(db_url, storage_url, data_dir, num_runs): + logger.info("Starting an agent.") + + job_repository = JobRepository(db_url) + model_repository = ModelRepository(storage_url) + + num_processes = torch.cuda.device_count() if torch.cuda.is_available() else 1 + + for _ in range(num_runs): + job = job_repository.get_next_job(num_processes=num_processes) + if not job: + logger.info("The job queue is empty. Exiting...") + break + + logger.info(f"Got a new job! 
{job.job_id}") + + try: + job_repository.update_job_status(job.job_id, 'running') + process_one_job(job, db_url, model_repository, data_dir) + job_repository.update_job_status(job.job_id, 'completed') + except Exception: + logger.exception(f"Failed to process a job {job.job_id}.") + job_repository.update_job_status(job.job_id, 'failed') + + logger.info(f"Job {job.job_id} completed!") + + logger.info("All done!") + + +def main(): + init_logging() + env = Environment() + parser = argparse.ArgumentParser() + parser.add_argument('--data', required=True, type=pathlib.Path) + parser.add_argument('--num_runs', '-n', type=int, default=10000) + parser.add_argument('--db_url', default=env.db_url, help="URL for a mongo db that stores training configs.") + parser.add_argument('--storage_url', default=env.storage_url, help="Blob container URL with SAS token. Trained weights will be stored here.") + + args = parser.parse_args() + + if not args.db_url: + parser.error("A database url must be specified.") + if not args.storage_url: + parser.error("A storage url must be specified.") + + run_agent(args.db_url, args.storage_url, args.data, args.num_runs) + + +if __name__ == '__main__': + main() diff --git a/mitorch/commands/common.py b/mitorch/commands/common.py new file mode 100644 index 0000000..abb3e98 --- /dev/null +++ b/mitorch/commands/common.py @@ -0,0 +1,5 @@ +import logging + + +def init_logging(): + logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(module)s %(process)d: %(message)s') diff --git a/mitorch/commands/query.py b/mitorch/commands/query.py new file mode 100644 index 0000000..992a7e8 --- /dev/null +++ b/mitorch/commands/query.py @@ -0,0 +1,52 @@ +"""A command to submit a new job.""" +import argparse +import uuid +from mitorch.common import Environment, JobRepository, MetricsRepository +from mitorch.commands.common import init_logging + + +def query_job(db_url, job_id): + job_repository = JobRepository(db_url) + metrics_repository = MetricsRepository(db_url) + job = job_repository.get_job(job_id) + if not job: + print("Not found.") + return + + metrics = metrics_repository.get_metrics(job_id) + + print(job) + for m in metrics: + print(m) + + +def query_job_list(db_url): + job_repository = JobRepository(db_url) + jobs = job_repository.query_jobs() + if not jobs: + print("No job found.") + + for job_record in jobs: + print(job_record) + + +def main(): + init_logging() + + env = Environment() + parser = argparse.ArgumentParser() + parser.add_argument('--db_url', default=env.db_url) + parser.add_argument('--job_id', type=uuid.UUID) + + args = parser.parse_args() + if not args.db_url: + parser.error("A database url must be specified via commandline argument or environment variable.") + + if args.job_id: + query_job(args.db_url, args.job_id) + else: + query_job_list(args.db_url) + + +if __name__ == '__main__': + main() diff --git a/mitorch/commands/submit.py b/mitorch/commands/submit.py new file mode 100644 index 0000000..a616371 --- /dev/null +++ b/mitorch/commands/submit.py @@ -0,0 +1,35 @@ +"""A command to submit a new job.""" +import argparse +import pathlib +import uuid +import jsons +from mitorch.common import Environment, TrainingConfig, JobRepository +from mitorch.commands.common import init_logging + + +def submit_job(training_config, priority, base_job_id, job_repository): + print(f"Adding {training_config}") + job_id = job_repository.add_new_job(training_config, priority, base_job_id) + print(f"Successfully added. 
Job id is {job_id}") + + +def main(): + init_logging() + + env = Environment() + parser = argparse.ArgumentParser() + parser.add_argument('config_filepath', type=pathlib.Path) + parser.add_argument('--db_url', default=env.db_url) + parser.add_argument('--priority', type=int, default=2, help="Lower value has higher priority.") + parser.add_argument('--base_job_id', type=uuid.UUID) + + args = parser.parse_args() + + training_config = jsons.loads(args.config_filepath.read_text(), TrainingConfig) + job_repository = JobRepository(args.db_url) + + submit_job(training_config, args.priority, args.base_job_id, job_repository) + + +if __name__ == '__main__': + main() diff --git a/mitorch/commands/train.py b/mitorch/commands/train.py new file mode 100644 index 0000000..1ddd927 --- /dev/null +++ b/mitorch/commands/train.py @@ -0,0 +1,77 @@ +"""Train a model based on the given config.""" +import argparse +import importlib.metadata +import logging +import pathlib +import uuid +import jsons +import pytorch_lightning as pl +import torch +from mitorch.builders import DataLoaderBuilder +from mitorch.common import MiModel, TrainingConfig, StandardLogger, MongoDBLogger +from mitorch.commands.common import init_logging + +_logger = logging.getLogger(__name__) + + +def train(config, train_dataset_filepath, val_dataset_filepath, weights_filepath, output_filepath, job_id, db_url, tensorboard_log_dir, fast_dev_run=False): + try: + _logger.info(f"Training started. mitorch version is {importlib.metadata.version('mitorch')}. model version is {importlib.metadata.version('mitorch-models')}") + except Exception: + _logger.info("Training started.") + + logger = [StandardLogger()] + if job_id and db_url: + logger.append(MongoDBLogger(db_url, job_id)) + if tensorboard_log_dir: + logger.append(pl.loggers.TensorBoardLogger(str(tensorboard_log_dir))) + + pl.seed_everything(0) + + _logger.debug(f"Loaded config: {config}") + + gpus = config.num_processes if torch.cuda.is_available() else None + if config.num_processes > 1 and not gpus: + _logger.warning("Multiple processes are requested, but only 1 CPU is available on this node.") + + train_dataloader, val_dataloader = DataLoaderBuilder(config).build(train_dataset_filepath, val_dataset_filepath) + num_classes = len(train_dataloader.dataset.labels) if train_dataloader else len(val_dataloader.dataset.labels) + + trainer = pl.Trainer(max_epochs=config.max_epochs, fast_dev_run=fast_dev_run, gpus=gpus, distributed_backend='ddp', terminate_on_nan=True, + logger=logger, progress_bar_refresh_rate=0, check_val_every_n_epoch=10, num_sanity_val_steps=0, deterministic=False, + accumulate_grad_batches=config.accumulate_grad_batches, checkpoint_callback=False) + + model = MiModel(config, num_classes, weights_filepath) + trainer.fit(model, train_dataloader, val_dataloader) + _logger.info("Training completed.") + + trainer.validate(model, val_dataloader) + _logger.info("Validation completed.") + + if output_filepath: + model.save(output_filepath) + + +def main(): + init_logging() + + parser = argparse.ArgumentParser(description="Train a model") + parser.add_argument('config_filepath', type=pathlib.Path) + parser.add_argument('train_dataset_filepath', type=pathlib.Path) + parser.add_argument('val_dataset_filepath', nargs='?', type=pathlib.Path) + parser.add_argument('--weights_filepath', '-w', type=pathlib.Path) + parser.add_argument('--output_filepath', '-o', type=pathlib.Path) + parser.add_argument('--fast_dev_run', '-d', action='store_true') + parser.add_argument('--job_id', type=uuid.UUID) 
+ parser.add_argument('--db_url') + parser.add_argument('--tensorboard_log', type=pathlib.Path) + + args = parser.parse_args() + config = jsons.loads(args.config_filepath.read_text(), TrainingConfig) + + train(config, args.train_dataset_filepath, args.val_dataset_filepath, args.weights_filepath, + args.output_filepath, args.job_id, args.db_url, args.tensorboard_log, args.fast_dev_run) + + +if __name__ == '__main__': + main() diff --git a/mitorch/common/__init__.py b/mitorch/common/__init__.py new file mode 100644 index 0000000..ecba409 --- /dev/null +++ b/mitorch/common/__init__.py @@ -0,0 +1,9 @@ +from .environment import Environment +from .job_repository import JobRepository +from .logger import StandardLogger, StdoutLogger, MongoDBLogger +from .metrics_repository import MetricsRepository +from .mimodel import MiModel +from .model_repository import ModelRepository +from .training_config import TrainingConfig + +__all__ = ['TrainingConfig', 'Environment', 'JobRepository', 'StandardLogger', 'StdoutLogger', 'MongoDBLogger', 'MetricsRepository', 'MiModel', 'ModelRepository'] diff --git a/mitorch/environment.py b/mitorch/common/environment.py similarity index 56% rename from mitorch/environment.py rename to mitorch/common/environment.py index f987442..df9c191 100644 --- a/mitorch/environment.py +++ b/mitorch/common/environment.py @@ -4,9 +4,12 @@ class Environment: def __init__(self): self._db_url = os.getenv('MITORCH_DATABASE_URL') + self._storage_url = os.getenv('MITORCH_STORAGE_URL') @property def db_url(self): - if not self._db_url: - raise RuntimeError("MITORCH_DATABASE_URL is not set") return self._db_url + + @property + def storage_url(self): + return self._storage_url diff --git a/mitorch/common/job_repository.py b/mitorch/common/job_repository.py new file mode 100644 index 0000000..d24e95b --- /dev/null +++ b/mitorch/common/job_repository.py @@ -0,0 +1,78 @@ +"""Repository for training configs. 
Used by mitorch-agent.""" +import dataclasses +import datetime +from typing import Optional +import uuid +import jsons +import pymongo +import tenacity +from mitorch.common.training_config import TrainingConfig + + +@dataclasses.dataclass(frozen=True) +class JobRecord: + job_id: uuid.UUID + status: str + config: TrainingConfig + base_job_id: Optional[uuid.UUID] = None + priority: int = 2 + created_at: datetime.datetime = None + updated_at: datetime.datetime = None + + +class JobRepository: + def __init__(self, mongodb_url): + client = pymongo.MongoClient(mongodb_url, uuidRepresentation='standard') + db = client.mitorch + self._job_collection = db.jobs + + @tenacity.retry(retry=tenacity.retry_if_exception_type(pymongo.errors.PyMongoError), stop=tenacity.stop_after_attempt(2), reraise=True) + def get_next_job(self, num_processes=-1): + assert isinstance(num_processes, int) + raw_data = self._job_collection.find_one({'status': 'queued', 'config.num_processes': {'$in': [-1, num_processes]}}, + sort=[('priority', pymongo.ASCENDING), ('created_at', pymongo.ASCENDING)]) + if not raw_data: + return None + return self._job_document_to_record(raw_data) + + @tenacity.retry(retry=tenacity.retry_if_exception_type(pymongo.errors.PyMongoError), stop=tenacity.stop_after_attempt(2), reraise=True) + def update_job_status(self, job_id: uuid.UUID, new_status): + assert isinstance(job_id, uuid.UUID) + assert new_status in ['queued', 'running', 'failed', 'cancelled', 'completed'] + + # TODO: Transaction + result = self._job_collection.update_one({'_id': job_id}, {'$set': {'status': new_status, + 'updated_at': datetime.datetime.utcnow()}}) + if result.modified_count == 0: + raise RuntimeError(f"Job not found: {job_id}") + + @tenacity.retry(retry=tenacity.retry_if_exception_type(pymongo.errors.PyMongoError), stop=tenacity.stop_after_attempt(2), reraise=True) + def add_new_job(self, training_config: TrainingConfig, priority=2, base_job_id=None): + assert base_job_id is None or isinstance(base_job_id, uuid.UUID) + job_id = uuid.uuid4() + job_dict = {'_id': job_id, + 'status': 'queued', + 'config': dataclasses.asdict(training_config), + 'priority': priority, + 'base_job_id': base_job_id, + 'created_at': datetime.datetime.utcnow(), + 'updated_at': datetime.datetime.utcnow()} + + self._job_collection.insert_one(job_dict) + return job_id + + def query_jobs(self): + jobs = self._job_collection.find(sort=[('updated_at', pymongo.DESCENDING)]) + return [self._job_document_to_record(d) for d in jobs] + + def get_job(self, job_id): + assert isinstance(job_id, uuid.UUID) + raw_data = self._job_collection.find_one({'_id': job_id}) + return self._job_document_to_record(raw_data) + + @staticmethod + def _job_document_to_record(raw_data): + raw_data['job_id'] = raw_data['_id'] + raw_data['config'] = jsons.load(raw_data['config'], TrainingConfig) + del raw_data['_id'] + return JobRecord(**raw_data) diff --git a/mitorch/common/logger.py b/mitorch/common/logger.py new file mode 100644 index 0000000..148746b --- /dev/null +++ b/mitorch/common/logger.py @@ -0,0 +1,83 @@ +import datetime +import json +import jsons +import logging +import uuid +import pymongo +from pytorch_lightning.loggers import LightningLoggerBase +from pytorch_lightning.utilities import rank_zero_only +import torch + + +logger = logging.getLogger(__name__) + + +class SerializableMongoClient: + def __init__(self, url): + self._url = url + # w=0: Disable write achknowledgement. 
+ self._client = pymongo.MongoClient(url, uuidRepresentation='standard', w=0) + + def __getattr__(self, name): + return getattr(self._client, name) + + def __getstate__(self): + return {'url': self._url} + + def __setstate__(self, state): + self._url = state['url'] + self._client = pymongo.MongoClient(self._url, uuidRepresentation='standard', w=0) + + +class LoggerBase(LightningLoggerBase): + @property + def experiment(self): + return self + + @property + def name(self): + return 'experiment' + + @property + def version(self): + return 0 + + @rank_zero_only + def log_hyperparams(self, params): + pass + + +class StdoutLogger(LoggerBase): + @rank_zero_only + def log_metrics(self, metrics, step): + print(f"{datetime.datetime.now()}: {step}: {metrics}") + + @rank_zero_only + def log_hyperparams(self, params): + print(f"hyperparams: {jsons.dumps(params)}") + + +class StandardLogger(LoggerBase): + @rank_zero_only + def log_metrics(self, metrics, step): + logger.info(f"Step {step}: {json.dumps(metrics)}") + + @rank_zero_only + def log_hyperparams(self, params): + logger.info(f"hyperparams: {jsons.dumps(params)}") + + +class MongoDBLogger(LoggerBase): + def __init__(self, db_url, job_id): + super().__init__() + assert isinstance(job_id, uuid.UUID) + self._db_url = db_url + self._job_id = job_id + self._client = SerializableMongoClient(db_url) + + @rank_zero_only + def log_metrics(self, metrics, step): + m = {key: value.tolist() if torch.is_tensor(value) else value for key, value in metrics.items() if not key.endswith('_step') and key != 'epoch'} + if m and 'epoch' in metrics: + epoch = metrics['epoch'] + self._client.mitorch.training_metrics.insert_one({'job_id': self._job_id, 'e': epoch, 'm': m}) diff --git a/mitorch/common/metrics_repository.py b/mitorch/common/metrics_repository.py new file mode 100644 index 0000000..6373e14 --- /dev/null +++ b/mitorch/common/metrics_repository.py @@ -0,0 +1,28 @@ +import dataclasses +import typing +import uuid +import pymongo +import tenacity + + +@dataclasses.dataclass(frozen=True) +class Metrics: + job_id: uuid.UUID + epoch: int + metrics: typing.Any + + +class MetricsRepository: + def __init__(self, mongodb_url): + client = pymongo.MongoClient(mongodb_url, uuidRepresentation='standard') + db = client.mitorch + self._metrics_collection = db.training_metrics + + @tenacity.retry(retry=tenacity.retry_if_exception_type(pymongo.errors.PyMongoError), stop=tenacity.stop_after_attempt(2), reraise=True) + def get_metrics(self, job_id): + results = self._metrics_collection.find({'job_id': job_id}) + return [self._to_metrics(r) for r in results] + + @staticmethod + def _to_metrics(raw_data): + return Metrics(job_id=raw_data['job_id'], epoch=raw_data['e'], metrics=raw_data['m']) diff --git a/mitorch/common/mimodel.py b/mitorch/common/mimodel.py new file mode 100644 index 0000000..80dd6c1 --- /dev/null +++ b/mitorch/common/mimodel.py @@ -0,0 +1,66 @@ +"""Lightning Module class for all trainings in mitorch.""" +import logging +from pytorch_lightning import LightningModule +from pytorch_lightning.utilities import rank_zero_only +import torch +from mitorch.builders import EvaluatorBuilder, LrSchedulerBuilder, ModelBuilder, OptimizerBuilder + + +class MiModel(LightningModule): + def __init__(self, config, num_classes, weights_filepath=None): + super().__init__() + self.save_hyperparameters('config') + self.model = ModelBuilder(config).build(num_classes, weights_filepath) + # TODO: Leverage torchmetrics + self.evaluator = EvaluatorBuilder(config).build() + + def 
configure_optimizers(self): + # lr_scheduler.step() is called after every training steps. + optimizer = OptimizerBuilder(self.hparams['config']).build(self.model) + num_samples = len(self.train_dataloader.dataloader.dataset) + lr_scheduler = LrSchedulerBuilder(self.hparams['config']).build(optimizer, num_samples) + return {'optimizer': optimizer, 'lr_scheduler': {'scheduler': lr_scheduler, 'interval': 'step'}} + + def training_step(self, batch, batch_index): + image, target = batch + output = self.forward(image) + loss = self.model.loss(output, target) + self.log('train_loss', loss, on_epoch=True) + return loss + + def validation_step(self, batch, batch_index): + image, target = batch + output = self.forward(image) + loss = self.model.loss(output, target) + predictions = self.model.predictor(output) + self.evaluator.add_predictions(predictions, target) + self.log('val_loss', loss, sync_dist=True) + + def validation_epoch_end(self, outputs): + results = self.evaluator.get_report() + self.evaluator.reset() + results = {'val_' + key: torch.tensor(value).to(self.device) for key, value in results.items()} + self.log_dict(results, sync_dist=True) + + def test_step(self, batch, batch_index): + image, target = batch + output = self.forward(image) + loss = self.model.loss(output, target) + predictions = self.model.predictor(output) + self.evaluator.add_predictions(predictions, target) + self.log('test_loss', loss, sync_dist=True) + + def test_epoch_end(self, outputs): + results = self.evaluator.get_report() + self.evaluator.reset() + results = {'test_' + key: torch.tensor(value).to(self.device) for key, value in results.items()} + self.log_dict(results, sync_dist=True) + + def forward(self, x): + return self.model(x) + + @rank_zero_only + def save(self, filepath): + logging.info(f"Saving a model to {filepath}") + state_dict = self.model.state_dict() + torch.save(state_dict, filepath) diff --git a/mitorch/common/model_repository.py b/mitorch/common/model_repository.py new file mode 100644 index 0000000..bc1f374 --- /dev/null +++ b/mitorch/common/model_repository.py @@ -0,0 +1,60 @@ +import dataclasses +import json +import logging +import shutil +import urllib.parse +import requests +import tenacity + +logger = logging.getLogger(__name__) + + +class ModelRepository: + def __init__(self, base_url): + self._base_url = base_url + + def upload_weights(self, job_id, filepath): + url = self._get_model_url(job_id) + logger.info(f"Uploading to {url}.") + self._put_blob(url, filepath.read_bytes()) + logger.info("Upload completed.") + + @tenacity.retry(retry=tenacity.retry_if_exception_type(IOError), stop=tenacity.stop_after_attempt(2), reraise=True) + def download_weights(self, job_id, filepath): + url = self._get_model_url(job_id) + logger.info(f"Downloading from {url}.") + with requests.get(url, stream=True) as r: + with open(filepath, 'wb') as f: + shutil.copyfileobj(r.raw, f, length=4 * 1024 * 1024) + + def upload_config(self, job_id, training_config): + config = dataclasses.asdict(training_config) + config_binary = json.dumps(config, indent=4).encode('utf-8') + url = self._get_config_url(job_id) + self._put_blob(url, config_binary) + + def upload_file(self, job_id, filepath): + url = self._get_file_url(job_id, filepath.name) + self._put_blob(url, filepath.read_bytes()) + + def upload_dir(self, job_id, directory): + all_files = [p for p in directory.rglob('*') if p.is_file()] + for filepath in all_files: + name = filepath.relative_to(directory.parent) + url = self._get_file_url(job_id, name) + 
self._put_blob(url, filepath.read_bytes()) + + def _get_model_url(self, job_id): + return self._get_file_url(job_id, 'model.pth') + + def _get_config_url(self, job_id): + return self._get_file_url(job_id, 'config.json') + + def _get_file_url(self, job_id, relative_path): + parsed = urllib.parse.urlparse(self._base_url) + path = parsed.path + ('/' if parsed.path[-1] != '/' else '') + str(job_id) + '/' + str(relative_path) + return urllib.parse.urlunparse((*parsed[:2], path, *parsed[3:])) + + @tenacity.retry(retry=tenacity.retry_if_exception_type(IOError), stop=tenacity.stop_after_attempt(2), reraise=True) + def _put_blob(self, url, data): + return requests.put(url=url, data=data, headers={'Content-Type': 'application/octet-stream', 'x-ms-blob-type': 'BlockBlob'}) diff --git a/mitorch/common/training_config.py b/mitorch/common/training_config.py new file mode 100644 index 0000000..1d3fd55 --- /dev/null +++ b/mitorch/common/training_config.py @@ -0,0 +1,54 @@ +import dataclasses +from typing import Optional, List + + +@dataclasses.dataclass(frozen=True) +class AugmentationConfig: + train: Optional[str] + val: Optional[str] + + +@dataclasses.dataclass(frozen=True) +class LrSchedulerConfig: + name: str + base_lr: float + step_size: Optional[int] + step_gamma: Optional[float] + warmup: Optional[str] + warmup_epochs: Optional[int] + warmup_factor: Optional[float] = 0.01 + + +@dataclasses.dataclass(frozen=True) +class ModelConfig: + name: str + input_size: int + options: List[str] = dataclasses.field(default_factory=list) + + +@dataclasses.dataclass(frozen=True) +class OptimizerConfig: + name: str = 'sgd' + momentum: float = 0.9 + weight_decay: float = 1e-5 + + +@dataclasses.dataclass(frozen=True) +class DatasetConfig: + """Used by mitorch-agent to prepare a training environment.""" + train: str + val: Optional[str] + + +@dataclasses.dataclass(frozen=True) +class TrainingConfig: + task_type: str + batch_size: int + max_epochs: int + model: ModelConfig = None + augmentation: AugmentationConfig = None + lr_scheduler: LrSchedulerConfig = None + optimizer: OptimizerConfig = None + dataset: Optional[DatasetConfig] = None + num_processes: int = -1 + accumulate_grad_batches: int = 1 diff --git a/mitorch/datasets/__init__.py b/mitorch/datasets/__init__.py index ca5d2cd..c76d4f5 100644 --- a/mitorch/datasets/__init__.py +++ b/mitorch/datasets/__init__.py @@ -1,9 +1,4 @@ -from .image_dataset import ImageDataset -from .albumentations_transforms import (CenterCropTransform, CenterCropTransformV2, CenterCropTransformV3, - RandomResizedCropTransform, RandomResizedCropTransformV2, RandomResizedCropTransformV3, RandomResizedCropTransformV4, - ResizeTransform, ResizeFlipTransform, RandomSizedBBoxSafeCropTransform) -from .transforms import InceptionTransform, DevTransform, Dev2Transform +from .image_dataset import ImageDataset, ObjectDetectionDataset +from .factory import TransformFactory -__all__ = ['ImageDataset', 'CenterCropTransform', 'CenterCropTransformV2', 'CenterCropTransformV3', 'ResizeFlipTransform', 'ResizeTransform', - 'RandomResizedCropTransform', 'RandomResizedCropTransformV2', 'RandomResizedCropTransformV3', 'RandomResizedCropTransformV4', - 'RandomSizedBBoxSafeCropTransform', 'InceptionTransform', 'DevTransform', 'Dev2Transform'] +__all__ = ['ImageDataset', 'ObjectDetectionDataset', 'TransformFactory'] diff --git a/mitorch/datasets/albumentations_transforms.py b/mitorch/datasets/albumentations_transforms.py index 0f795e3..8bc1693 100644 --- a/mitorch/datasets/albumentations_transforms.py +++ 
b/mitorch/datasets/albumentations_transforms.py @@ -1,108 +1,72 @@ import albumentations +import albumentations.pytorch import cv2 import numpy as np -import torch -import torchvision - -MIN_BOX_SIZE = 0.001 class AlbumentationsTransform: - def __init__(self, transforms, is_object_detection): - bbox_params = albumentations.BboxParams(format='pascal_voc', label_fields=['category_id']) if is_object_detection else None + def __init__(self, input_size, is_object_detection): + bbox_params = albumentations.BboxParams(format='albumentations', label_fields=['category_id'], min_area=16, min_visibility=0.1) if is_object_detection else None + normalize = albumentations.Normalize(mean=[0.5, 0.5, 0.5], std=[1, 1, 1]) + to_tensor = albumentations.pytorch.ToTensorV2() + transforms = self.get_transforms(input_size) + self.transforms = albumentations.Compose(transforms + [normalize, to_tensor], bbox_params=bbox_params) self.is_object_detection = is_object_detection - self.transforms = albumentations.Compose(transforms, bbox_params=bbox_params) - self.to_tensor = torchvision.transforms.ToTensor() - self.mean_value = torch.tensor([0.5, 0.5, 0.5]).reshape(3, 1, 1) def __call__(self, image, target): + w, h = image.width, image.height image = np.array(image) if self.is_object_detection: - bboxes = [t[1:] for t in target] + bboxes = [[t[1] / w, t[2] / h, t[3] / w, t[4] / h] for t in target] category_id = [t[0] for t in target] - augmented = self.transforms(image=np.array(image), bboxes=bboxes, category_id=category_id) - image = augmented['image'] - w, h = image.shape[0:2] - target = [[label, bbox[0] / w, bbox[1] / h, bbox[2] / w, bbox[3] / h] for label, bbox in zip(augmented['category_id'], augmented['bboxes']) - if bbox[0] + MIN_BOX_SIZE < bbox[2] and bbox[1] + MIN_BOX_SIZE < bbox[3]] + augmented = self.transforms(image=image, bboxes=bboxes, category_id=category_id) + target = [[label, *bbox] for label, bbox in zip(augmented['category_id'], augmented['bboxes'])] else: - image = self.transforms(image=image)['image'] - - image = self.to_tensor(image) - self.mean_value - return image, target + augmented = self.transforms(image=image) + return augmented['image'], target -class RandomResizedCropTransformV2(AlbumentationsTransform): - def __init__(self, input_size, is_object_detection): - transforms = [albumentations.augmentations.transforms.RandomResizedCrop(input_size, input_size), - albumentations.augmentations.transforms.HorizontalFlip()] - super().__init__(transforms, is_object_detection) + def get_transforms(self, input_size): + raise NotImplementedError -class RandomResizedCropTransform(AlbumentationsTransform): - def __init__(self, input_size, is_object_detection): - transforms = [albumentations.augmentations.transforms.RandomResizedCrop(input_size, input_size), - albumentations.augmentations.transforms.Flip(), - albumentations.augmentations.transforms.RandomBrightnessContrast()] - super().__init__(transforms, is_object_detection) +class SurveillanceCameraTransform(AlbumentationsTransform): + def get_transforms(self, input_size): + return [albumentations.RandomSizedBBoxSafeCrop(input_size, input_size, interpolation=cv2.INTER_CUBIC), + albumentations.HorizontalFlip(), + albumentations.ImageCompression(quality_lower=20, quality_upper=100), + albumentations.RandomBrightnessContrast(), + albumentations.ToGray(p=0.1)] -class RandomResizedCropTransformV3(AlbumentationsTransform): - def __init__(self, input_size, is_object_detection): - transforms = 
[albumentations.augmentations.transforms.RandomResizedCrop(input_size, input_size), - albumentations.augmentations.transforms.HorizontalFlip(), - albumentations.augmentations.transforms.RandomBrightnessContrast()] - super().__init__(transforms, is_object_detection) - - -class RandomResizedCropTransformV4(AlbumentationsTransform): - def __init__(self, input_size, is_object_detection): - transforms = [albumentations.augmentations.transforms.RandomResizedCrop(input_size, input_size, interpolation=cv2.INTER_CUBIC), - albumentations.augmentations.transforms.HorizontalFlip(), - albumentations.augmentations.transforms.RandomBrightnessContrast()] - super().__init__(transforms, is_object_detection) +class RandomResizedCropTransform(AlbumentationsTransform): + def get_transforms(self, input_size): + return [albumentations.RandomResizedCrop(input_size, input_size, interpolation=cv2.INTER_CUBIC), + albumentations.HorizontalFlip(), + albumentations.RandomBrightnessContrast()] class ResizeTransform(AlbumentationsTransform): - def __init__(self, input_size, is_object_detection): - transforms = [albumentations.augmentations.transforms.Resize(input_size, input_size)] - super().__init__(transforms, is_object_detection) + def get_transforms(self, input_size): + return [albumentations.Resize(input_size, input_size)] class ResizeFlipTransform(AlbumentationsTransform): - def __init__(self, input_size, is_object_detection): - transforms = [albumentations.augmentations.transforms.Resize(input_size, input_size), - albumentations.augmentations.transforms.Flip()] - super().__init__(transforms, is_object_detection) + def get_transforms(self, input_size): + return [albumentations.Resize(input_size, input_size), + albumentations.Flip()] class RandomSizedBBoxSafeCropTransform(AlbumentationsTransform): - def __init__(self, input_size, is_object_detection): - transforms = [albumentations.augmentations.transforms.RandomSizedBBoxSafeCrop(input_size, input_size, erosion_rate=0.2), - albumentations.augmentations.transforms.Flip(), - albumentations.augmentations.transforms.RandomBrightnessContrast()] - super().__init__(transforms, is_object_detection) + def get_transforms(self, input_size): + return [albumentations.RandomSizedBBoxSafeCrop(input_size, input_size, erosion_rate=0.2), + albumentations.Flip(), + albumentations.RandomBrightnessContrast()] class CenterCropTransform(AlbumentationsTransform): - def __init__(self, input_size, is_object_detection): - transforms = [albumentations.augmentations.transforms.SmallestMaxSize(input_size), - albumentations.augmentations.transforms.CenterCrop(input_size, input_size)] - super().__init__(transforms, is_object_detection) - - -class CenterCropTransformV2(AlbumentationsTransform): """This method was found in pytorch's imagenet training example.""" - def __init__(self, input_size, is_object_detection): - transforms = [albumentations.augmentations.transforms.SmallestMaxSize(int(input_size / 224 * 256)), - albumentations.augmentations.transforms.CenterCrop(input_size, input_size)] - super().__init__(transforms, is_object_detection) - - -class CenterCropTransformV3(AlbumentationsTransform): - """This method was found in pytorch's imagenet training example.""" - def __init__(self, input_size, is_object_detection): - transforms = [albumentations.augmentations.transforms.SmallestMaxSize(int(input_size / 224 * 256), interpolation=cv2.INTER_CUBIC), - albumentations.augmentations.transforms.CenterCrop(input_size, input_size)] - super().__init__(transforms, is_object_detection) + def 
get_transforms(self, input_size): + return [albumentations.SmallestMaxSize(int(input_size / 224 * 256), interpolation=cv2.INTER_CUBIC), + albumentations.CenterCrop(input_size, input_size)] diff --git a/mitorch/datasets/factory.py b/mitorch/datasets/factory.py new file mode 100644 index 0000000..683dbef --- /dev/null +++ b/mitorch/datasets/factory.py @@ -0,0 +1,23 @@ +from mitorch.datasets.albumentations_transforms import (CenterCropTransform, ResizeTransform, ResizeFlipTransform, RandomResizedCropTransform, + RandomSizedBBoxSafeCropTransform, SurveillanceCameraTransform) +from mitorch.datasets.transforms import InceptionTransform + + +class TransformFactory: + def __init__(self, is_object_detection): + self._is_object_detection = is_object_detection + + def create(self, name, input_size): + augmentation_class = {'center_crop': CenterCropTransform, + 'inception': InceptionTransform, + 'resize': ResizeTransform, + 'resize_flip': ResizeFlipTransform, + 'random_resize': RandomResizedCropTransform, + 'random_resize_bbox': RandomSizedBBoxSafeCropTransform, + 'surveillance': SurveillanceCameraTransform + }.get(name) + + if not augmentation_class: + return None + + return augmentation_class(input_size, self._is_object_detection) diff --git a/mitorch/datasets/image_dataset.py b/mitorch/datasets/image_dataset.py index f98feb7..b145641 100644 --- a/mitorch/datasets/image_dataset.py +++ b/mitorch/datasets/image_dataset.py @@ -27,7 +27,7 @@ def _load_labels(self, max_label): labels_filepath = self.base_dir / 'labels.txt' if labels_filepath.exists(): with open(labels_filepath) as f: - labels = [l.strip() for l in f.readlines()] + labels = [line.strip() for line in f.readlines()] assert len(labels) > max_label return labels else: diff --git a/mitorch/logger.py b/mitorch/logger.py deleted file mode 100644 index 0436065..0000000 --- a/mitorch/logger.py +++ /dev/null @@ -1,87 +0,0 @@ -import datetime -import uuid -import pymongo -from pytorch_lightning.loggers import LightningLoggerBase -from pytorch_lightning.utilities import rank_zero_only -import torch - - -class StdoutLogger(LightningLoggerBase): - @rank_zero_only - def log_metrics(self, metrics, step): - if metrics: - print(f"{datetime.datetime.now()}: {step}: {metrics}") - - @rank_zero_only - def log_hyperparams(self, params): - print(str(params)) - - @property - def experiment(self): - return self - - @rank_zero_only - def log_epoch_metrics(self, metrics, epoch): - metrics = {key: value.tolist() if isinstance(value, torch.Tensor) else value for key, value in metrics.items()} - print(f"{datetime.datetime.now()}: Epoch {epoch}: {metrics}") - - @property - def name(self): - return "experiment" - - @property - def version(self): - return 0 - - -class SerializableMongoClient: - def __init__(self, url): - self._url = url - # w=0: Disable write achknowledgement. 
- self._client = pymongo.MongoClient(url, uuidRepresentation='standard', w=0) - - def __getattr__(self, name): - return getattr(self._client, name) - - def __getstate__(self): - return {'url': self._url} - - def __setstate__(self, state): - self._url = state['url'] - self._client = pymongo.MongoClient(self._url, uuidRepresentation='standard', w=0) - - -class MongoDBLogger(LightningLoggerBase): - def __init__(self, db_url, training_id): - super().__init__() - assert isinstance(training_id, uuid.UUID) - self._db_url = db_url - self.training_id = training_id - self.client = SerializableMongoClient(db_url) - - @rank_zero_only - def log_metrics(self, metrics, step): - pass - - @rank_zero_only - def log_hyperparams(self, params): - if params and 'model_versions' in params: - model_versions = params['model_versions'] - self.client.mitorch.trainings.update_one({'_id': self.training_id}, {'$set': {'model_versions': model_versions}}) - - @property - def experiment(self): - return self - - @rank_zero_only - def log_epoch_metrics(self, metrics, epoch): - m = {key: value.tolist() if isinstance(value, torch.Tensor) else value for key, value in metrics.items()} - self.client.mitorch.training_metrics.insert_one({'tid': self.training_id, 'e': epoch, 'm': m}) - - @property - def name(self): - return "experiment" - - @property - def version(self): - return 0 diff --git a/mitorch/mimodel.py b/mitorch/mimodel.py deleted file mode 100644 index 59295ce..0000000 --- a/mitorch/mimodel.py +++ /dev/null @@ -1,103 +0,0 @@ -import logging -import pytorch_lightning as pl -from pytorch_lightning.utilities import rank_zero_only -import torch -from .builders import DataLoaderBuilder, LrSchedulerBuilder, ModelBuilder, OptimizerBuilder -from .evaluators import MulticlassClassificationEvaluator, MultilabelClassificationEvaluator, ObjectDetectionEvaluator - - -class MiModel(pl.LightningModule): - def __init__(self, hparams): - super().__init__() - # Save arguments so that checkpoint can load it. Remove after updating the pytorch lightning - self.hparams = hparams - config = hparams['config'] - train_dataset_filepath = hparams['train_dataset_filepath'] - val_dataset_filepath = hparams['val_dataset_filepath'] - weights_filepath = hparams['weights_filepath'] - - self._train_dataloader, self._val_dataloader = DataLoaderBuilder(config).build(train_dataset_filepath, val_dataset_filepath) - num_classes = len(self._train_dataloader.dataset.labels) - self.model = ModelBuilder(config).build(num_classes, weights_filepath) - self.optimizer = OptimizerBuilder(config).build(self.model) - self.lr_scheduler = LrSchedulerBuilder(config).build(self.optimizer, len(self._train_dataloader)) - self.evaluator = self._get_evaluator(config['task_type']) - self.train_epoch = 0 - - @property - def model_version(self): - return self.model.version - - @staticmethod - def _get_evaluator(task_type): - mappings = {'multiclass_classification': MulticlassClassificationEvaluator, - 'multilabel_classification': MultilabelClassificationEvaluator, - 'object_detection': ObjectDetectionEvaluator} - assert task_type in mappings - return mappings[task_type]() - - def configure_optimizers(self): - # lr_scheduler.step() is called after every training steps. 
- return {'optimizer': self.optimizer, 'lr_scheduler': {'scheduler': self.lr_scheduler, 'interval': 'step'}} - - def train_dataloader(self): - return self._train_dataloader - - def val_dataloader(self): - return self._val_dataloader - - def test_dataloader(self): - return self.val_dataloader() - - def training_step(self, batch, batch_index): - image, target = batch - output = self.forward(image) - loss = self.model.loss(output, target) - return {'loss': loss, 'log': {'train_loss': float(loss)}} - - def training_epoch_end(self, outputs): - train_loss = torch.cat([o['loss'] if o['loss'].shape else o['loss'].unsqueeze(0) for o in outputs], dim=0).mean() - self._log_epoch_metrics({'train_loss': train_loss}, self.current_epoch) - return {} - - def validation_step(self, batch, batch_index): - image, target = batch - output = self.forward(image) - loss = self.model.loss(output, target) - predictions = self.model.predictor(output) - self.evaluator.add_predictions(predictions, target) - return {'val_loss': loss} - - def validation_epoch_end(self, outputs): - results = self.evaluator.get_report() - self.evaluator.reset() - results = {key: torch.tensor(value).to(self.device) for key, value in results.items()} - results['val_loss'] = torch.cat([o['val_loss'] if o['val_loss'].shape else o['val_loss'].unsqueeze(0) for o in outputs], dim=0).to(self.device).mean() - self._log_epoch_metrics(results, self.current_epoch) - return {'log': results} - - def test_step(self, batch, batch_index): - val_loss = self.validation_step(batch, batch_index) - return {'test_loss': val_loss['val_loss']} - - def test_epoch_end(self, outputs): - results = self.evaluator.get_report() - self.evaluator.reset() - results = {key: torch.tensor(value).to(self.device) for key, value in results.items()} - results['test_loss'] = torch.cat([o['test_loss'] if o['test_loss'].shape else o['test_loss'].unsqueeze(0) for o in outputs], dim=0).to(self.device).mean() - self._log_epoch_metrics(results, self.current_epoch) - return {'log': results} - - def forward(self, x): - return self.model(x) - - @rank_zero_only - def save(self, filepath): - logging.info(f"Saving a model to {filepath}") - state_dict = self.model.state_dict() - torch.save(state_dict, filepath) - - def _log_epoch_metrics(self, metrics, epoch): - loggers = self.logger.experiment if isinstance(self.logger.experiment, list) else [self.logger.experiment] - for l in loggers: - l.log_epoch_metrics(metrics, epoch) diff --git a/mitorch/service/__init__.py b/mitorch/service/__init__.py deleted file mode 100644 index f0d4bf0..0000000 --- a/mitorch/service/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .database_client import DatabaseClient -from .task import Task - -__all__ = ['DatabaseClient', 'Task'] diff --git a/mitorch/service/control.py b/mitorch/service/control.py deleted file mode 100644 index 40ac811..0000000 --- a/mitorch/service/control.py +++ /dev/null @@ -1,98 +0,0 @@ -import argparse -import datetime -import logging -import time - -from ..azureml import AzureMLManager -from ..environment import Environment -from .database_client import DatabaseClient -from .task import Task - - -logger = logging.getLogger(__name__) - -def control_loop(env): - while True: - try: - control(env) - except KeyboardInterrupt: - return - except Exception as e: - print(f"Exception happened: {e}") - # Ignore the exception - - time.sleep(600) # Sleep for 10 minutes - - -def control(env): - client = DatabaseClient(env.db_url) - aml_manager = AzureMLManager(client.get_settings().azureml_settings) - - # 
If there is available resource, pick a training job and submit to AzureML. - process_trainings(client, aml_manager, env.db_url) - - # Check the status of jobs. Queue new trainings if needed. - process_tasks(client) - - print(f"{datetime.datetime.now()}: Completed") - - -def process_trainings(client, aml_manager, db_url): - # If there are available resources, submit new AML jobs - num_available_nodes = aml_manager.get_num_available_nodes() - logger.debug(f"Found {num_available_nodes} available nodes.") - if num_available_nodes > 0: - pending_jobs = client.get_new_trainings(num_available_nodes) - for job in pending_jobs: - submit_result = aml_manager.submit(db_url, job['_id']) - if submit_result: - aml_run_id, aml_run_url, region = submit_result - updated = client.update_training(job['_id'], {'status': 'queued', 'run_id': aml_run_id, 'run_url': aml_run_url, 'region': region}) - if not updated: - raise RuntimeError(f"Failed to update {job['_id']}") - print(f"Queued a new AML run: id: {job['_id']}, run_id: {aml_run_id}, url: {aml_run_url}, region: {region}") - - # Check the status of ongoing tasks. If there is a task which is dead silently, update its record. - jobs = list(client.get_running_trainings()) + list(client.get_queued_trainings()) - for job in jobs: - status = aml_manager.query(job['run_id'], job['region']) - if status in ['completed', 'failed', 'canceled']: - print(f"Unexpected run status: id: {job['_id']}, run_id: {job['run_id']}, status: {status}") - updated = client.update_training(job['_id'], {'status': 'failed'}) - if not updated: - raise RuntimeError(f"Failed to update {job['_id']}") - - -def process_tasks(client): - # Get all active jobs - tasks = client.get_active_tasks() - for task_data in tasks: - task = Task.from_dict(task_data) - candidate = task.fetch_next() - assert candidate - training = client.find_training_by_config(candidate) - if not training: - client.add_training(candidate) - print(f"Submitted new training: {candidate}") - else: - task.update_training_status(training) - new_task_data = task.to_dict() - if new_task_data != task_data: - client.update_task(new_task_data) - - -def main(): - logger.setLevel(logging.DEBUG) - parser = argparse.ArgumentParser(description="Manage training jobs.") - parser.add_argument('--loop', '-l', action='store_true', help="Keep running until interuppted. 
The jobs will be processed every 5 minutes.") - args = parser.parse_args() - - env = Environment() - if args.loop: - control_loop(env) - else: - control(env) - - -if __name__ == '__main__': - main() diff --git a/mitorch/service/database_client.py b/mitorch/service/database_client.py deleted file mode 100644 index 3576acc..0000000 --- a/mitorch/service/database_client.py +++ /dev/null @@ -1,132 +0,0 @@ -import copy -import dataclasses -import datetime -import uuid -import pymongo -from ..settings import Settings - - -class DatabaseClient: - def __init__(self, mongodb_url): - client = pymongo.MongoClient(mongodb_url, uuidRepresentation='standard') - self.db = client.mitorch - - def add_training(self, config, priority=100): - record = {'config': config} - record['_id'] = uuid.uuid4() - record['created_at'] = datetime.datetime.utcnow() - record['priority'] = priority - record['status'] = 'new' - self.db.trainings.insert_one(record) - - return record['_id'] - - def find_job_by_id(self, job_id): - record = self.db.trainings.find_one({'_id': job_id}) - if not record: - record = self.db.jobs.find_one({'_id': job_id}) - return record - - def find_training_by_id(self, training_id): - return self.db.trainings.find_one({'_id': training_id}) - - def find_training_by_config(self, config): - return self.db.trainings.find_one({'config': config}) - - def get_new_trainings(self, max_num=100): - return self.db.trainings.find({'status': 'new'}).sort('priority').limit(max_num) - - def get_running_trainings(self): - return self.db.trainings.find({'status': 'running'}) - - def get_queued_trainings(self): - return self.db.trainings.find({'status': 'queued'}) - - def get_failed_trainings(self): - return self.db.trainings.find({'status': 'failed'}) - - def delete_training(self, training_id): - result = self.db.trainings.delete_one({'_id': training_id}) - return result.deleted_count == 1 - - def update_training(self, training_id, set_data): - assert isinstance(set_data, dict) - result = self.db.trainings.update_one({'_id': training_id}, {'$set': set_data}) - return result.modified_count == 1 - - def start_training(self, training_id, num_gpus): - result = self.db.trainings.update_one({'_id': training_id}, {'$set': {'status': 'running', - 'started_at': datetime.datetime.utcnow(), - 'machine': {'num_gpus': num_gpus}}}) - return result.modified_count == 1 - - def complete_training(self, training_id): - # Get the test metrics - result = self.db.training_metrics.find_one({'tid': training_id, 'm.test_loss': {'$exists': True}}) - metrics = result['m'] - - result = self.db.trainings.update_one({'_id': training_id}, {'$set': {'status': 'completed', - 'completed_at': datetime.datetime.utcnow(), - 'evaluation': metrics}}) - return result.modified_count == 1 - - def fail_training(self, training_id): - result = self.db.trainings.update_one({'_id': training_id}, {'$set': {'status': 'failed', - 'completed_at': datetime.datetime.utcnow()}}) - return result.modified_count == 1 - - # Datasets - def find_dataset_by_name(self, dataset_name, version=None): - # TODO: Find the latest version. 
- return self.db.datasets.find_one({'name': dataset_name}) - - def add_dataset(self, dataset): - return self.db.datasets.insert_one(dataset) - - def delete_dataset(self, dataset_id): - self.db.datasets.delete_one({'_id': dataset_id}) - - # Common Settings - def get_settings(self): - record = self.db.settings.find_one({'key': 'settings'}) - return record and Settings.from_dict(record['value']) - - def put_settings(self, settings): - settings = dataclasses.asdict(settings) - if self.get_settings(): - result = self.db.settings.update_one({'key': 'settings'}, {'$set': {'value': settings}}) - assert result.modified_count == 1 - else: - self.db.settings.insert_one({'key': 'settings', 'value': settings}) - - # Tasks - def add_task(self, task): - assert 'config' in task - assert 'max_trainings' in task - record = copy.deepcopy(task) - record['_id'] = uuid.uuid4() - record['created_at'] = datetime.datetime.utcnow() - record['status'] = 'active' - self.db.tasks.insert_one(record) - return record['_id'] - - def get_tasks(self): - return self.db.tasks.find() - - def get_task_by_id(self, task_id): - return self.db.tasks.find_one({'_id': task_id}) - - def get_active_tasks(self): - return self.db.tasks.find({'status': 'active'}) - - def update_task(self, task_description): - assert task_description['_id'] - self.db.tasks.update_one({'_id': task_description['_id']}, {'$set': task_description}) - - def cancel_task(self, task_id): - result = self.db.tasks.update_one({'_id': task_id}, {'$set': {'status': 'cancelled'}}) - return result.modified_count == 1 - - def delete_task(self, task_id): - result = self.db.tasks.delete_one({'_id': task_id}) - return result.deleted_count == 1 diff --git a/mitorch/service/dataset.py b/mitorch/service/dataset.py deleted file mode 100644 index 8ed3f6f..0000000 --- a/mitorch/service/dataset.py +++ /dev/null @@ -1,31 +0,0 @@ -import argparse -import json -from ..environment import Environment -from .database_client import DatabaseClient - - -def add_dataset(db_url, json_filepath): - with open(json_filepath) as f: - datasets = json.load(f) - - client = DatabaseClient(db_url) - for dataset in datasets: - if not client.find_dataset_by_name(dataset['name'], dataset['version']): - print(f"Adding dataset: {dataset}") - client.add_dataset(dataset) - - -def main(): - parser = argparse.ArgumentParser("Manage the regisgered datasets") - subparsers = parser.add_subparsers(dest='command') - parser_add = subparsers.add_parser('add') - parser_add.add_argument('json_filepath', help="JSON file which has the datasets") - - args = parser.parse_args() - env = Environment() - if args.command == 'add': - add_dataset(env.db_url, args.json_filepath) - - -if __name__ == '__main__': - main() diff --git a/mitorch/service/random_search_task.py b/mitorch/service/random_search_task.py deleted file mode 100644 index 5e278b3..0000000 --- a/mitorch/service/random_search_task.py +++ /dev/null @@ -1,70 +0,0 @@ -import random -from .task import Task - - -class RandomElement: - pass - - -class RandomChoice(RandomElement): - def __init__(self, choice_list): - assert all(isinstance(c, (int, float, str)) for c in choice_list), f"Invalid choices: {choice_list}" - self.choice_list = choice_list - self.current_choice = None - - def update(self): - self.current_choice = random.choice(self.choice_list) - - def get(self): - return self.current_choice - - -def _create_element(data): - if '_choice' in data: - return RandomChoice(data['_choice']) - return None - - -class RandomSearchTask(Task): - def __init__(self, 
task_description): - super().__init__(task_description) - self.config = task_description['config'] - self._elements = [] - self._parsed_config = self._parse_config(self.config) - - def _parse_config(self, config): - if isinstance(config, list): - return [self._parse_config(c) for c in config] - elif isinstance(config, dict): - element = _create_element(config) - if element: - self._elements.append(element) - return element - - return {key: self._parse_config(config[key]) for key in config} - else: - return config - - def _get_config(self, config): - if isinstance(config, RandomElement): - return config.get() - if isinstance(config, list): - return [self._get_config(c) for c in config] - elif isinstance(config, dict): - return {key: self._get_config(config[key]) for key in config} - else: - return config - - def fetch_next(self): - if not self.has_next(): - return None - - random.seed(0) - for i in range(self.num_trainings + 1): - self._update() # Fast forward to the current index. - - return self._get_config(self._parsed_config) - - def _update(self): - for element in self._elements: - element.update() diff --git a/mitorch/service/task.py b/mitorch/service/task.py deleted file mode 100644 index b912b32..0000000 --- a/mitorch/service/task.py +++ /dev/null @@ -1,39 +0,0 @@ -import copy - - -class Task: - def __init__(self, task_description): - self._task_description = task_description - self.state = task_description.get('state', {}) - self.status = task_description.get('status', 'new') - self.num_trainings = task_description.get('num_trainings', 0) - self.max_trainings = task_description['max_trainings'] - - @staticmethod - def from_dict(data): - from .random_search_task import RandomSearchTask - name = data['name'] - task_class = {'random_search': RandomSearchTask}[name] - return task_class(data) - - def fetch_next(self): - raise NotImplementedError - - def update_training_status(self, training): - if training['status'] in ['completed', 'failed']: - self.num_trainings += 1 - if self.num_trainings >= self.max_trainings: - self.status = 'completed' - - def has_next(self): - return self.status == 'active' and self.num_trainings < self.max_trainings - - def to_dict(self): - task = copy.deepcopy(self._task_description) - task['state'] = self.state - task['status'] = self.status - task['num_trainings'] = self.num_trainings - return task - - def __str__(self): - return str(self.to_dict()) diff --git a/mitorch/settings.py b/mitorch/settings.py deleted file mode 100644 index d1c31b1..0000000 --- a/mitorch/settings.py +++ /dev/null @@ -1,32 +0,0 @@ -import dataclasses -from typing import Dict, List - - -@dataclasses.dataclass -class AzureMLSetting: - region_name: str # Region name for this resource. Must be unique. - subscription_id: str - workspace_name: str - cluster_name: str - sp_tenant_id: str = None # Service Principal tenant id (optional) - sp_username: str = None # Service Principal username (optional) - sp_password: str = None # Service Principal password (optional) - - -@dataclasses.dataclass -class Settings: - # Azure Blob url with SAS token to store trained models. - storage_url: str - - # Dictionary of dataset for each regions. 
- dataset_url: Dict[str, str] - - # AzureML settings - azureml_settings: List[AzureMLSetting] - - readonly_storage_url: str = None - - @classmethod - def from_dict(cls, data): - azureml_settings = [AzureMLSetting(**s) for s in data['azureml_settings']] - return cls(storage_url=data['storage_url'], dataset_url=data['dataset_url'], azureml_settings=azureml_settings, readonly_storage_url=data.get('readonly_storage_url')) diff --git a/mitorch/test.py b/mitorch/test.py deleted file mode 100644 index 6340a90..0000000 --- a/mitorch/test.py +++ /dev/null @@ -1,56 +0,0 @@ -"""Train a model based on the given config.""" -import argparse -import json -import logging -import pathlib -import uuid -import pytorch_lightning as pl -import torch -from .logger import StdoutLogger, MongoDBLogger -from .mimodel import MiModel - -_logger = logging.getLogger(__name__) - - -def test(config_filepath, train_dataset_filepath, val_dataset_filepath, weights_filepath, fast_dev_run, job_id, db_url): - _logger.info("started") - logger = [StdoutLogger()] - if job_id and db_url: - logger.append(MongoDBLogger(db_url, job_id)) - - pl.seed_everything(0) - - config = json.loads(config_filepath.read_text()) - - if fast_dev_run: - config['batch_size'] = 2 - - gpus = 1 if torch.cuda.is_available() else None - hparams = {'config': config, 'train_dataset_filepath': train_dataset_filepath, 'val_dataset_filepath': val_dataset_filepath, 'weights_filepath': weights_filepath} - model = MiModel(hparams) - - trainer = pl.Trainer(max_epochs=config['max_epochs'], fast_dev_run=fast_dev_run, gpus=gpus, - logger=logger, progress_bar_refresh_rate=0, check_val_every_n_epoch=10, num_sanity_val_steps=0, deterministic=True) - - trainer.test(model) - - -def main(): - logging.getLogger().setLevel(logging.INFO) - logging.getLogger('mitorch').setLevel(logging.DEBUG) - - parser = argparse.ArgumentParser(description="Train a model") - parser.add_argument('config_filepath', type=pathlib.Path) - parser.add_argument('train_dataset_filepath', type=pathlib.Path) - parser.add_argument('val_dataset_filepath', type=pathlib.Path) - parser.add_argument('--weights_filepath', '-w', type=pathlib.Path) - parser.add_argument('--fast_dev_run', '-d', action='store_true') - parser.add_argument('--job_id', type=uuid.UUID) - parser.add_argument('--db_url') - - args = parser.parse_args() - test(args.config_filepath, args.train_dataset_filepath, args.val_dataset_filepath, args.weights_filepath, args.fast_dev_run, args.job_id, args.db_url) - - -if __name__ == '__main__': - main() diff --git a/mitorch/train.py b/mitorch/train.py deleted file mode 100644 index 5cd46e4..0000000 --- a/mitorch/train.py +++ /dev/null @@ -1,69 +0,0 @@ -"""Train a model based on the given config.""" -import argparse -import json -import logging -import pathlib -import uuid -import pytorch_lightning as pl -import torch -from .logger import StdoutLogger, MongoDBLogger -from .mimodel import MiModel - -_logger = logging.getLogger(__name__) - - -def train(config_filepath, train_dataset_filepath, val_dataset_filepath, weights_filepath, output_filepath, fast_dev_run, job_id, db_url): - _logger.info("started") - - logger = [StdoutLogger()] - if job_id and db_url: - logger.append(MongoDBLogger(db_url, job_id)) - - pl.seed_everything(0) - - config = json.loads(config_filepath.read_text()) - - if fast_dev_run: - config['batch_size'] = 2 - - num_processes = config.get('num_processes', -1) - accumulate_grad_batches = config.get('accumulate_grad_batches', 1) - gpus = num_processes if torch.cuda.is_available() else 
None - if num_processes > 1 and not gpus: - _logger.warning(f"Multiple processes are requested, but only 1 CPU is available on this node.") - - hparams = {'config': config, 'train_dataset_filepath': train_dataset_filepath, 'val_dataset_filepath': val_dataset_filepath, 'weights_filepath': weights_filepath} - model = MiModel(hparams) - for lo in logger: - lo.log_hyperparams({'model_versions': model.model_version}) - - trainer = pl.Trainer(max_epochs=config['max_epochs'], fast_dev_run=fast_dev_run, gpus=gpus, distributed_backend='ddp', - logger=logger, progress_bar_refresh_rate=0, check_val_every_n_epoch=10, num_sanity_val_steps=0, deterministic=True, - accumulate_grad_batches=accumulate_grad_batches) - - trainer.fit(model) - if output_filepath: - model.save(output_filepath) - - -def main(): - logging.getLogger().setLevel(logging.INFO) - logging.getLogger('mitorch').setLevel(logging.DEBUG) - - parser = argparse.ArgumentParser(description="Train a model") - parser.add_argument('config_filepath', type=pathlib.Path) - parser.add_argument('train_dataset_filepath', type=pathlib.Path) - parser.add_argument('val_dataset_filepath', type=pathlib.Path) - parser.add_argument('--weights_filepath', '-w', type=pathlib.Path) - parser.add_argument('--output_filepath', '-o', type=pathlib.Path) - parser.add_argument('--fast_dev_run', '-d', action='store_true') - parser.add_argument('--job_id', type=uuid.UUID) - parser.add_argument('--db_url') - - args = parser.parse_args() - train(args.config_filepath, args.train_dataset_filepath, args.val_dataset_filepath, args.weights_filepath, - args.output_filepath, args.fast_dev_run, args.job_id, args.db_url) - - -if __name__ == '__main__': - main() diff --git a/setup.py b/setup.py index c0a0b84..4259e92 100644 --- a/setup.py +++ b/setup.py @@ -7,17 +7,21 @@ description="MiTorch training framework", url='https://github.com/shonohs/mitorch', packages=setuptools.find_namespace_packages(include=['mitorch', 'mitorch.*']), - install_requires=['mitorch-models', 'pymongo', 'pytorch_lightning~=0.8', 'requests', 'tenacity', 'torch>=1.4.0', 'torchvision>=0.5.0', 'sklearn', 'azureml-sdk', 'albumentations'], + install_requires=['mitorch-models', + 'pymongo', + 'pytorch_lightning~=1.3.0', + 'requests', + 'tenacity', + 'torch~=1.9.0', + 'torchvision>=0.5.0', + 'sklearn', + 'albumentations'], entry_points={ 'console_scripts': [ - 'micontrol=mitorch.service.control:main', - 'mitest=mitorch.test:main', - 'mitrain=mitorch.train:main', - 'misubmit=mitorch.service.submit:main', - 'miquery=mitorch.service.query:main', - 'miamlmanager=mitorch.azureml.manager:main', - 'miamlrun=mitorch.azureml.runner:main', - 'midataset=mitorch.service.dataset:main' + 'miagent=mitorch.commands.agent:main', + 'misubmit=mitorch.commands.submit:main', + 'mitrain=mitorch.commands.train:main', + 'miquery=mitorch.commands.query:main', ] }, classifiers=[ @@ -26,4 +30,5 @@ 'Programming Language :: Python :: 3 :: Only', 'Topic :: Scientific/Engineering :: Artificial Intelligence' ], + python_requires='>=3.8', license='MIT') diff --git a/test/test_augmentations.py b/test/test_augmentations.py index 5a8f5ee..8c0a9cf 100644 --- a/test/test_augmentations.py +++ b/test/test_augmentations.py @@ -1,7 +1,7 @@ import unittest import PIL.Image import torch -from mitorch.datasets import ResizeTransform, ResizeFlipTransform, RandomResizedCropTransform +from mitorch.datasets.albumentations_transforms import ResizeTransform, ResizeFlipTransform, RandomResizedCropTransform TRANSFORMS = [ResizeTransform, ResizeFlipTransform, 
RandomResizedCropTransform] diff --git a/test/test_logger.py b/test/test_logger.py index 4b33ac7..3f44ddd 100644 --- a/test/test_logger.py +++ b/test/test_logger.py @@ -1,7 +1,7 @@ import pickle import unittest import uuid -from mitorch.logger import MongoDBLogger +from mitorch.common.logger import MongoDBLogger class TestLogger(unittest.TestCase): @@ -10,8 +10,8 @@ def test_pickle_mongodb_logger(self): training_id = uuid.uuid4() logger = MongoDBLogger(url, training_id) new_logger = pickle.loads(pickle.dumps(logger)) - self.assertEqual(new_logger.training_id, training_id) - self.assertIsNotNone(new_logger.client) + self.assertEqual(new_logger._job_id, training_id) + self.assertIsNotNone(new_logger._client) if __name__ == '__main__': diff --git a/test/test_mimodel.py b/test/test_mimodel.py index a4ff150..21fef9c 100644 --- a/test/test_mimodel.py +++ b/test/test_mimodel.py @@ -1,28 +1,17 @@ import unittest from unittest.mock import MagicMock, patch -from mitorch.mimodel import MiModel +from mitorch.common.mimodel import MiModel class TestMiModel(unittest.TestCase): - def test_version(self): - mock_builder = MagicMock() - mock_builder.build.return_value = [MagicMock(), [None]] + def test_init(self): mock_model_builder = MagicMock() mock_model_builder.build.return_value = [None, None, None] - mock_optimizer_builder = MagicMock() - mock_lrscheduler_builder = MagicMock() - MiModel._get_evaluator = MagicMock() - with patch('mitorch.mimodel.DataLoaderBuilder', return_value=mock_builder): - with patch('mitorch.mimodel.ModelBuilder', return_value=mock_model_builder): - with patch('mitorch.mimodel.OptimizerBuilder', return_value=mock_optimizer_builder): - with patch('mitorch.mimodel.LrSchedulerBuilder', return_value=mock_lrscheduler_builder): - model = MiModel({'config': MagicMock(), 'train_dataset_filepath': None, - 'val_dataset_filepath': None, 'weights_filepath': None}) - - model.model = MagicMock() - model.model.version = 42 - self.assertEqual(model.model_version, 42) + config = MagicMock() + config.task_type = 'multiclass_classification' + with patch('mitorch.common.mimodel.ModelBuilder', return_value=mock_model_builder): + MiModel(config, None) if __name__ == '__main__': diff --git a/test/test_model_builder.py b/test/test_model_builder.py index 555b22b..01d6481 100644 --- a/test/test_model_builder.py +++ b/test/test_model_builder.py @@ -1,13 +1,13 @@ import unittest from unittest.mock import patch from mitorch.builders import ModelBuilder +from mitorch.common.training_config import ModelConfig, TrainingConfig class TestModelBuilder(unittest.TestCase): def test_build(self): - config = {'task_type': 'multiclass_classification', - 'model': {'name': 'MobileNetV2', - 'options': []}} + model_config = ModelConfig(name='MobileNetV2', input_size=224) + config = TrainingConfig(task_type='multiclass_classification', model=model_config, batch_size=4, max_epochs=10) builder = ModelBuilder(config) with patch('mitorch.builders.model_builder.ModelFactory') as mock_factory: model = builder.build(3) @@ -15,9 +15,8 @@ def test_build(self): mock_factory.create.assert_called_once_with('MobileNetV2', 3, []) def test_build_multilabel(self): - config = {'task_type': 'multilabel_classification', - 'model': {'name': 'MobileNetV2', - 'options': []}} + model_config = ModelConfig(name='MobileNetV2', input_size=224) + config = TrainingConfig(task_type='multilabel_classification', model=model_config, batch_size=4, max_epochs=10) builder = ModelBuilder(config) with patch('mitorch.builders.model_builder.ModelFactory') as 
mock_factory: model = builder.build(3) diff --git a/test/test_random_search_task.py b/test/test_random_search_task.py deleted file mode 100644 index 198da79..0000000 --- a/test/test_random_search_task.py +++ /dev/null @@ -1,39 +0,0 @@ -import unittest -from mitorch.service.random_search_task import RandomSearchTask - - -class TestRandomSearchTask(unittest.TestCase): - BASE_CONFIG = {'name': 'random_search', - 'status': 'active', - 'num_trainings': 0, - 'max_trainings': 100, - 'config': {'test0': 'value0', - 'test1': {'_choice': [1, 2, 3]}, - 'test2': {'_choice': [100, 200, 300]}}} - - def test_get_random_same_choice(self): - task = RandomSearchTask(self.BASE_CONFIG) - results = [task.fetch_next() for i in range(50)] - results_set = set([str(r) for r in results]) - self.assertEqual(len(results_set), 1) - - def test_update_random_choice(self): - task = RandomSearchTask(self.BASE_CONFIG) - results = [] - for i in range(100): - results.append(task.fetch_next()) - task.update_training_status({'status': 'completed'}) - results_set = set([str(r) for r in results]) - self.assertEqual(len(results_set), 9) # Most likely it covers all combinations. - - def test_complete_task(self): - task = RandomSearchTask(self.BASE_CONFIG) - for i in range(100): - task.fetch_next() - task.update_training_status({'status': 'completed'}) - self.assertEqual(task.status, 'completed') - self.assertIsNone(task.fetch_next()) - - -if __name__ == '__main__': - unittest.main() diff --git a/test/test_settings.py b/test/test_settings.py deleted file mode 100644 index 803c697..0000000 --- a/test/test_settings.py +++ /dev/null @@ -1,21 +0,0 @@ -import dataclasses -import unittest -from mitorch.settings import Settings - - -class TestSettings(unittest.TestCase): - def test_from_dict(self): - obj = {'storage_url': 'storage_url', - 'readonly_storage_url': None, - 'dataset_url': {'region': 'dataset_url'}, - 'azureml_settings': [{'subscription_id': 'subscription_id', - 'workspace_name': 'workspace_name', - 'cluster_name': 'cluster', - 'region_name': 'region'}]} - - settings = Settings(**obj) - self.assertEqual(dataclasses.asdict(settings), obj) - - -if __name__ == '__main__': - unittest.main() diff --git a/utils/cancel_task.py b/utils/cancel_task.py deleted file mode 100644 index 8f65cc1..0000000 --- a/utils/cancel_task.py +++ /dev/null @@ -1,23 +0,0 @@ -import argparse -import uuid -from mitorch.environment import Environment -from mitorch.service import DatabaseClient - - -def cancel_task(task_id): - env = Environment() - client = DatabaseClient(env.db_url) - client.cancel_task(task_id) - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument('task_id', type=uuid.UUID) - - args = parser.parse_args() - - cancel_task(args.task_id) - - -if __name__ == '__main__': - main() diff --git a/utils/delete_task.py b/utils/delete_task.py deleted file mode 100644 index bb0b3d5..0000000 --- a/utils/delete_task.py +++ /dev/null @@ -1,30 +0,0 @@ -import argparse -import uuid -from mitorch.environment import Environment -from mitorch.service import DatabaseClient - - -def delete_task(task_id): - assert isinstance(task_id, uuid.UUID) - env = Environment() - client = DatabaseClient(env.db_url) - task = client.get_task_by_id(task_id) - print(task) - response = input("Delete this task? 
[y/N]") - if response == 'y': - result = client.delete_task(task_id) - assert result - print("Deleted") - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument('task_id', type=uuid.UUID) - - args = parser.parse_args() - - delete_task(args.task_id) - - -if __name__ == '__main__': - main() diff --git a/utils/delete_training.py b/utils/delete_training.py deleted file mode 100644 index 54fd8e2..0000000 --- a/utils/delete_training.py +++ /dev/null @@ -1,53 +0,0 @@ -import argparse -import uuid -from mitorch.environment import Environment -from mitorch.service import DatabaseClient - - -def remove_training(training_id): - assert isinstance(training_id, uuid.UUID) - env = Environment() - client = DatabaseClient(env.db_url) - - training = client.find_training_by_id(training_id) - if not training: - print(f"Training {training_id} not found") - return - - print(training) - - response = input("Remove? [y/N]: ") - if response.lower() == 'y': - result = client.delete_training(training_id) - assert result - print("Removed successfully") - - -def remove_failed_trainings(): - env = Environment() - client = DatabaseClient(env.db_url) - trainings = client.get_failed_trainings() - for t in trainings: - remove_training(t['_id']) - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument('training_id', nargs='?', type=uuid.UUID) - parser.add_argument('--delete_failed_trainings', action='store_true') - - args = parser.parse_args() - if args.delete_failed_trainings and args.training_id: - parser.error("training_id and --delete_failed_trainings cannot be specified at the same time.") - - if not (args.delete_failed_trainings or args.training_id): - parser.error("Please specify training_id or --delete_failed_trainings.") - - if args.delete_failed_trainings: - remove_failed_trainings() - else: - remove_training(args.training_id) - - -if __name__ == '__main__': - main() diff --git a/utils/init_database.py b/utils/init_database.py new file mode 100644 index 0000000..2cd185a --- /dev/null +++ b/utils/init_database.py @@ -0,0 +1,28 @@ +import argparse +import logging +import pymongo +from mitorch.common import Environment + + +def init_database(db_url): + client = pymongo.MongoClient(db_url) + response = client.mitorch.jobs.create_index([('priority', pymongo.ASCENDING), ('created_at', pymongo.ASCENDING)]) + print(response) + + +def main(): + logging.basicConfig(level=logging.INFO) + + env = Environment() + parser = argparse.ArgumentParser() + parser.add_argument('--db_url', default=env.db_url) + + args = parser.parse_args() + if not args.db_url: + parser.error("A database url must be specified via commandline argument or environment variable.") + + init_database(args.db_url) + + +if __name__ == '__main__': + main() diff --git a/utils/submit_task.py b/utils/submit_task.py deleted file mode 100644 index a1ef7a3..0000000 --- a/utils/submit_task.py +++ /dev/null @@ -1,31 +0,0 @@ -import argparse -import json -import pathlib -from mitorch.environment import Environment -from mitorch.service import DatabaseClient, Task - - -def submit_task(json_filepath): - task_dict = json.loads(json_filepath.read_text()) - task = Task.from_dict(task_dict) - - print(task) - response = input("Add the task? 
[y/N]") - if response == 'y': - env = Environment() - client = DatabaseClient(env.db_url) - task_id = client.add_task(task.to_dict()) - print(f"Added task: {task_id}") - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument('json_filepath', type=pathlib.Path) - - args = parser.parse_args() - - submit_task(args.json_filepath) - - -if __name__ == '__main__': - main() diff --git a/utils/update_settings.py b/utils/update_settings.py deleted file mode 100644 index f04ab63..0000000 --- a/utils/update_settings.py +++ /dev/null @@ -1,35 +0,0 @@ -import argparse -import json -import pathlib -from mitorch.environment import Environment -from mitorch.service import DatabaseClient -from mitorch.settings import Settings - - -def update_settings(json_filepath): - env = Environment() - client = DatabaseClient(env.db_url) - current_settings = client.get_settings() - print(f"Current settings: {current_settings}") - - if json_filepath: - settings_data = json.loads(json_filepath.read_text()) - settings = Settings.from_dict(settings_data) - print(f"New settings: {settings}") - response = input("Continue? [y/N]") - if response == 'y': - client.put_settings(settings) - print("Updated") - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument('settings_filepath', nargs='?', type=pathlib.Path) - - args = parser.parse_args() - - update_settings(args.settings_filepath) - - -if __name__ == '__main__': - main() diff --git a/utils/visualize_augmentation.py b/utils/visualize_augmentation.py new file mode 100644 index 0000000..8f8e2ef --- /dev/null +++ b/utils/visualize_augmentation.py @@ -0,0 +1,60 @@ +import argparse +import pathlib +import numpy as np +import PIL.Image +import PIL.ImageDraw +from mitorch.datasets import TransformFactory, ImageDataset, ObjectDetectionDataset + + +COLOR_CODES = ["black", "brown", "red", "orange", "yellow", "green", "blue", "violet", "grey", "white"] + + +def _draw_od_labels(image, annotations, label_names): + draw = PIL.ImageDraw.Draw(image) + w, h = image.width, image.height + for class_id, x, y, x2, y2 in annotations: + x, x2 = x * w, x2 * w + y, y2 = y * h, y2 * h + color = COLOR_CODES[class_id % len(COLOR_CODES)] + draw.rectangle(((x, y), (x2, y2)), outline=color) + draw.text((x, y), label_names[class_id]) + + +def visualize_augmentation(dataset_filepath, output_dir, name, input_size, num_images, num_tries): + dataset = ImageDataset.from_file(dataset_filepath, lambda x: x) + is_object_detection = isinstance(dataset, ObjectDetectionDataset) + transform = TransformFactory(is_object_detection).create(name, input_size) + if not transform: + raise RuntimeError(f"Unknown transform: {name}") + dataset.transform = transform + + for i in range(num_images): + for j in range(num_tries): + image, target = dataset[i] + image = (image + 0.5) * 255 + image = image.permute((1, 2, 0)) + filepath = output_dir / f'{i}_{j}.jpg' + image = PIL.Image.fromarray(np.array(image, dtype=np.uint8)) + if is_object_detection: + _draw_od_labels(image, target, dataset.labels) + image.save(filepath) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('transform_name') + parser.add_argument('dataset_filepath', type=pathlib.Path) + parser.add_argument('output_dir', type=pathlib.Path) + parser.add_argument('--input_size', type=int, default=224) + parser.add_argument('--num_images', '-n', type=int, default=10) + parser.add_argument('--num_tries', '-t', type=int, default=10) + + args = parser.parse_args() + + args.output_dir.mkdir(parents=True, 
exist_ok=True) + + visualize_augmentation(args.dataset_filepath, args.output_dir, args.transform_name, args.input_size, args.num_images, args.num_tries) + + +if __name__ == '__main__': + main()
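
Note: the snippet below is not part of the diff; it is a minimal usage sketch of the new `TransformFactory` introduced in `mitorch/datasets/factory.py`, assuming `mitorch` and `albumentations` are installed. The transform names (`center_crop`, `resize`, `random_resize`, ...) are the keys registered in that factory.

```python
import PIL.Image
from mitorch.datasets import TransformFactory

# Build a classification-style augmentation by name. create() returns None
# for unknown names, so callers can fall back or raise as appropriate.
factory = TransformFactory(is_object_detection=False)
transform = factory.create('random_resize', input_size=224)
assert transform is not None

# AlbumentationsTransform takes a PIL image and a target list and returns a
# normalized CHW torch tensor (Normalize + ToTensorV2) plus the target.
image = PIL.Image.new('RGB', (640, 480), color=(127, 127, 127))
tensor, target = transform(image, [])
print(tensor.shape)  # torch.Size([3, 224, 224])
```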