Skip to content

Commit

Permalink
improve binary escaping
Browse files Browse the repository at this point in the history
  • Loading branch information
lonnywong committed Sep 2, 2023
1 parent 6c37e8d commit 2d4299a
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 99 deletions.
86 changes: 42 additions & 44 deletions trzsz/escape.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,13 @@ import (

type unicode string

type escapeArray [][]byte
type escapeTable struct {
totalCount int
escapeCodes []*byte
unescapeCodes []*byte
}

const escapeLeaderByte = '\xee'
const escapeLeaderByte = byte('\xee')

func (s unicode) MarshalJSON() ([]byte, error) {
b := new(bytes.Buffer)
Expand All @@ -59,18 +63,24 @@ func getEscapeChars(escapeAll bool) [][]unicode {
{"\u007e", "\u00ee\u0031"},
}
if escapeAll {
const chars = unicode("\x02\x10\x1b\x1d\u009d")
for i, c := range chars {
escapeChars = append(escapeChars, []unicode{unicode(c), "\u00ee" + unicode(byte(i+0x41))})
const chars = unicode("\x02\x0d\x10\x11\x13\x18\x1b\x1d\u008d\u0090\u0091\u0093\u009d")
e := byte('A')
for _, c := range chars {
escapeChars = append(escapeChars, []unicode{unicode(c), unicode(escapeLeaderByte) + unicode(e)})
e += 1
}
}
return escapeChars
}

func escapeCharsToCodes(escapeChars []interface{}) ([][]byte, error) {
escapeCodes := make([][]byte, len(escapeChars))
func escapeCharsToTable(escapeChars []interface{}) (*escapeTable, error) {
table := &escapeTable{
totalCount: len(escapeChars),
escapeCodes: make([]*byte, 256),
unescapeCodes: make([]*byte, 256),
}
encoder := charmap.ISO8859_1.NewEncoder()
for i, v := range escapeChars {
for _, v := range escapeChars {
a, ok := v.([]interface{})
if !ok {
return nil, simpleTrzszError("Escape chars invalid: %v", v)
Expand Down Expand Up @@ -103,54 +113,49 @@ func escapeCharsToCodes(escapeChars []interface{}) ([][]byte, error) {
if cc[0] != escapeLeaderByte {
return nil, simpleTrzszError("Escape chars invalid: %v", v)
}
escapeCodes[i] = make([]byte, 3)
escapeCodes[i][0] = bb[0]
escapeCodes[i][1] = cc[0]
escapeCodes[i][2] = cc[1]
table.escapeCodes[bb[0]] = &cc[1]
table.unescapeCodes[cc[1]] = &bb[0]
}
return escapeCodes, nil
return table, nil
}

func (c *escapeArray) UnmarshalJSON(data []byte) error {
func (c *escapeTable) UnmarshalJSON(data []byte) error {
var codes []interface{}
if err := json.Unmarshal(data, &codes); err != nil {
return err
}
var err error
*c, err = escapeCharsToCodes(codes)
return err
table, err := escapeCharsToTable(codes)
if err != nil {
return err
}
*c = *table
return nil
}

func escapeData(data []byte, escapeCodes [][]byte) []byte {
if len(escapeCodes) == 0 {
func escapeData(data []byte, table *escapeTable) []byte {
if table == nil || table.totalCount == 0 {
return data
}

buf := make([]byte, len(data)*2)
idx := 0
for _, d := range data {
escapeIdx := -1
for j, e := range escapeCodes {
if d == e[0] {
escapeIdx = j
break
}
}
if escapeIdx < 0 {
buf[idx] = d
for _, bdata := range data {
ecode := table.escapeCodes[bdata]
if ecode == nil {
buf[idx] = bdata
idx++
} else {
buf[idx] = escapeCodes[escapeIdx][1]
buf[idx] = escapeLeaderByte
idx++
buf[idx] = escapeCodes[escapeIdx][2]
buf[idx] = *ecode
idx++
}
}
return buf[:idx]
}

func unescapeData(data []byte, escapeCodes [][]byte, dst []byte) ([]byte, []byte, error) {
if len(escapeCodes) == 0 {
func unescapeData(data []byte, table *escapeTable, dst []byte) ([]byte, []byte, error) {
if table == nil || table.totalCount == 0 {
return data, nil, nil
}

Expand All @@ -166,18 +171,11 @@ func unescapeData(data []byte, escapeCodes [][]byte, dst []byte) ([]byte, []byte
return buf[:idx], data[i:], nil
}
i++
b := data[i]
escaped := false
for _, e := range escapeCodes {
if b == e[2] {
buf[idx] = e[0]
escaped = true
break
}
}
if !escaped {
return nil, nil, simpleTrzszError("Unknown escape code: %v", b)
ecode := table.unescapeCodes[data[i]]
if ecode == nil {
return nil, nil, simpleTrzszError("Unknown escape code: %v", data[i])
}
buf[idx] = *ecode
} else {
buf[idx] = data[i]
}
Expand Down
32 changes: 16 additions & 16 deletions trzsz/pipeline.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,19 +222,19 @@ func newBase64Writer(writer io.WriteCloser) writeCloseFlusher {
}

type escapeReader struct {
transfer *trzszTransfer
reader io.Reader
buffer []byte
table *escapeTable
reader io.Reader
buffer []byte
}

func (e *escapeReader) Read(p []byte) (n int, err error) {
if len(e.transfer.transferConfig.EscapeCodes) == 0 {
if e.table == nil || e.table.totalCount == 0 {
return e.reader.Read(p)
}
for {
var idx int
if len(e.buffer) > 0 {
buf, remaining, err := unescapeData(e.buffer, e.transfer.transferConfig.EscapeCodes, p)
buf, remaining, err := unescapeData(e.buffer, e.table, p)
if err != nil {
return 0, err
}
Expand All @@ -260,17 +260,17 @@ func (e *escapeReader) Read(p []byte) (n int, err error) {
func (e *escapeReader) Close() {
}

func newEscapeReader(transfer *trzszTransfer, reader io.Reader) readCloser {
return &escapeReader{transfer, reader, nil}
func newEscapeReader(table *escapeTable, reader io.Reader) readCloser {
return &escapeReader{table, reader, nil}
}

type escapeWriter struct {
transfer *trzszTransfer
writer io.WriteCloser
table *escapeTable
writer io.WriteCloser
}

func (e *escapeWriter) Write(p []byte) (int, error) {
buf := escapeData(p, e.transfer.transferConfig.EscapeCodes)
buf := escapeData(p, e.table)
if err := writeAll(e.writer, buf); err != nil {
return 0, err
}
Expand All @@ -285,8 +285,8 @@ func (e *escapeWriter) Flush() error {
return nil
}

func newEscapeWriter(transfer *trzszTransfer, writer io.WriteCloser) writeCloseFlusher {
return &escapeWriter{transfer, writer}
func newEscapeWriter(table *escapeTable, writer io.WriteCloser) writeCloseFlusher {
return &escapeWriter{table, writer}
}

type zstdReader struct {
Expand Down Expand Up @@ -553,9 +553,9 @@ func (t *trzszTransfer) pipelineEncodeData(ctx *pipelineContext, fileDataChan <-
var writer writeCloseFlusher
if t.transferConfig.Binary {
if compress {
writer, err = newZstdWriter(newEscapeWriter(t, newSendDataWriter(t, ctx, sendDataChan)))
writer, err = newZstdWriter(newEscapeWriter(t.transferConfig.EscapeTable, newSendDataWriter(t, ctx, sendDataChan)))
} else {
writer = newEscapeWriter(t, newSendDataWriter(t, ctx, sendDataChan))
writer = newEscapeWriter(t.transferConfig.EscapeTable, newSendDataWriter(t, ctx, sendDataChan))
}
} else {
if compress {
Expand Down Expand Up @@ -964,9 +964,9 @@ func (t *trzszTransfer) pipelineDecodeData(ctx *pipelineContext, recvDataChan <-
var reader readCloser
if t.transferConfig.Binary {
if compress {
reader, err = newZstdReader(newEscapeReader(t, newRecvDataReader(ctx, recvDataChan)))
reader, err = newZstdReader(newEscapeReader(t.transferConfig.EscapeTable, newRecvDataReader(ctx, recvDataChan)))
} else {
reader = newEscapeReader(t, newRecvDataReader(ctx, recvDataChan))
reader = newEscapeReader(t.transferConfig.EscapeTable, newRecvDataReader(ctx, recvDataChan))
}
} else {
if compress {
Expand Down
62 changes: 36 additions & 26 deletions trzsz/pipeline_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ func TestBase64Writer(t *testing.T) {
func TestEscapeReader(t *testing.T) {
assert := assert.New(t)
transfer := newTransferWithEscape(t, nil, true)
reader := newEscapeReader(transfer, newMockReader([][]byte{
reader := newEscapeReader(transfer.transferConfig.EscapeTable, newMockReader([][]byte{
[]byte("ABC"),
[]byte("\xeeA\xeeB\xeeC\xeeDE"),
[]byte("ABCDEFGXYX\xee"),
Expand All @@ -381,8 +381,8 @@ func TestEscapeReader(t *testing.T) {
}
assertReadEqual("ABC", 100)
assertReadEqual("\x02", 1)
assertReadEqual("\x10", 1)
assertReadEqual("\x1b\x1d", 2)
assertReadEqual("\x0d", 1)
assertReadEqual("\x10\x11", 2)
assertReadEqual("E", 100)
assertReadEqual("A", 1)
assertReadEqual("BC", 2)
Expand All @@ -395,19 +395,25 @@ func TestEscapeWriter(t *testing.T) {
assert := assert.New(t)
transfer := newTransferWithEscape(t, nil, true)
var out mockWriter
writer := newEscapeWriter(transfer, &out)
writer := newEscapeWriter(transfer.transferConfig.EscapeTable, &out)
assertWriteSucc := func(data string) {
t.Helper()
n, err := writer.Write([]byte(data))
assert.Nil(err)
assert.Equal(n, len(data))
}
assertWriteSucc("AB\xeeC\x7e")
assertWriteSucc("\x02\x10\x1b\x1d")
assertWriteSucc("\x02\x0d\x10\x11")
assertWriteSucc("\x13\x18\x1b\x1d")
assertWriteSucc("\x8d\x90\x91\x93")
assertWriteSucc("\x9dXZ")
writer.Close()
assert.True(out.closed)
assert.Equal("AB\xee\xeeC\xee1\xeeA\xeeB\xeeC\xeeD\xeeEXZ", out.buf.String())
assert.Equal("AB\xee\xeeC\xee1"+
"\xeeA\xeeB\xeeC\xeeD"+
"\xeeE\xeeF\xeeG\xeeH"+
"\xeeI\xeeJ\xeeK\xeeL"+
"\xeeMXZ", out.buf.String())
}

func TestZstdReaderWriter(t *testing.T) {
Expand Down Expand Up @@ -638,16 +644,16 @@ func TestPipelineUnescapeData(t *testing.T) {

// escape at the end
recvDataChan <- []byte("ABC\xee\x41\xee\x42")
assertChannelFrontEqual(t, []byte("ABC\x02\x10"), fileDataChan)
assertChannelFrontEqual(t, []byte("ABC\x02\x10"), md5SourceChan)
assertChannelFrontEqual(t, []byte("ABC\x02\x0d"), fileDataChan)
assertChannelFrontEqual(t, []byte("ABC\x02\x0d"), md5SourceChan)

// escaping across buffers
recvDataChan <- []byte("ABC\xee\x41\xee")
recvDataChan <- []byte("\x42DEF")
assertChannelFrontEqual(t, []byte("ABC\x02"), fileDataChan)
assertChannelFrontEqual(t, []byte("ABC\x02"), md5SourceChan)
assertChannelFrontEqual(t, []byte("\x10DEF"), fileDataChan)
assertChannelFrontEqual(t, []byte("\x10DEF"), md5SourceChan)
assertChannelFrontEqual(t, []byte("\x0dDEF"), fileDataChan)
assertChannelFrontEqual(t, []byte("\x0dDEF"), md5SourceChan)

// complex escaping
recvDataChan <- []byte("ABC\xee\xee\xee\xee\xee")
Expand All @@ -658,23 +664,27 @@ func TestPipelineUnescapeData(t *testing.T) {
recvDataChan <- []byte("\xee\xee\xee\xee\xee")
recvDataChan <- []byte("G\xee\xee\xee\xee\xee")
recvDataChan <- []byte("\x31\xee\xee\xee\xee")
recvDataChan <- []byte("\xeeA\xeeB\xeeC\xeeD")
recvDataChan <- []byte("\xeeE\xeeF\xeeG\xeeH")
recvDataChan <- []byte("\xeeI\xeeJ\xeeK\xeeL")
recvDataChan <- []byte("\xeeM")
close(recvDataChan)
assertChannelFrontEqual(t, []byte("ABC\xee\xee"), fileDataChan)
assertChannelFrontEqual(t, []byte("ABC\xee\xee"), md5SourceChan)
assertChannelFrontEqual(t, []byte("\xee\xee\xee"), fileDataChan)
assertChannelFrontEqual(t, []byte("\xee\xee\xee"), md5SourceChan)
assertChannelFrontEqual(t, []byte("\xee\xee"), fileDataChan)
assertChannelFrontEqual(t, []byte("\xee\xee"), md5SourceChan)
assertChannelFrontEqual(t, []byte("\xee\xeeDEF\xee"), fileDataChan)
assertChannelFrontEqual(t, []byte("\xee\xeeDEF\xee"), md5SourceChan)
assertChannelFrontEqual(t, []byte("\xee\xee"), fileDataChan)
assertChannelFrontEqual(t, []byte("\xee\xee"), md5SourceChan)
assertChannelFrontEqual(t, []byte("\xee\xee\xee"), fileDataChan)
assertChannelFrontEqual(t, []byte("\xee\xee\xee"), md5SourceChan)
assertChannelFrontEqual(t, []byte("G\xee\xee"), fileDataChan)
assertChannelFrontEqual(t, []byte("G\xee\xee"), md5SourceChan)
assertChannelFrontEqual(t, []byte("\x7e\xee\xee"), fileDataChan)
assertChannelFrontEqual(t, []byte("\x7e\xee\xee"), md5SourceChan)
assertUnescapeEqual := func(dataChan <-chan []byte) {
assertChannelFrontEqual(t, []byte("ABC\xee\xee"), dataChan)
assertChannelFrontEqual(t, []byte("\xee\xee\xee"), dataChan)
assertChannelFrontEqual(t, []byte("\xee\xee"), dataChan)
assertChannelFrontEqual(t, []byte("\xee\xeeDEF\xee"), dataChan)
assertChannelFrontEqual(t, []byte("\xee\xee"), dataChan)
assertChannelFrontEqual(t, []byte("\xee\xee\xee"), dataChan)
assertChannelFrontEqual(t, []byte("G\xee\xee"), dataChan)
assertChannelFrontEqual(t, []byte("\x7e\xee\xee"), dataChan)
assertChannelFrontEqual(t, []byte("\x02\x0d\x10\x11"), dataChan)
assertChannelFrontEqual(t, []byte("\x13\x18\x1b\x1d"), dataChan)
assertChannelFrontEqual(t, []byte("\x8d\x90\x91\x93"), dataChan)
assertChannelFrontEqual(t, []byte("\x9d"), dataChan)
}
assertUnescapeEqual(fileDataChan)
assertUnescapeEqual(md5SourceChan)

// cancel
recvDataChan = make(chan []byte, 100)
Expand Down
6 changes: 3 additions & 3 deletions trzsz/transfer.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ type transferConfig struct {
Newline string `json:"newline"`
Protocol int `json:"protocol"`
MaxBufSize int64 `json:"bufsize"`
EscapeCodes escapeArray `json:"escape_chars"`
EscapeTable *escapeTable `json:"escape_chars"`
TmuxPaneColumns int32 `json:"tmux_pane_width"`
TmuxOutputJunk bool `json:"tmux_output_junk"`
CompressType compressType `json:"compress"`
Expand Down Expand Up @@ -404,7 +404,7 @@ func (t *trzszTransfer) sendData(data []byte) error {
if !t.transferConfig.Binary {
return t.sendBinary("DATA", data)
}
buf := escapeData(data, t.transferConfig.EscapeCodes)
buf := escapeData(data, t.transferConfig.EscapeTable)
if err := t.writeAll([]byte(fmt.Sprintf("#DATA:%d\n", len(buf)))); err != nil {
return err
}
Expand Down Expand Up @@ -434,7 +434,7 @@ func (t *trzszTransfer) recvData() ([]byte, error) {
}
return nil, err
}
buf, remaining, err := unescapeData(data, t.transferConfig.EscapeCodes, nil)
buf, remaining, err := unescapeData(data, t.transferConfig.EscapeTable, nil)
if err != nil {
return nil, err
}
Expand Down
Loading

0 comments on commit 2d4299a

Please sign in to comment.