diff --git a/webpush.go b/webpush.go index a31a94d..decb919 100644 --- a/webpush.go +++ b/webpush.go @@ -20,6 +20,8 @@ import ( const MaxRecordSize uint32 = 4096 +var ErrMaxPadExceeded = errors.New("payload has exceeded the maximum length") + // saltFunc generates a salt of 16 bytes var saltFunc = func() ([]byte, error) { salt := make([]byte, 16) @@ -166,7 +168,9 @@ func SendNotification(message []byte, s *Subscription, options *Options) (*http. // Pad content to max record size - 16 - header // Padding ending delimeter dataBuf.Write([]byte("\x02")) - pad(dataBuf, recordLength-recordBuf.Len()) + if err := pad(dataBuf, recordLength-recordBuf.Len()); err != nil { + return nil, err + } // Compose the ciphertext ciphertext := gcm.Seal([]byte{}, nonce, dataBuf.Bytes(), nil) @@ -244,10 +248,16 @@ func getHKDFKey(hkdf io.Reader, length int) ([]byte, error) { return key, nil } -func pad(payload *bytes.Buffer, maxPadLen int) { +func pad(payload *bytes.Buffer, maxPadLen int) error { payloadLen := payload.Len() + if payloadLen > maxPadLen { + return ErrMaxPadExceeded + } + padLen := maxPadLen - payloadLen padding := make([]byte, padLen) payload.Write(padding) + + return nil } diff --git a/webpush_test.go b/webpush_test.go index 93ec994..807a1f7 100644 --- a/webpush_test.go +++ b/webpush_test.go @@ -2,6 +2,7 @@ package webpush import ( "net/http" + "strings" "testing" ) @@ -76,3 +77,17 @@ func TestSendNotificationToStandardEncodedSubscription(t *testing.T) { ) } } + +func TestSendTooLargeNotification(t *testing.T) { + _, err := SendNotification([]byte(strings.Repeat("Test", int(MaxRecordSize))), getStandardEncodedTestSubscription(), &Options{ + HTTPClient: &testHTTPClient{}, + Subscriber: "", + Topic: "test_topic", + TTL: 0, + Urgency: "low", + VAPIDPrivateKey: "testKey", + }) + if err == nil { + t.Fatalf("Error is nil, expected=%s", ErrMaxPadExceeded) + } +}