Skip to content

Commit

Permalink
aghio: add validation to constructor
Browse files Browse the repository at this point in the history
  • Loading branch information
EugeneOne1 committed Nov 23, 2020
1 parent f57a2f5 commit 060f923
Show file tree
Hide file tree
Showing 8 changed files with 64 additions and 13 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ and this project adheres to

### Added

- HTTP API request body limit [#2305].
- HTTP API request body size limit [#2305].

[#2305]: https://github.com/AdguardTeam/AdGuardHome/issues/2305

Expand Down
8 changes: 4 additions & 4 deletions internal/aghio/limitedreadcloser.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ type limitedReadCloser struct {

// Read implements Reader interface.
func (lrc *limitedReadCloser) Read(p []byte) (n int, err error) {
if lrc.n <= 0 {
if lrc.n == 0 {
return 0, &LimitReachedError{
Limit: lrc.limit,
}
Expand All @@ -47,13 +47,13 @@ func (lrc *limitedReadCloser) Close() error {

// LimitReadCloser wraps ReadCloser to make it's Reader stop with
// ErrLimitReached after n bytes read.
func LimitReadCloser(rc io.ReadCloser, n int64) io.ReadCloser {
func LimitReadCloser(rc io.ReadCloser, n int64) (new io.ReadCloser, err error) {
if n < 0 {
n = 0
return nil, fmt.Errorf("invalid n: %d", n)
}
return &limitedReadCloser{
limit: n,
n: n,
rc: rc,
}
}, nil
}
33 changes: 31 additions & 2 deletions internal/aghio/limitedreadcloser_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package aghio

import (
"fmt"
"io"
"io/ioutil"
"strings"
Expand All @@ -9,6 +10,33 @@ import (
"github.com/stretchr/testify/assert"
)

func TestLimitReadCloser(t *testing.T) {
testCases := []struct {
name string
n int64
want error
}{{
name: "positive",
n: 1,
want: nil,
}, {
name: "zero",
n: 0,
want: nil,
}, {
name: "negative",
n: -1,
want: fmt.Errorf("invalid n: -1"),
}}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
_, err := LimitReadCloser(nil, tc.n)
assert.Equal(t, tc.want, err)
})
}
}

func TestLimitedReadCloser_Read(t *testing.T) {
testCases := []struct {
name string
Expand Down Expand Up @@ -49,9 +77,10 @@ func TestLimitedReadCloser_Read(t *testing.T) {
readCloser := ioutil.NopCloser(strings.NewReader(tc.rStr))
buf := make([]byte, tc.limit+1)

lreader := LimitReadCloser(readCloser, tc.limit)
n, err := lreader.Read(buf)
lreader, err := LimitReadCloser(readCloser, tc.limit)
assert.Nil(t, err)

n, err := lreader.Read(buf)
assert.Equal(t, n, tc.want)
assert.Equal(t, tc.err, err)
})
Expand Down
7 changes: 6 additions & 1 deletion internal/home/auth_glinet.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,12 @@ func glGetTokenDate(file string) uint32 {
log.Error("os.Open: %s", err)
return 0
}
fileReadCloser := aghio.LimitReadCloser(f, MaxFileSize)
fileReadCloser, err := aghio.LimitReadCloser(f, MaxFileSize)
if err != nil {
log.Error("LimitReadCloser: %s", err)
f.Close()
return 0
}
defer fileReadCloser.Close()

var dateToken uint32
Expand Down
6 changes: 5 additions & 1 deletion internal/home/middlewares.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,11 @@ const RequestBodySizeLimit = 64 * 1024
// method limited.
func limitRequestBody(h http.Handler) (limited http.Handler) {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r.Body = aghio.LimitReadCloser(r.Body, RequestBodySizeLimit)
var err error
r.Body, err = aghio.LimitReadCloser(r.Body, RequestBodySizeLimit)
if err != nil {
return
}

h.ServeHTTP(w, r)
})
Expand Down
6 changes: 5 additions & 1 deletion internal/home/whois.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,11 @@ func (w *Whois) query(target, serverAddr string) (string, error) {
if err != nil {
return "", err
}
connReadCloser := aghio.LimitReadCloser(conn, MaxConnReadSize)
connReadCloser, err := aghio.LimitReadCloser(conn, MaxConnReadSize)
if err != nil {
conn.Close()
return "", err
}
defer connReadCloser.Close()

_ = conn.SetReadDeadline(time.Now().Add(time.Duration(w.timeoutMsec) * time.Millisecond))
Expand Down
6 changes: 5 additions & 1 deletion internal/update/check.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,11 @@ func (u *Updater) GetVersionResponse(forceRecheck bool) (VersionInfo, error) {
if err != nil {
return VersionInfo{}, fmt.Errorf("updater: HTTP GET %s: %w", u.VersionURL, err)
}
resp.Body = aghio.LimitReadCloser(resp.Body, MaxResponseSize)
resp.Body, err = aghio.LimitReadCloser(resp.Body, MaxResponseSize)
if err != nil {
resp.Body.Close()
return VersionInfo{}, fmt.Errorf("updater: LimitReadCloser: %w", err)
}
defer resp.Body.Close()

// This use of ReadAll is safe, because we just limited the appropriate
Expand Down
9 changes: 7 additions & 2 deletions internal/update/updater.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,8 @@ func (u *Updater) clean() {
_ = os.RemoveAll(u.updateDir)
}

// MaxPackageFileSize is a maximum package file length in bytes.
// MaxPackageFileSize is a maximum package file length in bytes. The largest
// package whose size is limited by this constant has size of 9MiB for now.
const MaxPackageFileSize = 32 * 1024 * 1024

// Download package file and save it to disk
Expand All @@ -228,7 +229,11 @@ func (u *Updater) downloadPackageFile(url string, filename string) error {
return fmt.Errorf("http request failed: %w", err)
}

resp.Body = aghio.LimitReadCloser(resp.Body, MaxPackageFileSize)
resp.Body, err = aghio.LimitReadCloser(resp.Body, MaxPackageFileSize)
if err != nil {
resp.Body.Close()
return fmt.Errorf("http request failed: %w", err)
}
defer resp.Body.Close()

log.Debug("updater: reading HTTP body")
Expand Down

0 comments on commit 060f923

Please sign in to comment.