training.py
import torch
import time
from data.dataset import GameplayActionPairVideoDataset
from torch import optim
from torch.utils.data import DataLoader
from model.agent import Agent, AgentConfig, device
from model.action_loss import ActionLoss
from model.cvivit import CvivitConfig
from model.encoder import MultiModelEncoderConfig
from model.decoder import MultiModelDecoderConfig
from tools.utils import custom_collate_fn
from torch.utils.tensorboard import SummaryWriter
# Set up TensorBoard writer
writer = SummaryWriter(log_dir='runs/behavior_cloning')
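# The logged scalars can be viewed with: tensorboard --logdir runs/behavior_cloning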
# Initialize your model, loss, optimizer
config = AgentConfig(
    encoder_config=MultiModelEncoderConfig(
        vit_model_name='google/vit-base-patch16-224-in21k',
        language_model_name='bert-base-uncased',
        cvivit_config=CvivitConfig(
            image_size=224,
            color_channel=3,
            emb_size=768,
            d_model=768,
            patch_size=(2, 8, 8),
            num_layers_spatial=2,
            num_heads_spatial=4,
            dim_feedforward_spatial=512,
            dropout_spatial=0.1,
            num_layers_temporal=2,
            num_heads_temporal=4,
            dim_feedforward_temporal=512,
            dropout_temporal=0.1
        )
    ),
    decoder_config=MultiModelDecoderConfig(
        d_model=768,
        dim_feedforward=512,
        nhead=4,
        num_layers=2
    )
)
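# Note: d_model=768 matches the hidden size of both 'google/vit-base-patch16-224-in21k'
# and 'bert-base-uncased' (both use 768-dim hidden states); presumably this lets the
# encoder and decoder consume those features without an extra projection layer.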
agent = Agent(config=config, debug=False).to(device)
criterion = ActionLoss()
optimizer = optim.Adam(agent.parameters(), lr=0.001)
# Load data
root_dir = "output_logs"
dataset = GameplayActionPairVideoDataset(root_dir=root_dir, image_size=(224, 224))
dataloader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=custom_collate_fn)
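# Each batch yielded through custom_collate_fn is unpacked below as
# (instruction, frames, action), with frames shaped
# (batch, time, channel, height, width) -- inferred from the unpacking in train().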
epochs = 20000

def train():
    for epoch in range(epochs):
        epoch_start = time.time()
        running_loss = 0.0
        for batch, (instruction, frames, action) in enumerate(dataloader):
            frames = frames.to(device)
            action = action.to(device)
            # Flatten the batch and time dimensions into one image batch for the ViT branch
            _, _, channel, height, width = frames.shape
            images = frames.reshape(-1, channel, height, width)

            logits = agent(images, frames, instruction)
            loss = criterion(logits, action)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Accumulate the loss
            running_loss += loss.item()

            # Log gradients and weights
            # for name, param in agent.named_parameters():
            #     if param.grad is not None:
            #         writer.add_histogram(f'{name}.grad', param.grad, epoch)
            #     writer.add_histogram(name, param, epoch)

        # Log average loss per epoch
        avg_loss = running_loss / len(dataloader)
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}")
        writer.add_scalar('training_loss_per_epoch', avg_loss, epoch)

        # Log learning rate
        for param_group in optimizer.param_groups:
            writer.add_scalar('learning_rate', param_group['lr'], epoch)

        # Log the time taken per epoch (seconds elapsed since the epoch started)
        writer.add_scalar('time_per_epoch', time.time() - epoch_start, epoch)

        # Save the model weights every 2000 epochs
        if (epoch + 1) % 2000 == 0:
            save()


def save():
    torch.save(agent.state_dict(), 'model_weights.pth')
    print("Model weights saved.")

def close_writer():
    writer.close()


if __name__ == "__main__":
    train()
    close_writer()
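
# Example (not part of the original script): to reload the saved checkpoint for
# evaluation or to resume training, something along these lines should work,
# assuming the model is rebuilt with the same AgentConfig as above:
#
#   agent = Agent(config=config, debug=False).to(device)
#   agent.load_state_dict(torch.load('model_weights.pth', map_location=device))
#   agent.eval()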