forked from tinkerbell/smee
-
Notifications
You must be signed in to change notification settings - Fork 0
/
tftp.go
138 lines (117 loc) · 2.88 KB
/
tftp.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
package main
import (
"flag"
"io"
"net"
"os"
"path"
"github.com/avast/retry-go"
tftp "github.com/betawaffle/tftp-go"
"github.com/pkg/errors"
"github.com/prometheus/client_golang/prometheus"
"github.com/tinkerbell/boots/conf"
"github.com/tinkerbell/boots/job"
"github.com/tinkerbell/boots/metrics"
)
// tftpAddr is the address the TFTP server listens on. It starts out as
// conf.TFTPBind and is registered in init as the target of the -tftp-addr flag.
var tftpAddr = conf.TFTPBind
// init registers the -tftp-addr command-line flag. The flag writes into
// tftpAddr and uses tftpAddr's current value as its default.
func init() {
	const (
		flagName  = "tftp-addr"
		flagUsage = "IP and port to listen on for TFTP."
	)
	flag.StringVar(&tftpAddr, flagName, tftpAddr, flagUsage)
}
// ServeTFTP is a useless comment
func ServeTFTP() {
err := retry.Do(
func() error {
return errors.Wrap(tftp.ListenAndServe(tftpAddr, tftpHandler{}), "serving tftp")
},
)
if err != nil {
mainlog.Fatal(errors.Wrap(err, "retry tftp serve"))
}
}
// tftpHandler is a stateless receiver for the TFTP ReadFile and WriteFile
// callbacks; it carries no fields.
type tftpHandler struct{}
// ReadFile handles a TFTP read request for filename from the client behind c.
//
// It counts the request in the jobs metrics (labels from=tftp, op=read), times
// it, and tracks it as in-progress until the function returns. The client IP
// is resolved from c.RemoteAddr() and used to look up a job:
//
//   - if the lookup succeeds, the request is delegated to the job's ServeTFTP;
//   - otherwise the lookup error is logged and only the base names "test.1mb"
//     and "test.8mb" are served, as fake readers of 1 MiB and 8 MiB;
//   - any other name is logged as an access violation and os.ErrPermission is
//     returned.
func (tftpHandler) ReadFile(c tftp.Conn, filename string) (tftp.ReadCloser, error) {
	labels := prometheus.Labels{"from": "tftp", "op": "read"}
	metrics.JobsTotal.With(labels).Inc()
	metrics.JobsInProgress.With(labels).Inc()
	timer := prometheus.NewTimer(metrics.JobDuration.With(labels))
	// Deferred calls run in reverse order: the in-progress gauge is
	// decremented first, then the duration is observed.
	defer timer.ObserveDuration()
	defer metrics.JobsInProgress.With(labels).Dec()

	clientIP := tftpClientIP(c.RemoteAddr())
	j, lookupErr := job.CreateFromIP(clientIP)
	if lookupErr != nil {
		lookupErr = errors.WithMessage(lookupErr, "retrieved job is empty")

		// Only the final path element is considered for the fallback files.
		base := path.Base(filename)
		logger := mainlog.With("client", clientIP, "event", "open", "filename", base)
		logger.With("error", lookupErr).Info()

		var fakeSize int
		switch base {
		case "test.1mb":
			fakeSize = 1 * 1024 * 1024
		case "test.8mb":
			fakeSize = 8 * 1024 * 1024
		default:
			logger.With("error", errors.Wrap(os.ErrPermission, "access_violation")).Info()
			return nil, os.ErrPermission
		}

		logger.With("tftp_fake_read", true).Info()
		return &fakeReader{fakeSize}, nil
	}

	return j.ServeTFTP(filename, clientIP.String())
}
// WriteFile rejects every TFTP write request. It logs the client IP, the
// requested filename and an "access_violation" error, then returns the bare
// os.ErrPermission (not the wrapped error) with a nil writer.
func (tftpHandler) WriteFile(c tftp.Conn, filename string) (tftp.WriteCloser, error) {
	clientIP := tftpClientIP(c.RemoteAddr())
	denied := errors.Wrap(os.ErrPermission, "access_violation")
	mainlog.With(
		"client", clientIP,
		"event", "create",
		"filename", filename,
		"error", denied,
	).Info()
	return nil, os.ErrPermission
}
func tftpClientIP(addr net.Addr) net.IP {
switch a := addr.(type) {
case *net.IPAddr:
return a.IP
case *net.UDPAddr:
return a.IP
case *net.TCPAddr:
return a.IP
}
host, _, err := net.SplitHostPort(addr.String())
if err != nil {
err = errors.Wrap(err, "parse host:port")
mainlog.Error(err)
return nil
}
l := mainlog.With("host", host)
if ip := net.ParseIP(host); ip != nil {
l.With("ip", ip).Info()
if v4 := ip.To4(); v4 != nil {
ip = v4
}
return ip
}
l.Info("returning nil")
return nil
}
// zeros is a shared buffer of zero bytes that fakeReader.Read copies from to
// fill callers' buffers. It is never written to.
// NOTE(review): 1456 presumably matches the TFTP block size in use — confirm.
var zeros = make([]byte, 1456)

// fakeReader is a synthetic file: it yields N zero bytes and then io.EOF.
type fakeReader struct {
	N int // number of bytes still to be produced
}

// Close is a no-op and always returns nil; fakeReader holds no resources.
func (r *fakeReader) Close() error {
	return nil
}

// Read fills p with zero bytes, up to the number of bytes remaining in r, and
// returns how many bytes were written.
//
// It returns io.EOF together with the final bytes once the remaining count
// reaches zero, and (0, io.EOF) on any later call with a non-empty p. A
// zero-length p yields (0, nil) without consuming anything. A non-positive
// remaining count is treated as exhausted.
func (r *fakeReader) Read(p []byte) (int, error) {
	if len(p) == 0 {
		return 0, nil
	}
	if r.N <= 0 {
		return 0, io.EOF
	}
	if len(p) > r.N {
		p = p[:r.N]
	}

	// p may be larger than zeros, so fill it in chunks and accumulate the
	// total; the count returned must cover every chunk, not just the last.
	total := 0
	for total < len(p) {
		total += copy(p[total:], zeros)
	}
	r.N -= total

	if r.N == 0 {
		return total, io.EOF
	}
	return total, nil
}