Skip to content

Commit db7cf4f

Browse files
Merge pull request #80 from datadius/preserve-protocol
Handle T message and prepare for adding -p option
2 parents bd16750 + b4cd115 commit db7cf4f

File tree

2 files changed

+179
-77
lines changed

2 files changed

+179
-77
lines changed

client.go

+59-27
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ package scp
99
import (
1010
"bytes"
1111
"context"
12-
"errors"
1312
"fmt"
1413
"io"
1514
"io/ioutil"
@@ -85,13 +84,24 @@ func (a *Client) SSHClient() *ssh.Client {
8584
}
8685

8786
// CopyFromFile copies the contents of an os.File to a remote location, it will get the length of the file by looking it up from the filesystem.
88-
func (a *Client) CopyFromFile(ctx context.Context, file os.File, remotePath string, permissions string) error {
87+
func (a *Client) CopyFromFile(
88+
ctx context.Context,
89+
file os.File,
90+
remotePath string,
91+
permissions string,
92+
) error {
8993
return a.CopyFromFilePassThru(ctx, file, remotePath, permissions, nil)
9094
}
9195

9296
// CopyFromFilePassThru copies the contents of an os.File to a remote location, it will get the length of the file by looking it up from the filesystem.
9397
// Access copied bytes by providing a PassThru reader factory.
94-
func (a *Client) CopyFromFilePassThru(ctx context.Context, file os.File, remotePath string, permissions string, passThru PassThru) error {
98+
func (a *Client) CopyFromFilePassThru(
99+
ctx context.Context,
100+
file os.File,
101+
remotePath string,
102+
permissions string,
103+
passThru PassThru,
104+
) error {
95105
stat, err := file.Stat()
96106
if err != nil {
97107
return fmt.Errorf("failed to stat file: %w", err)
@@ -101,21 +111,39 @@ func (a *Client) CopyFromFilePassThru(ctx context.Context, file os.File, remoteP
101111

102112
// CopyFile copies the contents of an io.Reader to a remote location, the length is determined by reading the io.Reader until EOF
103113
// if the file length in know in advance please use "Copy" instead.
104-
func (a *Client) CopyFile(ctx context.Context, fileReader io.Reader, remotePath string, permissions string) error {
114+
func (a *Client) CopyFile(
115+
ctx context.Context,
116+
fileReader io.Reader,
117+
remotePath string,
118+
permissions string,
119+
) error {
105120
return a.CopyFilePassThru(ctx, fileReader, remotePath, permissions, nil)
106121
}
107122

108123
// CopyFilePassThru copies the contents of an io.Reader to a remote location, the length is determined by reading the io.Reader until EOF
109124
// if the file length in know in advance please use "Copy" instead.
110125
// Access copied bytes by providing a PassThru reader factory.
111-
func (a *Client) CopyFilePassThru(ctx context.Context, fileReader io.Reader, remotePath string, permissions string, passThru PassThru) error {
126+
func (a *Client) CopyFilePassThru(
127+
ctx context.Context,
128+
fileReader io.Reader,
129+
remotePath string,
130+
permissions string,
131+
passThru PassThru,
132+
) error {
112133
contentsBytes, err := ioutil.ReadAll(fileReader)
113134
if err != nil {
114135
return fmt.Errorf("failed to read all data from reader: %w", err)
115136
}
116137
bytesReader := bytes.NewReader(contentsBytes)
117138

118-
return a.CopyPassThru(ctx, bytesReader, remotePath, permissions, int64(len(contentsBytes)), passThru)
139+
return a.CopyPassThru(
140+
ctx,
141+
bytesReader,
142+
remotePath,
143+
permissions,
144+
int64(len(contentsBytes)),
145+
passThru,
146+
)
119147
}
120148

121149
// wait waits for the waitgroup for the specified max timeout.
@@ -139,27 +167,36 @@ func wait(wg *sync.WaitGroup, ctx context.Context) error {
139167
// checkResponse checks the response it reads from the remote, and will return a single error in case
140168
// of failure.
141169
func checkResponse(r io.Reader) error {
142-
response, err := ParseResponse(r)
170+
_, err := ParseResponse(r, nil)
143171
if err != nil {
144172
return err
145173
}
146174

147-
if response.IsFailure() {
148-
return errors.New(response.GetMessage())
149-
}
150-
151175
return nil
152176

153177
}
154178

155179
// Copy copies the contents of an io.Reader to a remote location.
156-
func (a *Client) Copy(ctx context.Context, r io.Reader, remotePath string, permissions string, size int64) error {
180+
func (a *Client) Copy(
181+
ctx context.Context,
182+
r io.Reader,
183+
remotePath string,
184+
permissions string,
185+
size int64,
186+
) error {
157187
return a.CopyPassThru(ctx, r, remotePath, permissions, size, nil)
158188
}
159189

160190
// CopyPassThru copies the contents of an io.Reader to a remote location.
161191
// Access copied bytes by providing a PassThru reader factory
162-
func (a *Client) CopyPassThru(ctx context.Context, r io.Reader, remotePath string, permissions string, size int64, passThru PassThru) error {
192+
func (a *Client) CopyPassThru(
193+
ctx context.Context,
194+
r io.Reader,
195+
remotePath string,
196+
permissions string,
197+
size int64,
198+
passThru PassThru,
199+
) error {
163200
session, err := a.sshClient.NewSession()
164201
if err != nil {
165202
return fmt.Errorf("Error creating ssh session in copy to remote: %v", err)
@@ -272,7 +309,12 @@ func (a *Client) CopyFromRemote(ctx context.Context, file *os.File, remotePath s
272309
// CopyFromRemotePassThru copies a file from the remote to the given writer. The passThru parameter can be used
273310
// to keep track of progress and how many bytes that were download from the remote.
274311
// `passThru` can be set to nil to disable this behaviour.
275-
func (a *Client) CopyFromRemotePassThru(ctx context.Context, w io.Writer, remotePath string, passThru PassThru) error {
312+
func (a *Client) CopyFromRemotePassThru(
313+
ctx context.Context,
314+
w io.Writer,
315+
remotePath string,
316+
passThru PassThru,
317+
) error {
276318
session, err := a.sshClient.NewSession()
277319
if err != nil {
278320
return fmt.Errorf("Error creating ssh session in copy from remote: %v", err)
@@ -319,17 +361,7 @@ func (a *Client) CopyFromRemotePassThru(ctx context.Context, w io.Writer, remote
319361
return
320362
}
321363

322-
res, err := ParseResponse(r)
323-
if err != nil {
324-
errCh <- err
325-
return
326-
}
327-
if res.IsFailure() {
328-
errCh <- errors.New(res.GetMessage())
329-
return
330-
}
331-
332-
infos, err := res.ParseFileInfos()
364+
fileInfo, err := ParseResponse(r, in)
333365
if err != nil {
334366
errCh <- err
335367
return
@@ -342,10 +374,10 @@ func (a *Client) CopyFromRemotePassThru(ctx context.Context, w io.Writer, remote
342374
}
343375

344376
if passThru != nil {
345-
r = passThru(r, infos.Size)
377+
r = passThru(r, fileInfo.Size)
346378
}
347379

348-
_, err = CopyN(w, r, infos.Size)
380+
_, err = CopyN(w, r, fileInfo.Size)
349381
if err != nil {
350382
errCh <- err
351383
return

protocol.go

+120-50
Original file line numberDiff line numberDiff line change
@@ -9,42 +9,30 @@ package scp
99
import (
1010
"bufio"
1111
"errors"
12+
"fmt"
1213
"io"
1314
"strconv"
1415
"strings"
1516
)
1617

17-
type ResponseType = uint8
18+
type ResponseType = byte
1819

1920
const (
2021
Ok ResponseType = 0
2122
Warning ResponseType = 1
2223
Error ResponseType = 2
24+
Create ResponseType = 'C'
25+
Time ResponseType = 'T'
2326
)
2427

25-
// Response represent a response from the SCP command.
26-
// There are tree types of responses that the remote can send back:
27-
// ok, warning and error
28-
//
29-
// The difference between warning and error is that the connection is not closed by the remote,
30-
// however, a warning can indicate a file transfer failure (such as invalid destination directory)
31-
// and such be handled as such.
32-
//
33-
// All responses except for the `Ok` type always have a message (although these can be empty)
34-
//
35-
// The remote sends a confirmation after every SCP command, because a failure can occur after every
36-
// command, the response should be read and checked after sending them.
37-
type Response struct {
38-
Type ResponseType
39-
Message string
40-
}
41-
4228
// ParseResponse reads from the given reader (assuming it is the output of the remote) and parses it into a Response structure.
43-
func ParseResponse(reader io.Reader) (Response, error) {
29+
func ParseResponse(reader io.Reader, writer io.Writer) (*FileInfos, error) {
30+
fileInfos := NewFileInfos()
31+
4432
buffer := make([]uint8, 1)
4533
_, err := reader.Read(buffer)
4634
if err != nil {
47-
return Response{}, err
35+
return fileInfos, err
4836
}
4937

5038
responseType := buffer[0]
@@ -53,61 +41,143 @@ func ParseResponse(reader io.Reader) (Response, error) {
5341
bufferedReader := bufio.NewReader(reader)
5442
message, err = bufferedReader.ReadString('\n')
5543
if err != nil {
56-
return Response{}, err
44+
return fileInfos, err
5745
}
58-
}
5946

60-
return Response{responseType, message}, nil
61-
}
47+
if responseType == Warning || responseType == Error {
48+
return fileInfos, errors.New(message)
49+
}
6250

63-
func (r *Response) IsOk() bool {
64-
return r.Type == Ok
65-
}
51+
// Exit early because we're only interested in the ok response
52+
if responseType == Ok {
53+
return fileInfos, nil
54+
}
6655

67-
func (r *Response) IsWarning() bool {
68-
return r.Type == Warning
69-
}
56+
if !(responseType == Create || responseType == Time) {
57+
return fileInfos, errors.New(
58+
fmt.Sprintf(
59+
"Message does not follow scp protocol: %s\n Cmmmm <length> <filename> or T<mtime> 0 <atime> 0",
60+
message,
61+
),
62+
)
63+
}
7064

71-
// IsError returns true when the remote responded with an error.
72-
func (r *Response) IsError() bool {
73-
return r.Type == Error
74-
}
65+
if responseType == Time {
66+
err = ParseFileTime(message, fileInfos)
67+
if err != nil {
68+
return nil, err
69+
}
70+
71+
message, err = bufferedReader.ReadString('\n')
72+
if err == io.EOF {
73+
err = Ack(writer)
74+
if err != nil {
75+
return fileInfos, err
76+
}
77+
message, err = bufferedReader.ReadString('\n')
78+
79+
if err != nil {
80+
return fileInfos, err
81+
}
82+
}
83+
84+
if err != nil && err != io.EOF {
85+
return fileInfos, err
86+
}
87+
88+
responseType = message[0]
89+
}
7590

76-
// IsFailure returns true when the remote answered with a warning or an error.
77-
func (r *Response) IsFailure() bool {
78-
return r.IsWarning() || r.IsError()
79-
}
91+
if responseType == Create {
92+
err = ParseFileInfos(message, fileInfos)
93+
if err != nil {
94+
return nil, err
95+
}
96+
}
97+
}
8098

81-
// GetMessage returns the message the remote sent back.
82-
func (r *Response) GetMessage() string {
83-
return r.Message
99+
return fileInfos, nil
84100
}
85101

86102
type FileInfos struct {
87103
Message string
88104
Filename string
89105
Permissions string
90106
Size int64
107+
Atime int64
108+
Mtime int64
109+
}
110+
111+
func NewFileInfos() *FileInfos {
112+
return &FileInfos{}
91113
}
92114

93-
func (r *Response) ParseFileInfos() (*FileInfos, error) {
94-
message := strings.ReplaceAll(r.Message, "\n", "")
95-
parts := strings.Split(message, " ")
115+
func (fileInfos *FileInfos) Update(new *FileInfos) {
116+
if new == nil {
117+
return
118+
}
119+
if new.Filename != "" {
120+
fileInfos.Filename = new.Filename
121+
}
122+
if new.Permissions != "" {
123+
fileInfos.Permissions = new.Permissions
124+
}
125+
if new.Size != 0 {
126+
fileInfos.Size = new.Size
127+
}
128+
if new.Atime != 0 {
129+
fileInfos.Atime = new.Atime
130+
}
131+
if new.Mtime != 0 {
132+
fileInfos.Mtime = new.Mtime
133+
}
134+
}
135+
136+
func ParseFileInfos(message string, fileInfos *FileInfos) error {
137+
processMessage := strings.ReplaceAll(message, "\n", "")
138+
parts := strings.Split(processMessage, " ")
96139
if len(parts) < 3 {
97-
return nil, errors.New("unable to parse message as file infos")
140+
return errors.New("unable to parse Chmod protocol")
98141
}
99142

100143
size, err := strconv.Atoi(parts[1])
101144
if err != nil {
102-
return nil, err
145+
return err
103146
}
104147

105-
return &FileInfos{
106-
Message: r.Message,
148+
fileInfos.Update(&FileInfos{
149+
Filename: parts[2],
107150
Permissions: parts[0],
108151
Size: int64(size),
109-
Filename: parts[2],
110-
}, nil
152+
})
153+
154+
return nil
155+
}
156+
157+
func ParseFileTime(
158+
message string,
159+
fileInfos *FileInfos,
160+
) error {
161+
processMessage := strings.ReplaceAll(message, "\n", "")
162+
parts := strings.Split(processMessage, " ")
163+
if len(parts) < 3 {
164+
return errors.New("unable to parse Time protocol")
165+
}
166+
167+
aTime, err := strconv.Atoi(string(parts[0][0:10]))
168+
if err != nil {
169+
return errors.New("unable to parse ATime component of message")
170+
}
171+
mTime, err := strconv.Atoi(string(parts[2][0:10]))
172+
if err != nil {
173+
return errors.New("unable to parse MTime component of message")
174+
}
175+
176+
fileInfos.Update(&FileInfos{
177+
Atime: int64(aTime),
178+
Mtime: int64(mTime),
179+
})
180+
return nil
111181
}
112182

113183
// Ack writes an `Ack` message to the remote, does not await its response, a seperate call to ParseResponse is

0 commit comments

Comments
 (0)