Skip to content

Commit

Permalink
Update
Browse files: browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Jan 26, 2025
2 parents 19d45a8 + 72ddbac commit 848c251
Showing 1 changed file with 58 additions and 37 deletions.
95 changes: 58 additions & 37 deletions torchrl/envs/custom/chess.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,19 +76,28 @@ class ChessEnv(EnvBase, metaclass=_ChessMeta):
being a subset of this space. The environment uses a mask to ensure only legal moves are selected.
Examples:
>>> import torch
>>> from torchrl.envs import ChessEnv
>>> _ = torch.manual_seed(0)
>>> env = ChessEnv(include_fen=True, include_san=True, include_pgn=True, include_legal_moves=True)
>>> print(env)
TransformedEnv(
env=ChessEnv(),
transform=ActionMask(keys=['action', 'action_mask']))
>>> r = env.reset()
>>> env.rand_step(r)
>>> print(env.rand_step(r))
TensorDict(
fields={
action: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
action_mask: Tensor(shape=torch.Size([29275]), device=cpu, dtype=torch.bool, is_shared=False),
done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
fen: NonTensorData(data=rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1, batch_size=torch.Size([]), device=None),
legal_moves: Tensor(shape=torch.Size([219]), device=cpu, dtype=torch.int64, is_shared=False),
next: TensorDict(
fields={
action_mask: Tensor(shape=torch.Size([29275]), device=cpu, dtype=torch.bool, is_shared=False),
done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
fen: NonTensorData(data=rnbqkbnr/pppppppp/8/8/8/1P6/P1PPPPPP/RNBQKBNR b KQkq - 0 1, batch_size=torch.Size([]), device=None),
fen: NonTensorData(data=rnbqkbnr/pppppppp/8/8/5P2/8/PPPPP1PP/RNBQKBNR b KQkq - 0 1, batch_size=torch.Size([]), device=None),
legal_moves: Tensor(shape=torch.Size([219]), device=cpu, dtype=torch.int64, is_shared=False),
pgn: NonTensorData(data=[Event "?"]
[Site "?"]
Expand All @@ -97,9 +106,10 @@ class ChessEnv(EnvBase, metaclass=_ChessMeta):
[White "?"]
[Black "?"]
[Result "*"]
1. b3 *, batch_size=torch.Size([]), device=None),
1. f4 *, batch_size=torch.Size([]), device=None),
reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
san: NonTensorData(data=b3, batch_size=torch.Size([]), device=None),
san: NonTensorData(data=f4, batch_size=torch.Size([]), device=None),
terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
turn: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([]),
Expand All @@ -112,56 +122,59 @@ class ChessEnv(EnvBase, metaclass=_ChessMeta):
[White "?"]
[Black "?"]
[Result "*"]
*, batch_size=torch.Size([]), device=None),
san: NonTensorData(data=[SAN][START], batch_size=torch.Size([]), device=None),
san: NonTensorData(data=<start>, batch_size=torch.Size([]), device=None),
terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
turn: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)
>>> env.rollout(1000)
>>> print(env.rollout(1000))
TensorDict(
fields={
action: Tensor(shape=torch.Size([352]), device=cpu, dtype=torch.int64, is_shared=False),
done: Tensor(shape=torch.Size([352, 1]), device=cpu, dtype=torch.bool, is_shared=False),
action: Tensor(shape=torch.Size([96]), device=cpu, dtype=torch.int64, is_shared=False),
action_mask: Tensor(shape=torch.Size([96, 29275]), device=cpu, dtype=torch.bool, is_shared=False),
done: Tensor(shape=torch.Size([96, 1]), device=cpu, dtype=torch.bool, is_shared=False),
fen: NonTensorStack(
['rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQ...,
batch_size=torch.Size([352]),
batch_size=torch.Size([96]),
device=None),
legal_moves: Tensor(shape=torch.Size([352, 219]), device=cpu, dtype=torch.int64, is_shared=False),
legal_moves: Tensor(shape=torch.Size([96, 219]), device=cpu, dtype=torch.int64, is_shared=False),
next: TensorDict(
fields={
done: Tensor(shape=torch.Size([352, 1]), device=cpu, dtype=torch.bool, is_shared=False),
action_mask: Tensor(shape=torch.Size([96, 29275]), device=cpu, dtype=torch.bool, is_shared=False),
done: Tensor(shape=torch.Size([96, 1]), device=cpu, dtype=torch.bool, is_shared=False),
fen: NonTensorStack(
['rnbqkbnr/pppppppp/8/8/8/N7/PPPPPPPP/R1BQKBNR b K...,
batch_size=torch.Size([352]),
['rnbqkbnr/pppppppp/8/8/8/5N2/PPPPPPPP/RNBQKB1R b ...,
batch_size=torch.Size([96]),
device=None),
legal_moves: Tensor(shape=torch.Size([352, 219]), device=cpu, dtype=torch.int64, is_shared=False),
legal_moves: Tensor(shape=torch.Size([96, 219]), device=cpu, dtype=torch.int64, is_shared=False),
pgn: NonTensorStack(
['[Event "?"]\n[Site "?"]\n[Date "????.??.??"]\n[R...,
batch_size=torch.Size([352]),
batch_size=torch.Size([96]),
device=None),
reward: Tensor(shape=torch.Size([352, 1]), device=cpu, dtype=torch.float32, is_shared=False),
reward: Tensor(shape=torch.Size([96, 1]), device=cpu, dtype=torch.float32, is_shared=False),
san: NonTensorStack(
['Na3', 'a5', 'Nb1', 'Nc6', 'a3', 'g6', 'd4', 'd6'...,
batch_size=torch.Size([352]),
['Nf3', 'Na6', 'c4', 'f6', 'h4', 'Rb8', 'Na3', 'Ra...,
batch_size=torch.Size([96]),
device=None),
terminated: Tensor(shape=torch.Size([352, 1]), device=cpu, dtype=torch.bool, is_shared=False),
turn: Tensor(shape=torch.Size([352]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([352]),
terminated: Tensor(shape=torch.Size([96, 1]), device=cpu, dtype=torch.bool, is_shared=False),
turn: Tensor(shape=torch.Size([96]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([96]),
device=None,
is_shared=False),
pgn: NonTensorStack(
['[Event "?"]\n[Site "?"]\n[Date "????.??.??"]\n[R...,
batch_size=torch.Size([352]),
batch_size=torch.Size([96]),
device=None),
san: NonTensorStack(
['[SAN][START]', 'Na3', 'a5', 'Nb1', 'Nc6', 'a3', ...,
batch_size=torch.Size([352]),
['<start>', 'Nf3', 'Na6', 'c4', 'f6', 'h4', 'Rb8',...,
batch_size=torch.Size([96]),
device=None),
terminated: Tensor(shape=torch.Size([352, 1]), device=cpu, dtype=torch.bool, is_shared=False),
turn: Tensor(shape=torch.Size([352]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([352]),
terminated: Tensor(shape=torch.Size([96, 1]), device=cpu, dtype=torch.bool, is_shared=False),
turn: Tensor(shape=torch.Size([96]), device=cpu, dtype=torch.bool, is_shared=False)},
batch_size=torch.Size([96]),
device=None,
is_shared=False)
Expand Down Expand Up @@ -227,13 +240,15 @@ def _legal_moves_to_index(
[self._san_moves.index(board.san(m)) for m in board.legal_moves],
dtype=torch.int64,
)

mask = None
if return_mask:
return self._move_index_to_mask(indices)
mask = self._move_index_to_mask(indices)
if pad:
indices = torch.nn.functional.pad(
indices, [0, 218 - indices.numel() + 1], value=len(self.san_moves)
)
if return_mask:
return indices, mask
return indices

@classmethod
Expand Down Expand Up @@ -371,16 +386,19 @@ def _reset(self, tensordict=None):
dest.set("pgn", pgn)
dest.set("turn", turn)
if self.include_legal_moves:
moves_idx = self._legal_moves_to_index(board=self.board, pad=True)
dest.set("legal_moves", moves_idx)
moves_idx = self._legal_moves_to_index(
board=self.board, pad=True, return_mask=self.mask_actions
)
if self.mask_actions:
dest.set("action_mask", self._move_index_to_mask(moves_idx))
moves_idx, mask = moves_idx
dest.set("action_mask", mask)
dest.set("legal_moves", moves_idx)
elif self.mask_actions:
dest.set(
"action_mask",
self._legal_moves_to_index(
board=self.board, pad=True, return_mask=True
),
)[1],
)

if self.pixels:
Expand Down Expand Up @@ -527,16 +545,19 @@ def _step(self, tensordict):
dest.set("san", san)

if self.include_legal_moves:
moves_idx = self._legal_moves_to_index(board=board, pad=True)
dest.set("legal_moves", moves_idx)
moves_idx = self._legal_moves_to_index(
board=board, pad=True, return_mask=self.mask_actions
)
if self.mask_actions:
dest.set("action_mask", self._move_index_to_mask(moves_idx))
moves_idx, mask = moves_idx
dest.set("action_mask", mask)
dest.set("legal_moves", moves_idx)
elif self.mask_actions:
dest.set(
"action_mask",
self._legal_moves_to_index(
board=self.board, pad=True, return_mask=True
),
)[1],
)

turn = torch.tensor(board.turn)
Expand Down

0 comments on commit 848c251

Please sign in to comment.