Skip to content

Commit

Permalink
remove unused lines
Browse files Browse the repository at this point in the history
  • Loading branch information
Yanchao Sun committed Nov 23, 2023
1 parent a75de0a commit 0ae4006
Showing 1 changed file with 38 additions and 86 deletions.
124 changes: 38 additions & 86 deletions src/models/components/mingpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,9 +182,6 @@ def __init__(self, config):
self.model_type = config.model_type # act based on rtgs ('reward_conditioned') or not ('naive')
self.ct = 0

# input embedding stem
# self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)

# pos embedding
self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size + 1, config.n_embd))
self.global_pos_emb = nn.Parameter(torch.zeros(1, config.max_timestep + 1, config.n_embd))
Expand All @@ -196,23 +193,6 @@ def __init__(self, config):
# normalization
self.ln_f = nn.LayerNorm(config.n_embd)

# action prediction head
# if config.linear_rtg:
# self.reward_conditioned_head = nn.Linear(config.n_embd * 2, config.vocab_size, bias=False) # predict action conditioned on rtg
# else:
# self.reward_conditioned_head = nn.Sequential(
# nn.Linear(config.n_embd * 2, 512),
# nn.ReLU(),
# nn.Linear(512, config.vocab_size),
# )
# self.naive_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) # predict action with state embedding
# # forward prediction head
# self.forward_pred_head = nn.Linear(config.n_embd * 2, config.n_embd, bias=True)
# # inverse prediction head
# self.inverse_pred_head = nn.Linear(config.n_embd * 2, config.vocab_size, bias=False)
# # reward prediction head
# self.reward_pred_head = nn.Linear(config.n_embd * 2, 1, bias=False)

# rtg-based action prediction head
self.reward_conditioned_head = build_mlp(config.n_embd * 2, config.vocab_size, config.rtg_layers, bias=False)
# naive action prediction head (for behavior cloning)
Expand All @@ -228,33 +208,29 @@ def __init__(self, config):

self.apply(self._init_weights)

if hasattr(config, "vector_obs") and config.vector_obs:
self.state_encoder = nn.Sequential(nn.Linear(config.obs_dim, config.n_embd), nn.Tanh())
self.target_state_encoder = nn.Sequential(nn.Linear(config.obs_dim, config.n_embd), nn.Tanh())
else:
self.state_encoder = nn.Sequential(
nn.Conv2d(self.config.channels, 32, 8, stride=4, padding=0),
nn.ReLU(),
nn.Conv2d(32, 64, 4, stride=2, padding=0),
nn.ReLU(),
nn.Conv2d(64, 64, 3, stride=1, padding=0),
nn.ReLU(),
nn.Flatten(),
nn.Linear(3136, config.n_embd),
nn.Tanh(),
)
self.state_encoder = nn.Sequential(
nn.Conv2d(self.config.channels, 32, 8, stride=4, padding=0),
nn.ReLU(),
nn.Conv2d(32, 64, 4, stride=2, padding=0),
nn.ReLU(),
nn.Conv2d(64, 64, 3, stride=1, padding=0),
nn.ReLU(),
nn.Flatten(),
nn.Linear(3136, config.n_embd),
nn.Tanh(),
)

self.target_state_encoder = nn.Sequential(
nn.Conv2d(self.config.channels, 32, 8, stride=4, padding=0),
nn.ReLU(),
nn.Conv2d(32, 64, 4, stride=2, padding=0),
nn.ReLU(),
nn.Conv2d(64, 64, 3, stride=1, padding=0),
nn.ReLU(),
nn.Flatten(),
nn.Linear(3136, config.n_embd),
nn.Tanh(),
)
self.target_state_encoder = nn.Sequential(
nn.Conv2d(self.config.channels, 32, 8, stride=4, padding=0),
nn.ReLU(),
nn.Conv2d(32, 64, 4, stride=2, padding=0),
nn.ReLU(),
nn.Conv2d(64, 64, 3, stride=1, padding=0),
nn.ReLU(),
nn.Flatten(),
nn.Linear(3136, config.n_embd),
nn.Tanh(),
)
self.target_state_encoder.load_state_dict(self.state_encoder.state_dict())

# rtg encoder
Expand Down Expand Up @@ -320,15 +296,11 @@ def forward(

is_testing = (actions is None) or (actions.shape[1] != states.shape[1])

# (batch * context_length, n_embd)
if hasattr(self.config, "vector_obs") and self.config.vector_obs:
state_embeddings = self.state_encoder(states)
else:
state_embeddings = self.state_encoder(
states.reshape(-1, self.config.channels, 84, 84).type(torch.float32).contiguous()
)
# (batch, context_length, n_embd)
state_embeddings = state_embeddings.reshape(states.shape[0], states.shape[1], self.config.n_embd)
state_embeddings = self.state_encoder(
states.reshape(-1, self.config.channels, 84, 84).type(torch.float32).contiguous()
)
# (batch, context_length, n_embd)
state_embeddings = state_embeddings.reshape(states.shape[0], states.shape[1], self.config.n_embd)

if actions is not None:
if self.config.cont_action:
Expand Down Expand Up @@ -369,11 +341,6 @@ def forward(
state_output = x # for completeness
action_output = None

# print("token_embeddings", token_embeddings.size())
# print("final_embeddings", final_embeddings.size())
# print("state_output", state_output.size())
# print("action_output", action_output.size())

## act
rtg_action_logits, naive_action_logits = None, None
## compute losses
Expand Down Expand Up @@ -407,15 +374,12 @@ def forward(
raise NotImplementedError()

if pred_forward:
if hasattr(self.config, "vector_obs") and self.config.vector_obs:
next_state_embeddings = self.state_encoder(states).detach()
else:
next_state_embeddings = self.target_state_encoder(
states.reshape(-1, self.config.channels, 84, 84).type(torch.float32).contiguous()
).detach() # (batch, context_length, n_embd)
next_state_embeddings = next_state_embeddings.reshape(
states.shape[0], states.shape[1], self.config.n_embd
)
next_state_embeddings = self.target_state_encoder(
states.reshape(-1, self.config.channels, 84, 84).type(torch.float32).contiguous()
).detach() # (batch, context_length, n_embd)
next_state_embeddings = next_state_embeddings.reshape(
states.shape[0], states.shape[1], self.config.n_embd
)
next_state_embeddings = next_state_embeddings[:, 1:, :] # (batch, context_length-1, n_embd)
forward_pred = self.forward_pred_head(
torch.cat((state_output[:, :-1, :], action_output[:, : -1 + int(is_testing), :]), dim=2)
Expand Down Expand Up @@ -459,14 +423,6 @@ def forward(
rand_mask_obs_idx = np.random.choice(list(range(1, actions.shape[1] - 1)), mask_obs_size, replace=False)
for j in range(mask_obs_size):
masked_token[:, 2 * rand_mask_obs_idx[j], :] = -1
# batch_size = states.shape[0]
# all_global_pos_emb = torch.repeat_interleave(
# self.global_pos_emb, batch_size, dim=0
# ) # batch_size, traj_length, n_embd
# position_embeddings = (
# torch.gather(all_global_pos_emb, 1, torch.repeat_interleave(timesteps, self.config.n_embd, dim=-1))
# + self.pos_emb[:, : token_embeddings.shape[1], :]
# )

final_masked_embeddings = self.drop(masked_token + position_embeddings)

Expand Down Expand Up @@ -494,15 +450,11 @@ def get_embeddings(self, states, actions, timesteps):
actions = None
is_testing = (actions is None) or (actions.shape[1] != states.shape[1])

# (batch * context_length, n_embd)
if hasattr(self.config, "vector_obs") and self.config.vector_obs:
state_embeddings = self.state_encoder(states)
else:
state_embeddings = self.state_encoder(
states.reshape(-1, self.config.channels, 84, 84).type(torch.float32).contiguous()
)
# (batch, context_length, n_embd)
state_embeddings = state_embeddings.reshape(states.shape[0], states.shape[1], self.config.n_embd)
state_embeddings = self.state_encoder(
states.reshape(-1, self.config.channels, 84, 84).type(torch.float32).contiguous()
)
# (batch, context_length, n_embd)
state_embeddings = state_embeddings.reshape(states.shape[0], states.shape[1], self.config.n_embd)

if actions is not None:
if self.config.cont_action:
Expand Down

0 comments on commit 0ae4006

Please sign in to comment.