Big update #14
Conversation
presto/eval/eval.py
Outdated
name: str
num_outputs: int
regression: bool
multilabel: bool

def __init__(self, seed: int = DEFAULT_SEED):
    self.seed = seed
I don't really like that this class has a seed: int
argument while its subclass EvalTaskWithAggregatedOutput has a seeds: List[int]
argument, so you basically have to use isinstance(..., EvalTaskWithAggregatedOutput)
to know which field to use. I think it would be nicer to have this class use a field seeds: List[int]
too and include a check that its length equals 1 (or something along those lines).
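A minimal sketch of what this unified interface could look like. The class and field names follow the comment above, but the bodies (and the DEFAULT_SEED value) are assumptions, not the actual repo code:

```python
from typing import List, Optional

DEFAULT_SEED = 42  # assumed value; the repo defines its own constant


class EvalTask:
    """Base task: exactly one seed, stored in a `seeds` list so the
    interface matches EvalTaskWithAggregatedOutput (no isinstance check needed)."""

    def __init__(self, seeds: Optional[List[int]] = None):
        seeds = seeds if seeds is not None else [DEFAULT_SEED]
        assert len(seeds) == 1, "EvalTask expects exactly one seed"
        self.seeds = seeds


class EvalTaskWithAggregatedOutput(EvalTask):
    """Aggregating subclass: relaxes the base-class check to allow any
    number of seeds."""

    def __init__(self, seeds: List[int]):
        self.seeds = seeds  # deliberately skips the length-1 assertion
```

Callers can then read `task.seeds` uniformly regardless of the concrete class.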
updated!
"encoder": model.encoder.__class__,
"decoder": model.decoder.__class__,
"device": device,
"model_parameters": "random" if fully_supervised else path_to_state_dict,
I think I should have added "logging_dir": logging_dir,
here but I forgot; same in train.py. Because right now output_dir
gets logged, but it's basically always something like /network/scratch/<u>/<user>/presto/,
so you don't know what the actual subdir for this run was.
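Concretely, the config dict above would gain one entry. This is a sketch of the fix being requested; the Dummy* classes and variable values are stand-ins so the snippet is self-contained, not the repo's actual code:

```python
# Stand-ins for the objects referenced by the config dict in the diff above.
class DummyEncoder:
    pass


class DummyDecoder:
    pass


class DummyModel:
    encoder = DummyEncoder()
    decoder = DummyDecoder()


model = DummyModel()
device = "cpu"
fully_supervised = False
path_to_state_dict = "models/default_model.pt"  # hypothetical path
logging_dir = "output/run_0"  # hypothetical run-specific subdir

config = {
    "encoder": model.encoder.__class__,
    "decoder": model.decoder.__class__,
    "device": device,
    "model_parameters": "random" if fully_supervised else path_to_state_dict,
    "logging_dir": logging_dir,  # the run-specific subdir, now recoverable from the logs
}
```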
added!
@@ -194,7 +244,7 @@ def load_dataset(url, shuffle_on_load):
}
Add "logging_dir": logging_dir,
updated
argparser.add_argument("--fully_supervised", dest="fully_supervised", action="store_true")
argparser.add_argument("--wandb", dest="wandb", action="store_true")
Could you add an argument argparser.add_argument("--nb_seeds", type=int, default=1)
so it's actually also still possible to run the eval script with just 1 seed?
Or do we not want to support that?
Then below: seeds = list(range(0, DEFAULT_SEED * nb_seeds, DEFAULT_SEED)) if nb_seeds > 1 else [DEFAULT_SEED]
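Wrapped as a function, the seed-list construction suggested above looks like this (with the closing parenthesis that the one-liner in the comment is missing; the DEFAULT_SEED value is an assumption):

```python
from typing import List

DEFAULT_SEED = 42  # assumed value; the repo defines its own constant


def make_seeds(nb_seeds: int, default_seed: int = DEFAULT_SEED) -> List[int]:
    """With nb_seeds == 1, fall back to the single default seed;
    otherwise generate nb_seeds evenly spaced seeds starting at 0."""
    if nb_seeds > 1:
        return list(range(0, default_seed * nb_seeds, default_seed))
    return [default_seed]
```

For example, make_seeds(1) gives [42] and make_seeds(3) gives [0, 42, 84].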
added an eval_seeds
argument
eval.py
Outdated
CropHarvestEval("Brazil", seed=0),
CropHarvestEval("Kenya", seed=84),
CropHarvestEval("Togo", seed=84),
CropHarvestEval("Brazil", seed=84),
Then here
eval_task_list: List[EvalTask] = [
*[CropHarvestEval("Kenya", seed=seed) for seed in seeds],
*[CropHarvestEval("Kenya", ignore_dynamic_world=True, seed=seed) for seed in seeds],
*[CropHarvestEval("Brazil", seed=seed) for seed in seeds],
*[CropHarvestEval("Brazil", ignore_dynamic_world=True, seed=seed) for seed in seeds],
*[CropHarvestEval("Togo", seed=seed) for seed in seeds],
*[CropHarvestEval("Togo", ignore_dynamic_world=True, seed=seed) for seed in seeds],
*[FuelMoistureEval(seed=seed) for seed in seeds],
*[AlgaeBloomsEval(seed=seed) for seed in seeds],
*[TreeSatEval("S1", input_patch_size=ps, seeds=seeds) for ps in [1, 2, 3, 6]],
*[TreeSatEval("S2", input_patch_size=ps, seeds=seeds) for ps in [1, 2, 3, 6]],
... # EuroSat
]
Nice, this is great. Updated.
Great, looks good to me!
This PR introduces a number of significant updates: previously, num_timesteps
SRTM tokens were passed through the model, even though SRTM was only collected at a single timestep. A single SRTM token is now passed through the model. In addition, a bug in the decoder which shuffled some bands during reconstruction has been fixed. The
default_model.pt
weights have now been updated with a model trained after these changes. The latest arxiv version of the paper reflects these changes. data_dir
is now configurable in train.py,
and (vi) the evaluation results are dumped in a json file locally in addition to (optionally) being stored on wandb. The repository before this PR is tagged
v0.1
.
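The local json dump mentioned in point (vi) can be sketched as follows. The results dict shape and file name here are illustrative placeholders, not the repo's actual schema:

```python
import json
import tempfile
from pathlib import Path

# Placeholder results for one eval task; real keys/metrics will differ.
results = {"CropHarvestEval: Togo": {"f1": 0.0, "seed": 42}}

# Write the results next to the run's other outputs (a temp dir here).
out_path = Path(tempfile.gettempdir()) / "results.json"
out_path.write_text(json.dumps(results, indent=2))

# The file round-trips, so results survive even when wandb logging is off.
loaded = json.loads(out_path.read_text())
```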