Skip to content

Commit

Permalink
feat: extra protocol information (#60)
Browse files Browse the repository at this point in the history
  • Loading branch information
jesusprubio authored Nov 12, 2024
1 parent 11f0022 commit 3c410a2
Show file tree
Hide file tree
Showing 6 changed files with 71 additions and 34 deletions.
3 changes: 2 additions & 1 deletion internal/probe.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,13 @@ func (p Probe) Run(ctx context.Context) error {
p.Logger.Debug(
"New protocol", "count", count, "protocol", proto,
)
rhost, err := proto.Probe("")
rhost, extra, err := proto.Probe("")
report := Report{
ProtocolID: proto.String(),
Time: time.Since(start),
Error: err,
RHost: rhost,
Extra: extra,
}
p.Logger.Debug(
"Sending report back",
Expand Down
5 changes: 3 additions & 2 deletions internal/probe_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@ import (
)

const testHostPort = "127.0.0.1:3355"
const testExtra = "test-extra"

type testProtocol struct{}

func (p *testProtocol) String() string { return "test-proto" }

func (p *testProtocol) Probe(target string) (string, error) {
return testHostPort, nil
func (p *testProtocol) Probe(target string) (string, string, error) {
return testHostPort, testExtra, nil
}

func TestProbeValidate(t *testing.T) {
Expand Down
42 changes: 25 additions & 17 deletions internal/protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@ type Protocol interface {
String() string
// Attempt to check the connectivity to the target.
// The target depends on the protocol. For example, for HTTP it's a URL.
// Returns the used target or error if the attempt failed.
Probe(target string) (string, error)
// Returns the used target or error if the attempt failed. Some protocols
// include an additional string with extra information. For example, the
// HTTP protocol returns the status code.
Probe(target string) (string, string, error)
}

// HTTP protocol implementation.
Expand All @@ -33,26 +35,28 @@ func (h *HTTP) String() string {
}

// Probe makes an HTTP request to a random captive portal.
//
// The target is a URL.
func (h *HTTP) Probe(target string) (string, error) {
// The extra data is the status code.
func (h *HTTP) Probe(target string) (string, string, error) {
cli := &http.Client{Timeout: h.Timeout}
url := target
if url == "" {
var err error
url, err = RandomCaptivePortal()
if err != nil {
return "", fmt.Errorf("selecting captive portal: %w", err)
return "", "", fmt.Errorf("selecting captive portal: %w", err)
}
}
resp, err := cli.Get(url)
if err != nil {
return "", err
return "", "", err
}
err = resp.Body.Close()
if err != nil {
return "", fmt.Errorf("closing response body: %w", err)
return "", "", fmt.Errorf("closing response body: %w", err)
}
return url, nil
return url, resp.Status, nil
}

// TCP protocol implementation.
Expand All @@ -66,25 +70,27 @@ func (t *TCP) String() string {
}

// Probe makes a TCP request to a random server.
//
// The target is a host:port.
func (t *TCP) Probe(target string) (string, error) {
// The extra data is the local interface.
func (t *TCP) Probe(target string) (string, string, error) {
hostPort := target
if hostPort == "" {
var err error
hostPort, err = RandomTCPServer()
if err != nil {
return "", fmt.Errorf("selecting TCP server: %w", err)
return "", "", fmt.Errorf("selecting TCP server: %w", err)
}
}
conn, err := net.DialTimeout("tcp", hostPort, t.Timeout)
if err != nil {
return "", err
return "", "", err
}
err = conn.Close()
if err != nil {
return "", fmt.Errorf("closing connection: %w", err)
return "", "", fmt.Errorf("closing connection: %w", err)
}
return hostPort, nil
return hostPort, conn.LocalAddr().String(), nil
}

// DNS protocol implementation.
Expand All @@ -100,8 +106,10 @@ func (d *DNS) String() string {
}

// Probe resolves a random domain name.
//
// The target is a domain name.
func (d *DNS) Probe(target string) (string, error) {
// The extra data is the first resolved IP address.
func (d *DNS) Probe(target string) (string, string, error) {
var r net.Resolver
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
Expand All @@ -121,12 +129,12 @@ func (d *DNS) Probe(target string) (string, error) {
var err error
domain, err = RandomDomain()
if err != nil {
return "", fmt.Errorf("selecting domain: %w", err)
return "", "", fmt.Errorf("selecting domain: %w", err)
}
}
_, err := r.LookupHost(ctx, domain)
addrs, err := r.LookupHost(ctx, domain)
if err != nil {
return "", err
return "", "", err
}
return domain, nil
return domain, addrs[0], nil
}
37 changes: 31 additions & 6 deletions internal/protocol_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,23 @@ func TestHTTPProbe(t *testing.T) {
func(t *testing.T) {
u := url.URL{Scheme: "http", Host: server.Addr}
proto := HTTP{Timeout: tout}
got, err := proto.Probe(u.String())
got, extra, err := proto.Probe(u.String())
if err != nil {
t.Fatal(err)
}
want := "http://127.0.0.1:8080"
if got != want {
t.Fatalf("got %q, want %q", got, want)
}
if extra != "200 OK" {
t.Fatalf("got %q, want %q", extra, "200 OK")
}
},
)
t.Run("returns an error if the request fails", func(t *testing.T) {
u := url.URL{Scheme: "http", Host: "localhost"}
proto := HTTP{Timeout: 1}
got, err := proto.Probe(u.String())
got, extra, err := proto.Probe(u.String())
if got != "" {
t.Fatalf("got %q should be zero", got)
}
Expand All @@ -42,6 +45,9 @@ func TestHTTPProbe(t *testing.T) {
if got != want {
t.Fatalf("got %q, want %q", got, want)
}
if extra != "" {
t.Fatalf("got %q should be zero", extra)
}
})
}

Expand Down Expand Up @@ -85,18 +91,28 @@ func TestTCPProbe(t *testing.T) {
"returns the remote host/port if the request is successful",
func(t *testing.T) {
proto := &TCP{Timeout: tout}
got, err := proto.Probe(hostPort)
got, extra, err := proto.Probe(hostPort)
if err != nil {
t.Fatal(err)
}
if got != hostPort {
t.Fatalf("got %q, want %q", got, hostPort)
}
host, port, err := net.SplitHostPort(extra)
if err != nil {
t.Fatal(err)
}
if host != "127.0.0.1" {
t.Fatalf("got %q, want %q", host, "127.0.0.1")
}
if port == "" {
t.Fatalf("got %q, want a valid port", port)
}
},
)
t.Run("returns an error if the request fails", func(t *testing.T) {
proto := &TCP{Timeout: 1}
got, err := proto.Probe("localhost:80")
got, extra, err := proto.Probe("localhost:80")
if err == nil {
t.Fatal("got nil, want an error")
}
Expand All @@ -108,6 +124,9 @@ func TestTCPProbe(t *testing.T) {
if got != want {
t.Fatalf("got %q, want %q", got, want)
}
if extra != "" {
t.Fatalf("got %q should be zero", extra)
}
})
}

Expand All @@ -130,18 +149,21 @@ func TestDNSProbe(t *testing.T) {
func(t *testing.T) {
proto := &DNS{Timeout: tout}
domain := "google.com"
got, err := proto.Probe(domain)
got, extra, err := proto.Probe(domain)
if err != nil {
t.Fatal(err)
}
if got != domain {
t.Fatalf("got %q, want %q", got, domain)
}
if !net.ParseIP(extra).IsGlobalUnicast() {
t.Fatalf("got %q, want a valid IP address", extra)
}
},
)
t.Run("returns an error if the request fails", func(t *testing.T) {
proto := &DNS{Timeout: 1}
got, err := proto.Probe("invalid.aa")
got, extra, err := proto.Probe("invalid.aa")
if err == nil {
t.Fatal("got nil, want an error")
}
Expand All @@ -156,5 +178,8 @@ func TestDNSProbe(t *testing.T) {
if got != want {
t.Fatalf("got %q, want %q", got, want)
}
if extra != "" {
t.Fatalf("got %q should be zero", extra)
}
})
}
2 changes: 2 additions & 0 deletions internal/report.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,6 @@ type Report struct {
Time time.Duration `json:"time"`
// Network error.
Error error `json:"error,omitempty"`
// Extra information. Depends on the protocol.
Extra string `json:"extra,omitempty"`
}
16 changes: 8 additions & 8 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,23 +132,23 @@ func main() {
}
}

// Fatal logs the error to the standard output and exits with status 1.
// Logs the error to the standard output and exits with status 1.
func fatal(err error) {
fmt.Fprintf(os.Stderr, "%s: %s\n", appName, err)
os.Exit(1)
}

// ReportToLine returns a human-readable representation of the report.
// Returns a human-readable representation of the report.
func reportToLine(r *internal.Report) string {
symbol := green("✔")
suffix := r.RHost
line := fmt.Sprintf("%-15s %-14s %s", bold(r.ProtocolID), r.Time, r.RHost)
suffix := r.Extra
prefix := green("✔")
if r.Error != nil {
symbol = red("✘")
prefix = red("✘")
suffix = r.Error.Error()
}
return fmt.Sprintf("%s %s", symbol, fmt.Sprintf(
"%-15s %-14s %-15s", bold(r.ProtocolID), r.Time, faint(suffix),
))
suffix = fmt.Sprintf("(%s)", suffix)
return fmt.Sprintf("%s %s %s", prefix, line, faint(suffix))
}

var (
Expand Down

0 comments on commit 3c410a2

Please sign in to comment.