Skip to content

Commit

Permalink
Provide correct pty/tty file paths on OpenBSD (#148)
Browse files Browse the repository at this point in the history
While here, add test coverage for opening the TTY from the given filename.
  • Loading branch information
4a6f656c authored Apr 21, 2022
1 parent 2e47437 commit 0d412c9
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 2 deletions.
39 changes: 39 additions & 0 deletions doc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package pty
import (
"bytes"
"io"
"os"
"testing"
)

Expand Down Expand Up @@ -65,6 +66,44 @@ func TestName(t *testing.T) {
}
}

// TestOpenByName ensures that the name associated with the tty is valid
// and can be opened and used if passed by file name (rather than passing
// the existing open file descriptor).
func TestOpenByName(t *testing.T) {
t.Parallel()

pty, tty, err := Open()
if err != nil {
t.Fatal(err)
}
defer pty.Close()
defer tty.Close()

ttyFile, err := os.OpenFile(tty.Name(), os.O_RDWR, 0600)
if err != nil {
t.Fatalf("Failed to open tty file: %v", err)
}
defer ttyFile.Close()

// Ensure we can write to the newly opened tty file and read on the pty.
text := []byte("ping")
n, err := ttyFile.Write(text)
if err != nil {
t.Errorf("Unexpected error from Write: %s", err)
}
if n != len(text) {
t.Errorf("Unexpected count returned from Write, got %d expected %d", n, len(text))
}

buffer := make([]byte, len(text))
if err := readBytes(pty, buffer); err != nil {
t.Errorf("Unexpected error from readBytes: %s", err)
}
if !bytes.Equal(text, buffer) {
t.Errorf("Unexpected result returned from Read, got %v expected %v", buffer, text)
}
}

func TestGetsize(t *testing.T) {
t.Parallel()

Expand Down
15 changes: 13 additions & 2 deletions pty_openbsd.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,17 @@ import (
"unsafe"
)

func cInt8ToString(in []int8) string {
var s []byte
for _, v := range in {
if v == 0 {
break
}
s = append(s, byte(v))
}
return string(s)
}

func open() (pty, tty *os.File, err error) {
/*
* from ptm(4):
Expand All @@ -29,8 +40,8 @@ func open() (pty, tty *os.File, err error) {
return nil, nil, err
}

pty = os.NewFile(uintptr(ptm.Cfd), "/dev/ptm")
tty = os.NewFile(uintptr(ptm.Sfd), "/dev/ptm")
pty = os.NewFile(uintptr(ptm.Cfd), cInt8ToString(ptm.Cn[:]))
tty = os.NewFile(uintptr(ptm.Sfd), cInt8ToString(ptm.Sn[:]))

return pty, tty, nil
}

0 comments on commit 0d412c9

Please sign in to comment.