From ee7ac9cf49d48e72714e9e4258870e6781c1fbc4 Mon Sep 17 00:00:00 2001 From: WGH Date: Wed, 19 Oct 2022 20:59:34 +0300 Subject: [PATCH] journal: add StderrIsJournalStream function This function can be used for automatic protocol upgrade described in [1]. Both unit tests and runnable example are included. Only the latter requires systemd, as unit tests are self-sufficient, and only test that JOURNAL_STREAM environment variable is checked properly. [1] https://systemd.io/JOURNAL_NATIVE_PROTOCOL/#automatic-protocol-upgrading --- examples/journal/main.go | 37 +++++++++ examples/journal/run.sh | 13 ++++ journal/journal_unix.go | 38 ++++++++++ journal/journal_unix_test.go | 142 +++++++++++++++++++++++++++++++++++ 4 files changed, 230 insertions(+) create mode 100644 examples/journal/main.go create mode 100755 examples/journal/run.sh create mode 100644 journal/journal_unix_test.go diff --git a/examples/journal/main.go b/examples/journal/main.go new file mode 100644 index 00000000..86dadc96 --- /dev/null +++ b/examples/journal/main.go @@ -0,0 +1,37 @@ +// Copyright 2022 CoreOS, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License + +package main + +import ( + "fmt" + "os" + + "github.com/coreos/go-systemd/v22/journal" +) + +func main() { + ok, err := journal.StderrIsJournalStream() + if err != nil { + panic(err) + } + + if ok { + // use journal native protocol + journal.Send("this is a message logged through the native protocol", journal.PriInfo, nil) + } else { + // use stderr + fmt.Fprintln(os.Stderr, "this is a message logged through stderr") + } +} diff --git a/examples/journal/run.sh b/examples/journal/run.sh new file mode 100755 index 00000000..2909b8f0 --- /dev/null +++ b/examples/journal/run.sh @@ -0,0 +1,13 @@ +#!/bin/bash + +set -e + +go build + +echo "Running directly" +./journal + +echo "Running through systemd" +unit_name="run-$(systemd-id128 new)" +systemd-run -u "$unit_name" --user --wait --quiet ./journal +journalctl --user -u "$unit_name" diff --git a/journal/journal_unix.go b/journal/journal_unix.go index 439ad287..cd8c4c30 100644 --- a/journal/journal_unix.go +++ b/journal/journal_unix.go @@ -69,6 +69,44 @@ func Enabled() bool { return true } +// StderrIsJournalStream returns whether the process stderr is connected +// to the Journal's stream transport. +// +// This can be used for automatic protocol upgrading described in [Journal Native Protocol]. +// +// Returns true if JOURNAL_STREAM environment variable is present, +// and stderr's device and inode numbers match it. +// +// Error is returned if unexpected error occurs: e.g. if JOURNAL_STREAM environment variable +// is present, but malformed, fstat syscall fails, etc. +// +// [Journal Native Protocol]: https://systemd.io/JOURNAL_NATIVE_PROTOCOL/#automatic-protocol-upgrading +func StderrIsJournalStream() (res bool, err error) { + journalStream := os.Getenv("JOURNAL_STREAM") + if journalStream == "" { + return + } + + var expectedStat syscall.Stat_t + _, err = fmt.Sscanf(journalStream, "%d:%d", &expectedStat.Dev, &expectedStat.Ino) + if err != nil { + err = fmt.Errorf("failed to parse JOURNAL_STREAM=%q: %w", journalStream, err) + return + } + + var stat syscall.Stat_t + err = syscall.Fstat(syscall.Stderr, &stat) + if err != nil { + return + } + + if stat.Dev != expectedStat.Dev || stat.Ino != expectedStat.Ino { + return + } + + return true, nil +} + // Send a message to the local systemd journal. vars is a map of journald // fields to values. Fields must be composed of uppercase letters, numbers, // and underscores, but must not start with an underscore. Within these diff --git a/journal/journal_unix_test.go b/journal/journal_unix_test.go new file mode 100644 index 00000000..6a915522 --- /dev/null +++ b/journal/journal_unix_test.go @@ -0,0 +1,142 @@ +// Copyright 2022 CoreOS, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build unix +// +build unix + +package journal_test + +import ( + "fmt" + "os" + "syscall" + "testing" + + "github.com/coreos/go-systemd/v22/journal" +) + +func TestStderrIsJournalStream(t *testing.T) { + if _, ok := os.LookupEnv("JOURNAL_STREAM"); ok { + t.Fatal("unset JOURNAL_STREAM before running this test") + } + + t.Run("Missing", func(t *testing.T) { + ok, err := journal.StderrIsJournalStream() + if err != nil { + t.Fatal(err) + } + if ok { + t.Error("stderr shouldn't be connected to journal stream") + } + }) + t.Run("Present", func(t *testing.T) { + f, stat := getUnixStreamSocket(t) + defer f.Close() + os.Setenv("JOURNAL_STREAM", fmt.Sprintf("%d:%d", stat.Dev, stat.Ino)) + defer os.Unsetenv("JOURNAL_STREAM") + replaceStderr(int(f.Fd()), func() { + ok, err := journal.StderrIsJournalStream() + if err != nil { + t.Fatal(err) + } + if !ok { + t.Error("stderr should've been connected to journal stream") + } + }) + }) + t.Run("NotMatching", func(t *testing.T) { + f, stat := getUnixStreamSocket(t) + defer f.Close() + os.Setenv("JOURNAL_STREAM", fmt.Sprintf("%d:%d", stat.Dev+1, stat.Ino)) + defer os.Unsetenv("JOURNAL_STREAM") + replaceStderr(int(f.Fd()), func() { + ok, err := journal.StderrIsJournalStream() + if err != nil { + t.Fatal(err) + } + if ok { + t.Error("stderr shouldn't be connected to journal stream") + } + }) + }) + t.Run("Malformed", func(t *testing.T) { + f, stat := getUnixStreamSocket(t) + defer f.Close() + os.Setenv("JOURNAL_STREAM", fmt.Sprintf("%d-%d", stat.Dev, stat.Ino)) + defer os.Unsetenv("JOURNAL_STREAM") + replaceStderr(int(f.Fd()), func() { + _, err := journal.StderrIsJournalStream() + if err == nil { + t.Fatal("JOURNAL_STREAM is malformed, but no error returned") + } + }) + }) +} + +func ExampleStderrIsJournalStream() { + // NOTE: this is just an example. Production code + // will likely use this to setup a logging library + // to write messages to either journal or stderr. + ok, err := journal.StderrIsJournalStream() + if err != nil { + panic(err) + } + + if ok { + // use journal native protocol + journal.Send("this is a message logged through the native protocol", journal.PriInfo, nil) + } else { + // use stderr + fmt.Fprintln(os.Stderr, "this is a message logged through stderr") + } +} + +func replaceStderr(fd int, cb func()) { + savedStderr, err := syscall.Dup(syscall.Stderr) + if err != nil { + panic(err) + } + defer syscall.Close(savedStderr) + err = syscall.Dup2(fd, syscall.Stderr) + if err != nil { + panic(err) + } + defer func() { + err := syscall.Dup2(savedStderr, syscall.Stderr) + if err != nil { + panic(err) + } + }() + cb() +} + +// getUnixStreamSocket returns a unix stream socket obtained with +// socketpair(2), and its fstat result. Only one end of the socket pair +// is returned, and the other end is closed immediately: we don't need +// it for our purposes. +func getUnixStreamSocket(t *testing.T) (*os.File, *syscall.Stat_t) { + fds, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_STREAM, 0) + if err != nil { + t.Fatal(os.NewSyscallError("socketpair", err)) + } + // we don't need the remote end for our tests + syscall.Close(fds[1]) + + file := os.NewFile(uintptr(fds[0]), "unix-stream") + stat, err := file.Stat() + if err != nil { + t.Fatal(err) + } + return file, stat.Sys().(*syscall.Stat_t) +}