Skip to content

Commit

Permalink
Refactor, add tests and coverage badge
Browse files Browse the repository at this point in the history
  • Loading branch information
ei-grad committed Jan 3, 2024
1 parent 178dc46 commit f348b74
Show file tree
Hide file tree
Showing 5 changed files with 233 additions and 13 deletions.
26 changes: 26 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
name: test

on:
push:

jobs:
test:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Set up Go
uses: actions/setup-go@v4
with:
go-version: '>=1.21.5'
- name: Test
run: go test -v ./...
- name: Update coverage report
uses: ncruces/go-coverage-report@v0
with:
report: true
chart: true
amend: true
continue-on-error: true
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
/check-expiring-certs
coverage.txt
dist/
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
[![License](https://img.shields.io/github/license/ei-grad/check-expiring-certs)](LICENSE)
[![Workflow Status](https://github.com/ei-grad/check-expiring-certs/actions/workflows/release.yml/badge.svg)](https://github.com/ei-grad/check-expiring-certs/actions/workflows/release.yml)
[![Go Coverage](https://github.com/ei-grad/check-expiring-certs/wiki/coverage.svg)](https://raw.githack.com/wiki/ei-grad/check-expiring-certs/coverage.html)
[![Go Report](https://goreportcard.com/badge/github.com/ei-grad/check-expiring-certs)](https://goreportcard.com/report/github.com/ei-grad/check-expiring-certs)
[![Latest Release](https://img.shields.io/github/v/release/ei-grad/check-expiring-certs)](https://github.com/ei-grad/check-expiring-certs/releases/latest)
[![Downloads](https://img.shields.io/github/downloads/ei-grad/check-expiring-certs/total)](https://github.com/ei-grad/check-expiring-certs/graphs/traffic)
Expand Down
64 changes: 51 additions & 13 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ import (

func main() {

exitcode := 0

warning_period := flag.Int("warn", 7, "warning period in days")
timeout := flag.Duration("timeout", 2*time.Second, "timeout for connection")
concurrency := flag.Int("c", 128, "number of concurrent checks")
Expand All @@ -26,18 +24,33 @@ func main() {
// endpoints to check
endpoints := flag.Args()

// semaphore to limit concurrency to a reasonable number
semaphore := make(chan struct{}, *concurrency)

dialer := &net.Dialer{
Timeout: *timeout,
}

warn_if_expired_at := time.Now().AddDate(0, 0, *warning_period)

checker := NewSimpleHostChecker(dialer, warn_if_expired_at)

os.Exit(RunChecks(checker, endpoints, *concurrency))
}

type HostChecker interface {
CheckHost(host string) (bool, error)
}

func RunChecks(
checker HostChecker,
endpoints []string,
concurrency int,
) (exitcode int) {

// semaphore to limit concurrency to a reasonable number
semaphore := make(chan struct{}, concurrency)

wg := new(sync.WaitGroup)
wg.Add(len(endpoints))

warn_if_expired_at := time.Now().AddDate(0, 0, *warning_period)

for _, i := range endpoints {

// sleep 1ms to avoid hitting DNS resolver limits
Expand All @@ -54,7 +67,7 @@ func main() {
// release semaphore
defer func() { <-semaphore }()

is_expired, err := checkHost(dialer, i, warn_if_expired_at)
is_expired, err := checker.CheckHost(i)
if err != nil {
fmt.Printf("can't check %s: %s\n", i, err)
exitcode = 1
Expand All @@ -67,38 +80,63 @@ func main() {

wg.Wait()

os.Exit(exitcode)
return exitcode

}

var addrOverride = regexp.MustCompile(`^([^:]+):(((\[[0-9a-f:]+\])|([^:]+)):\d+)$`)

func checkHost(
type SimpleHostChecker struct {
dialer *net.Dialer
warn_if_expired_at time.Time
}

func NewSimpleHostChecker(
dialer *net.Dialer,
host string,
warn_if_expired_at time.Time,
) *SimpleHostChecker {
return &SimpleHostChecker{
dialer: dialer,
warn_if_expired_at: warn_if_expired_at,
}
}

func (c *SimpleHostChecker) CheckHost(
host string,
) (is_expired bool, err error) {

config := tls.Config{
// we still want to get connection even if the cert is expired, or if
// the hostname doesn't match
InsecureSkipVerify: true,
}

// custom address parsing to allow default port and address override
if !strings.Contains(host, ":") {
host = host + ":443"
} else if match := addrOverride.FindStringSubmatch(host); match != nil {
config.ServerName = match[1]
host = match[2]
}

conn, err := tls.DialWithDialer(dialer, "tcp", host, &config)
// make a connection to get the certificate
conn, err := tls.DialWithDialer(c.dialer, "tcp", host, &config)
if err != nil {
return
}
conn.Close()

// check all certificates in the chain for expiration
for _, cert := range conn.ConnectionState().PeerCertificates {
if warn_if_expired_at.After(cert.NotAfter) {
if c.warn_if_expired_at.After(cert.NotAfter) {
is_expired = true
fmt.Printf("Certificate for %s (%s) expires in %s\n",
host, cert.Subject.CommonName,
humanize.Time(cert.NotAfter))
}
}

// TODO: validate hostname and chain of trust

return
}
154 changes: 154 additions & 0 deletions main_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
package main

import (
"bytes"
"net"
"os"
"strings"
"testing"
"time"
)

var googleTimeout = 5 * time.Second

func TestCheckHostOneOneOneOne(t *testing.T) {
checker := NewSimpleHostChecker(
&net.Dialer{Timeout: googleTimeout},
time.Now().AddDate(0, 0, 7),
)
expired, err := checker.CheckHost("one.one.one.one:1.1.1.1:443")
if err != nil {
t.Errorf("checkHost() with valid host returned an error: %v", err)
}
if expired {
t.Errorf("checkHost() with valid host reported expired certificate")
}
}

func TestCheckHostGoogleValid(t *testing.T) {
checker := NewSimpleHostChecker(
&net.Dialer{Timeout: googleTimeout},
time.Now().AddDate(0, 0, 7),
)
expired, err := checker.CheckHost("google.com")
if err != nil {
t.Errorf("checkHost() with valid host returned an error: %v", err)
}
if expired {
t.Errorf("checkHost() with valid host reported expired certificate")
}
}

func captureOutput(f func()) string {
// Keep backup of the real stdout
old := os.Stdout
// Create a pipe for capturing output
r, w, _ := os.Pipe()
os.Stdout = w

// Execute the function
f()

// Close the pipe
w.Close()
// Restore the original stdout
os.Stdout = old

var buf bytes.Buffer
// Read the captured output
buf.ReadFrom(r)
return buf.String()
}

func TestGoogleCertInFiveYears(t *testing.T) {
checker := NewSimpleHostChecker(
&net.Dialer{Timeout: googleTimeout},
// Set the warning threshold to 5 years from now
time.Now().AddDate(5, 0, 0),
)
var isExpired bool
var err error
output := captureOutput(func() {
isExpired, err = checker.CheckHost("google.com")
})
if err != nil {
t.Fatalf("checkHost() returned an error: %v", err)
}
if !strings.Contains(output, "Certificate for google.com:443 (*.google.com) expires in") {
t.Errorf("Expected output to contain certificate expiration warning, got %q", output)
}
if !isExpired {
t.Errorf("Google's certificate is not marked as expiring within 5 years")
}
}

func TestCheckHostUnreachable(t *testing.T) {
checker := NewSimpleHostChecker(
&net.Dialer{Timeout: googleTimeout},
time.Now().AddDate(0, 0, 7),
)
isExpired, err := checker.CheckHost("some-unreachable-domain")
if err == nil {
t.Errorf("checkHost() with unreachable host did not return an error")
}
if isExpired {
t.Errorf("checkHost() with unreachable host reported expired certificate")
}
}

type MockHostChecker struct {
MockCheckHost func(host string) (bool, error)
}

func (m *MockHostChecker) CheckHost(host string) (bool, error) {
return m.MockCheckHost(host)
}

func TestRunChecks(t *testing.T) {

// Mock checkHost function
mockCheckHost := func(host string) (bool, error) {
if host == "test.com" {
return false, nil // test.com has a valid certificate
}
// unreachable domains error out
if strings.HasPrefix(host, "some-unreachable-domain") {
return false, &net.OpError{
Op: "dial",
Net: "tcp",
Err: &net.DNSError{Err: "no such host"},
}
}
return true, nil // other domains are considered expired
}

checker := &MockHostChecker{MockCheckHost: mockCheckHost}

// Define test cases
testCases := []struct {
name string
endpoints []string
expectedCode int
expectedOutput string
}{
{"All Valid", []string{"test.com", "test.com"}, 0, ""},
{"One Expired", []string{"test.com", "expired.com"}, 1, ""},
{"Unreachable", []string{"some-unreachable-domain"}, 1, "can't check"},
// Add more test cases as needed
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var exitCode int
output := captureOutput(func() {
exitCode = RunChecks(checker, tc.endpoints, 2)
})
if exitCode != tc.expectedCode {
t.Errorf("Expected exit code %d, got %d", tc.expectedCode, exitCode)
}
if tc.expectedOutput != "" && !strings.Contains(output, tc.expectedOutput) {
t.Errorf("Expected output %q, got %q", tc.expectedOutput, output)
}
})
}
}

0 comments on commit f348b74

Please sign in to comment.