diff --git a/README.md b/README.md index ef541046..18084bbb 100644 --- a/README.md +++ b/README.md @@ -130,8 +130,9 @@ True ### Get info on the top n moves ```python -stockfish.get_top_moves(3) +stockfish.get_top_moves(3, include_info=False) ``` +Optional parameter `include_info` specifies whether to include the full info from the engine in the returned dictionary, including seldepth, multipv, time, nodes, nps, and wdl if available. Boolean. Default is `False`. ```text [ {'Move': 'f5h7', 'Centipawn': None, 'Mate': 1}, diff --git a/stockfish/models.py b/stockfish/models.py index b90a3d55..f2e4e438 100644 --- a/stockfish/models.py +++ b/stockfish/models.py @@ -514,7 +514,9 @@ def get_evaluation(self) -> dict: elif splitted_text[0] == "bestmove": return evaluation - def get_top_moves(self, num_top_moves: int = 5) -> List[dict]: + def get_top_moves( + self, num_top_moves: int = 5, include_info: bool = False + ) -> List[dict]: """Returns info on the top moves in the position. Args: @@ -522,10 +524,18 @@ def get_top_moves(self, num_top_moves: int = 5) -> List[dict]: The number of moves to return info on, assuming there are at least those many legal moves. + include_info: + Option to include the full info from the engine in the returned dictionary, + including seldepth, multipv, time, nodes, nps, and wdl if available. + Boolean. Default is False. + Returns: A list of dictionaries. In each dictionary, there are keys for Move, Centipawn, and Mate; the corresponding value for either the Centipawn or Mate key will be None. If there are no moves in the position, an empty list is returned. + + If include_info is True, the dictionary will also include the keys SelectiveDepth, Time, + Nodes, N/s, MultiPVLine, and WDL (if available). WDL is set from the White player's perspective. """ if num_top_moves <= 0: @@ -562,20 +572,44 @@ def get_top_moves(self, num_top_moves: int = 5) -> List[dict]: raise RuntimeError( "Having a centipawn value and mate value should be mutually exclusive." ) - top_moves.insert( - 0, - { - "Move": current_line[current_line.index("pv") + 1], - "Centipawn": int(current_line[current_line.index("cp") + 1]) - * multiplier - if has_centipawn_value - else None, - "Mate": int(current_line[current_line.index("mate") + 1]) - * multiplier - if has_mate_value - else None, - }, - ) + move_evaluation = { + "Move": current_line[current_line.index("pv") + 1], + "Centipawn": int(current_line[current_line.index("cp") + 1]) + * multiplier + if has_centipawn_value + else None, + "Mate": int(current_line[current_line.index("mate") + 1]) + * multiplier + if has_mate_value + else None, + } + if include_info: + move_evaluation.update( + { + "Nodes": current_line[current_line.index("nodes") + 1], + "N/s": current_line[current_line.index("nps") + 1], + "Time": current_line[current_line.index("time") + 1], + "SelectiveDepth": current_line[ + current_line.index("seldepth") + 1 + ], + "MultiPVLine": current_line[ + current_line.index("multipv") + 1 + ], + } + ) + if self.does_current_engine_version_have_wdl_option(): + move_evaluation.update( + { + "WDL": " ".join( + [ + current_line[current_line.index("wdl") + 1], + current_line[current_line.index("wdl") + 2], + current_line[current_line.index("wdl") + 3], + ][::multiplier] + ) + } + ) + top_moves.insert(0, move_evaluation) else: break if old_MultiPV_value != self._parameters["MultiPV"]: diff --git a/tests/stockfish/test_models.py b/tests/stockfish/test_models.py index 45269a44..a77082fc 100644 --- a/tests/stockfish/test_models.py +++ b/tests/stockfish/test_models.py @@ -571,6 +571,31 @@ def test_get_top_moves_mate(self, stockfish): assert stockfish.get_top_moves() == [] assert stockfish.get_parameters()["MultiPV"] == 3 + def test_get_top_moves_with_info(self, stockfish): + stockfish.set_depth(15) + stockfish._set_option("MultiPV", 4) + stockfish.set_fen_position("1rQ1r1k1/5ppp/8/8/1R6/8/2r2PPP/4R1K1 w - - 0 1") + assert stockfish.get_top_moves(2, include_info=False) == [ + {"Move": "e1e8", "Centipawn": None, "Mate": 1}, + {"Move": "c8e8", "Centipawn": None, "Mate": 2}, + ] + moves = stockfish.get_top_moves(2, include_info=True) + assert all( + k in moves[0] + for k in ( + "Move", + "Centipawn", + "Mate", + "MultiPVLine", + "N/s", + "Nodes", + "SelectiveDepth", + "Time", + ) + ) + if stockfish.does_current_engine_version_have_wdl_option(): + assert "WDL" in moves[0] + def test_get_top_moves_raising_error(self, stockfish): stockfish.set_fen_position( "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"