Skip to content

Commit

Permalink
Merge pull request #255 from hrntsm/feature/clean-visualization
Browse files Browse the repository at this point in the history
Feature/clean visualization
  • Loading branch information
hrntsm authored Feb 24, 2024
2 parents 0cbf3e1 + 985cf1c commit 26b4f04
Show file tree
Hide file tree
Showing 7 changed files with 158 additions and 177 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ for soon-to-be removed features.

### Removed

for now removed features.
- Show hypervolume ratio while optimization running

### Fixed

Expand Down
5 changes: 5 additions & 0 deletions Optuna/Study/Study.cs
Original file line number Diff line number Diff line change
Expand Up @@ -146,5 +146,10 @@ public static dynamic CreateStudy(dynamic optuna, string studyName, dynamic samp
load_if_exists: loadIfExists
);
}

public static dynamic LoadStudy(dynamic optuna, dynamic storage, string studyName)
{
return optuna.load_study(storage: storage, study_name: studyName);
}
}
}
115 changes: 115 additions & 0 deletions Optuna/Visualization/Visualization.cs
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,121 @@ public void Hypervolume()
_fig = visualize(_study);
}

public void Clustering(int nClusters, string targetType, int targetIndex)
{
PyModule ps = Py.CreateScope();
ps.Exec(
"def visualize(study, n_clusters, target_type, target_index):\n" +
" import numpy as np\n" +
" import optuna\n" +
" from sklearn.cluster import KMeans\n" +
" import plotly.graph_objects as go\n" +
" from optuna.visualization._utils import _make_hovertext\n" +

" trials = study.get_trials(deepcopy=False, states=[optuna.trial.TrialState.COMPLETE])\n" +
" feasible_trials = []\n" +
" infeasible_trials = []\n" +
" for trial in trials:\n" +
" constraints = trial.system_attrs.get('constraints')\n" +
" if constraints is None or all([x <= 0.0 for x in constraints]):\n" +
" feasible_trials.append(trial)\n" +
" else:\n" +
" infeasible_trials.append(trial)\n" +

" target = []\n" +
" if target_type == 'objective':\n" +
" target = [trial.values[target_index] for trial in feasible_trials]\n" +
" else:\n" +
" target = [\n" +
" list(trial.params.values())[target_index] for trial in feasible_trials\n" +
" ]\n" +
" np_array = np.array(target).reshape(-1, 1)\n" +
" kmeans = KMeans(n_clusters=n_clusters).fit(np_array)\n" +

" feasible_marker = dict(\n" +
" color=kmeans.labels_,\n" +
" showscale=True,\n" +
" colorscale='RdYlBu_r',\n" +
" colorbar=dict(title='Cluster'),\n" +
" size=12,\n" +
" )\n" +
" infeasible_marker = dict(\n" +
" color='#cccccc',\n" +
" showscale=False,\n" +
" size=12,\n" +
" )\n" +
" fig = go.Figure()\n" +
" if len(study.directions) == 2:\n" +
" fig.add_trace(\n" +
" go.Scatter(\n" +
" x=[trial.values[0] for trial in feasible_trials],\n" +
" y=[trial.values[1] for trial in feasible_trials],\n" +
" mode='markers',\n" +
" marker=feasible_marker,\n" +
" showlegend=False,\n" +
" text=[_make_hovertext(trial) for trial in feasible_trials],\n" +
" hovertemplate='%{text}<extra>Trial</extra>',\n" +
" )\n" +
" )\n" +
" fig.add_trace(\n" +
" go.Scatter(\n" +
" x=[trial.values[0] for trial in infeasible_trials],\n" +
" y=[trial.values[1] for trial in infeasible_trials],\n" +
" mode='markers',\n" +
" marker=infeasible_marker,\n" +
" showlegend=False,\n" +
" text=[_make_hovertext(trial) for trial in feasible_trials],\n" +
" hovertemplate='%{text}<extra>Infeasible Trial</extra>',\n" +
" )\n" +
" )\n" +
" else:\n" +
" fig.add_trace(\n" +
" go.Scatter3d(\n" +
" x=[trial.values[0] for trial in feasible_trials],\n" +
" y=[trial.values[1] for trial in feasible_trials],\n" +
" z=[trial.values[2] for trial in feasible_trials],\n" +
" mode='markers',\n" +
" marker=feasible_marker,\n" +
" showlegend=False,\n" +
" text=[_make_hovertext(trial) for trial in feasible_trials],\n" +
" hovertemplate='%{text}<extra>Trial</extra>',\n" +
" )\n" +
" )\n" +
" fig.add_trace(\n" +
" go.Scatter3d(\n" +
" x=[trial.values[0] for trial in infeasible_trials],\n" +
" y=[trial.values[1] for trial in infeasible_trials],\n" +
" z=[trial.values[2] for trial in infeasible_trials],\n" +
" mode='markers',\n" +
" marker=infeasible_marker,\n" +
" showlegend=False,\n" +
" text=[_make_hovertext(trial) for trial in feasible_trials],\n" +
" hovertemplate='%{text}<extra>Infeasible Trial</extra>',\n" +
" )\n" +
" )\n" +
" metric_names = study.metric_names\n" +
" if metric_names is not None:\n" +
" if len(metric_names) == 3:\n" +
" fig.update_layout(\n" +
" title=f'Clustering of Trials',\n" +
" scene=dict(\n" +
" xaxis_title=metric_names[0],\n" +
" yaxis_title=metric_names[1],\n" +
" zaxis_title=metric_names[2],\n" +
" ),\n" +
" )\n" +
" else:\n" +
" fig.update_layout(\n" +
" title=f'Clustering of Trials',\n" +
" xaxis=dict(title=metric_names[0]),\n" +
" yaxis=dict(title=metric_names[1]),\n" +
" )\n" +
" return go.Figure(fig)\n"
);
dynamic visualize = ps.Get("visualize");
_fig = visualize(_study, nClusters, targetType, targetIndex);
}

public void TruncateParetoFrontPlotHover()
{
CheckPlotCreated();
Expand Down
12 changes: 5 additions & 7 deletions Tunny/Solver/Algorithm.cs
Original file line number Diff line number Diff line change
Expand Up @@ -456,33 +456,31 @@ private bool CheckOptimizeComplete(int nTrials, double timeout, int trialNum, Da
private ProgressState SetProgressState(OptimizationHandlingInfo optSet, Parameter[] parameter, int trialNum, DateTime startTime)
{
TLog.MethodStart();
ComputeBestValues(optSet.Study, trialNum, out double[][] bestValues, out double hypervolumeRatio);
double[][] bestValues = ComputeBestValues(optSet.Study);
return new ProgressState
{
TrialNumber = trialNum,
ObjectiveNum = Objective.Length,
BestValues = bestValues,
Parameter = parameter,
HypervolumeRatio = hypervolumeRatio,
HypervolumeRatio = 0,
EstimatedTimeRemaining = optSet.Timeout <= 0
? TimeSpan.FromSeconds((DateTime.Now - startTime).TotalSeconds * (optSet.NTrials - trialNum) / (trialNum + 1))
: TimeSpan.FromSeconds(optSet.Timeout - (DateTime.Now - startTime).TotalSeconds)
};
}

private void ComputeBestValues(dynamic study, int trialNum, out double[][] bestValues, out double hypervolumeRatio)
private double[][] ComputeBestValues(dynamic study)
{
TLog.MethodStart();
if (Settings.Optimize.ShowRealtimeResult)
{
dynamic[] bestTrials = study.best_trials;
bestValues = bestTrials.Select(t => (double[])t.values).ToArray();
hypervolumeRatio = trialNum == 0 ? 0 : trialNum == 1 || Objective.Length == 1 ? 1 : Hypervolume.Compute2dHypervolumeRatio(study);
return bestTrials.Select(t => (double[])t.values).ToArray();
}
else
{
bestValues = null;
hypervolumeRatio = 0;
return null;
}
}

Expand Down
152 changes: 0 additions & 152 deletions Tunny/Solver/Hypervolume.cs

This file was deleted.

29 changes: 13 additions & 16 deletions Tunny/Solver/Visualize.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using System;
using System.Windows.Forms;

using Optuna.Study;
using Optuna.Visualization;

using Python.Runtime;
Expand All @@ -24,20 +25,6 @@ public Visualize(TSettings settings, bool hasConstraint)
_hasConstraint = hasConstraint;
}

private static dynamic LoadStudy(dynamic optuna, dynamic storage, string studyName)
{
TLog.MethodStart();
try
{
return optuna.load_study(storage: storage, study_name: studyName);
}
catch (Exception e)
{
TunnyMessageBox.Show(e.Message, "Tunny", MessageBoxButtons.OK, MessageBoxIcon.Error);
return null;
}
}

public void Plot(Plot pSettings)
{
TLog.MethodStart();
Expand All @@ -46,7 +33,7 @@ public void Plot(Plot pSettings)
{
dynamic optuna = Py.Import("optuna");
dynamic storage = _settings.Storage.CreateNewOptunaStorage(false);
dynamic study = LoadStudy(optuna, storage, pSettings.TargetStudyName);
dynamic study = Study.LoadStudy(optuna, storage, pSettings.TargetStudyName);
if (study == null)
{
return;
Expand Down Expand Up @@ -85,7 +72,7 @@ private Visualization CreateFigure(dynamic study, Plot pSettings)
visualize.ParallelCoordinate(pSettings.TargetObjectiveName[0], pSettings.TargetObjectiveIndex[0], pSettings.TargetVariableName);
break;
case "param importances":
visualize.ParamImportances(pSettings.TargetObjectiveName[0], pSettings.TargetObjectiveIndex[0]);
visualize.ParamImportances(pSettings.TargetObjectiveName[0], pSettings.TargetObjectiveIndex[0]);
break;
case "pareto front":
visualize.ParetoFront(pSettings.TargetObjectiveName, pSettings.TargetObjectiveIndex, _hasConstraint, pSettings.IncludeDominatedTrials);
Expand All @@ -97,6 +84,16 @@ private Visualization CreateFigure(dynamic study, Plot pSettings)
case "hypervolume":
visualize.Hypervolume();
break;
case "clustering":
if (pSettings.TargetObjectiveIndex.Length > 0)
{
visualize.Clustering(pSettings.ClusterCount, "objective", pSettings.TargetObjectiveIndex[0]);
}
else
{
visualize.Clustering(pSettings.ClusterCount, "variable", pSettings.TargetVariableIndex[0]);
}
break;
default:
TunnyMessageBox.Show("This visualization type is not supported in this study case.", "Tunny");
break;
Expand Down
Loading

0 comments on commit 26b4f04

Please sign in to comment.