#!returnn.py
# kate: syntax python;
# -*- mode: python -*-
# sublime: syntax 'Packages/Python Improved/PythonImproved.tmLanguage'
# vim:set expandtab tabstop=4 fenc=utf-8 ff=unix ft=python:
import torch
from torch import nn
import os

import returnn.frontend as rf
from returnn.tensor import TensorDict
from returnn.util.basic import get_login_username

demo_name, _ = os.path.splitext(__file__)
print("Hello, experiment: %s" % demo_name)

backend = "torch"
task = "train"

train = {"class": "Task12AXDataset", "num_seqs": 1000}
dev = {"class": "Task12AXDataset", "num_seqs": 100, "fixed_random_seed": 1}
num_inputs = 9
num_outputs = 2

extern_data = {
    "data": {"dim": num_inputs},
    "classes": {"dim": num_outputs, "sparse": True, "available_for_inference": False},
}
model_outputs = {
    "output": {"dim": num_outputs},
}
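# extern_data declares what the dataset provides: dense float inputs with
# feature dim num_inputs, and sparse (integer class index) targets which are
# not available at inference time. model_outputs declares what forward_step
# will produce.
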
batching = "random"
batch_size = 5000
max_seqs = 10
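# Note: batch_size counts time frames summed over all sequences in a batch,
# while max_seqs additionally caps the number of sequences per batch.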


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Note that padding="same" is not exportable to ONNX as of 2023/05/19.
        self.layers = nn.Sequential(
            nn.Conv1d(num_inputs, 50, 5, padding=2),
            nn.ReLU(),
            nn.Conv1d(50, 100, 5, padding=2),
            nn.ReLU(),
            nn.Conv1d(100, num_outputs, 5, padding=2),
        )

    def forward(self, x):  # (B, T, F)
        # Conv1d expects channels-first input, thus permute to (B, F, T) and back.
        x = x.permute(0, 2, 1)  # (B, F, T)
        x = self.layers(x)
        x = x.permute(0, 2, 1)  # (B, T, F)
        return x  # logits
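
# Minimal standalone sanity check of the shapes (a sketch, not executed by
# RETURNN; uses only the names defined above, e.g. in a Python shell):
#
#   model = Model()
#   x = torch.randn(2, 17, num_inputs)  # (B, T, F)
#   assert model(x).shape == (2, 17, num_outputs)  # kernel 5 with padding 2 keeps T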


def get_model(*, epoch, step, **_kwargs):
    """Called by RETURNN to construct the model (epoch/step allow epoch-dependent setups)."""
    return Model()


def forward_step(*, model: Model, extern_data: TensorDict):
    """
    Function used in inference.
    """
    data = extern_data["data"]
    out = model(data.raw_tensor)
    # The convolutions keep the time axis, so the output seq lengths equal the input seq lengths.
    rf.get_run_ctx().expected_outputs["output"].dims[1].dyn_size_ext.raw_tensor = data.dims[1].dyn_size_ext.raw_tensor
    rf.get_run_ctx().mark_as_default_output(tensor=out)
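
# forward_step runs in forward/inference mode (e.g. for ONNX export); the
# dyn_size_ext assignment above copies the input seq lengths to the declared
# "output" entry of model_outputs, and mark_as_default_output registers the
# logits as the default output of the run.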


def train_step(*, model: Model, extern_data: TensorDict, **_kwargs):
    """
    Function used in training/dev.
    """
    data = extern_data["data"]
    seq_lens = data.dims[1].dyn_size_ext.raw_tensor  # (B,), lengths before padding
    logits = model(data.raw_tensor)
    # Pack the padded sequences so that padded frames do not enter the loss.
    logits_packed = torch.nn.utils.rnn.pack_padded_sequence(
        logits, seq_lens, batch_first=True, enforce_sorted=False
    )
    targets = extern_data["classes"]
    targets_packed = torch.nn.utils.rnn.pack_padded_sequence(
        targets.raw_tensor, seq_lens, batch_first=True, enforce_sorted=False
    )
    loss = nn.CrossEntropyLoss(reduction="none")(logits_packed.data, targets_packed.data.long())
    rf.get_run_ctx().mark_as_loss(name="ce", loss=loss)
    frame_error = torch.argmax(logits_packed.data, dim=-1).not_equal(targets_packed.data)
    rf.get_run_ctx().mark_as_loss(name="fer", loss=frame_error, as_error=True)
# training
optimizer = {"class": "adam"}
learning_rate = 0.01
learning_rate_control = "newbob"
learning_rate_decay = 0.9
newbob_relative_error_threshold = 0.0
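# newbob: after each epoch, the learning rate is multiplied by
# learning_rate_decay if the relative dev-score improvement falls below
# newbob_relative_error_threshold; with the threshold at 0.0, it decays as
# soon as the dev score stops improving.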
learning_rate_file = "/tmp/%s/returnn/%s/learning_rates" % (get_login_username(), demo_name)
model = "/tmp/%s/returnn/%s/model" % (get_login_username(), demo_name)
num_epochs = 5

# log
#log_verbosity = 3
log_verbosity = 5
torch_log_memory_usage = True