Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GTPの互換性向上(Lizzie用) #80

Merged
merged 2 commits into from
Dec 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion board/go_board.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""碁盤のデータ定義と操作処理。
"""
from typing import List, NoReturn
from typing import List, Tuple, NoReturn
from collections import deque
import numpy as np

Expand Down Expand Up @@ -469,6 +469,26 @@ def get_komi(self) -> float:
"""
return self.komi

def get_to_move(self) -> Stone:
"""手番の色を取得する。

Returns:
Stone: 手番の色。
"""
if self.moves == 1:
return Stone.BLACK
else:
last_move_color, _, _ = self.record.get(self.moves - 1)
return Stone.get_opponent_color(last_move_color)

def get_move_history(self) -> List[Tuple[Stone, int, np.array]]:
"""着手の履歴を取得する。

Returns:
[(Stone, int, np.array), ...]: (着手の色、座標、ハッシュ値) のリスト。
"""
return [self.record.get(m) for m in range(1, self.moves)]

def count_score(self) -> int: # pylint: disable=R0912
"""領地を簡易的にカウントする。

Expand Down
109 changes: 81 additions & 28 deletions gtp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from sgf.reader import SGFReader


gtp_command_id = ""

class GtpClient: # pylint: disable=R0902,R0903
"""_Go Text Protocolクライアントの実装クラス
Expand Down Expand Up @@ -163,6 +164,19 @@ def _play(self, color: str, pos: str) -> NoReturn:

respond_success("")

def _undo(self) -> NoReturn:
"""undoコマンドを処理する。
"""
# 一旦クリアして初手から直前手まで打ち直す非効率実装
history = self.board.get_move_history()
if not history:
respond_failure("cannot undo")
return
self._clear_board()
for (color, pos, _) in history[:-1]:
self.board.put_stone(pos, color)
respond_success("")

def _genmove(self, color: str) -> NoReturn:
"""genmoveコマンドを処理する。
入力された手番で思考し、着手を生成する。
Expand Down Expand Up @@ -295,25 +309,59 @@ def _load_sgf(self, arg_list: List[str]) -> NoReturn:

respond_success("")

def _decode_analyze_arg(self, arg_list: List[str]) -> (Stone, float):
"""analyzeコマンド(lz-analyze, cgos-analyze)の引数を解釈する。
不正な引数の場合は更新間隔として負値を返す。

Args:
arg_list (List[str]): コマンドの引数リスト。

Returns:
(Stone, float): 手番の色、更新間隔(秒)
"""
to_move = self.board.get_to_move()
interval = 0
error_value = (to_move, -1.0)
# 受けつける形式の例
# lz-analyze B 10
# lz-analyze B
# lz-analyze 10
# lz-analyze B interval 10
# lz-analyze interval 10
try:
if arg_list[0][0] in ['B', 'b']:
to_move = Stone.BLACK
arg_list.pop(0)
elif arg_list[0][0] in ['W', 'w']:
to_move = Stone.WHITE
arg_list.pop(0)
if arg_list[0] == "interval":
if len(arg_list) == 1:
return error_value
arg_list.pop(0)
if arg_list[0].isdigit():
interval = int(arg_list[0])/100
arg_list.pop(0)
except IndexError as e:
pass
if arg_list:
return error_value
return (to_move, interval)

def _analyze(self, mode: str, arg_list: List[str]) -> NoReturn:
"""analyzeコマンド(lz-analyze, cgos-analyze)を実行する。

Args:
mode (str): 解析モード。値は"lz"か"cgos"。
arg_list (List[str]): コマンドの引数リスト (手番の色, 更新間隔)。
"""
interval = 0
if len(arg_list) >= 2:
interval = int(arg_list[1])/100

if arg_list[0][0] in ['B', 'b']:
to_move = Stone.BLACK
elif arg_list[0][0] in ['W', 'w']:
to_move = Stone.WHITE
else:
respond_failure(f"{mode}-analyze color")
to_move, interval = self._decode_analyze_arg(arg_list)
if interval < 0:
respond_failure(f"{mode}-analyze [color] [interval]")
return

respond_success("", ongoing=True)

analysis_query = {
"mode" : mode,
"interval" : interval,
Expand All @@ -328,19 +376,13 @@ def _genmove_analyze(self, mode: str, arg_list: List[str]) -> NoReturn:
mode (str): 解析モード。値は"lz"か"cgos"。
arg_list (List[str]): コマンドの引数リスト(手番の色, 更新間隔)。
"""
color = arg_list[0]
interval = 0
if len(arg_list) >= 2:
interval = int(arg_list[1])/100

if color.lower()[0] == 'b':
genmove_color = Stone.BLACK
elif color.lower()[0] == 'w':
genmove_color = Stone.WHITE
else:
respond_failure(f"{mode}-genmove_analyze color")
genmove_color, interval = self._decode_analyze_arg(arg_list)
if interval < 0:
respond_failure(f"{mode}-analyze [color] [interval]")
return

respond_success("", ongoing=True)

if self.use_network:
# モンテカルロ木探索で着手生成
analysis_query = {
Expand Down Expand Up @@ -369,13 +411,24 @@ def run(self) -> NoReturn: # pylint: disable=R0912,R0915
"""Go Text Protocolのクライアントの実行処理。
入力されたコマンドに対応する処理を実行し、応答メッセージを表示する。
"""
global gtp_command_id
while True:
command = input()

command_list = command.rstrip().split(' ')

gtp_command_id = ""
input_gtp_command = command_list[0]

# 入力されたコマンドの冒頭が数字なら、それを id とみなす。
# (参照)
# Specification of the Go Text Protocol, version 2, draft 2
# の「2.5 Command Structure」
# http://www.lysator.liu.se/~gunnar/gtp/gtp2-spec-draft2/gtp2-spec.html#SECTION00035000000000000000
if input_gtp_command.isdigit():
gtp_command_id = command_list.pop(0)
input_gtp_command = command_list[0]

if input_gtp_command == "version":
_version()
elif input_gtp_command == "protocol_version":
Expand All @@ -392,6 +445,8 @@ def run(self) -> NoReturn: # pylint: disable=R0912,R0915
self._komi(command_list[1])
elif input_gtp_command == "play":
self._play(command_list[1], command_list[2])
elif input_gtp_command == "undo":
self._undo()
elif input_gtp_command == "genmove":
self._genmove(command_list[1])
elif input_gtp_command == "boardsize":
Expand Down Expand Up @@ -445,40 +500,38 @@ def run(self) -> NoReturn: # pylint: disable=R0912,R0915
self.board.display_self_atari(Stone.WHITE)
respond_success("")
elif input_gtp_command == "lz-analyze":
print_out("= ")
self._analyze("lz", command_list[1:])
print("")
elif input_gtp_command == "lz-genmove_analyze":
print_out("= ")
self._genmove_analyze("lz", command_list[1:])
elif input_gtp_command == "cgos-analyze":
print_out("= ")
self._analyze("cgos", command_list[1:])
print("")
elif input_gtp_command == "cgos-genmove_analyze":
print_out("= ")
self._genmove_analyze("cgos", command_list[1:])
elif input_gtp_command == "hash_record":
print_err(self.board.record.get_hash_history())
respond_success("")
else:
respond_failure("unknown_command")

def respond_success(response: str) -> NoReturn:
def respond_success(response: str, ongoing: bool = False) -> NoReturn:
"""コマンド処理成功時の応答メッセージを表示する。

Args:
response (str): 表示する応答メッセージ。
ongoing (bool): 追加の応答メッセージが後に続くかどうか。
"""
print("= " + response + '\n')
terminator = "" if ongoing else '\n'
print(f"={gtp_command_id} " + response + terminator)

def respond_failure(response: str) -> NoReturn:
"""コマンド処理失敗時の応答メッセージを表示する。

Args:
response (str): 表示する応答メッセージ。
"""
print("= ? " + response + '\n')
print(f"?{gtp_command_id} " + response + '\n')

def _version() -> NoReturn:
"""versionコマンドを処理する。
Expand Down