260 remove classif #263

Merged — 9 commits merged on Feb 7, 2022
Changes from 8 commits
2 changes: 1 addition & 1 deletion .github/workflows/github-actions-ci.yml
@@ -26,4 +26,4 @@ jobs:
unzip ./data/massachusetts_buildings.zip -d ./data
python GDL.py mode=sampling
python GDL.py mode=train
python GDL.py mode=inference
python GDL.py mode=inference
8 changes: 5 additions & 3 deletions GDL.py
@@ -2,8 +2,10 @@
import time
import hydra
import logging

from hydra.utils import get_method
from omegaconf import DictConfig, OmegaConf, open_dict
from utils.utils import load_obj, print_config, get_git_hash, getpath
from utils.utils import print_config, get_git_hash


@hydra.main(config_path="config", config_name="gdl_config_template")
@@ -45,7 +47,7 @@ def run_gdl(cfg: DictConfig) -> None:
logging.info('\nOverwritten parameters in the config: \n' + cfg.general.config_override_dirname)

# Start -----------------------------------
msg = "Let's start {} for {} !!!".format(cfg.mode, cfg.task.task_name)
msg = "Let's start {} for {} !!!".format(cfg.mode, cfg.general.task)
logging.info(
"\n" + "-" * len(msg) + "\n" + msg +
"\n" + "-" * len(msg)
@@ -55,7 +57,7 @@ def run_gdl(cfg: DictConfig) -> None:
# Start the timer
start_time = time.time()
# Read the task and execute it
task = load_obj(cfg.task.path_task_function)
task = get_method(f"{cfg.mode}_{cfg.general.task}.main")
task(cfg)

# Add git hash from current commit to parameters.
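Note on the new dispatch line: hydra.utils.get_method resolves a dotted import path to a callable, so the mode and task from the config pick the entry point by module name (e.g. train_segmentation.main when mode=train and general.task=segmentation). A minimal sketch of the pattern, with illustrative config values:

from hydra.utils import get_method
from omegaconf import OmegaConf

# Illustrative values; in GDL.py they come from the Hydra-composed cfg.
cfg = OmegaConf.create({"mode": "train", "general": {"task": "segmentation"}})

# Resolves the string "train_segmentation.main" and returns that callable,
# provided a module named train_segmentation is importable.
task = get_method(f"{cfg.mode}_{cfg.general.task}.main")
# task(cfg)  # GDL.py then calls the resolved entry point with the full config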
2 changes: 1 addition & 1 deletion config/gdl_config_template.yaml
@@ -1,5 +1,4 @@
defaults:
- task: segmentation
- model: unet
- training: default_training
- optimizer: adamw
@@ -20,6 +19,7 @@ general:
# hydra hijacks working directory by changing it to the current log directory,
# so it's useful to have this path as a special variable
# learn more here: https://hydra.cc/docs/next/tutorials/basic/running_your_app/working_directory
task: segmentation
work_dir: ${hydra:runtime.cwd} # where the code is executed
config_name: ${hydra:job.config_name}
config_override_dirname: ${hydra:job.override_dirname}
2 changes: 0 additions & 2 deletions config/task/segmentation.yaml

This file was deleted.

4 changes: 2 additions & 2 deletions data/images_to_samples_ci_csv.csv
@@ -1,2 +1,2 @@
./data/22978945_15.tif,,./data/massachusetts_buildings.gpkg,,trn
./data/23429155_15.tif,,./data/massachusetts_buildings.gpkg,,tst
./22978945_15.tif,,./massachusetts_buildings.gpkg,,trn
./23429155_15.tif,,./massachusetts_buildings.gpkg,,tst
Comment on lines +1 to +2
@remtav (Collaborator, Author) commented on Feb 3, 2022:

This change is temporary. Will refactor the in_case_of_path parameter. I'd rather use hydra.utils.get_original_cwd().

287 changes: 104 additions & 183 deletions inference_segmentation.py

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion sampling_segmentation.py
@@ -377,7 +377,6 @@ def main(cfg: DictConfig) -> None:
num_bands = len(cfg.dataset.modalities)
modalities = read_modalities(cfg.dataset.modalities) # TODO add the Victor module to manage the modalities
debug = cfg.debug
task = cfg.task.task_name

# RAW DATA PARAMETERS
# Data folder
4 changes: 1 addition & 3 deletions train_segmentation.py
@@ -508,14 +508,12 @@ def train(cfg: DictConfig) -> None:

# OPTIONAL PARAMETERS
debug = get_key_def('debug', cfg)
task = get_key_def('task_name', cfg['task'], default='segmentation')
task = get_key_def('task', cfg['general'], default='segmentation')
dontcare_val = get_key_def("ignore_index", cfg['dataset'], default=-1)
bucket_name = get_key_def('bucket_name', cfg['AWS'])
scale = get_key_def('scale_data', cfg['augmentation'], default=[0, 1])
batch_metrics = get_key_def('batch_metrics', cfg['training'], default=None)
crop_size = get_key_def('target_size', cfg['training'], default=None)
if task != 'segmentation':
raise logging.critical(ValueError(f"\nThe task should be segmentation. The provided value is {task}"))

# MODEL PARAMETERS
class_weights = get_key_def('class_weights', cfg['dataset'], default=None)
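Aside on get_key_def, for readers new to the codebase: judging only from the call sites above, it fetches a key from a config section and falls back to a default when the key is missing or None. A hypothetical minimal version (not the project's actual helper, which lives in utils/utils.py and likely handles more cases):

def get_key_def(key, config, default=None):
    # Hypothetical sketch inferred from the call pattern
    # get_key_def('task', cfg['general'], default='segmentation').
    if config is None or key not in config or config[key] is None:
        return default
    return config[key]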
2 changes: 1 addition & 1 deletion utils/metrics.py
@@ -59,7 +59,7 @@ def report_classification(pred, label, batch_size, metrics_dict, ignore_index=-1
"""Computes precision, recall and f-score for each class and average of all classes.
http://scikit-learn.org/stable/modules/generated/sklearn.metrics.classification_report.html
"""
class_report = classification_report(label.cpu(), pred.cpu(), output_dict=True)
class_report = classification_report(label.cpu(), pred.cpu(), output_dict=True, zero_division=1)

class_score = {}
for key, value in class_report.items():
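Context for the zero_division=1 addition: when a class has no predicted (or no true) samples, precision/recall/F-score are undefined; scikit-learn's default is to report 0 and emit an UndefinedMetricWarning, while zero_division=1 reports 1 and stays silent. A small standalone illustration (not project code):

from sklearn.metrics import classification_report

y_true = [0, 0, 1, 1]
y_pred = [0, 0, 0, 0]  # class 1 is never predicted, so its precision is undefined

# With zero_division=1 the undefined precision is reported as 1.0 and no warning is raised;
# with the default ("warn") it would be reported as 0.0 with an UndefinedMetricWarning.
report = classification_report(y_true, y_pred, output_dict=True, zero_division=1)
print(report["1"]["precision"])  # 1.0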
42 changes: 4 additions & 38 deletions utils/utils.py
@@ -1,9 +1,7 @@
import os
import csv
import logging
import numbers
import subprocess
import importlib as imp
from functools import reduce
from pathlib import Path
from typing import Sequence, List
@@ -26,7 +24,6 @@
from rasterio.crs import CRS
from affine import Affine

from utils.readers import read_parameters
from urllib.parse import urlparse

try:
@@ -427,7 +424,7 @@ def try2read_csv(path_file, in_case_of_path, msg):
Path(path_file).resolve(strict=True)
except FileNotFoundError:
if in_case_of_path:
path_file = os.path.join(in_case_of_path, os.path.basename(path_file))
path_file = str(Path(in_case_of_path) / (path_file.split('./')[-1]))
try:
@remtav (Collaborator, Author) commented:

I'll get back to this in_case_of_path story in a future PR. Would like to suggest a different, lighter approach with hydra.utils.get_original_cwd()

Path(path_file).resolve(strict=True)
except FileNotFoundError:
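Sketch of the get_original_cwd() alternative floated in the comments (not what this PR implements): since Hydra switches the working directory to its output directory, a relative csv path could simply be re-rooted at the launch directory.

from pathlib import Path
from hydra.utils import get_original_cwd  # only valid inside a running Hydra app

def resolve_from_launch_dir(path_str: str) -> Path:
    # Re-root a relative path at the directory the application was launched from.
    path = Path(path_str)
    if not path.is_absolute():
        path = Path(get_original_cwd()) / path
    return path.resolve(strict=True)  # raises FileNotFoundError if the file is missing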
@@ -458,8 +455,8 @@ def read_csv(csv_file_name, data_path=None):
if row[2]:
row[2] = try2read_csv(row[2], data_path, 'Gpkg not found:')
if not isinstance(row[3], str):
raise ValueError(f"Attribute name should be a string")
if row[3] is not "":
logging.error(f"Attribute name should be a string")
if row[3] != "":
logging.error(f"Deprecation notice:\nFiltering ground truth features by attribute name and values should"
f" be done through the dataset parameters in config/dataset. The attribute name value in "
f"csv will be ignored. Got: {row[3]}")
@@ -630,26 +627,6 @@ def compare_config_yamls(yaml1: dict, yaml2: dict, update_yaml1: bool = False) -
log.info(f'Value in yaml1 updated')


def load_obj(obj_path: str, default_obj_path: str = '') -> any:
"""
Extract an object from a given path.

:param obj_path: (str) Path to an object to be extracted, including the object name.
:param default_obj_path: (str) Default path object.

:return: Extract object. Can be a function or a class or ...

:raise AttributeError: When the object does not have the given named attribute.
"""
obj_path_list = obj_path.rsplit('.', 1)
obj_path = obj_path_list.pop(0) if len(obj_path_list) > 1 else default_obj_path
obj_name = obj_path_list[0]
module_obj = imp.import_module(obj_path)
if not hasattr(module_obj, obj_name):
raise AttributeError(f"Object `{obj_name}` cannot be loaded from from `{obj_path}`.")
return getattr(module_obj, obj_name)


@remtav (Collaborator, Author) commented:

@CharlesAuthier same here. Let me know if you had anything in mind for further developments with this function. If the purpose was just to use it in GDL.py to reach the right method, I think hydra.utils.get_method() does the job.

def read_modalities(modalities: str) -> list:
"""
Function that read the modalities from the yaml and convert it to a list
@@ -696,7 +673,7 @@ def getpath(d, path):
def print_config(
config: DictConfig,
fields: Sequence[str] = (
"task",
"general.task",
"mode",
"dataset",
"general.work_dir",
@@ -769,14 +746,3 @@ def print_config(

with open("run_config.config", "w") as fp:
rich.print(tree, file=fp)


# def save_useful_info():
# shutil.copytree(
# os.path.join(hydra.utils.get_original_cwd(), 'src'),
# os.path.join(os.getcwd(), 'code/src')
# )
# shutil.copy2(
# os.path.join(hydra.utils.get_original_cwd(), 'hydra_run.py'),
# os.path.join(os.getcwd(), 'code')
# )