diff --git a/src/sas/qtgui/Perspectives/Fitting/ComplexConstraint.py b/src/sas/qtgui/Perspectives/Fitting/ComplexConstraint.py index ac973f5bf6..9842245562 100644 --- a/src/sas/qtgui/Perspectives/Fitting/ComplexConstraint.py +++ b/src/sas/qtgui/Perspectives/Fitting/ComplexConstraint.py @@ -88,7 +88,6 @@ def setupWidgets(self): self.setupParamWidgets() - self.setupMenu() def setupMenu(self): @@ -110,10 +109,10 @@ def setupParamWidgets(self): # Populate the left combobox parameter arbitrarily with the parameters # from the first tab if `All` option is selected if self.cbModel1.currentText() == "All": - items1 = self.tabs[1].main_params_to_fit + items1 = self.tabs[1].main_params_to_fit + self.tabs[1].poly_params_to_fit else: tab_index1 = self.cbModel1.currentIndex() - items1 = self.tabs[tab_index1].main_params_to_fit + items1 = self.tabs[tab_index1].main_params_to_fit + self.tabs[tab_index1].poly_params_to_fit self.cbParam1.addItems(items1) # Show the previously selected parameter if available if previous_param1 in items1: @@ -122,10 +121,13 @@ def setupParamWidgets(self): # Store previously select parameter previous_param2 = self.cbParam2.currentText() - # M2 has to be non-constrained self.cbParam2.clear() tab_index2 = self.cbModel2.currentIndex() - items2 = [param for param in self.params[tab_index2] if not self.tabs[tab_index2].paramHasConstraint(param)] + items2 = [param for param in self.params[tab_index2]] + # The following can be used if it is judged preferable that constrained + # parameters are not used in the definition of a new constraint + #items2 = [param for param in self.params[tab_index2] if not self.tabs[tab_index2].paramHasConstraint(param)] + self.cbParam2.addItems(items2) # Show the previously selected parameter if available if previous_param2 in items2: @@ -210,9 +212,9 @@ def validateFormula(self): """ Add visual cues when formula is incorrect """ - # temporarily disable validation + # temporarily disable validation, as not yet fully operational return - # + formula_is_valid = self.validateConstraint(self.txtConstraint.text()) if not formula_is_valid: self.cmdOK.setEnabled(False) @@ -340,7 +342,9 @@ def applyAcrossTabs(self, tabs, param, expr): """ for tab in tabs: if hasattr(tab, "kernel_module"): - if param in tab.kernel_module.params: + if (param in tab.kernel_module.params or + param in tab.poly_params or + param in tab.magnet_params): value_ex = tab.kernel_module.name + "." +param constraint = Constraint(param=param, value=param, diff --git a/src/sas/qtgui/Perspectives/Fitting/Constraint.py b/src/sas/qtgui/Perspectives/Fitting/Constraint.py index b686363827..3fec14c89f 100644 --- a/src/sas/qtgui/Perspectives/Fitting/Constraint.py +++ b/src/sas/qtgui/Perspectives/Fitting/Constraint.py @@ -1,7 +1,7 @@ class Constraint(object): """ Internal representation of a single parameter constraint - Currently just a data structure, might get expaned with more functionality, + Currently just a data structure, might get expanded with more functionality, hence made into a class. """ def __init__(self, parent=None, param=None, value=0.0, @@ -14,6 +14,7 @@ def __init__(self, parent=None, param=None, value=0.0, self._min = min self._max = max self._operator = operator + self._model = None self.validate = True self.active = True @@ -81,3 +82,11 @@ def operator(self): def operator(self, val): self._operator = val + @property + def model(self): + # model this constraint originates from + return self._model + + @model.setter + def model(self, val): + self._model = val diff --git a/src/sas/qtgui/Perspectives/Fitting/ConstraintWidget.py b/src/sas/qtgui/Perspectives/Fitting/ConstraintWidget.py index c8f629ca48..4a477e95bc 100644 --- a/src/sas/qtgui/Perspectives/Fitting/ConstraintWidget.py +++ b/src/sas/qtgui/Perspectives/Fitting/ConstraintWidget.py @@ -190,7 +190,7 @@ def initializeWidgets(self): # Single Fit is the default, so disable chainfit self.chkChain.setVisible(False) - # disabled constraint + # disabled constraint labels = ['Constraint'] self.tblConstraints.setColumnCount(len(labels)) self.tblConstraints.setHorizontalHeaderLabels(labels) @@ -400,6 +400,10 @@ def onHelp(self): help_location = tree_location + helpfile # OMG, really? Crawling up the object hierarchy... + # + # It's the top level that needs to do the show help. + # Perhaps better to address directly, but it does need to + # be that object. I don't like that the type is hidden. :LW self.parent.parent.showHelp(help_location) def onTabCellEdit(self, row, column): @@ -500,16 +504,16 @@ def onTabCellEdit(self, row, column): def onConstraintChange(self, row, column): """ - Modify the constraint when the user edits the constraint list. If the - user changes the constrained parameter, the constraint is erased and a - new one is created. - Checking is performed on the constrained entered by the user, showing - message box warning him the constraint is not valid and cancelling - his changes by reloading the view. View is reloaded - when the user is finished for consistency. + Modify the constraint when the user edits the constraint list. + If the user changes the constrained parameter, the constraint is erased + and a new one is created. + Checking is performed on the constrained entered by the user. + In case of an error during checking, a warning message box is shown + and the constraint is cancelled by reloading the view. + View is also reloaded when the user is finished for consistency. """ item = self.tblConstraints.item(row, column) - # extract information from the constraint object + # Extract information from the constraint object constraint = self.available_constraints[row] model = constraint.value_ex[:constraint.value_ex.index(".")] param = constraint.param @@ -529,6 +533,7 @@ def onConstraintChange(self, row, column): QtWidgets.QMessageBox.Ok) self.initializeFitList() return + # Then check if the parameter is correctly defined with colons # separating model and parameter name lhs, rhs = re.split(" *= *", item.data(0).strip(), 1) @@ -546,7 +551,7 @@ def onConstraintChange(self, row, column): # We can parse the string new_param = lhs.split(":", 1)[1].strip() new_model = lhs.split(":", 1)[0].strip() - # Check that the symbol is known so we dont get an unknown tab + # Check that the symbol is known so we don't get an unknown tab # All the conditional statements could be grouped in one or # alternatively we could check with expression.py, but we would still # need to do some checks to parse the string @@ -566,6 +571,7 @@ def onConstraintChange(self, row, column): return new_function = rhs new_tab = self.available_tabs[new_model] + model_key = tab.getModelKeyFromName(param) # Make sure we are dealing with fit tabs assert isinstance(tab, FittingWidget) assert isinstance(new_tab, FittingWidget) @@ -574,8 +580,9 @@ def onConstraintChange(self, row, column): # Apply the new constraint constraint = Constraint(param=new_param, func=new_function, value_ex=new_model + "." + new_param) + model_key = tab.getModelKeyFromName(new_param) new_tab.addConstraintToRow(constraint=constraint, - row=tab.getRowFromName(new_param)) + row=tab.getRowFromName(new_param), model_key=model_key) # If the constraint is valid and we are changing model or # parameter, delete the old constraint if (self.constraint_accepted and new_model != model or @@ -594,9 +601,9 @@ def onConstraintChange(self, row, column): font.setItalic(True) brush = QtGui.QBrush(QtGui.QColor('blue')) tab.modifyViewOnRow(tab.getRowFromName(new_param), font=font, - brush=brush) + brush=brush, model_key=model_key) else: - tab.modifyViewOnRow(tab.getRowFromName(new_param)) + tab.modifyViewOnRow(tab.getRowFromName(new_param), model_key=model_key) # reload the view so the user gets a consistent feedback on the # constraints self.initializeFitList() @@ -891,7 +898,8 @@ def deleteConstraint(self):#, row): moniker = constraint[:constraint.index(':')] param = constraint[constraint.index(':')+1:constraint.index('=')].strip() tab = self.available_tabs[moniker] - tab.deleteConstraintOnParameter(param) + model_key = tab.getModelKeyFromName(param) + tab.deleteConstraintOnParameter(param, model_key=model_key) # Constraints removed - refresh the table widget self.initializeFitList() @@ -904,14 +912,17 @@ def uneditableItem(self, data=""): item.setFlags( QtCore.Qt.ItemIsSelectable | QtCore.Qt.ItemIsEnabled ) return item - def updateFitLine(self, tab): + def updateFitLine(self, tab, model_key="standard"): """ Update a single line of the table widget with tab info """ fit_page = ObjectLibrary.getObject(tab) model = fit_page.kernel_module + if model is None: + logging.warning("No model selected") return + tab_name = tab model_name = model.id moniker = model.name @@ -948,19 +959,29 @@ def updateFitLine(self, tab): self.tblTabList.blockSignals(False) # Check if any constraints present in tab - active_constraint_names = fit_page.getComplexConstraintsForModel() - constraint_names = fit_page.getFullConstraintNameListForModel() - constraints = fit_page.getConstraintObjectsForModel() + constraint_names = fit_page.getComplexConstraintsForAllModels() + constraints = fit_page.getConstraintObjectsForAllModels() + + active_constraint_names = [] + constraint_names = [] + constraints = [] + for model_key in fit_page.model_dict.keys(): + active_constraint_names += fit_page.getComplexConstraintsForModel(model_key=model_key) + constraint_names += fit_page.getFullConstraintNameListForModel(model_key=model_key) + constraints += fit_page.getConstraintObjectsForModel(model_key=model_key) + if not constraints: return + self.tblConstraints.setEnabled(True) self.tblConstraints.blockSignals(True) for constraint, constraint_name in zip(constraints, constraint_names): - # Ignore constraints that have no *func* attribute defined - if constraint.func is None: + if not constraint_name and len(constraint_name) < 2: + continue + if constraint_name[0] is None or constraint_name[1] is None: continue # Create the text for widget item - label = moniker + ":"+ constraint_name[0] + " = " + constraint_name[1] + label = moniker + ":" + constraint_name[0] + " = " + constraint_name[1] pos = self.tblConstraints.rowCount() self.available_constraints[pos] = constraint @@ -979,7 +1000,7 @@ def updateFitLine(self, tab): self.tblConstraints.setItem(pos, 0, item) self.tblConstraints.blockSignals(False) - def initializeFitList(self): + def initializeFitList(self, row=0, model_key="standard"): """ Fill the list of model/data sets for fitting/constraining """ @@ -1018,7 +1039,7 @@ def initializeFitList(self): self._row_order = tabs for tab in tabs: - self.updateFitLine(tab) + self.updateFitLine(tab, model_key=model_key) self.updateSignalsFromTab(tab) # We have at least 1 fit page, allow fitting self.cmdFit.setEnabled(True) @@ -1085,14 +1106,15 @@ def onAcceptConstraint(self, con_tuple): # Find the constrained parameter row constrained_row = constrained_tab.getRowFromName(constraint.param) + model_key = constrained_tab.getModelKeyFromName(constraint.param) # Update the tab - constrained_tab.addConstraintToRow(constraint, constrained_row) + constrained_tab.addConstraintToRow(constraint, constrained_row, model_key=model_key) if not self.constraint_accepted: return # Select this parameter for adjusting/fitting - constrained_tab.changeCheckboxStatus(constrained_row, True) + # constrained_tab.selectCheckbox(constrained_row, model=model) def showMultiConstraint(self): """ @@ -1212,4 +1234,5 @@ def uncheckConstraint(self, name): # deactivate the constraint tab = self.parent.getTabByName(name[:name.index(":")]) row = tab.getRowFromName(name[name.index(":") + 1:]) - tab.getConstraintForRow(row).active = False + model_key = tab.getModelKey(constraint) + tab.getConstraintForRow(row, model_key=model_key).active = False diff --git a/src/sas/qtgui/Perspectives/Fitting/FittingPerspective.py b/src/sas/qtgui/Perspectives/Fitting/FittingPerspective.py index 9208469f45..6ea8572c67 100644 --- a/src/sas/qtgui/Perspectives/Fitting/FittingPerspective.py +++ b/src/sas/qtgui/Perspectives/Fitting/FittingPerspective.py @@ -217,9 +217,11 @@ def updateFromConstraints(self, constraint_dict): constraint.param = constraint_param[1] constraint.value_ex = constraint_param[2] constraint.validate = constraint_param[3] + model_key = tab.getModelKey(constraint) tab.addConstraintToRow(constraint=constraint, row=tab.getRowFromName( - constraint_param[1])) + constraint_param[1]), + model_key=model_key) def closeEvent(self, event): """ @@ -535,9 +537,9 @@ def getActiveConstraintList(self): constraints = [] for tab in self.getFitTabs(): tab_name = tab.modelName() - tab_constraints = tab.getConstraintsForModel() - constraints.extend((tab_name + "." + par, expr) - for par, expr in tab_constraints) + tab_constraints = tab.getConstraintsForAllModels() + constraints.extend((tab_name + "." + par, expr) for par, expr in tab_constraints) + return constraints def getSymbolDictForConstraints(self): diff --git a/src/sas/qtgui/Perspectives/Fitting/FittingUtilities.py b/src/sas/qtgui/Perspectives/Fitting/FittingUtilities.py index a412dda933..79aae914c2 100644 --- a/src/sas/qtgui/Perspectives/Fitting/FittingUtilities.py +++ b/src/sas/qtgui/Perspectives/Fitting/FittingUtilities.py @@ -187,8 +187,8 @@ def addSimpleParametersToModel(parameters, is2D, parameters_original=None, model Actually appends to model, if model and view params are not None. Always returns list of lists of QStandardItems. - parameters_original: list of parameters before any tagging on their IDs, e.g. for product model (so that those are - the display names; see below) + parameters_original: list of parameters before any tagging on their IDs, + e.g. for product model (so that those are the display names; see below) """ if is2D: params = [p for p in parameters.kernel_parameters if p.type != 'magnetic'] @@ -196,8 +196,9 @@ def addSimpleParametersToModel(parameters, is2D, parameters_original=None, model params = parameters.iq_parameters if parameters_original: - # 'parameters_original' contains the parameters as they are to be DISPLAYED, while 'parameters' - # contains the parameters as they were renamed; this is for handling name collisions in product model. + # 'parameters_original' contains the parameters as they are to be DISPLAYED, + # while 'parameters' contains the parameters as they were renamed; + # this is for handling name collisions in product model. # The 'real name' of the parameter will be stored in the item's user data. if is2D: params_orig = [p for p in parameters_original.kernel_parameters if p.type != 'magnetic'] @@ -710,7 +711,6 @@ def getRelativeError(data, is2d, flag=None): return weight - def calcWeightIncrease(weights, ratios, flag=False): """ Calculate the weights to be passed to bumps in order to ensure that each data set contributes to the total residual with a @@ -767,7 +767,6 @@ def calcWeightIncrease(weights, ratios, flag=False): return weight_increase - def updateKernelWithResults(kernel, results): """ Takes model kernel and applies results dict to its parameters, @@ -782,7 +781,6 @@ def updateKernelWithResults(kernel, results): return local_kernel - def getStandardParam(model=None): """ Returns a list with standard parameters for the current model @@ -965,9 +963,15 @@ def isParamPolydisperse(param_name, kernel_params, is2D=False): """ Simple lookup for polydispersity for the given param name """ + # First, check if this is a polydisperse parameter directly + if '.width' in param_name: + return True + parameters = kernel_params.form_volume_parameters if is2D: parameters += kernel_params.orientation_parameters + + # Next, check if the parameter is included in para.polydisperse has_poly = False for param in parameters: if param.name==param_name and param.polydisperse: diff --git a/src/sas/qtgui/Perspectives/Fitting/FittingWidget.py b/src/sas/qtgui/Perspectives/Fitting/FittingWidget.py index b8178d714e..3e23185bbd 100644 --- a/src/sas/qtgui/Perspectives/Fitting/FittingWidget.py +++ b/src/sas/qtgui/Perspectives/Fitting/FittingWidget.py @@ -17,6 +17,7 @@ from sasmodels import generate from sasmodels import modelinfo +from sasmodels.sasview_model import SasviewModel from sasmodels.sasview_model import load_standard_models from sasmodels.sasview_model import MultiplicationModel from sasmodels.weights import MODELS as POLYDISPERSITY_MODELS @@ -63,7 +64,6 @@ # CRUFT: remove when new release of sasmodels is available # https://github.com/SasView/sasview/pull/181#discussion_r218135162 -from sasmodels.sasview_model import SasviewModel if not hasattr(SasviewModel, 'get_weights'): def get_weights(self: Any, name: str) -> Tuple[np.ndarray, np.ndarray]: """ @@ -102,7 +102,7 @@ class FittingWidget(QtWidgets.QWidget, Ui_FittingWidgetUI): """ Main widget for selecting form and structure factor models """ - constraintAddedSignal = QtCore.pyqtSignal(list) + constraintAddedSignal = QtCore.pyqtSignal(list, str) newModelSignal = QtCore.pyqtSignal() fittingFinishedSignal = QtCore.pyqtSignal(tuple) batchFittingFinishedSignal = QtCore.pyqtSignal(tuple) @@ -216,7 +216,6 @@ def data(self, value): self.order_widget.updateData(self.all_data) # Overwrite data type descriptor - self.is2D = True if isinstance(self.logic.data, Data2D) else False # Let others know we're full of data now @@ -254,6 +253,12 @@ def initializeGlobals(self): self._num_shell_params = 0 # Dictionary of {model name: model class} for the current category self.models = {} + # Dictionary of QModels + self.model_dict = {} + self.lst_dict = {} + self.tabToList = {} # tab_id -> list widget + self.tabToKey = {} # tab_id -> model key + # Parameters to fit self.main_params_to_fit = [] self.poly_params_to_fit = [] @@ -355,7 +360,7 @@ def initializeWidgets(self): # Magnetic angles explained in one picture self.magneticAnglesWidget = QtWidgets.QWidget() labl = QtWidgets.QLabel(self.magneticAnglesWidget) - pixmap = QtGui.QPixmap(GuiUtils.IMAGES_DIRECTORY_LOCATION + '/mag_vector.png') + pixmap = QtGui.QPixmap(GuiUtils.IMAGES_DIRECTORY_LOCATION + '/M_angles_pic.png') labl.setPixmap(pixmap) self.magneticAnglesWidget.setFixedSize(pixmap.width(), pixmap.height()) @@ -370,6 +375,22 @@ def initializeModels(self): self._poly_model = ToolTippedItemModel() self._magnet_model = ToolTippedItemModel() + self.model_dict["standard"] = self._model_model + self.model_dict["poly"] = self._poly_model + self.model_dict["magnet"] = self._magnet_model + + self.lst_dict["standard"] = self.lstParams + self.lst_dict["poly"] = self.lstPoly + self.lst_dict["magnet"] = self.lstMagnetic + + self.tabToList[0] = self.lstParams + self.tabToList[3] = self.lstPoly + self.tabToList[4] = self.lstMagnetic + + self.tabToKey[0] = "standard" + self.tabToKey[3] = "poly" + self.tabToKey[4] = "magnet" + # Param model displayed in param list self.lstParams.setModel(self._model_model) self.readCategoryInfo() @@ -425,6 +446,10 @@ def initializeModels(self): self.lstPoly.itemDelegate().combo_updated.connect(self.onPolyComboIndexChange) self.lstPoly.itemDelegate().filename_updated.connect(self.onPolyFilenameChange) + self.lstPoly.setContextMenuPolicy(QtCore.Qt.CustomContextMenu) + self.lstPoly.customContextMenuRequested.connect(self.showModelContextMenu) + self.lstPoly.setAttribute(QtCore.Qt.WA_MacShowFocusRect, False) + # Magnetism model displayed in magnetism list self.lstMagnetic.setModel(self._magnet_model) self.setMagneticModel() @@ -517,7 +542,6 @@ def togglePoly(self, isChecked): if key[-6:] == '.width': self.kernel_module.setParam(key, (value if isChecked else 0)) - def toggleMagnetism(self, isChecked): """ Enable/disable the magnetism tab """ self.tabFitting.setTabEnabled(TAB_MAGNETISM, isChecked) @@ -602,6 +626,7 @@ def initializeSignals(self): self.lstParams.installEventFilter(self) self.lstPoly.installEventFilter(self) self.lstMagnetic.installEventFilter(self) + self.lstPoly.selectionModel().selectionChanged.connect(self.onSelectionChanged) # Local signals self.batchFittingFinishedSignal.connect(self.batchFitComplete) @@ -658,11 +683,13 @@ def showModelContextMenu(self, position): When clicked on parameter(s): fitting/constraints options When clicked on white space: model description """ - rows = [s.row() for s in self.lstParams.selectionModel().selectedRows() + # See which model we're dealing with by looking at the tab id + current_list = self.tabToList[self.tabFitting.currentIndex()] + rows = [s.row() for s in current_list.selectionModel().selectedRows() if self.isCheckable(s.row())] menu = self.showModelDescription() if not rows else self.modelContextMenu(rows) try: - menu.exec_(self.lstParams.viewport().mapToGlobal(position)) + menu.exec_(current_list.viewport().mapToGlobal(position)) except AttributeError as ex: logger.error("Error generating context menu: %s" % ex) return @@ -675,11 +702,13 @@ def modelContextMenu(self, rows): num_rows = len(rows) if num_rows < 1: return menu + current_list = self.tabToList[self.tabFitting.currentIndex()] + model_key = self.tabToKey[self.tabFitting.currentIndex()] # Select for fitting param_string = "parameter " if num_rows == 1 else "parameters " to_string = "to its current value" if num_rows == 1 else "to their current values" - has_constraints = any([self.rowHasConstraint(i) for i in rows]) - has_real_constraints = any([self.rowHasActiveConstraint(i) for i in rows]) + has_constraints = any([self.rowHasConstraint(i, model_key=model_key) for i in rows]) + has_real_constraints = any([self.rowHasActiveConstraint(i, model_key=model_key) for i in rows]) self.actionSelect = QtWidgets.QAction(self) self.actionSelect.setObjectName("actionSelect") @@ -717,32 +746,38 @@ def modelContextMenu(self, rows): menu.addAction(self.actionRemoveConstraint) if num_rows == 1 and has_real_constraints: menu.addAction(self.actionEditConstraint) - #if num_rows == 1: - # menu.addAction(self.actionEditConstraint) else: - menu.addAction(self.actionConstrain) if num_rows == 2: menu.addAction(self.actionMutualMultiConstrain) + else: + menu.addAction(self.actionConstrain) # Define the callbacks self.actionConstrain.triggered.connect(self.addSimpleConstraint) self.actionRemoveConstraint.triggered.connect(self.deleteConstraint) self.actionEditConstraint.triggered.connect(self.editConstraint) - self.actionMutualMultiConstrain.triggered.connect(self.showMultiConstraint) + self.actionMutualMultiConstrain.triggered.connect(lambda: self.showMultiConstraint(current_list=current_list)) self.actionSelect.triggered.connect(self.selectParameters) self.actionDeselect.triggered.connect(self.deselectParameters) return menu - def showMultiConstraint(self): + def showMultiConstraint(self, current_list=None): """ Show the constraint widget and receive the expression """ - selected_rows = self.lstParams.selectionModel().selectedRows() + if current_list is None: + current_list = self.lstParams + model = current_list.model() + for key, val in self.model_dict.items(): + if val == model: + model_key = key + + selected_rows = current_list.selectionModel().selectedRows() # There have to be only two rows selected. The caller takes care of that # but let's check the correctness. assert len(selected_rows) == 2 - params_list = [s.data() for s in selected_rows] + params_list = [s.data(role=QtCore.Qt.UserRole) for s in selected_rows] # Create and display the widget for param1 and param2 mc_widget = MultiConstraint(self, params=params_list) # Check if any of the parameters are polydisperse @@ -776,30 +811,99 @@ def showMultiConstraint(self): constraint.validate = mc_widget.validate # Create a new item and add the Constraint object as a child - self.addConstraintToRow(constraint=constraint, row=row) + self.addConstraintToRow(constraint=constraint, row=row, model_key=model_key) + + def getModelKeyFromName(self, name): + """ + Given parameter name, get the model index. + """ + if name in self.getParamNamesMain(): + return "standard" + elif name in self.getParamNamesPoly(): + return "poly" + elif name in self.getParamNamesMagnet(): + return "magnet" + else: + return "standard" def getRowFromName(self, name): """ - Given parameter name get the row number in self._model_model + Given parameter name, get the row number in a model. + The model is the main _model_model by default """ - for row in range(self._model_model.rowCount()): - row_name = self._model_model.item(row).text() + model_key = self.getModelKeyFromName(name) + model = self.model_dict[model_key] + + for row in range(model.rowCount()): + row_name = model.item(row).text() + if model_key == 'poly': + row_name = self.polyNameToParam(row_name) if row_name == name: return row return None def getParamNames(self): """ - Return list of all parameters for the current model + Return list of all active parameters for the current model + """ + main_model_params = self.getParamNamesMain() + poly_model_params = self.getParamNamesPoly() + # magnet_model_params = self.getParamNamesMagnet() + return main_model_params + poly_model_params # + magnet_model_params + + def getParamNamesMain(self): + """ + Return list of main parameters for the current model """ - return [self._model_model.item(row).text() - for row in range(self._model_model.rowCount()) - if self.isCheckable(row)] + main_model_params = [self._model_model.item(row).text() + for row in range(self._model_model.rowCount()) + if self.isCheckable(row, model_key="standard")] + return main_model_params - def modifyViewOnRow(self, row, font=None, brush=None): + def getParamNamesPoly(self): + """ + Return list of polydisperse parameters for the current model + """ + if not self.chkPolydispersity.isChecked(): + return [] + poly_model_params = [self.polyNameToParam(self._poly_model.item(row).text()) + for row in range(self._poly_model.rowCount()) + if self.chkPolydispersity.isChecked() and + self.isCheckable(row, model_key="poly")] + return poly_model_params + + def getParamNamesMagnet(self): + """ + Return list of magnetic parameters for the current model + """ + if not self.chkMagnetism.isChecked(): + return [] + magnetic_model_params = [self._magnet_model.item(row).text() + for row in range(self._magnet_model.rowCount()) + if self.isCheckable(row, model_key="magnet")] + return magnetic_model_params + + def polyParamToName(self, param_name): + """ + Translate polydisperse parameter name into QTable representation + """ + param_name = param_name.replace('.width', '') + param_name = 'Distribution of ' + param_name + return param_name + + def polyNameToParam(self, param_name): + """ + Translate polydisperse QTable representation into parameter name + """ + param_name = param_name.replace('Distribution of ', '') + param_name += '.width' + return param_name + + def modifyViewOnRow(self, row, font=None, brush=None, model_key="standard"): """ Change how the given row of the main model is shown """ + model = self.model_dict[model_key] fields_enabled = False if font is None: font = QtGui.QFont() @@ -807,25 +911,38 @@ def modifyViewOnRow(self, row, font=None, brush=None): if brush is None: brush = QtGui.QBrush() fields_enabled = True - self._model_model.blockSignals(True) + model.blockSignals(True) # Modify font and foreground of affected rows - for column in range(0, self._model_model.columnCount()): - self._model_model.item(row, column).setForeground(brush) - self._model_model.item(row, column).setFont(font) + for column in range(0, model.columnCount()): + model.item(row, column).setForeground(brush) + model.item(row, column).setFont(font) # Allow the user to interact or not with the fields depending on # whether the parameter is constrained or not - self._model_model.item(row, column).setEditable(fields_enabled) + model.item(row, column).setEditable(fields_enabled) # Force checkbox selection when parameter is constrained and disable # checkbox interaction - if not fields_enabled and self._model_model.item(row, 0).isCheckable(): - self._model_model.item(row, 0).setCheckState(2) - self._model_model.item(row, 0).setEnabled(False) + if not fields_enabled and model.item(row, 0).isCheckable(): + model.item(row, 0).setCheckState(2) + model.item(row, 0).setEnabled(False) else: # Enable checkbox interaction - self._model_model.item(row, 0).setEnabled(True) - self._model_model.blockSignals(False) + model.item(row, 0).setEnabled(True) + model.blockSignals(False) + + def getModelKey(self, constraint): + """ + Given parameter name get the model index. + """ + if constraint.param in self.getParamNamesMain(): + return "standard" + elif constraint.param in self.getParamNamesPoly(): + return "poly" + elif constraint.param in self.getParamNamesMagnet(): + return "magnet" + else: + return None - def addConstraintToRow(self, constraint=None, row=0): + def addConstraintToRow(self, constraint=None, row=0, model_key="standard"): """ Adds the constraint object to requested row. The constraint is first checked for errors, and a message box interrupting flow is @@ -833,8 +950,9 @@ def addConstraintToRow(self, constraint=None, row=0): """ # Create a new item and add the Constraint object as a child assert isinstance(constraint, Constraint) - assert 0 <= row <= self._model_model.rowCount() - assert self.isCheckable(row) + model = self.model_dict[model_key] + assert 0 <= row <= model.rowCount() + assert self.isCheckable(row, model_key=model_key) # Error checking # First, get a list of constraints and symbols @@ -859,19 +977,20 @@ def addConstraintToRow(self, constraint=None, row=0): # constraint tab that the constraint was not accepted constraint_tab.constraint_accepted = False return + item = QtGui.QStandardItem() item.setData(constraint) - self._model_model.item(row, 1).setChild(0, item) + model.item(row, 1).setChild(0, item) # Set min/max to the value constrained - self.constraintAddedSignal.emit([row]) + self.constraintAddedSignal.emit([row], model_key) # Show visual hints for the constraint font = QtGui.QFont() font.setItalic(True) brush = QtGui.QBrush(QtGui.QColor('blue')) - self.modifyViewOnRow(row, font=font, brush=brush) + self.modifyViewOnRow(row, font=font, brush=brush, model_key=model_key) # update the main parameter list so the constrained parameter gets # updated when fitting - self.checkboxSelected(self._model_model.item(row, 0)) + self.checkboxSelected(model.item(row, 0), model_key=model_key) self.communicate.statusBarUpdateSignal.emit('Constraint added') if constraint_tab: # Set the constraint_accepted flag to True to inform the @@ -882,45 +1001,49 @@ def addSimpleConstraint(self): """ Adds a constraint on a single parameter. """ + model_key = self.tabToKey[self.tabFitting.currentIndex()] + model = self.model_dict[model_key] min_col = self.lstParams.itemDelegate().param_min max_col = self.lstParams.itemDelegate().param_max - for row in self.selectedParameters(): - assert(self.isCheckable(row)) - param = self._model_model.item(row, 0).text() - value = self._model_model.item(row, 1).text() - min_t = self._model_model.item(row, min_col).text() - max_t = self._model_model.item(row, max_col).text() + for row in self.selectedParameters(model_key=model_key): + param = model.item(row, 0).text() + value = model.item(row, 1).text() + min_t = model.item(row, min_col).text() + max_t = model.item(row, max_col).text() # Create a Constraint object constraint = Constraint(param=param, value=value, min=min_t, max=max_t) # Create a new item and add the Constraint object as a child item = QtGui.QStandardItem() item.setData(constraint) - self._model_model.item(row, 1).setChild(0, item) + model.item(row, 1).setChild(0, item) # Assumed correctness from the validator value = float(value) # BUMPS calculates log(max-min) without any checks, so let's assign minor range min_v = value - (value/10000.0) max_v = value + (value/10000.0) # Set min/max to the value constrained - self._model_model.item(row, min_col).setText(str(min_v)) - self._model_model.item(row, max_col).setText(str(max_v)) - self.constraintAddedSignal.emit([row]) + model.item(row, min_col).setText(str(min_v)) + model.item(row, max_col).setText(str(max_v)) + self.constraintAddedSignal.emit([row], model_key) # Show visual hints for the constraint font = QtGui.QFont() font.setItalic(True) brush = QtGui.QBrush(QtGui.QColor('blue')) - self.modifyViewOnRow(row, font=font, brush=brush) + self.modifyViewOnRow(row, font=font, brush=brush, model_key=model_key) self.communicate.statusBarUpdateSignal.emit('Constraint added') def editConstraint(self): """ - Delete constraints from selected parameters. + Edit constraints for selected parameters. """ - params_list = [s.data() for s in self.lstParams.selectionModel().selectedRows() - if self.isCheckable(s.row())] + current_list = self.tabToList[self.tabFitting.currentIndex()] + model_key = self.tabToKey[self.tabFitting.currentIndex()] + + params_list = [s.data(role=QtCore.Qt.UserRole) for s in current_list.selectionModel().selectedRows() + if self.isCheckable(s.row(), model_key=model_key)] assert len(params_list) == 1 - row = self.lstParams.selectionModel().selectedRows()[0].row() - constraint = self.getConstraintForRow(row) + row = current_list.selectionModel().selectedRows()[0].row() + constraint = self.getConstraintForRow(row, model_key=model_key) # Create and display the widget for param1 and param2 mc_widget = MultiConstraint(self, params=params_list, constraint=constraint) # Check if any of the parameters are polydisperse @@ -952,31 +1075,33 @@ def editConstraint(self): row = self.getRowFromName(constraint.param) # Create a new item and add the Constraint object as a child - self.addConstraintToRow(constraint=constraint, row=row) + self.addConstraintToRow(constraint=constraint, row=row, model_key=model_key) def deleteConstraint(self): """ Delete constraints from selected parameters. """ - params = [s.data() for s in self.lstParams.selectionModel().selectedRows() - if self.isCheckable(s.row())] + current_list = self.tabToList[self.tabFitting.currentIndex()] + model_key = self.tabToKey[self.tabFitting.currentIndex()] + params = [s.data(role=QtCore.Qt.UserRole) for s in current_list.selectionModel().selectedRows() + if self.isCheckable(s.row(), model_key=model_key)] for param in params: - self.deleteConstraintOnParameter(param=param) + self.deleteConstraintOnParameter(param=param, model_key=model_key) - def deleteConstraintOnParameter(self, param=None): + def deleteConstraintOnParameter(self, param=None, model_key="standard"): """ Delete the constraint on model parameter 'param' """ - min_col = self.lstParams.itemDelegate().param_min - max_col = self.lstParams.itemDelegate().param_max - for row in range(self._model_model.rowCount()): - if not self.isCheckable(row): + param_list = self.lst_dict[model_key] + model = self.model_dict[model_key] + for row in range(model.rowCount()): + if not self.isCheckable(row, model_key=model_key): continue - if not self.rowHasConstraint(row): + if not self.rowHasConstraint(row, model_key=model_key): continue # Get the Constraint object from of the model item - item = self._model_model.item(row, 1) - constraint = self.getConstraintForRow(row) + item = model.item(row, 1) + constraint = self.getConstraintForRow(row, model_key=model_key) if constraint is None: continue if not isinstance(constraint, Constraint): @@ -986,23 +1111,32 @@ def deleteConstraintOnParameter(self, param=None): # Now we got the right row. Delete the constraint and clean up # Retrieve old values and put them on the model if constraint.min is not None: - self._model_model.item(row, min_col).setText(constraint.min) + try: + min_col = param_list.itemDelegate().param_min + except AttributeError: + min_col = 2 + model.item(row, min_col).setText(constraint.min) if constraint.max is not None: - self._model_model.item(row, max_col).setText(constraint.max) + try: + max_col = param_list.itemDelegate().param_max + except AttributeError: + max_col = 3 + model.item(row, max_col).setText(constraint.max) # Remove constraint item item.removeRow(0) - self.constraintAddedSignal.emit([row]) - self.modifyViewOnRow(row) + self.constraintAddedSignal.emit([row], model_key) + self.modifyViewOnRow(row, model_key=model_key) self.communicate.statusBarUpdateSignal.emit('Constraint removed') - def getConstraintForRow(self, row): + def getConstraintForRow(self, row, model_key="standard"): """ For the given row, return its constraint, if any (otherwise None) """ - if not self.isCheckable(row): + model = self.model_dict[model_key] + if not self.isCheckable(row, model_key=model_key): return None - item = self._model_model.item(row, 1) + item = model.item(row, 1) try: return item.child(0).data() except AttributeError: @@ -1013,35 +1147,46 @@ def allParamNames(self): Returns a list of all parameter names defined on the current model """ all_params = self.kernel_module._model_info.parameters.kernel_parameters - all_param_names = [param.name for param in all_params] + all_params = list(self.kernel_module.details.keys()) + + # all_param_names = [param.name for param in all_params] # Assure scale and background are always included - if 'scale' not in all_param_names: - all_param_names.append('scale') - if 'background' not in all_param_names: - all_param_names.append('background') - return all_param_names + # if 'scale' not in all_param_names: + # all_param_names.append('scale') + # if 'background' not in all_param_names: + # all_param_names.append('background') + return all_params def paramHasConstraint(self, param=None): """ - Finds out if the given parameter in the main model has a constraint child + Finds out if the given parameter in all the models has a constraint child """ - if param is None: return False - if param not in self.allParamNames(): return False + if param is None: + return False + if param not in self.allParamNames(): + return False - for row in range(self._model_model.rowCount()): - if self._model_model.item(row,0).text() != param: continue - return self.rowHasConstraint(row) + for model_key in self.model_dict.keys(): + for row in range(self.model_dict[model_key].rowCount()): + param_name = self.model_dict[model_key].item(row,0).text() + if model_key == 'poly': + param_name = self.polyNameToParam(param_name) + if param_name != param: + continue + return self.rowHasConstraint(row, model_key=model_key) # nothing found return False - def rowHasConstraint(self, row): + def rowHasConstraint(self, row, model_key="standard"): """ Finds out if row of the main model has a constraint child """ - if not self.isCheckable(row): + model = self.model_dict[model_key] + + if not self.isCheckable(row, model_key=model_key): return False - item = self._model_model.item(row, 1) + item = model.item(row, 1) if not item.hasChildren(): return False c = item.child(0).data() @@ -1049,13 +1194,14 @@ def rowHasConstraint(self, row): return True return False - def rowHasActiveConstraint(self, row): + def rowHasActiveConstraint(self, row, model_key="standard"): """ Finds out if row of the main model has an active constraint child """ - if not self.isCheckable(row): + model = self.model_dict[model_key] + if not self.isCheckable(row, model_key=model_key): return False - item = self._model_model.item(row, 1) + item = model.item(row, 1) if not item.hasChildren(): return False c = item.child(0).data() @@ -1063,13 +1209,14 @@ def rowHasActiveConstraint(self, row): return True return False - def rowHasActiveComplexConstraint(self, row): + def rowHasActiveComplexConstraint(self, row, model_key="standard"): """ Finds out if row of the main model has an active, nontrivial constraint child """ - if not self.isCheckable(row): + model = self.model_dict[model_key] + if not self.isCheckable(row, model_key=model_key): return False - item = self._model_model.item(row, 1) + item = model.item(row, 1) if not item.hasChildren(): return False c = item.child(0).data() @@ -1082,23 +1229,27 @@ def selectParameters(self): Selected parameter is chosen for fitting """ status = QtCore.Qt.Checked - item = self._model_model.itemFromIndex(self.lstParams.currentIndex()) - self.setParameterSelection(status, item=item) + model_key = self.tabToKey[self.tabFitting.currentIndex()] + model = self.model_dict[model_key] + item = model.itemFromIndex(self.lstParams.currentIndex()) + self.setParameterSelection(status, item=item, model_key=model_key) def deselectParameters(self): """ Selected parameters are removed for fitting """ status = QtCore.Qt.Unchecked - item = self._model_model.itemFromIndex(self.lstParams.currentIndex()) - self.setParameterSelection(status, item=item) + model_key = self.tabToKey[self.tabFitting.currentIndex()] + model = self.model_dict[model_key] + item = model.itemFromIndex(self.lstParams.currentIndex()) + self.setParameterSelection(status, item=item, model_key=model_key) - def selectedParameters(self): + def selectedParameters(self, model_key="standard"): """ Returns list of selected (highlighted) parameters """ - return [s.row() for s in self.lstParams.selectionModel().selectedRows() - if self.isCheckable(s.row())] + return [s.row() for s in self.lst_dict[model_key].selectionModel().selectedRows() + if self.isCheckable(s.row(), model_key=model_key)] - def setParameterSelection(self, status=QtCore.Qt.Unchecked, item=None): + def setParameterSelection(self, status=QtCore.Qt.Unchecked, item=None, model_key="standard"): """ Selected parameters are chosen for fitting """ @@ -1109,57 +1260,95 @@ def setParameterSelection(self, status=QtCore.Qt.Unchecked, item=None): # `item` is also selected! # Otherwise things get confusing. # https://github.com/SasView/sasview/issues/1676 - if item.row() not in self.selectedParameters(): + if item.row() not in self.selectedParameters(model_key=model_key): return - for row in self.selectedParameters(): - self._model_model.item(row, 0).setCheckState(status) + for row in self.selectedParameters(model_key=model_key): + self.model_dict[model_key].item(row, 0).setCheckState(status) - def getConstraintsForModel(self): + def getConstraintsForAllModels(self): """ Return a list of tuples. Each tuple contains constraints mapped as ('constrained parameter', 'function to constrain') e.g. [('sld','5*sld_solvent')] """ - param_number = self._model_model.rowCount() - params = [(self._model_model.item(s, 0).text(), - self._model_model.item(s, 1).child(0).data().func) - for s in range(param_number) if self.rowHasActiveConstraint(s)] + params = [] + for model_key in self.model_dict.keys(): + model = self.model_dict[model_key] + param_number = model.rowCount() + if model_key == 'poly': + params += [(self.polyNameToParam(model.item(s, 0).text()), + model.item(s, 1).child(0).data().func) + for s in range(param_number) if self.rowHasActiveConstraint(s, model_key=model_key)] + else: + params += [(model.item(s, 0).text(), + model.item(s, 1).child(0).data().func) + for s in range(param_number) if self.rowHasActiveConstraint(s, model_key=model_key)] return params - def getComplexConstraintsForModel(self): + def getComplexConstraintsForAllModels(self): + """ + Returns a list of tuples containing all the constraints defined + for a given FitPage + """ + constraints = [] + for model_key in self.model_dict.keys(): + constraints += self.getComplexConstraintsForModel(model_key=model_key) + return constraints + + def getComplexConstraintsForModel(self, model_key): """ Return a list of tuples. Each tuple contains constraints mapped as ('constrained parameter', 'function to constrain') e.g. [('sld','5*M2.sld_solvent')]. Only for constraints with defined VALUE """ - param_number = self._model_model.rowCount() - params = [(self._model_model.item(s, 0).text(), - self._model_model.item(s, 1).child(0).data().func) - for s in range(param_number) if self.rowHasActiveComplexConstraint(s)] + model = self.model_dict[model_key] + params = [] + param_number = model.rowCount() + for s in range(param_number): + if self.rowHasActiveComplexConstraint(s, model_key): + if model.item(s, 0).data(role=QtCore.Qt.UserRole): + parameter_name = str(model.item(s, 0).data(role=QtCore.Qt.UserRole)) + else: + parameter_name = str(model.item(s, 0).data(0)) + params.append((parameter_name, model.item(s, 1).child(0).data().func)) return params - def getFullConstraintNameListForModel(self): + def getFullConstraintNameListForModel(self, model_key): """ Return a list of tuples. Each tuple contains constraints mapped as ('constrained parameter', 'function to constrain') e.g. [('sld','5*M2.sld_solvent')]. Returns a list of all constraints, not only active ones """ - param_number = self._model_model.rowCount() - params = [(self._model_model.item(s, 0).text(), - self._model_model.item(s, 1).child(0).data().func) - for s in range(param_number) if self.rowHasConstraint(s)] + model = self.model_dict[model_key] + param_number = model.rowCount() + params = list() + for s in range(param_number): + if self.rowHasConstraint(s, model_key=model_key): + param_name = model.item(s, 0).text() + if model_key == 'poly': + param_name = self.polyNameToParam(model.item(s, 0).text()) + params.append((param_name, model.item(s, 1).child(0).data().func)) return params - def getConstraintObjectsForModel(self): + def getConstraintObjectsForAllModels(self): """ - Returns Constraint objects present on the whole model + Returns a list of the constraint object for a given FitPage """ - param_number = self._model_model.rowCount() - constraints = [self._model_model.item(s, 1).child(0).data() - for s in range(param_number) if self.rowHasConstraint(s)] + constraints = [] + for model_key in self.model_dict.keys(): + constraints += self.getConstraintObjectsForModel(model_key=model_key) + return constraints + def getConstraintObjectsForModel(self, model_key): + """ + Returns Constraint objects present on the whole model + """ + model = self.model_dict[model_key] + param_number = model.rowCount() + constraints = [model.item(s, 1).child(0).data() + for s in range(param_number) if self.rowHasConstraint(s, model_key=model_key)] return constraints def getConstraintsForFitting(self): @@ -1167,7 +1356,9 @@ def getConstraintsForFitting(self): Return a list of constraints in format ready for use in fiting """ # Get constraints - constraints = self.getComplexConstraintsForModel() + constraints = [] + for model_key in self.model_dict.keys(): + constraints += self.getComplexConstraintsForModel(model_key=model_key) # See if there are any constraints across models multi_constraints = [cons for cons in constraints if self.isConstraintMultimodel(cons[1])] @@ -1196,19 +1387,20 @@ def getConstraintsForFitting(self): for cons in multi_constraints: # deactivate the constraint row = self.getRowFromName(cons[0]) - self.getConstraintForRow(row).active = False + model_key = self.getModelKeyFromName(cons[0]) + self.getConstraintForRow(row, model_key=model_key).active = False # uncheck in the constraint tab if constraint_tab: constraint_tab.uncheckConstraint( self.kernel_module.name + ':' + cons[0]) # re-read the constraints - constraints = self.getComplexConstraintsForModel() + constraints = self.getComplexConstraintsForModel(model_key=model_key) return constraints def showModelDescription(self): """ - Creates a window with model description, when right clicked in the treeview + Creates a window with model description, when right-clicked in the treeview """ msg = 'Model description:\n' if self.kernel_module is not None: @@ -1350,16 +1542,19 @@ def onSelectionChanged(self): """ React to parameter selection """ - rows = self.lstParams.selectionModel().selectedRows() + current_list = self.tabToList[self.tabFitting.currentIndex()] + model_key = self.tabToKey[self.tabFitting.currentIndex()] + + rows = current_list.selectionModel().selectedRows() # Clean previous messages self.communicate.statusBarUpdateSignal.emit("") if len(rows) == 1: # Show constraint, if present row = rows[0].row() - if not self.rowHasConstraint(row): + if not self.rowHasConstraint(row, model_key=model_key): return - constr = self.getConstraintForRow(row) - func = self.getConstraintForRow(row).func + constr = self.getConstraintForRow(row, model_key=model_key) + func = self.getConstraintForRow(row, model_key=model_key).func if constr.func is not None: # inter-parameter constraint update_text = "Active constraint: "+func @@ -1488,17 +1683,18 @@ def onPolyModelChange(self, top, bottom): model_row = item.row() name_index = self._poly_model.index(model_row, 0) parameter_name = str(name_index.data()) # "distribution of sld" etc. - if "istribution of" in parameter_name: - # just the last word - parameter_name = parameter_name.rsplit()[-1] + parameter_name_w = self.polyNameToParam(parameter_name) + # Needs to retrieve also name of main parameter in order to update + # corresponding values in FitPage + parameter_name = parameter_name.rsplit()[-1] delegate = self.lstPoly.itemDelegate() - parameter_name_w = parameter_name + '.width' # Extract changed value. if model_column == delegate.poly_parameter: # Is the parameter checked for fitting? value = item.checkState() + if value == QtCore.Qt.Checked: self.poly_params_to_fit.append(parameter_name_w) else: @@ -1537,17 +1733,17 @@ def onPolyModelChange(self, top, bottom): # PD[ratio] -> width, npts -> npts, nsigs -> nsigmas if model_column not in delegate.columnDict(): return - key = parameter_name + '.' + delegate.columnDict()[model_column] - self.poly_params[key] = value - self.kernel_module.setParam(key, value) + self.poly_params[parameter_name_w] = value + self.kernel_module.setParam(parameter_name_w, value) # Update plot self.updateData() # update in param model if model_column in [delegate.poly_pd, delegate.poly_error, delegate.poly_min, delegate.poly_max]: + model_key = self.getModelKeyFromName(parameter_name) row = self.getRowFromName(parameter_name) - param_item = self._model_model.item(row).child(0).child(0, model_column) + param_item = self.model_dict[model_key].item(row).child(0).child(0, model_column) if param_item is None: return self._model_model.blockSignals(True) @@ -1576,7 +1772,7 @@ def onMagnetModelChange(self, top, bottom): self.updateUndo() return - # Extract changed value. + # Extract changed value try: value = GuiUtils.toDouble(item.text()) except TypeError: @@ -1597,6 +1793,9 @@ def onMagnetModelChange(self, top, bottom): self.kernel_module.details[parameter_name][pos] = value else: self.magnet_params[parameter_name] = value + #self.kernel_module.setParam(parameter_name) = value + # Force the chart update when actual parameters changed + self.recalculatePlotData() # Update state stack self.updateUndo() @@ -1885,14 +2084,20 @@ def prepareFitters(self, fitter=None, fit_id=0, weight_increase=1): params_to_fit = copy.deepcopy(self.main_params_to_fit) if self.chkPolydispersity.isChecked(): - params_to_fit += self.poly_params_to_fit + for p in self.poly_params_to_fit: + if "Distribution of" in p: + params_to_fit += [self.polyNameToParam(p)] + else: + params_to_fit += [p] if self.chkMagnetism.isChecked() and self.canHaveMagnetism(): params_to_fit += self.magnet_params_to_fit if not params_to_fit: raise ValueError('Fitting requires at least one parameter to optimize.') # Get the constraints. - constraints = self.getComplexConstraintsForModel() + constraints = [] + for model_key in self.model_dict.keys(): + constraints += self.getComplexConstraintsForModel(model_key=model_key) if fitter is None: # For single fits - check for inter-model constraints constraints = self.getConstraintsForFitting() @@ -2349,6 +2554,8 @@ def addBackgroundToModel(self, model): last_row = model.rowCount()-1 model.item(last_row, 0).setEditable(False) model.item(last_row, 4).setEditable(False) + model.item(last_row,0).setData('background', role=QtCore.Qt.UserRole) + def addScaleToModel(self, model): """ @@ -2360,6 +2567,7 @@ def addScaleToModel(self, model): last_row = model.rowCount()-1 model.item(last_row, 0).setEditable(False) model.item(last_row, 4).setEditable(False) + model.item(last_row,0).setData('scale', role=QtCore.Qt.UserRole) def addWeightingToData(self, data): """ @@ -2625,7 +2833,7 @@ def onMainParamsChange(self, top, bottom): model_column = item.column() if model_column == 0: - self.checkboxSelected(item) + self.checkboxSelected(item, model_key="standard") self.cmdFit.setEnabled(self.haveParamsToFit()) # Update state stack self.updateUndo() @@ -2737,22 +2945,27 @@ def setParamEditableByRow(self, row, editable=True): item_name.setCheckState(QtCore.Qt.Unchecked) item_name.setCheckable(False) - def isCheckable(self, row): - return self._model_model.item(row, 0).isCheckable() + def isCheckable(self, row, model_key="standard"): + model = self.model_dict[model_key] + if model.item(row,0) is None: + return False + return model.item(row, 0).isCheckable() - def changeCheckboxStatus(self, row, checkbox_status): + def changeCheckboxStatus(self, row, checkbox_status, model_key="standard"): """ Checks/unchecks the checkbox at given row """ - assert 0<= row <= self._model_model.rowCount() - index = self._model_model.index(row, 0) - item = self._model_model.itemFromIndex(index) + model = self.model_dict[model_key] + + assert 0<= row <= model.rowCount() + index = model.index(row, 0) + item = model.itemFromIndex(index) if checkbox_status: item.setCheckState(QtCore.Qt.Checked) else: item.setCheckState(QtCore.Qt.Unchecked) - def checkboxSelected(self, item): + def checkboxSelected(self, item, model_key="standard"): # Assure we're dealing with checkboxes if not item.isCheckable(): return @@ -2762,22 +2975,31 @@ def checkboxSelected(self, item): # Convert to proper indices and set requested enablement # Careful with `item` NOT being selected. This means we only want to # select that one item. - self.setParameterSelection(status, item=item) + self.setParameterSelection(status, item=item, model_key=model_key) # update the list of parameters to fit - self.main_params_to_fit = self.checkedListFromModel(self._model_model) + self.main_params_to_fit = self.checkedListFromModel("standard") + self.poly_params_to_fit = self.checkedListFromModel("poly") + self.magnet_params_to_fit = self.checkedListFromModel("magnet") - def checkedListFromModel(self, model): + def checkedListFromModel(self, model_key): """ Returns list of checked parameters for given model """ def isChecked(row): + model = self.model_dict[model_key] return model.item(row, 0).checkState() == QtCore.Qt.Checked - return [str(model.item(row_index, 0).text()) - for row_index in range(model.rowCount()) - if isChecked(row_index)] + model = self.model_dict[model_key] + if model_key == "poly": + return [self.polyNameToParam(str(model.item(row_index, 0).text())) + for row_index in range(model.rowCount()) + if isChecked(row_index)] + else: + return [str(model.item(row_index, 0).text()) + for row_index in range(model.rowCount()) + if isChecked(row_index)] def createNewIndex(self, fitted_data): """ Create a model or theory index with passed Data1D/Data2D @@ -2937,7 +3159,7 @@ def completed2D(self, return_data): def _appendPlotsPolyDisp(self, new_plots, return_data, fitted_data): """ Internal helper for 1D and 2D for creating plots of the polydispersity distribution for - parameters which have a polydispersity enabled. + parameters which have a polydispersity enabled """ for plot in FittingUtilities.plotPolydispersities(return_data.get('model', None)): data_id = fitted_data.id.split() @@ -3186,7 +3408,7 @@ def setPolyModelParameters(self, i, param): Standard of multishell poly parameter driver """ param_name = param.name - # see it the parameter is multishell + # see it if the parameter is multishell if '[' in param.name: # Skip empty shells if self.current_shell_displayed == 0: @@ -3225,6 +3447,9 @@ def addNameToPolyModel(self, i, param_name): str(npts), str(nsigs), "gaussian ",''] FittingUtilities.addCheckedListToModel(self._poly_model, checked_list) + all_items = self._poly_model.rowCount() + self._poly_model.item(all_items-1,0).setData(param_wname, role=QtCore.Qt.UserRole) + # All possible polydisp. functions as strings in combobox func = QtWidgets.QComboBox() func.addItems([str(name_disp) for name_disp in POLYDISPERSITY_MODELS.keys()]) @@ -3433,6 +3658,8 @@ def addCheckedMagneticListToModel(self, param, value): self.magnet_params[param.name] = value FittingUtilities.addCheckedListToModel(self._magnet_model, checked_list) + all_items = self._magnet_model.rowCount() + self._magnet_model.item(all_items-1,0).setData(param.name, role=QtCore.Qt.UserRole) def enableStructureFactorControl(self, structure_factor): """ @@ -3765,7 +3992,7 @@ def saveToFitPage(self, fp): fp.smearing_options[fp.SMEARING_MIN] = smearing_min fp.smearing_options[fp.SMEARING_MAX] = smearing_max - # TODO: add polidyspersity and magnetism + # TODO: add polydispersity and magnetism def updateUndo(self): """ @@ -4048,13 +4275,16 @@ def gatherParams(row): Create list of main parameters based on _model_model """ param_name = str(self._model_model.item(row, 0).text()) - + current_list = self.tabToList[self.tabFitting.currentIndex()] + model = self._model_model + if model.item(row, 0) is None: + return # Assure this is a parameter - must contain a checkbox - if not self._model_model.item(row, 0).isCheckable(): + if not model.item(row, 0).isCheckable(): # maybe it is a combobox item (multiplicity) try: - index = self._model_model.index(row, 1) - widget = self.lstParams.indexWidget(index) + index = model.index(row, 1) + widget = current_list.indexWidget(index) if widget is None: return if isinstance(widget, QtWidgets.QComboBox): @@ -4065,23 +4295,23 @@ def gatherParams(row): pass return - param_checked = str(self._model_model.item(row, 0).checkState() == QtCore.Qt.Checked) + param_checked = str(model.item(row, 0).checkState() == QtCore.Qt.Checked) # Value of the parameter. In some cases this is the text of the combobox choice. - param_value = str(self._model_model.item(row, 1).text()) + param_value = str(model.item(row, 1).text()) param_error = None param_min = None param_max = None column_offset = 0 if self.has_error_column: column_offset = 1 - param_error = str(self._model_model.item(row, 1+column_offset).text()) + param_error = str(model.item(row, 1+column_offset).text()) try: - param_min = str(self._model_model.item(row, 2+column_offset).text()) - param_max = str(self._model_model.item(row, 3+column_offset).text()) + param_min = str(model.item(row, 2+column_offset).text()) + param_max = str(model.item(row, 3+column_offset).text()) except: pass # Do we have any constraints on this parameter? - constraint = self.getConstraintForRow(row) + constraint = self.getConstraintForRow(row, model_key="standard") cons = () if constraint is not None: value = constraint.value @@ -4453,6 +4683,7 @@ def getSymbolDict(self): return sym_dict model_name = self.kernel_module.name for param in self.getParamNames(): + model_key = self.getModelKeyFromName(param) sym_dict[f"{model_name}.{param}"] = GuiUtils.toDouble( - self._model_model.item(self.getRowFromName(param), 1).text()) + self.model_dict[model_key].item(self.getRowFromName(param), 1).text()) return sym_dict diff --git a/src/sas/qtgui/Perspectives/Inversion/InversionPerspective.py b/src/sas/qtgui/Perspectives/Inversion/InversionPerspective.py index d5247ae452..438d838b54 100644 --- a/src/sas/qtgui/Perspectives/Inversion/InversionPerspective.py +++ b/src/sas/qtgui/Perspectives/Inversion/InversionPerspective.py @@ -699,7 +699,7 @@ def removeData(self, data_list=None): self.nTermsSuggested)) self.regConstantSuggestionButton.setText("{:-3.2g}".format( REGULARIZATION)) - self.updateGuiValues() + # self.updateGuiValues() self.setupModel() else: self.dataList.setCurrentIndex(0) diff --git a/src/sas/sascalc/fit/BumpsFitting.py b/src/sas/sascalc/fit/BumpsFitting.py index b5d51d630b..44bd8c2758 100644 --- a/src/sas/sascalc/fit/BumpsFitting.py +++ b/src/sas/sascalc/fit/BumpsFitting.py @@ -282,6 +282,7 @@ def fit(self, msg_q=None, values, errs, cov = result['value'], result['stderr'], result[ 'covariance'] assert values is not None and errs is not None + assert len(values) == cov.shape[0] == cov.shape[1] # Propagate uncertainty through the parameter expressions # We are going to abuse bumps a little here and stuff uncertainty @@ -297,24 +298,24 @@ def fit(self, msg_q=None, param.value = uncertainties.ufloat(val, err) else: try: - uncertainties.correlated_values(values, cov) - except: - # No convergance - for param, val, err in zip(varying, values, errs): - # Convert all varying parameters to uncertainties objects - param.value = uncertainties.ufloat(val, err) - else: # Use the covariance matrix to calculate error in the parameter fitted = uncertainties.correlated_values(values, cov) for param, val in zip(varying, fitted): param.value = val + except Exception: + # No convergence. Convert all varying parameters to uncertainties objects + for param, val, err in zip(varying, values, errs): + param.value = uncertainties.ufloat(val, err) # Propagate correlated uncertainty through constraints. problem.setp_hook() - # collect the results + # Collect the results all_results = [] + # Check if uncertainty is missing for any parameter + uncertainty_warning = False + for fitting_module in problem.models: fitness = fitting_module.fitness pars = fitness.fitted_pars + fitness.computed_pars @@ -331,69 +332,31 @@ def fit(self, msg_q=None, if result['uncertainty'] is not None: fitting_result.uncertainty_state = result['uncertainty'] - if fitting_result.success: - pvec = list() - stderr = list() - for p in pars: - # If p is already defined as an uncertainties object it is not constrained based on another - # parameter - if isinstance(p.value, uncertainties.core.Variable) or \ - isinstance(p.value, uncertainties.core.AffineScalarFunc): - # value.n returns value p - pvec.append(p.value.n) - # value.n returns error in p - stderr.append(p.value.s) - # p constrained based on another parameter - else: - # Details of p - param_model, param_name = p.name.split(".")[0], p.name.split(".")[1] - # Constraints applied on p, list comprehension most efficient method, will always return a - # list with 1 entry - constraints = [model.constraints for model in models if model.name == param_model][0] - # Parameters p is constrained on. - reference_params = [v for v in varying if str(v.name) in str(constraints[param_name])] - err_exp = str(constraints[param_name]) - # Convert string entries into variable names within the code. - for i, index in enumerate(reference_params): - err_exp = err_exp.replace(reference_params[index].name, f"reference_params[{index}].value") - try: - # Evaluate a string containing constraints as if it where a line of code - pvec.append(eval(err_exp).n) - stderr.append(eval(err_exp).s) - except NameError as e: - pvec.append(p.value) - stderr.append(0) - # Get model causing error - name_error = e.args[0].split()[1].strip("'") - # Safety net if following code does not work - error_param = name_error - # Get parameter causing error - constraints_sections = constraints[param_name].split(".") - for i in range(len(constraints_sections)): - if name_error in constraints_sections[i]: - error_param = f"{name_error}.{constraints_sections[i+1]}" - logging.error(f"Constraints ordered incorrectly. Attempting to constrain {p}, based on " - f"{error_param}, however {error_param} is not defined itself. This is " - f"because {error_param} is also constrained.\n" - f"The fitting will continue, but {name_error} will be incorrect.") - logging.error(e) - except Exception as e: - logging.error(e) - pvec.append(p.value) - stderr.append(0) - - fitting_result.pvec = (np.array(pvec)) - fitting_result.stderr = (np.array(stderr)) - DOF = max(1, fitness.numpoints() - len(fitness.fitted_pars)) - fitting_result.fitness = np.sum(fitting_result.residuals ** 2) / DOF - else: - fitting_result.pvec = np.asarray([p.value for p in pars]) - fitting_result.stderr = np.NaN * np.ones(len(pars)) + fitting_result.pvec = np.array([getattr(p.value, 'n', p.value) for p in pars]) + fitting_result.stderr = np.array([getattr(p.value, 's', 0) for p in pars]) + DOF = max(1, fitness.numpoints() - len(fitness.fitted_pars)) + fitting_result.fitness = np.sum(fitting_result.residuals ** 2) / DOF + + # Warn user about any parameter that is not an uncertainty object + miss_uncertainty = [p for p in pars if not isinstance(p.value, + (uncertainties.core.Variable, uncertainties.core.AffineScalarFunc))] + if miss_uncertainty: + uncertainty_warning = True + for p in miss_uncertainty: + logging.warn(p.name + " uncertainty could not be calculated.") + + # TODO: Let the GUI decided how to handle success/failure. + if not fitting_result.success: + fitting_result.stderr[:] = np.NaN fitting_result.fitness = np.NaN all_results.append(fitting_result) + all_results[0].mesg = result['errors'] + if uncertainty_warning: + logging.warn("Consider checking related constraint definitions and status of parameters used there.") + if q is not None: q.put(all_results) return q diff --git a/src/sas/sascalc/fit/expression.py b/src/sas/sascalc/fit/expression.py index 08129fdc8a..9c3c18e6bd 100644 --- a/src/sas/sascalc/fit/expression.py +++ b/src/sas/sascalc/fit/expression.py @@ -306,6 +306,7 @@ def _compile_constraints(symtab, exprs, context={}, html=False): if errors: return None, errors + #print(f"{symtab=}\n {deps=}\n {order=}\n") # Rather than using the full path to the parameters in the parameter # expressions, instead use Pn, and substitute Pn.value for each occurrence @@ -356,7 +357,6 @@ def order_dependencies(pairs): Order elements from pairs so that b comes before a in the ordered list for all pairs (a, b). """ - #print("order_dependencies", pairs) emptyset = set() order = [] @@ -364,7 +364,6 @@ def order_dependencies(pairs): # Note: pairs is array or list, so use "len(pairs) > 0" to check for empty. left, right = [set(s) for s in zip(*pairs)] if len(pairs) > 0 else ([], []) while len(pairs) > 0: - #print "within", pairs # Find which items only occur on the right independent = right - left if independent == emptyset: @@ -379,10 +378,9 @@ def order_dependencies(pairs): else: left, right = [set(s) for s in zip(*pairs)] resolved = dependent - left - #print "independent", independent, "dependent", dependent, "resolvable", resolved + order += resolved - #print "new order", order - order.reverse() + return order # ========= Test code ======== @@ -391,20 +389,21 @@ def _check(msg, pairs): Verify that the list n contains the given items, and that the list satisfies the partial ordering given by the pairs in partial order. """ + # pairs are a list of (lhs, rhs) + # lhs may be repeated e.g., x = a+b has pairs (x, a) and (x, b) + # lhs may be in rhs e.g., x = a+b; b = 2*c has pairs (x, a) (x, b), (b, c) + # find lhs eval order; since (x, b) is a pair then b must come before x # Note: pairs is array or list, so use "len(pairs) > 0" to check for empty. - left, right = zip(*pairs) if len(pairs) > 0 else ([], []) - items = set(left) - n = order_dependencies(pairs) - if set(n) != items or len(n) != len(items): - n.sort() - items = list(items) - items.sort() + lhs_list, rhs_list = zip(*pairs) if len(pairs) > 0 else ([], []) + items = set(lhs_list) # items = all LHS, removing duplicates + order = order_dependencies(pairs) + if set(order) != set(items) or len(order) != len(items): raise ValueError("%s expect %s to contain %s for %s" - % (msg, n, items, pairs)) - for lo, hi in pairs: - if lo in n and hi in n and n.index(lo) >= n.index(hi): + % (msg, order, sorted(items), pairs)) + for lhs, rhs in pairs: + if lhs in order and rhs in order and order.index(rhs) >= order.index(lhs): raise ValueError("%s expect %s before %s in %s for %s" - % (msg, lo, hi, n, pairs)) + % (msg, lhs, rhs, order, pairs)) def test_deps(): import numpy as np @@ -480,14 +479,21 @@ def world(*pars): p3_circular = TestParameter('M1.G1', expression='other + 6') p3_self = TestParameter('M1.G1', expression='M1.G1') p4 = TestParameter('constant', expression='2*pi*35') - # Simple chain + p5 = TestParameter('chain', expression='2+other') + # Simple pairs assert (set(_find_dependencies(*world(p1, p2, p3))) == set([(p2.path, p1.path), (p2.path, p3.path)])) + # Chain + assert (set(_find_dependencies(*world(p1, p2, p3, p5))) + == set([(p2.path, p1.path), (p2.path, p3.path), (p5.path, p2.path)])) # Constant expression assert set(_find_dependencies(*world(p1, p4))) == set([(p4.path, None)]) # No dependencies assert not set(_find_dependencies(*world(p1, p3))) + # Make sure 'other' is evaluated before 'chain' + assert (order_dependencies(_find_dependencies(*world(p1, p2, p3, p5))) + == [p2.path, p5.path]) # Check function builder fn = compile_constraints(*world(p1, p2, p3))