diff --git a/trzsz/comm.go b/trzsz/comm.go index 2a673d9..3901c54 100644 --- a/trzsz/comm.go +++ b/trzsz/comm.go @@ -31,6 +31,7 @@ import ( "encoding/json" "fmt" "io" + "net" "os" "os/exec" "os/signal" @@ -530,20 +531,22 @@ func getTerminalColumns() int { return cols } -func wrapStdinInput(transfer *trzszTransfer) { - const bufSize = 32 * 1024 - buffer := make([]byte, bufSize) - for { - n, err := os.Stdin.Read(buffer) - if n > 0 { - buf := buffer[0:n] - transfer.addReceivedData(buf) - buffer = make([]byte, bufSize) - } - if err == io.EOF { - transfer.stopTransferringFiles(false) +func wrapTransferInput(transfer *trzszTransfer, reader io.Reader, tunnel bool) { + go func() { + const bufSize = 32 * 1024 + buffer := make([]byte, bufSize) + for { + n, err := reader.Read(buffer) + if n > 0 { + buf := buffer[0:n] + transfer.addReceivedData(buf, tunnel) + buffer = make([]byte, bufSize) + } + if err != nil { + break + } } - } + }() } func handleServerSignal(transfer *trzszTransfer) { @@ -634,6 +637,15 @@ func (v *trzszVersion) compare(ver *trzszVersion) int { return 0 } +type trzszTrigger struct { + mode byte + version *trzszVersion + uniqueID string + winServer bool + tunnelPort int + tmuxPrefix string +} + type trzszDetector struct { relay bool tmux bool @@ -644,9 +656,9 @@ func newTrzszDetector(relay, tmux bool) *trzszDetector { return &trzszDetector{relay, tmux, make(map[string]int)} } -var trzszRegexp = regexp.MustCompile(`::TRZSZ:TRANSFER:([SRD]):(\d+\.\d+\.\d+)(:\d+)?`) +var trzszRegexp = regexp.MustCompile(`::TRZSZ:TRANSFER:([SRD]):(\d+\.\d+\.\d+)(:\d+)?(:\d+)?`) var uniqueIDRegexp = regexp.MustCompile(`::TRZSZ:TRANSFER:[SRD]:\d+\.\d+\.\d+:(\d{13}\d*)`) -var tmuxControlModeRegexp = regexp.MustCompile(`((%output %\d+)|(%extended-output %\d+ \d+ :)) .*::TRZSZ:TRANSFER:`) +var tmuxControlModeRegexp = regexp.MustCompile(`((%output %\d+ )|(%extended-output %\d+ \d+ : )).*::TRZSZ:TRANSFER:`) func (detector *trzszDetector) rewriteTrzszTrigger(buf []byte) []byte { for _, match := range uniqueIDRegexp.FindAllSubmatch(buf, -1) { @@ -681,13 +693,32 @@ func (detector *trzszDetector) addRelaySuffix(output []byte, idx int) []byte { return buf.Bytes() } -func (detector *trzszDetector) detectTrzsz(output []byte) ([]byte, *byte, *trzszVersion, bool) { +func (detector *trzszDetector) isRepeatedID(uniqueID string) bool { + if len(uniqueID) > 6 && (isWindowsEnvironment() || !(len(uniqueID) == 13 && strings.HasSuffix(uniqueID, "00"))) { + if _, ok := detector.uniqueIDMap[uniqueID]; ok { + return true + } + if len(detector.uniqueIDMap) > 100 { + m := make(map[string]int) + for k, v := range detector.uniqueIDMap { + if v >= 50 { + m[k] = v - 50 + } + } + detector.uniqueIDMap = m + } + detector.uniqueIDMap[uniqueID] = len(detector.uniqueIDMap) + } + return false +} + +func (detector *trzszDetector) detectTrzsz(output []byte, tunnel bool) ([]byte, *trzszTrigger) { if len(output) < 24 { - return output, nil, nil, false + return output, nil } idx := bytes.LastIndex(output, []byte("::TRZSZ:TRANSFER:")) if idx < 0 { - return output, nil, nil, false + return output, nil } if detector.relay && detector.tmux { @@ -698,50 +729,51 @@ func (detector *trzszDetector) detectTrzsz(output []byte) ([]byte, *byte, *trzsz subOutput := output[idx:] match := trzszRegexp.FindSubmatch(subOutput) if len(match) < 3 { - return output, nil, nil, false + return output, nil } - if tmuxControlModeRegexp.Match(output) { - return output, nil, nil, false + + tmuxPrefix := "" + tmuxMatch := tmuxControlModeRegexp.FindSubmatch(output) + if len(tmuxMatch) > 1 { + if !tunnel || len(match) < 5 || match[4] == nil { + return output, nil + } + tmuxPrefix = string(tmuxMatch[1]) } if len(subOutput) > 40 { for _, s := range []string{"#CFG:", "Saved", "Cancelled", "Stopped", "Interrupted"} { if bytes.Contains(subOutput[40:], []byte(s)) { - return output, nil, nil, false + return output, nil } } } mode := match[1][0] - serverVersion, err := parseTrzszVersion(string(match[2])) + version, err := parseTrzszVersion(string(match[2])) if err != nil { - return output, nil, nil, false + return output, nil } uniqueID := "" - if len(match) > 3 { - uniqueID = string(match[3]) + if len(match) > 3 && match[3] != nil { + uniqueID = string(match[3][1:]) } - if len(uniqueID) >= 8 && (isWindowsEnvironment() || !(len(uniqueID) == 14 && strings.HasSuffix(uniqueID, "00"))) { - if _, ok := detector.uniqueIDMap[uniqueID]; ok { - return output, nil, nil, false - } - if len(detector.uniqueIDMap) > 100 { - m := make(map[string]int) - for k, v := range detector.uniqueIDMap { - if v >= 50 { - m[k] = v - 50 - } - } - detector.uniqueIDMap = m - } - detector.uniqueIDMap[uniqueID] = len(detector.uniqueIDMap) + if detector.isRepeatedID(uniqueID) { + return output, nil + } + + winServer := false + if uniqueID == "1" || (len(uniqueID) == 13 && strings.HasSuffix(uniqueID, "10")) { + winServer = true } - remoteIsWindows := false - if uniqueID == ":1" || (len(uniqueID) == 14 && strings.HasSuffix(uniqueID, "10")) { - remoteIsWindows = true + port := 0 + if len(match) > 4 && match[4] != nil { + if v, err := strconv.Atoi(string(match[4][1:])); err == nil { + port = v + } } if detector.relay { @@ -750,7 +782,14 @@ func (detector *trzszDetector) detectTrzsz(output []byte) ([]byte, *byte, *trzsz output = bytes.ReplaceAll(output, []byte("TRZSZ"), []byte("TRZSZGO")) } - return output, &mode, serverVersion, remoteIsWindows + return output, &trzszTrigger{ + mode: mode, + version: version, + uniqueID: uniqueID, + winServer: winServer, + tunnelPort: port, + tmuxPrefix: tmuxPrefix, + } } type traceLogger struct { @@ -926,3 +965,44 @@ func resolveHomeDir(path string) string { } return path } + +func listenForTunnel() (net.Listener, int) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return nil, 0 + } + return listener, listener.Addr().(*net.TCPAddr).Port +} + +func encodeTmuxOutput(prefix string, output []byte) []byte { + buffer := bytes.NewBuffer(make([]byte, 0, len(prefix)+len(output)<<2+2)) + buffer.Write([]byte(prefix)) + for _, b := range output { + if b >= '0' && b <= '9' || b >= 'A' && b <= 'Z' || b >= 'a' && b <= 'z' { + buffer.WriteByte(b) + continue + } + buffer.Write([]byte(fmt.Sprintf("\\%03o", b))) + } + buffer.Write([]byte("\r\n")) + return buffer.Bytes() +} + +type promptWriter struct { + prefix string + writer io.Writer +} + +func (w *promptWriter) Write(p []byte) (int, error) { + if w.prefix == "" { + return w.writer.Write(p) + } + if err := writeAll(w.writer, encodeTmuxOutput(w.prefix, p)); err != nil { + return 0, err + } + return len(p), nil +} + +func (w *promptWriter) Close() error { + return nil +} diff --git a/trzsz/comm_test.go b/trzsz/comm_test.go index 6dbb2ac..cddf1ed 100644 --- a/trzsz/comm_test.go +++ b/trzsz/comm_test.go @@ -153,119 +153,201 @@ func TestTrzszVersion(t *testing.T) { func TestTrzszDetector(t *testing.T) { assert := assert.New(t) detector := newTrzszDetector(false, false) - assertDetectTrzsz := func(output string, mode *byte, ver *trzszVersion, win bool) { + assertDetectTrzsz := func(output string, tunnel bool, trigger *trzszTrigger) { t.Helper() - buf, m, v, w := detector.detectTrzsz([]byte(output)) - if mode == nil { + buf, t := detector.detectTrzsz([]byte(output), tunnel) + assert.Equal(trigger, t) + if trigger == nil { assert.Equal([]byte(output), buf) } else { assert.Equal(bytes.ReplaceAll([]byte(output), []byte("TRZSZ"), []byte("TRZSZGO")), buf) } - assert.Equal(mode, m) - assert.Equal(ver, v) - assert.Equal(win, w) } - assertDetectTrzsz("", nil, nil, false) - assertDetectTrzsz("ABC", nil, nil, false) - assertDetectTrzsz(strings.Repeat("A::", 10), nil, nil, false) - assertDetectTrzsz("::TRZSZ:TRANSFER:R:", nil, nil, false) + assertDetectTrzsz("", false, nil) + assertDetectTrzsz("ABC", false, nil) + assertDetectTrzsz(strings.Repeat("A::", 10), false, nil) + assertDetectTrzsz("::TRZSZ:TRANSFER:R:", false, nil) // normal trzsz trigger - R := byte('R') - D := byte('D') - S := byte('S') - assertDetectTrzsz("::TRZSZ:TRANSFER:"+"R:1.0.0:0", &R, &trzszVersion{1, 0, 0}, false) - assertDetectTrzsz("ABC::TRZSZ:TRANSFER:"+"D:1.0.0:123", &D, &trzszVersion{1, 0, 0}, false) - assertDetectTrzsz("\x1b7\x07::TRZSZ:TRANSFER:"+"S:1.0.0:1", &S, &trzszVersion{1, 0, 0}, true) - assertDetectTrzsz("\x1b7\x07::TRZSZ:TRANSFER:"+"S:1.0.0:1:1234", &S, &trzszVersion{1, 0, 0}, true) - assertDetectTrzsz("XYX\x1b7\x07::TRZSZ:TRANSFER:"+"S:1.0.0:1:7890", &S, &trzszVersion{1, 0, 0}, true) - assertDetectTrzsz("\x1b7\x07::TRZSZ:TRANSFER:"+"S:1.0.0:1:1234ABC\n", &S, &trzszVersion{1, 0, 0}, true) - assertDetectTrzsz("XYX\x1b7\x07::TRZSZ:TRANSFER:"+"S:1.0.0:1:7890EFG\r\n", &S, &trzszVersion{1, 0, 0}, true) + newTrigger100 := func(mode byte, uid string, win bool, port int) *trzszTrigger { + return &trzszTrigger{mode, &trzszVersion{1, 0, 0}, uid, win, port, ""} + } + assertDetectTrzsz("::TRZSZ:TRANSFER:"+"R:1.0.0:0", false, + newTrigger100('R', "0", false, 0)) + assertDetectTrzsz("ABC::TRZSZ:TRANSFER:"+"D:1.0.0:123", false, + newTrigger100('D', "123", false, 0)) + assertDetectTrzsz("\x1b7\x07::TRZSZ:TRANSFER:"+"S:1.0.0:1", false, + newTrigger100('S', "1", true, 0)) + assertDetectTrzsz("\x1b7\x07::TRZSZ:TRANSFER:"+"S:1.0.0:1:1234", false, + newTrigger100('S', "1", true, 1234)) + assertDetectTrzsz("XYX\x1b7\x07::TRZSZ:TRANSFER:"+"S:1.0.0:1:7890", false, + newTrigger100('S', "1", true, 7890)) + assertDetectTrzsz("\x1b7\x07::TRZSZ:TRANSFER:"+"S:1.0.0:1:1337:1234", false, + newTrigger100('S', "1", true, 1337)) + assertDetectTrzsz("XYX\x1b7\x07::TRZSZ:TRANSFER:"+"S:1.0.0:1:1337:7890", false, + newTrigger100('S', "1", true, 1337)) + assertDetectTrzsz("\x1b7\x07::TRZSZ:TRANSFER:"+"S:1.0.0:1:1337:1234ABC\n", false, + newTrigger100('S', "1", true, 1337)) + assertDetectTrzsz("XYX\x1b7\x07::TRZSZ:TRANSFER:"+"S:1.0.0:1:1337:7890EFG\r\n", false, + newTrigger100('S', "1", true, 1337)) // repeated trigger uniqueID := time.Now().UnixMilli() % 10e10 - assertDetectTrzsz(fmt.Sprintf("::TRZSZ:TRANSFER:R:1.1.0:%013d", uniqueID*100+10), &R, &trzszVersion{1, 1, 0}, true) + newTrigger110 := func(mode byte, uid int64, win bool, port int) *trzszTrigger { + return &trzszTrigger{mode, &trzszVersion{1, 1, 0}, fmt.Sprintf("%013d", uid), win, port, ""} + } + assertDetectTrzsz(fmt.Sprintf("::TRZSZ:TRANSFER:R:1.1.0:%013d", uniqueID*100+10), false, + newTrigger110('R', uniqueID*100+10, true, 0)) for i := 0; i <= 100; i++ { - assertDetectTrzsz(fmt.Sprintf("::TRZSZ:TRANSFER:R:1.1.0:%013d", i*100+10), &R, &trzszVersion{1, 1, 0}, true) - assertDetectTrzsz(fmt.Sprintf("::TRZSZ:TRANSFER:R:1.1.0:%013d", i*100+10), nil, nil, false) + assertDetectTrzsz(fmt.Sprintf("::TRZSZ:TRANSFER:R:1.1.0:%013d", i*100+10), false, + newTrigger110('R', int64(i*100+10), true, 0)) + assertDetectTrzsz(fmt.Sprintf("::TRZSZ:TRANSFER:R:1.1.0:%013d", i*100+10), false, nil) if i > 0 { - assertDetectTrzsz(fmt.Sprintf("::TRZSZ:TRANSFER:R:1.1.0:%013d", (i-1)*100+10), nil, nil, false) + assertDetectTrzsz(fmt.Sprintf("::TRZSZ:TRANSFER:R:1.1.0:%013d", (i-1)*100+10), false, nil) } } for i := 0; i < 49; i++ { - assertDetectTrzsz(fmt.Sprintf("::TRZSZ:TRANSFER:R:1.1.0:%013d", i*100+10), &R, &trzszVersion{1, 1, 0}, true) - assertDetectTrzsz(fmt.Sprintf("::TRZSZ:TRANSFER:R:1.1.0:%013d", i*100+10), nil, nil, false) + assertDetectTrzsz(fmt.Sprintf("::TRZSZ:TRANSFER:R:1.1.0:%013d", i*100+10), false, + newTrigger110('R', int64(i*100+10), true, 0)) + assertDetectTrzsz(fmt.Sprintf("::TRZSZ:TRANSFER:R:1.1.0:%013d", i*100+10), false, nil) if i > 0 { - assertDetectTrzsz(fmt.Sprintf("::TRZSZ:TRANSFER:R:1.1.0:%013d", (i-1)*100+10), nil, nil, false) + assertDetectTrzsz(fmt.Sprintf("::TRZSZ:TRANSFER:R:1.1.0:%013d", (i-1)*100+10), false, nil) } } - // ignore tmux control mode - assertDetectTrzsz("%output %1 \x1b7\x07::TRZSZ:TRANSFER:"+"R:1.0.0:0ABC", nil, nil, false) - assertDetectTrzsz("%output %23 \x1b7\x07::TRZSZ:TRANSFER:"+"R:1.0.0:0ABC", nil, nil, false) - assertDetectTrzsz("%extended-output %0 0 : \x1b7\x07::TRZSZ:TRANSFER:"+"R:1.0.0:0ABC", nil, nil, false) - assertDetectTrzsz("%extended-output %10 0 : \x1b7\x07::TRZSZ:TRANSFER:"+"R:1.0.0:0ABC", nil, nil, false) - - assertDetectTrzsz("%output %x \x1b7\x07::TRZSZ:TRANSFER:"+"R:1.0.0:0ABC", &R, &trzszVersion{1, 0, 0}, false) - assertDetectTrzsz("%output 1 \x1b7\x07::TRZSZ:TRANSFER:"+"R:1.0.0:0ABC", &R, &trzszVersion{1, 0, 0}, false) - assertDetectTrzsz("%output % \x1b7\x07::TRZSZ:TRANSFER:"+"R:1.0.0:0ABC", &R, &trzszVersion{1, 0, 0}, false) - assertDetectTrzsz("output %1 \x1b7\x07::TRZSZ:TRANSFER:"+"R:1.0.0:0ABC", &R, &trzszVersion{1, 0, 0}, false) - - assertDetectTrzsz("%extended-output %a 0 : \x1b7\x07::TRZSZ:TRANSFER:"+"R:1.0.0:0ABC", &R, &trzszVersion{1, 0, 0}, false) - assertDetectTrzsz("%extended-output %0 b : \x1b7\x07::TRZSZ:TRANSFER:"+"R:1.0.0:0ABC", &R, &trzszVersion{1, 0, 0}, false) - assertDetectTrzsz("extended-output %0 0 : \x1b7\x07::TRZSZ:TRANSFER:"+"R:1.0.0:0ABC", &R, &trzszVersion{1, 0, 0}, false) - assertDetectTrzsz("%extended-output 0 0 : \x1b7\x07::TRZSZ:TRANSFER:"+"R:1.0.0:0ABC", &R, &trzszVersion{1, 0, 0}, false) - assertDetectTrzsz("%extended-output % 0 : \x1b7\x07::TRZSZ:TRANSFER:"+"R:1.0.0:0ABC", &R, &trzszVersion{1, 0, 0}, false) - assertDetectTrzsz("%extended-output %0 0 \x1b7\x07::TRZSZ:TRANSFER:"+"R:1.0.0:0ABC", &R, &trzszVersion{1, 0, 0}, false) + // tmux control mode + newTriggerTMUX := func(prefix string, port int) *trzszTrigger { + return &trzszTrigger{'R', &trzszVersion{1, 0, 0}, "0", false, port, prefix} + } + assertDetectTrzsz("%output %1 \x1b7\x07::TRZSZ:TRANSFER:"+"R:1.0.0:0ABC", false, nil) + assertDetectTrzsz("%output %23 \x1b7\x07::TRZSZ:TRANSFER:"+"R:1.0.0:0ABC", false, nil) + assertDetectTrzsz("%extended-output %0 0 : \x1b7\x07::TRZSZ:TRANSFER:"+"R:1.0.0:0ABC", false, nil) + assertDetectTrzsz("%extended-output %10 0 : \x1b7\x07::TRZSZ:TRANSFER:"+"R:1.0.0:0ABC", false, nil) + + assertDetectTrzsz("%output %1 \x1b7\x07::TRZSZ:TRANSFER:"+"R:1.0.0:0:1337ABC", false, nil) + assertDetectTrzsz("%extended-output %0 0 : \x1b7\x07::TRZSZ:TRANSFER:"+"R:1.0.0:0:1337ABC", false, nil) + assertDetectTrzsz("%output %1 \x1b7\x07::TRZSZ:TRANSFER:"+"R:1.0.0:0:1337ABC", true, + newTriggerTMUX("%output %1 ", 1337)) + assertDetectTrzsz("%extended-output %0 0 : \x1b7\x07::TRZSZ:TRANSFER:"+"R:1.0.0:0:1337ABC", true, + newTriggerTMUX("%extended-output %0 0 : ", 1337)) + + tmuxTrigger := newTriggerTMUX("", 0) + assertDetectTrzsz("%output %x \x1b7\x07::TRZSZ:TRANSFER:"+"R:1.0.0:0ABC", false, tmuxTrigger) + assertDetectTrzsz("%output 1 \x1b7\x07::TRZSZ:TRANSFER:"+"R:1.0.0:0ABC", false, tmuxTrigger) + assertDetectTrzsz("%output % \x1b7\x07::TRZSZ:TRANSFER:"+"R:1.0.0:0ABC", false, tmuxTrigger) + assertDetectTrzsz("output %1 \x1b7\x07::TRZSZ:TRANSFER:"+"R:1.0.0:0ABC", false, tmuxTrigger) + + assertDetectTrzsz("%extended-output %a 0 : \x1b7\x07::TRZSZ:TRANSFER:"+"R:1.0.0:0ABC", false, tmuxTrigger) + assertDetectTrzsz("%extended-output %0 b : \x1b7\x07::TRZSZ:TRANSFER:"+"R:1.0.0:0ABC", false, tmuxTrigger) + assertDetectTrzsz("extended-output %0 0 : \x1b7\x07::TRZSZ:TRANSFER:"+"R:1.0.0:0ABC", false, tmuxTrigger) + assertDetectTrzsz("%extended-output 0 0 : \x1b7\x07::TRZSZ:TRANSFER:"+"R:1.0.0:0ABC", false, tmuxTrigger) + assertDetectTrzsz("%extended-output % 0 : \x1b7\x07::TRZSZ:TRANSFER:"+"R:1.0.0:0ABC", false, tmuxTrigger) + assertDetectTrzsz("%extended-output %0 0 \x1b7\x07::TRZSZ:TRANSFER:"+"R:1.0.0:0ABC", false, tmuxTrigger) } func TestRelayDetector(t *testing.T) { assert := assert.New(t) detector := newTrzszDetector(true, true) - R := byte('R') prefix := "\x1b7\x07::TRZSZ:TRANSFER:R:1.0.0" - assertRewriteEqual := func(output, expected string, mode *byte, win bool) { + assertRewriteEqual := func(output, expected string, trigger *trzszTrigger) { t.Helper() detector.uniqueIDMap = make(map[string]int) // ignore unique check - buf, m, v, w := detector.detectTrzsz([]byte(prefix + output)) + buf, t := detector.detectTrzsz([]byte(prefix+output), false) assert.Equal([]byte(prefix+expected), buf) - assert.Equal(mode, m) - assert.Equal(win, w) - if mode != nil { - assert.Equal(&trzszVersion{1, 0, 0}, v) - } + assert.Equal(t, trigger) + } + newTrigger := func(uid string, win bool, port int) *trzszTrigger { + return &trzszTrigger{'R', &trzszVersion{1, 0, 0}, uid, win, port, ""} } - assertRewriteEqual(":0", ":0#R", &R, false) - assertRewriteEqual(":1", ":1#R", &R, true) - assertRewriteEqual(":0\n", ":0#R\n", &R, false) - assertRewriteEqual(":1\r\n", ":1#R\r\n", &R, true) - - assertRewriteEqual(":1234567890110", ":1234567890110#R", &R, true) - assertRewriteEqual(":9876543210210", ":9876543210210#R", &R, true) - assertRewriteEqual(":1234567890110\n", ":1234567890110#R\n", &R, true) - assertRewriteEqual(":9876543210210\r\n", ":9876543210210#R\r\n", &R, true) - assertRewriteEqual(":1234567890110#R\n", ":1234567890110#R#R\n", &R, true) - assertRewriteEqual(":9876543210210#R\r\n", ":9876543210210#R#R\r\n", &R, true) - - assertRewriteEqual(":123456789\n0100", ":123456789#R\n0100", &R, false) - assertRewriteEqual(":123456789\r\n0200", ":123456789#R\r\n0200", &R, false) - assertRewriteEqual(":123456789\n0100\n", ":123456789#R\n0100\n", &R, false) - assertRewriteEqual(":123456789\r\n0200\r\n", ":123456789#R\r\n0200\r\n", &R, false) - - assertRewriteEqual(":1234567890100", ":1234567890120#R", &R, false) - assertRewriteEqual(":9876543210200", ":9876543210220#R", &R, false) - assertRewriteEqual(":1234567890100\n", ":1234567890120#R\n", &R, false) - assertRewriteEqual(":9876543210200\r\n", ":9876543210220#R\r\n", &R, false) - assertRewriteEqual(":1234567890100#R\n", ":1234567890120#R#R\n", &R, false) - assertRewriteEqual(":9876543210200#R\r\n", ":9876543210220#R#R\r\n", &R, false) + assertRewriteEqual(":0", ":0#R", newTrigger("0", false, 0)) + assertRewriteEqual(":1", ":1#R", newTrigger("1", true, 0)) + assertRewriteEqual(":0\n", ":0#R\n", newTrigger("0", false, 0)) + assertRewriteEqual(":1\r\n", ":1#R\r\n", newTrigger("1", true, 0)) + + assertRewriteEqual(":0:1234", ":0:1234#R", newTrigger("0", false, 1234)) + assertRewriteEqual(":1:1234", ":1:1234#R", newTrigger("1", true, 1234)) + assertRewriteEqual(":0:1337\n", ":0:1337#R\n", newTrigger("0", false, 1337)) + assertRewriteEqual(":1:1337\r\n", ":1:1337#R\r\n", newTrigger("1", true, 1337)) + + assertRewriteEqual(":1234567890110", ":1234567890110#R", + newTrigger("1234567890110", true, 0)) + assertRewriteEqual(":9876543210210", ":9876543210210#R", + newTrigger("9876543210210", true, 0)) + assertRewriteEqual(":1234567890110\n", ":1234567890110#R\n", + newTrigger("1234567890110", true, 0)) + assertRewriteEqual(":9876543210210\r\n", ":9876543210210#R\r\n", + newTrigger("9876543210210", true, 0)) + assertRewriteEqual(":1234567890110#R\n", ":1234567890110#R#R\n", + newTrigger("1234567890110", true, 0)) + assertRewriteEqual(":9876543210210#R\r\n", ":9876543210210#R#R\r\n", + newTrigger("9876543210210", true, 0)) + + assertRewriteEqual(":1234567890110:12345", ":1234567890110:12345#R", + newTrigger("1234567890110", true, 12345)) + assertRewriteEqual(":9876543210210:12345", ":9876543210210:12345#R", + newTrigger("9876543210210", true, 12345)) + assertRewriteEqual(":1234567890110:12345\n", ":1234567890110:12345#R\n", + newTrigger("1234567890110", true, 12345)) + assertRewriteEqual(":9876543210210:12345\r\n", ":9876543210210:12345#R\r\n", + newTrigger("9876543210210", true, 12345)) + assertRewriteEqual(":1234567890110:12345#R\n", ":1234567890110:12345#R#R\n", + newTrigger("1234567890110", true, 12345)) + assertRewriteEqual(":9876543210210:12345#R\r\n", ":9876543210210:12345#R#R\r\n", + newTrigger("9876543210210", true, 12345)) + + assertRewriteEqual(":123456789\n0100", ":123456789#R\n0100", + newTrigger("123456789", false, 0)) + assertRewriteEqual(":123456789\r\n0200", ":123456789#R\r\n0200", + newTrigger("123456789", false, 0)) + assertRewriteEqual(":123456789\n0100\n", ":123456789#R\n0100\n", + newTrigger("123456789", false, 0)) + assertRewriteEqual(":123456789\r\n0200\r\n", ":123456789#R\r\n0200\r\n", + newTrigger("123456789", false, 0)) + + assertRewriteEqual(":123456789:1223\n0100", ":123456789:1223#R\n0100", + newTrigger("123456789", false, 1223)) + assertRewriteEqual(":123456789:1223\r\n0200", ":123456789:1223#R\r\n0200", + newTrigger("123456789", false, 1223)) + assertRewriteEqual(":123456789:1223\n0100\n", ":123456789:1223#R\n0100\n", + newTrigger("123456789", false, 1223)) + assertRewriteEqual(":123456789:1223\r\n0200\r\n", ":123456789:1223#R\r\n0200\r\n", + newTrigger("123456789", false, 1223)) + + assertRewriteEqual(":1234567890100", ":1234567890120#R", + newTrigger("1234567890120", false, 0)) + assertRewriteEqual(":9876543210200", ":9876543210220#R", + newTrigger("9876543210220", false, 0)) + assertRewriteEqual(":1234567890100\n", ":1234567890120#R\n", + newTrigger("1234567890120", false, 0)) + assertRewriteEqual(":9876543210200\r\n", ":9876543210220#R\r\n", + newTrigger("9876543210220", false, 0)) + assertRewriteEqual(":1234567890100#R\n", ":1234567890120#R#R\n", + newTrigger("1234567890120", false, 0)) + assertRewriteEqual(":9876543210200#R\r\n", ":9876543210220#R#R\r\n", + newTrigger("9876543210220", false, 0)) + + assertRewriteEqual(":1234567890100:333", ":1234567890120:333#R", + newTrigger("1234567890120", false, 333)) + assertRewriteEqual(":9876543210200:333", ":9876543210220:333#R", + newTrigger("9876543210220", false, 333)) + assertRewriteEqual(":1234567890100:333\n", ":1234567890120:333#R\n", + newTrigger("1234567890120", false, 333)) + assertRewriteEqual(":9876543210200:333\r\n", ":9876543210220:333#R\r\n", + newTrigger("9876543210220", false, 333)) + assertRewriteEqual(":1234567890100:333#R\n", ":1234567890120:333#R#R\n", + newTrigger("1234567890120", false, 333)) + assertRewriteEqual(":9876543210200:333#R\r\n", ":9876543210220:333#R#R\r\n", + newTrigger("9876543210220", false, 333)) assertRewriteEqual(":1234567890100\n"+prefix+":9876543210200\r\n", - ":1234567890120\n"+prefix+":9876543210220#R\r\n", &R, false) + ":1234567890120\n"+prefix+":9876543210220#R\r\n", newTrigger("9876543210220", false, 0)) + assertRewriteEqual(":1234567890100\n"+prefix+":9876543210200:8\r\n", + ":1234567890120\n"+prefix+":9876543210220:8#R\r\n", newTrigger("9876543210220", false, 8)) assertRewriteEqual(":1234567890100\n"+prefix+":9876543210200\r\n::TRZSZ:TRANSFER:R:", - ":1234567890120\n"+prefix+":9876543210220\r\n::TRZSZ:TRANSFER:R:", nil, false) + ":1234567890120\n"+prefix+":9876543210220\r\n::TRZSZ:TRANSFER:R:", nil) } func TestFormatSavedFileNames(t *testing.T) { diff --git a/trzsz/filter.go b/trzsz/filter.go index 80c232e..99b7bc5 100644 --- a/trzsz/filter.go +++ b/trzsz/filter.go @@ -28,9 +28,12 @@ import ( "bufio" "fmt" "io" + "net" "os" "os/exec" "path/filepath" + "regexp" + "strconv" "strings" "sync" "sync/atomic" @@ -62,8 +65,7 @@ type TrzszFilter struct { transfer atomic.Pointer[trzszTransfer] progress atomic.Pointer[textProgressBar] promptPipe atomic.Pointer[io.PipeWriter] - serverVersion *trzszVersion - remoteIsWindows bool + trigger *trzszTrigger dragging atomic.Bool dragHasDir atomic.Bool dragMutex sync.Mutex @@ -73,6 +75,7 @@ type TrzszFilter struct { logger *traceLogger defaultUploadPath string defaultDownloadPath string + tunnelConnector func(int) net.Conn } // NewTrzszFilter create a TrzszFilter to support trzsz ( trz / tsz ). @@ -159,6 +162,11 @@ func (filter *TrzszFilter) SetDefaultDownloadPath(downloadPath string) { filter.defaultDownloadPath = downloadPath } +// SetTunnelConnector set the connector for tunnel transferring. +func (filter *TrzszFilter) SetTunnelConnector(connector func(int) net.Conn) { + filter.tunnelConnector = connector +} + func (filter *TrzszFilter) getTrzszConfig(name string) string { home, err := os.UserHomeDir() if err != nil { @@ -288,7 +296,7 @@ func (filter *TrzszFilter) chooseUploadPaths(directory bool) ([]string, error) { func (filter *TrzszFilter) downloadFiles(transfer *trzszTransfer) error { path, err := filter.chooseDownloadPath() if err == zenity.ErrCanceled { - return transfer.sendAction(false, filter.serverVersion, filter.remoteIsWindows) + return transfer.sendAction(false, filter.trigger.version, filter.trigger.winServer) } if err != nil { return err @@ -301,7 +309,7 @@ func (filter *TrzszFilter) downloadFiles(transfer *trzszTransfer) error { return simpleTrzszError("Swap transfer failed") } - if err := transfer.sendAction(true, filter.serverVersion, filter.remoteIsWindows); err != nil { + if err := transfer.sendAction(true, filter.trigger.version, filter.trigger.winServer); err != nil { return err } config, err := transfer.recvConfig() @@ -311,7 +319,8 @@ func (filter *TrzszFilter) downloadFiles(transfer *trzszTransfer) error { filter.progress.Store(nil) if !config.Quiet { - filter.progress.Store(newTextProgressBar(filter.clientOut, filter.options.TerminalColumns, config.TmuxPaneColumns)) + filter.progress.Store(newTextProgressBar(filter.clientOut, filter.options.TerminalColumns, + config.TmuxPaneColumns, filter.trigger.tmuxPrefix)) defer filter.progress.Store(nil) } @@ -326,7 +335,7 @@ func (filter *TrzszFilter) downloadFiles(transfer *trzszTransfer) error { func (filter *TrzszFilter) uploadFiles(transfer *trzszTransfer, directory bool) error { paths, err := filter.chooseUploadPaths(directory) if err == zenity.ErrCanceled { - return transfer.sendAction(false, filter.serverVersion, filter.remoteIsWindows) + return transfer.sendAction(false, filter.trigger.version, filter.trigger.winServer) } if err != nil { return err @@ -340,7 +349,7 @@ func (filter *TrzszFilter) uploadFiles(transfer *trzszTransfer, directory bool) return simpleTrzszError("Swap transfer failed") } - if err := transfer.sendAction(true, filter.serverVersion, filter.remoteIsWindows); err != nil { + if err := transfer.sendAction(true, filter.trigger.version, filter.trigger.winServer); err != nil { return err } config, err := transfer.recvConfig() @@ -356,7 +365,8 @@ func (filter *TrzszFilter) uploadFiles(transfer *trzszTransfer, directory bool) filter.progress.Store(nil) if !config.Quiet { - filter.progress.Store(newTextProgressBar(filter.clientOut, filter.options.TerminalColumns, config.TmuxPaneColumns)) + filter.progress.Store(newTextProgressBar(filter.clientOut, filter.options.TerminalColumns, + config.TmuxPaneColumns, filter.trigger.tmuxPrefix)) defer filter.progress.Store(nil) } @@ -367,10 +377,15 @@ func (filter *TrzszFilter) uploadFiles(transfer *trzszTransfer, directory bool) return transfer.clientExit(formatSavedFiles(remoteNames, "")) } -func (filter *TrzszFilter) handleTrzsz(mode byte) { - transfer := newTransfer(filter.serverIn, nil, isWindowsEnvironment() || filter.remoteIsWindows, filter.logger) +func (filter *TrzszFilter) handleTrzsz() { + transfer := newTransfer(filter.serverIn, nil, isWindowsEnvironment() || filter.trigger.winServer, filter.logger) + + if filter.tunnelConnector != nil { + transfer.connectToTunnel(filter.tunnelConnector, filter.trigger.uniqueID, filter.trigger.tunnelPort) + } defer func() { + transfer.cleanup() filter.transfer.CompareAndSwap(transfer, nil) }() @@ -381,7 +396,7 @@ func (filter *TrzszFilter) handleTrzsz(mode byte) { }() var err error - switch mode { + switch filter.trigger.mode { case 'S': err = filter.downloadFiles(transfer) case 'R': @@ -445,38 +460,60 @@ func (filter *TrzszFilter) uploadDragFiles() { filter.resetDragFiles() } +var tmuxInputRegexp = regexp.MustCompile(`send -(l?)t %\d+ (.*?)[;\r]`) + func (filter *TrzszFilter) transformPromptInput(promptPipe *io.PipeWriter, buf []byte) { + if len(buf) > 6 { + var input []byte + for _, match := range tmuxInputRegexp.FindAllSubmatch(buf, -1) { + if len(match) == 3 { + if len(match[1]) == 1 { + input = append(input, match[2]...) + continue + } + for _, hex := range strings.Fields(string(match[2])) { + if strings.HasPrefix(hex, "0x") { + if char, err := strconv.ParseInt(hex[2:], 16, 32); err == nil { + input = append(input, byte(char)) + } + } + } + } + } + buf = input + } + const keyPrev = '\x10' const keyNext = '\x0E' - n := len(buf) - for i := 0; i < n; i++ { - c := buf[i] - if c == '\x1b' && n-i > 2 && buf[i+1] == '[' { - switch buf[i+2] { - case '\x42': // ↓ to Next - c = keyNext - case '\x41', '\x5A': // ↑ Shift-TAB to Prev - c = keyPrev - } - i += 2 - } else { - switch c { - case '\x03': // Ctrl-C to Stop - _, _ = promptPipe.Write([]byte{keyPrev, keyPrev, '\r'}) - return - case 'q', 'Q', '\x11': // q Ctrl-C Ctrl-Q to Quit - _, _ = promptPipe.Write([]byte{keyNext, keyNext, '\r'}) - return - case '\t', '\x0E', 'j', 'J', '\x0A': // Tab ↓ j Ctrl-J to Next - c = keyNext - case '\x10', 'k', 'K', '\x0B': // ↑ k Ctrl-K to Prev - c = keyPrev - case '\r': // Enter - default: - continue - } + const keyEnter = '\r' + moveNext := func() { _, _ = promptPipe.Write([]byte{keyNext}) } + movePrev := func() { _, _ = promptPipe.Write([]byte{keyPrev}) } + stop := func() { _, _ = promptPipe.Write([]byte{keyPrev, keyPrev, keyEnter}) } + quit := func() { _, _ = promptPipe.Write([]byte{keyNext, keyNext, keyEnter}) } + confirm := func() { _, _ = promptPipe.Write([]byte{keyEnter}) } + + if len(buf) == 3 && buf[0] == '\x1b' && buf[1] == '[' { + switch buf[2] { + case '\x42': // ↓ to Next + moveNext() + case '\x41', '\x5A': // ↑ Shift-TAB to Prev + movePrev() + } + } + + if len(buf) == 1 { + switch buf[0] { + case '\x03': // Ctrl-C to Stop + stop() + case 'q', 'Q', '\x11': // q Ctrl-C Ctrl-Q to Quit + quit() + case '\t', '\x0E', 'j', 'J', '\x0A': // Tab ↓ j Ctrl-J to Next + moveNext() + case '\x10', 'k', 'K', '\x0B': // ↑ k Ctrl-K to Prev + movePrev() + case '\r': // Enter + confirm() } - _, _ = promptPipe.Write([]byte{c}) } } @@ -495,14 +532,15 @@ func (filter *TrzszFilter) confirmStopTransfer(transfer *trzszTransfer) { defer pipeOut.Close() defer filter.promptPipe.Store(nil) + writer := &promptWriter{filter.trigger.tmuxPrefix, filter.clientOut} if progress := filter.progress.Load(); progress != nil { progress.setPause(true) defer func() { progress.setTerminalColumns(filter.options.TerminalColumns) progress.setPause(false) }() - time.Sleep(50 * time.Millisecond) // wait for the progress bar output - _, _ = filter.clientOut.Write([]byte("\r\n")) // keep the progress bar displayed + time.Sleep(50 * time.Millisecond) // wait for the progress bar output + _, _ = writer.Write([]byte("\r\n")) // keep the progress bar displayed } prompt := promptui.Select{ @@ -513,7 +551,7 @@ func (filter *TrzszFilter) confirmStopTransfer(transfer *trzszTransfer) { "Continue to transfer remaining files", }, Stdin: pipeIn, - Stdout: filter.clientOut, + Stdout: writer, Templates: &promptui.SelectTemplates{ Help: `{{ "Use ↓ ↑ j k to navigate" | faint }}`, }, @@ -533,6 +571,8 @@ func (filter *TrzszFilter) confirmStopTransfer(transfer *trzszTransfer) { }() } +var ctrlCRegexp = regexp.MustCompile(`^send -t %\d+ 0x3\r$`) + func (filter *TrzszFilter) sendInput(buf []byte) { if filter.logger != nil { filter.logger.writeTraceLog(buf, "stdin") @@ -542,8 +582,9 @@ func (filter *TrzszFilter) sendInput(buf []byte) { return } if transfer := filter.transfer.Load(); transfer != nil { - if buf[0] == '\x03' { // `ctrl + c` to stop transferring files - if filter.serverVersion.compare(&trzszVersion{1, 1, 3}) > 0 { + if len(buf) == 1 && buf[0] == '\x03' || len(buf) > 14 && ctrlCRegexp.Match(buf) { + // `ctrl + c` to stop transferring files + if filter.trigger.version.compare(&trzszVersion{1, 1, 3}) > 0 { filter.confirmStopTransfer(transfer) } else { transfer.stopTransferringFiles(false) @@ -589,23 +630,20 @@ func (filter *TrzszFilter) wrapOutput() { n, err := filter.serverOut.Read(buffer) if n > 0 { buf := buffer[0:n] - if filter.logger != nil { - buf = filter.logger.writeTraceLog(buf, "svrout") - } if transfer := filter.transfer.Load(); transfer != nil { - transfer.addReceivedData(buf) + transfer.addReceivedData(buf, false) buffer = make([]byte, bufSize) continue } - var win bool - var mode *byte - var ver *trzszVersion - buf, mode, ver, win = detector.detectTrzsz(buf) - if mode != nil { + if filter.logger != nil { + buf = filter.logger.writeTraceLog(buf, "svrout") + } + var trigger *trzszTrigger + buf, trigger = detector.detectTrzsz(buf, filter.tunnelConnector != nil) + if trigger != nil { _ = writeAll(filter.clientOut, buf) - filter.serverVersion = ver - filter.remoteIsWindows = win - go filter.handleTrzsz(*mode) + filter.trigger = trigger + go filter.handleTrzsz() continue } if filter.interrupting.Load() { diff --git a/trzsz/progress.go b/trzsz/progress.go index dabc4ec..02438ef 100644 --- a/trzsz/progress.go +++ b/trzsz/progress.go @@ -154,13 +154,14 @@ type textProgressBar struct { timeArray [kSpeedArraySize]*time.Time stepArray [kSpeedArraySize]int64 pausing atomic.Bool + tmuxPrefix string } -func newTextProgressBar(writer io.Writer, columns int32, tmuxPaneColumns int32) *textProgressBar { +func newTextProgressBar(writer io.Writer, columns int32, tmuxPaneColumns int32, tmuxPrefix string) *textProgressBar { if tmuxPaneColumns > 1 { columns = tmuxPaneColumns - 1 // -1 to avoid messing up the tmux pane } - progress := &textProgressBar{writer: writer, firstWrite: true} + progress := &textProgressBar{writer: writer, firstWrite: true, tmuxPrefix: tmuxPrefix} progress.columns.Store(columns) progress.tmuxPaneColumns.Store(tmuxPaneColumns) return progress @@ -247,6 +248,14 @@ func (p *textProgressBar) setPause(pausing bool) { p.pausing.Store(pausing) } +func (p *textProgressBar) writeProgress(progress string) { + data := []byte(progress) + if p.tmuxPrefix != "" { + data = encodeTmuxOutput(p.tmuxPrefix, data) + } + _ = writeAll(p.writer, data) +} + func (p *textProgressBar) showProgress() { now := timeNowFunc() if p.lastUpdateTime != nil && now.Sub(*p.lastUpdateTime) < 200*time.Millisecond { @@ -270,14 +279,14 @@ func (p *textProgressBar) showProgress() { if p.firstWrite { p.firstWrite = false - _ = writeAll(p.writer, []byte(progressText)) + p.writeProgress(progressText) return } if p.tmuxPaneColumns.Load() > 0 { - _ = writeAll(p.writer, []byte(fmt.Sprintf("\x1b[%dD%s", p.columns.Load(), progressText))) + p.writeProgress(fmt.Sprintf("\x1b[%dD%s", p.columns.Load(), progressText)) } else { - _ = writeAll(p.writer, []byte(fmt.Sprintf("\r%s", progressText))) + p.writeProgress(fmt.Sprintf("\r%s", progressText)) } } diff --git a/trzsz/progress_test.go b/trzsz/progress_test.go index f3ad6c8..01ccc62 100644 --- a/trzsz/progress_test.go +++ b/trzsz/progress_test.go @@ -59,7 +59,7 @@ func TestProgressWithEmptyFile(t *testing.T) { writer := newTestWriter(t) callTimeNowCount := mockTimeNow([]int64{1646564135000, 1646564135000}, 0) - progress := newTextProgressBar(writer, 100, 0) + progress := newTextProgressBar(writer, 100, 0, "") progress.onNum(1) progress.onName("中文😀test.txt") progress.onSize(0) @@ -75,7 +75,7 @@ func TestProgressZeroStep(t *testing.T) { writer := newTestWriter(t) callTimeNowCount := mockTimeNow([]int64{1646564135000, 1646564135100}, 0) - progress := newTextProgressBar(writer, 100, 0) + progress := newTextProgressBar(writer, 100, 0, "") progress.onNum(1) progress.onName("中文😀test.txt") progress.onSize(100) @@ -91,7 +91,7 @@ func TestProgressLastStep(t *testing.T) { writer := newTestWriter(t) callTimeNowCount := mockTimeNow([]int64{1646564135000, 1646564135200}, 0) - progress := newTextProgressBar(writer, 100, 0) + progress := newTextProgressBar(writer, 100, 0, "") progress.onNum(1) progress.onName("中文😀test.txt") progress.onSize(100) @@ -107,7 +107,7 @@ func TestProgressWithSpeedAndEta(t *testing.T) { writer := newTestWriter(t) callTimeNowCount := mockTimeNow([]int64{1646564135000, 1646564135100}, 0) - progress := newTextProgressBar(writer, 100, 0) + progress := newTextProgressBar(writer, 100, 0, "") progress.onNum(1) progress.onName("中文😀test.txt") progress.onSize(100) @@ -128,7 +128,7 @@ func TestProgressNewestSpeed(t *testing.T) { } callTimeNowCount := mockTimeNow(mockTimes, 0) - progress := newTextProgressBar(writer, 100, 0) + progress := newTextProgressBar(writer, 100, 0, "") progress.onNum(1) progress.onName("中文😀test.txt") progress.onSize(100000) @@ -187,7 +187,7 @@ func TestProgressReduceOutput(t *testing.T) { writer := newTestWriter(t) callTimeNowCount := mockTimeNow([]int64{1646564135000, 1646564135001, 1646564135099}, 0) - progress := newTextProgressBar(writer, 100, 0) + progress := newTextProgressBar(writer, 100, 0, "") progress.onNum(1) progress.onName("中文😀test.txt") progress.onSize(100) @@ -204,7 +204,7 @@ func TestProgressFastSpeed(t *testing.T) { writer := newTestWriter(t) callTimeNowCount := mockTimeNow([]int64{1646564135000, 1646564136000}, 0) - progress := newTextProgressBar(writer, 100, 0) + progress := newTextProgressBar(writer, 100, 0, "") progress.onNum(1) progress.onName("中文😀test.txt") progress.onSize(1125899906842624) @@ -220,7 +220,7 @@ func TestProgressSlowSpeed(t *testing.T) { writer := newTestWriter(t) callTimeNowCount := mockTimeNow([]int64{1646564135000, 1646564136000}, 0) - progress := newTextProgressBar(writer, 100, 0) + progress := newTextProgressBar(writer, 100, 0, "") progress.onNum(1) progress.onName("中文😀test.txt") progress.onSize(1024 * 1024) @@ -236,7 +236,7 @@ func TestProgressLongFileName(t *testing.T) { writer := newTestWriter(t) callTimeNowCount := mockTimeNow([]int64{1646564135000, 1646564136000, 1646564138000}, 0) - progress := newTextProgressBar(writer, 110, 0) + progress := newTextProgressBar(writer, 110, 0, "") progress.onNum(1) progress.onName("中文😀非常长非常长非常长非常长非常长非常长非常长非常长.txt") progress.onSize(1024 * 1024) @@ -257,7 +257,7 @@ func TestProgressWithoutTotalSize(t *testing.T) { writer := newTestWriter(t) callTimeNowCount := mockTimeNow([]int64{1646564135000, 1646564136000, 1646564138000}, 0) - progress := newTextProgressBar(writer, 95, 0) + progress := newTextProgressBar(writer, 95, 0, "") progress.onNum(1) progress.onName("中文😀非常长非常长非常长非常长非常长非常长非常长非常长.txt") progress.onSize(1000 * 1024 * 1024 * 1024) @@ -276,7 +276,7 @@ func TestProgressWithoutSpeedOrEta(t *testing.T) { writer := newTestWriter(t) callTimeNowCount := mockTimeNow([]int64{1646564135000, 1646564136000, 1646564138000}, 0) - progress := newTextProgressBar(writer, 70, 0) + progress := newTextProgressBar(writer, 70, 0, "") progress.onNum(1) progress.onName("中文😀longlonglonglonglonglongname.txt") progress.onSize(1000) @@ -295,7 +295,7 @@ func TestProgressWithoutFileName(t *testing.T) { writer := newTestWriter(t) callTimeNowCount := mockTimeNow([]int64{1646564135000, 1646564136000, 1646564138000}, 0) - progress := newTextProgressBar(writer, 48, 0) + progress := newTextProgressBar(writer, 48, 0, "") progress.onNum(1) progress.onName("中文😀llong文件名.txt") progress.onSize(1000) @@ -315,7 +315,7 @@ func TestProgressWithoutBar(t *testing.T) { writer := newTestWriter(t) callTimeNowCount := mockTimeNow([]int64{1646564135000, 1646564136000}, 0) - progress := newTextProgressBar(writer, 10, 0) + progress := newTextProgressBar(writer, 10, 0, "") progress.onNum(1) progress.onName("中文😀test.txt") progress.onSize(1000) @@ -332,7 +332,7 @@ func TestProgressWithMultiFiles(t *testing.T) { callTimeNowCount := mockTimeNow([]int64{1646564135000, 1646564136000, 1646564136000, 1646564137000, 1646564139000, 1646564139000}, 0) - progress := newTextProgressBar(writer, 100, 0) + progress := newTextProgressBar(writer, 100, 0, "") progress.onNum(2) progress.onName("中文😀test.txt") progress.onSize(1000) @@ -358,7 +358,7 @@ func TestProgressInTmuxPane(t *testing.T) { callTimeNowCount := mockTimeNow([]int64{1646564135000, 1646564136000, 1646564137000, 1646564137000, 1646564138000, 1646564139000, 1646564139000}, 0) - progress := newTextProgressBar(writer, 100, 80) + progress := newTextProgressBar(writer, 100, 80, "") progress.onNum(2) progress.onName("中文😀test.txt") progress.onSize(1000) diff --git a/trzsz/relay.go b/trzsz/relay.go index 79924d2..f06d21c 100644 --- a/trzsz/relay.go +++ b/trzsz/relay.go @@ -356,12 +356,11 @@ func (r *trzszRelay) wrapOutput() { continue } - var mode *byte - var serverIsWindows bool - buf, mode, _, serverIsWindows = detector.detectTrzsz(buf) - if mode != nil { + var trigger *trzszTrigger + buf, trigger = detector.detectTrzsz(buf, false) // TODO r.tunnelConnector != nil + if trigger != nil { r.relayStatus.Store(kRelayHandshaking) // store status before send to client - r.serverIsWindows = serverIsWindows + r.serverIsWindows = trigger.winServer go r.handshake() } diff --git a/trzsz/transfer.go b/trzsz/transfer.go index 147896f..3510d43 100644 --- a/trzsz/transfer.go +++ b/trzsz/transfer.go @@ -31,6 +31,7 @@ import ( "fmt" "io" "io/fs" + "net" "os" "path/filepath" "strconv" @@ -60,6 +61,7 @@ type transferAction struct { Protocol int `json:"protocol"` SupportBinary bool `json:"binary"` SupportDirectory bool `json:"support_dir"` + TunnelConnected bool `json:"tunnel"` } type transferConfig struct { @@ -101,6 +103,9 @@ type trzszTransfer struct { transferConfig transferConfig logger *traceLogger createdFiles []string + tunnelConnected bool + tunnelConn atomic.Pointer[net.Conn] + tunnelInitWG sync.WaitGroup } func maxDuration(a, b time.Duration) time.Duration { @@ -144,7 +149,112 @@ func newTransfer(writer io.Writer, stdinState *term.State, flushInTime bool, log return t } -func (t *trzszTransfer) addReceivedData(buf []byte) { +func getHelloConstant(uniqueID string, port int) (string, string) { + uid := uniqueID + if len(uid) > 2 { + uid = uid[:len(uid)-2] + } + clientHello := fmt.Sprintf("::TRZSZ::CLIENT::HELLO::%s:%d", uid, port) + serverHello := fmt.Sprintf("::TRZSZ::SERVER::HELLO::%s:%d", uid, port) + return clientHello, serverHello +} + +func (t *trzszTransfer) acceptOnTunnel(listener net.Listener, uniqueID string, port int) { + go func() { + defer listener.Close() + clientHello, serverHello := getHelloConstant(uniqueID, port) + for { + conn, err := listener.Accept() + if err != nil { + return + } + if t.tunnelConn.Load() != nil { + conn.Close() + return + } + go func(conn net.Conn) { + buf := make([]byte, 100) + n, err := conn.Read(buf) + if err != nil || string(buf[:n]) != clientHello { + conn.Close() + return + } + if _, err := conn.Write([]byte(serverHello)); err != nil { + conn.Close() + return + } + if t.tunnelConn.CompareAndSwap(nil, &conn) { + wrapTransferInput(t, conn, true) + listener.Close() + } + }(conn) + } + }() +} + +func (t *trzszTransfer) connectToTunnel(connector func(int) net.Conn, uniqueID string, port int) { + t.tunnelInitWG.Add(1) + go func() { + defer t.tunnelInitWG.Done() + + timeout := false + connChan := make(chan net.Conn, 1) + go func() { + defer close(connChan) + conn := connector(port) + if conn == nil { + connChan <- nil + return + } + if timeout { + conn.Close() + connChan <- nil + return + } + clientHello, serverHello := getHelloConstant(uniqueID, port) + if _, err := conn.Write([]byte(clientHello)); err != nil || timeout { + conn.Close() + connChan <- nil + return + } + buf := make([]byte, 100) + n, err := conn.Read(buf) + if err != nil || string(buf[:n]) != serverHello || timeout { + conn.Close() + connChan <- nil + return + } + connChan <- conn + }() + + select { + case conn := <-connChan: + if conn != nil { + t.tunnelConn.Store(&conn) + wrapTransferInput(t, conn, true) + } + case <-time.After(time.Second): + timeout = true + } + }() +} + +func (t *trzszTransfer) cleanup() { + if conn := t.tunnelConn.Load(); conn != nil { + (*conn).Close() + } +} + +func (t *trzszTransfer) addReceivedData(buf []byte, tunnel bool) { + if t.tunnelConnected && !tunnel { + if t.logger != nil { + t.logger.writeTraceLog(buf, "ignout") + } + return + } + if t.logger != nil { + t.logger.writeTraceLog(buf, "svrout") + } if !t.stopped.Load() { t.buffer.addBuffer(buf) } @@ -159,18 +269,20 @@ func (t *trzszTransfer) stopTransferringFiles(stopAndDelete bool) { t.stopped.Store(true) t.buffer.stopBuffer() - maxChunkTime := time.Duration(0) - for _, chunkTime := range t.lastChunkTimeArr { - if chunkTime > maxChunkTime { - maxChunkTime = chunkTime + if !t.tunnelConnected { + maxChunkTime := time.Duration(0) + for _, chunkTime := range t.lastChunkTimeArr { + if chunkTime > maxChunkTime { + maxChunkTime = chunkTime + } } + waitTime := maxChunkTime * 2 + beginTime := t.pauseBeginTime.Load() + if beginTime > 0 { + waitTime -= time.Since(time.UnixMilli(beginTime)) + } + t.cleanTimeout = maxDuration(waitTime, 500*time.Millisecond) } - waitTime := maxChunkTime * 2 - beginTime := t.pauseBeginTime.Load() - if beginTime > 0 { - waitTime -= time.Since(time.UnixMilli(beginTime)) - } - t.cleanTimeout = maxDuration(waitTime, 500*time.Millisecond) } func (t *trzszTransfer) pauseTransferringFiles() { @@ -260,7 +372,7 @@ func (t *trzszTransfer) recvLine(expectType string, mayHasJunk bool, timeout <-c return nil, err } - if isWindowsEnvironment() || t.windowsProtocol { + if !t.tunnelConnected && (isWindowsEnvironment() || t.windowsProtocol) { line, err := t.buffer.readLineOnWindows(timeout) if err != nil { if e := t.checkStop(); e != nil { @@ -280,7 +392,13 @@ func (t *trzszTransfer) recvLine(expectType string, mayHasJunk bool, timeout <-c return line, nil } - line, err := t.buffer.readLine(t.transferConfig.TmuxOutputJunk || mayHasJunk, timeout) + if t.tunnelConnected { + mayHasJunk = false + } else if t.transferConfig.TmuxOutputJunk { + mayHasJunk = true + } + + line, err := t.buffer.readLine(mayHasJunk, timeout) if err != nil { if e := t.checkStop(); e != nil { return nil, e @@ -288,7 +406,7 @@ func (t *trzszTransfer) recvLine(expectType string, mayHasJunk bool, timeout <-c return nil, err } - if t.transferConfig.TmuxOutputJunk || mayHasJunk { + if mayHasJunk { idx := bytes.LastIndex(line, []byte("#"+expectType+":")) if idx >= 0 { line = line[idx:] @@ -459,7 +577,15 @@ func (t *trzszTransfer) sendAction(confirm bool, serverVersion *trzszVersion, re SupportBinary: true, SupportDirectory: true, } - if isWindowsEnvironment() || remoteIsWindows { + + t.tunnelInitWG.Wait() + if conn := t.tunnelConn.Load(); conn != nil { + t.writer = *conn + t.tunnelConnected = true + action.TunnelConnected = true + } + + if !t.tunnelConnected && (isWindowsEnvironment() || remoteIsWindows) { action.Newline = "!\n" action.SupportBinary = false } @@ -471,7 +597,11 @@ func (t *trzszTransfer) sendAction(confirm bool, serverVersion *trzszVersion, re t.windowsProtocol = true t.transferConfig.Newline = "!\n" } - return t.sendString("ACT", string(actStr)) + if err := t.sendString("ACT", string(actStr)); err != nil { + return err + } + t.transferConfig.Newline = action.Newline + return nil } func (t *trzszTransfer) recvAction() (*transferAction, error) { @@ -486,6 +616,14 @@ func (t *trzszTransfer) recvAction() (*transferAction, error) { if err := json.Unmarshal([]byte(actStr), action); err != nil { return nil, err } + if action.TunnelConnected { + t.tunnelConnected = true + if conn := t.tunnelConn.Load(); conn != nil { + t.writer = *conn + } else { + return nil, simpleTrzszError("The tunnel connection is nil") + } + } t.transferConfig.Newline = action.Newline return action, nil } @@ -497,7 +635,9 @@ func (t *trzszTransfer) sendConfig(args *baseArgs, action *transferAction, escap if args.Quiet { cfgMap["quiet"] = true } - if args.Binary { + if action.TunnelConnected { + cfgMap["binary"] = true + } else if args.Binary { cfgMap["binary"] = true cfgMap["escape_chars"] = escapeChars } diff --git a/trzsz/transfer_test.go b/trzsz/transfer_test.go index 1781898..b26919d 100644 --- a/trzsz/transfer_test.go +++ b/trzsz/transfer_test.go @@ -43,8 +43,9 @@ func TestTransferAction(t *testing.T) { serverTransfer := newTransfer(writer, nil, false, nil) // compatible with older versions - serverTransfer.addReceivedData([]byte( - "#ACT:eJyrVspJzEtXslJQKqhU0lFQSs7PS8ssygUKlBSVpgIFylKLijPz80AqDPUM9AxAiopLCwryi0riUzKLEAoLivJL8pPzc4AiBrUAlAQbEA==\n")) + serverTransfer.addReceivedData( + []byte("#ACT:eJyrVspJzEtXslJQKqhU0lFQSs7PS8ssygUKlBSVpgIFylKLijPz80AqDPUM9AxAiopLCwryi0riUzKLEAoLivJL8pPzc4AiBrUAlAQbEA==\n"), + false) action, err := serverTransfer.recvAction() assert.Nil(err) assert.Equal(&transferAction{ @@ -68,7 +69,7 @@ func TestTransferAction(t *testing.T) { assert.Equal("\n", clientTransfer.transferConfig.Newline) SetAffectedByWindows(false) - serverTransfer.addReceivedData([]byte(writer.buffer[0])) + serverTransfer.addReceivedData([]byte(writer.buffer[0]), false) action, err = serverTransfer.recvAction() assert.Nil(err) assert.Equal("\n", action.Newline) @@ -86,7 +87,7 @@ func TestTransferAction(t *testing.T) { assert.Equal("\n", clientTransfer.transferConfig.Newline) SetAffectedByWindows(false) - serverTransfer.addReceivedData([]byte(writer.buffer[1])) + serverTransfer.addReceivedData([]byte(writer.buffer[1]), false) action, err = serverTransfer.recvAction() assert.Nil(err) assert.Equal("!\n", action.Newline) @@ -104,7 +105,7 @@ func TestTransferAction(t *testing.T) { assert.Equal("!\n", clientTransfer.transferConfig.Newline) SetAffectedByWindows(true) - serverTransfer.addReceivedData([]byte(writer.buffer[2])) + serverTransfer.addReceivedData([]byte(writer.buffer[2]), false) action, err = serverTransfer.recvAction() assert.Nil(err) assert.Equal("!\n", action.Newline) @@ -122,7 +123,7 @@ func TestTransferAction(t *testing.T) { assert.Equal("!\n", clientTransfer.transferConfig.Newline) SetAffectedByWindows(true) - serverTransfer.addReceivedData([]byte(writer.buffer[3])) + serverTransfer.addReceivedData([]byte(writer.buffer[3]), false) action, err = serverTransfer.recvAction() assert.Nil(err) assert.Equal("!\n", action.Newline) @@ -180,7 +181,7 @@ func TestTransferConfig(t *testing.T) { assertConfigEqual := func(cfgStr string) { t.Helper() - transfer.addReceivedData([]byte(cfgStr)) + transfer.addReceivedData([]byte(cfgStr), false) transferConfig, err := transfer.recvConfig() assert.Nil(err) assert.Equal(config, *transferConfig) diff --git a/trzsz/trz.go b/trzsz/trz.go index 80a87ea..c9af21b 100644 --- a/trzsz/trz.go +++ b/trzsz/trz.go @@ -158,11 +158,13 @@ func TrzMain() int { uniqueID += 20 } + listener, port := listenForTunnel() + mode := "R" if args.Directory { mode = "D" } - os.Stdout.WriteString(fmt.Sprintf("\x1b7\x07::TRZSZ:TRANSFER:%s:%s:%013d\r\n", mode, kTrzszVersion, uniqueID)) + os.Stdout.WriteString(fmt.Sprintf("\x1b7\x07::TRZSZ:TRANSFER:%s:%s:%013d:%d\r\n", mode, kTrzszVersion, uniqueID, port)) os.Stdout.Sync() var state *term.State @@ -183,12 +185,18 @@ func TrzMain() int { } }() - go wrapStdinInput(transfer) + if listener != nil { + defer listener.Close() + transfer.acceptOnTunnel(listener, fmt.Sprintf("%013d", uniqueID), port) + } + wrapTransferInput(transfer, os.Stdin, false) handleServerSignal(transfer) if err := recvFiles(transfer, args, tmuxMode, tmuxPaneWidth); err != nil { transfer.serverError(err) } + transfer.cleanup() + return 0 } diff --git a/trzsz/tsz.go b/trzsz/tsz.go index 62ea375..3308229 100644 --- a/trzsz/tsz.go +++ b/trzsz/tsz.go @@ -158,7 +158,9 @@ func TszMain() int { uniqueID += 20 } - os.Stdout.WriteString(fmt.Sprintf("\x1b7\x07::TRZSZ:TRANSFER:S:%s:%013d\r\n", kTrzszVersion, uniqueID)) + listener, port := listenForTunnel() + + os.Stdout.WriteString(fmt.Sprintf("\x1b7\x07::TRZSZ:TRANSFER:S:%s:%013d:%d\r\n", kTrzszVersion, uniqueID, port)) os.Stdout.Sync() var state *term.State @@ -179,12 +181,18 @@ func TszMain() int { } }() - go wrapStdinInput(transfer) + if listener != nil { + defer listener.Close() + transfer.acceptOnTunnel(listener, fmt.Sprintf("%013d", uniqueID), port) + } + wrapTransferInput(transfer, os.Stdin, false) handleServerSignal(transfer) if err := sendFiles(transfer, files, args, tmuxMode, tmuxPaneWidth); err != nil { transfer.serverError(err) } + transfer.cleanup() + return 0 }