@@ -9,7 +9,6 @@ package scp
9
9
import (
10
10
"bytes"
11
11
"context"
12
- "errors"
13
12
"fmt"
14
13
"io"
15
14
"io/ioutil"
@@ -85,13 +84,24 @@ func (a *Client) SSHClient() *ssh.Client {
85
84
}
86
85
87
86
// CopyFromFile copies the contents of an os.File to a remote location, it will get the length of the file by looking it up from the filesystem.
88
- func (a * Client ) CopyFromFile (ctx context.Context , file os.File , remotePath string , permissions string ) error {
87
+ func (a * Client ) CopyFromFile (
88
+ ctx context.Context ,
89
+ file os.File ,
90
+ remotePath string ,
91
+ permissions string ,
92
+ ) error {
89
93
return a .CopyFromFilePassThru (ctx , file , remotePath , permissions , nil )
90
94
}
91
95
92
96
// CopyFromFilePassThru copies the contents of an os.File to a remote location, it will get the length of the file by looking it up from the filesystem.
93
97
// Access copied bytes by providing a PassThru reader factory.
94
- func (a * Client ) CopyFromFilePassThru (ctx context.Context , file os.File , remotePath string , permissions string , passThru PassThru ) error {
98
+ func (a * Client ) CopyFromFilePassThru (
99
+ ctx context.Context ,
100
+ file os.File ,
101
+ remotePath string ,
102
+ permissions string ,
103
+ passThru PassThru ,
104
+ ) error {
95
105
stat , err := file .Stat ()
96
106
if err != nil {
97
107
return fmt .Errorf ("failed to stat file: %w" , err )
@@ -101,21 +111,39 @@ func (a *Client) CopyFromFilePassThru(ctx context.Context, file os.File, remoteP
101
111
102
112
// CopyFile copies the contents of an io.Reader to a remote location, the length is determined by reading the io.Reader until EOF
103
113
// if the file length in know in advance please use "Copy" instead.
104
- func (a * Client ) CopyFile (ctx context.Context , fileReader io.Reader , remotePath string , permissions string ) error {
114
+ func (a * Client ) CopyFile (
115
+ ctx context.Context ,
116
+ fileReader io.Reader ,
117
+ remotePath string ,
118
+ permissions string ,
119
+ ) error {
105
120
return a .CopyFilePassThru (ctx , fileReader , remotePath , permissions , nil )
106
121
}
107
122
108
123
// CopyFilePassThru copies the contents of an io.Reader to a remote location, the length is determined by reading the io.Reader until EOF
109
124
// if the file length in know in advance please use "Copy" instead.
110
125
// Access copied bytes by providing a PassThru reader factory.
111
- func (a * Client ) CopyFilePassThru (ctx context.Context , fileReader io.Reader , remotePath string , permissions string , passThru PassThru ) error {
126
+ func (a * Client ) CopyFilePassThru (
127
+ ctx context.Context ,
128
+ fileReader io.Reader ,
129
+ remotePath string ,
130
+ permissions string ,
131
+ passThru PassThru ,
132
+ ) error {
112
133
contentsBytes , err := ioutil .ReadAll (fileReader )
113
134
if err != nil {
114
135
return fmt .Errorf ("failed to read all data from reader: %w" , err )
115
136
}
116
137
bytesReader := bytes .NewReader (contentsBytes )
117
138
118
- return a .CopyPassThru (ctx , bytesReader , remotePath , permissions , int64 (len (contentsBytes )), passThru )
139
+ return a .CopyPassThru (
140
+ ctx ,
141
+ bytesReader ,
142
+ remotePath ,
143
+ permissions ,
144
+ int64 (len (contentsBytes )),
145
+ passThru ,
146
+ )
119
147
}
120
148
121
149
// wait waits for the waitgroup for the specified max timeout.
@@ -139,27 +167,36 @@ func wait(wg *sync.WaitGroup, ctx context.Context) error {
139
167
// checkResponse checks the response it reads from the remote, and will return a single error in case
140
168
// of failure.
141
169
func checkResponse (r io.Reader ) error {
142
- response , err := ParseResponse (r )
170
+ _ , err := ParseResponse (r , nil )
143
171
if err != nil {
144
172
return err
145
173
}
146
174
147
- if response .IsFailure () {
148
- return errors .New (response .GetMessage ())
149
- }
150
-
151
175
return nil
152
176
153
177
}
154
178
155
179
// Copy copies the contents of an io.Reader to a remote location.
156
- func (a * Client ) Copy (ctx context.Context , r io.Reader , remotePath string , permissions string , size int64 ) error {
180
+ func (a * Client ) Copy (
181
+ ctx context.Context ,
182
+ r io.Reader ,
183
+ remotePath string ,
184
+ permissions string ,
185
+ size int64 ,
186
+ ) error {
157
187
return a .CopyPassThru (ctx , r , remotePath , permissions , size , nil )
158
188
}
159
189
160
190
// CopyPassThru copies the contents of an io.Reader to a remote location.
161
191
// Access copied bytes by providing a PassThru reader factory
162
- func (a * Client ) CopyPassThru (ctx context.Context , r io.Reader , remotePath string , permissions string , size int64 , passThru PassThru ) error {
192
+ func (a * Client ) CopyPassThru (
193
+ ctx context.Context ,
194
+ r io.Reader ,
195
+ remotePath string ,
196
+ permissions string ,
197
+ size int64 ,
198
+ passThru PassThru ,
199
+ ) error {
163
200
session , err := a .sshClient .NewSession ()
164
201
if err != nil {
165
202
return fmt .Errorf ("Error creating ssh session in copy to remote: %v" , err )
@@ -272,7 +309,12 @@ func (a *Client) CopyFromRemote(ctx context.Context, file *os.File, remotePath s
272
309
// CopyFromRemotePassThru copies a file from the remote to the given writer. The passThru parameter can be used
273
310
// to keep track of progress and how many bytes that were download from the remote.
274
311
// `passThru` can be set to nil to disable this behaviour.
275
- func (a * Client ) CopyFromRemotePassThru (ctx context.Context , w io.Writer , remotePath string , passThru PassThru ) error {
312
+ func (a * Client ) CopyFromRemotePassThru (
313
+ ctx context.Context ,
314
+ w io.Writer ,
315
+ remotePath string ,
316
+ passThru PassThru ,
317
+ ) error {
276
318
session , err := a .sshClient .NewSession ()
277
319
if err != nil {
278
320
return fmt .Errorf ("Error creating ssh session in copy from remote: %v" , err )
@@ -319,17 +361,7 @@ func (a *Client) CopyFromRemotePassThru(ctx context.Context, w io.Writer, remote
319
361
return
320
362
}
321
363
322
- res , err := ParseResponse (r )
323
- if err != nil {
324
- errCh <- err
325
- return
326
- }
327
- if res .IsFailure () {
328
- errCh <- errors .New (res .GetMessage ())
329
- return
330
- }
331
-
332
- infos , err := res .ParseFileInfos ()
364
+ fileInfo , err := ParseResponse (r , in )
333
365
if err != nil {
334
366
errCh <- err
335
367
return
@@ -342,10 +374,10 @@ func (a *Client) CopyFromRemotePassThru(ctx context.Context, w io.Writer, remote
342
374
}
343
375
344
376
if passThru != nil {
345
- r = passThru (r , infos .Size )
377
+ r = passThru (r , fileInfo .Size )
346
378
}
347
379
348
- _ , err = CopyN (w , r , infos .Size )
380
+ _ , err = CopyN (w , r , fileInfo .Size )
349
381
if err != nil {
350
382
errCh <- err
351
383
return
0 commit comments