Skip to content

Commit

Permalink
Merge pull request #13 from rinchsan/tidy-and-add-tests
Browse files Browse the repository at this point in the history
  • Loading branch information
rinchsan authored Jan 13, 2021
2 parents 6ab2ab8 + 9fa0fa0 commit 9f0a343
Show file tree
Hide file tree
Showing 10 changed files with 157 additions and 33 deletions.
14 changes: 14 additions & 0 deletions api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,3 +198,17 @@ func TestAPI_do(t *testing.T) {
})
}
}

type roundTripFunc func(req *http.Request) *http.Response

func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req), nil
}

func newMockHTTPClient(resp *http.Response) *http.Client {
return &http.Client{
Transport: roundTripFunc(func(r *http.Request) *http.Response {
return resp
}),
}
}
6 changes: 6 additions & 0 deletions codecov.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
coverage:
status:
project:
default:
target: 90%
patch: off
11 changes: 3 additions & 8 deletions error.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package devicecheck
import (
"errors"
"net/http"
"strings"
)

const (
Expand All @@ -21,14 +22,8 @@ var (
ErrBitStateNotFound = errors.New("bit state not found")
)

func newErrorForQuery(code int, body string) error {
if code != http.StatusOK {
return newError(code)
}
if body == bitStateNotFoundStr {
return ErrBitStateNotFound
}
return nil
func isErrBitStateNotFound(body string) bool {
return strings.Contains(body, bitStateNotFoundStr)
}

func newError(code int) error {
Expand Down
31 changes: 31 additions & 0 deletions error_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,37 @@ import (
"testing"
)

func Test_isErrBitStateNotFound(t *testing.T) {
t.Parallel()

cases := map[string]struct {
body string
want bool
}{
"is ErrBitStateNotFound": {
body: "Failed to find bit state",
want: true,
},
"is not ErrBitStateNotFound": {
body: "Missing or incorrectly formatted bits",
want: false,
},
}

for name, c := range cases {
c := c
t.Run(name, func(t *testing.T) {
t.Parallel()

got := isErrBitStateNotFound(c.body)

if !reflect.DeepEqual(got, c.want) {
t.Errorf("want '%+v', got '%+v'", c.want, got)
}
})
}
}

func Test_newError(t *testing.T) {
t.Parallel()

Expand Down
2 changes: 0 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
github.com/dvsekhvalnov/jose2go v0.0.0-20201001154944-b09cfaf05951 h1:U+H8oUNmugZTzB9c3EZQzdC2B9rXCqROImRSxnwh4Ck=
github.com/dvsekhvalnov/jose2go v0.0.0-20201001154944-b09cfaf05951/go.mod h1:QsHjhyTlD/lAVqn/NSbVZmSCGeDehTB/mPZadG+mhXU=
github.com/google/uuid v1.1.2 h1:EVhdT+1Kseyi1/pUmXKaFxYsDNy9RQYkMWRH68J/W7Y=
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.1.4 h1:0ecGp3skIrHWPNGPJDaBIghfA6Sp7Ruo2Io8eLKzWm0=
github.com/google/uuid v1.1.4/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
Expand Down
7 changes: 2 additions & 5 deletions jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package devicecheck
import (
"crypto/ecdsa"
"encoding/json"
"fmt"
"time"

jose "github.com/dvsekhvalnov/jose2go"
Expand All @@ -27,10 +26,8 @@ func (jwt jwt) generate(key *ecdsa.PrivateKey) (string, error) {
"iat": time.Now().UTC().Unix(),
}

claimsJSON, err := json.Marshal(claims)
if err != nil {
return "", fmt.Errorf("json: %w", err)
}
// Ignoring error, because json.Marshal never fails.
claimsJSON, _ := json.Marshal(claims)

headers := map[string]interface{}{
"alg": jose.ES256,
Expand Down
27 changes: 15 additions & 12 deletions query.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"strings"
"time"

Expand Down Expand Up @@ -71,17 +72,19 @@ func (client *Client) QueryTwoBits(deviceToken string, result *QueryTwoBitsResul
}
defer resp.Body.Close()

respBody, err := ioutil.ReadAll(resp.Body)
if err != nil {
return err
}
err = newErrorForQuery(resp.StatusCode, string(respBody))
if err == ErrBitStateNotFound {
return err
}
if err != nil {
return fmt.Errorf("devicecheck: %w", err)
}
switch resp.StatusCode {
case http.StatusOK:
respBody, err := ioutil.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("devicecheck: failed to read response body: %w", err)
}

if isErrBitStateNotFound(string(respBody)) {
return ErrBitStateNotFound
}

return json.NewDecoder(bytes.NewReader(respBody)).Decode(result)
return json.NewDecoder(bytes.NewReader(respBody)).Decode(result)
default:
return fmt.Errorf("devicecheck: %w", newError(resp.StatusCode))
}
}
38 changes: 36 additions & 2 deletions query_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package devicecheck

import (
"io/ioutil"
"net/http"
"reflect"
"strings"
"testing"
"time"
)
Expand Down Expand Up @@ -89,13 +91,15 @@ func TestClient_QueryTwoBits(t *testing.T) {

cases := map[string]struct {
client Client
noErr bool
}{
"invalid key": {
client: Client{
api: newAPI(Development),
cred: NewCredentialFile("unknown_file.p8"),
jwt: newJWT("issuer", "keyID"),
},
noErr: false,
},
"invalid url": {
client: Client{
Expand All @@ -106,13 +110,37 @@ func TestClient_QueryTwoBits(t *testing.T) {
cred: NewCredentialFile("revoked_private_key.p8"),
jwt: newJWT("issuer", "keyID"),
},
noErr: false,
},
"invalid device token": {
client: Client{
api: newAPI(Development),
cred: NewCredentialFile("revoked_private_key.p8"),
jwt: newJWT("issuer", "keyID"),
},
noErr: false,
},
"status ok with ErrBitStateNotFound": {
client: Client{
api: newAPIWithHTTPClient(newMockHTTPClient(&http.Response{
StatusCode: http.StatusOK,
Body: ioutil.NopCloser(strings.NewReader("Failed to find bit state")),
}), Development),
cred: NewCredentialFile("revoked_private_key.p8"),
jwt: newJWT("issuer", "keyID"),
},
noErr: false,
},
"status ok with valid response": {
client: Client{
api: newAPIWithHTTPClient(newMockHTTPClient(&http.Response{
StatusCode: http.StatusOK,
Body: ioutil.NopCloser(strings.NewReader(`{"bit0":true,"bit1":false,"last_update_time":"2006-01"}`)),
}), Development),
cred: NewCredentialFile("revoked_private_key.p8"),
jwt: newJWT("issuer", "keyID"),
},
noErr: true,
},
}

Expand All @@ -124,8 +152,14 @@ func TestClient_QueryTwoBits(t *testing.T) {
var result QueryTwoBitsResult
err := c.client.QueryTwoBits("device_token", &result)

if err == nil {
t.Error("want 'not nil', got 'nil'")
if c.noErr {
if err != nil {
t.Errorf("want 'nil', got '%+v'", err)
}
} else {
if err == nil {
t.Error("want 'not nil', got 'nil'")
}
}
})
}
Expand Down
27 changes: 25 additions & 2 deletions update_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package devicecheck

import (
"io/ioutil"
"net/http"
"strings"
"testing"
)

Expand All @@ -10,13 +12,15 @@ func TestClient_UpdateTwoBits(t *testing.T) {

cases := map[string]struct {
client Client
noErr bool
}{
"invalid key": {
client: Client{
api: newAPI(Development),
cred: NewCredentialFile("unknown_file.p8"),
jwt: newJWT("issuer", "keyID"),
},
noErr: false,
},
"invalid url": {
client: Client{
Expand All @@ -27,13 +31,26 @@ func TestClient_UpdateTwoBits(t *testing.T) {
cred: NewCredentialFile("revoked_private_key.p8"),
jwt: newJWT("issuer", "keyID"),
},
noErr: false,
},
"invalid device token": {
client: Client{
api: newAPI(Development),
cred: NewCredentialFile("revoked_private_key.p8"),
jwt: newJWT("issuer", "keyID"),
},
noErr: false,
},
"status ok": {
client: Client{
api: newAPIWithHTTPClient(newMockHTTPClient(&http.Response{
StatusCode: http.StatusOK,
Body: ioutil.NopCloser(strings.NewReader("success")),
}), Development),
cred: NewCredentialFile("revoked_private_key.p8"),
jwt: newJWT("issuer", "keyID"),
},
noErr: true,
},
}

Expand All @@ -44,8 +61,14 @@ func TestClient_UpdateTwoBits(t *testing.T) {

err := c.client.UpdateTwoBits("device_token", true, true)

if err == nil {
t.Error("want 'not nil', got 'nil'")
if c.noErr {
if err != nil {
t.Errorf("want 'nil', got '%+v'", err)
}
} else {
if err == nil {
t.Error("want 'not nil', got 'nil'")
}
}
})
}
Expand Down
27 changes: 25 additions & 2 deletions validate_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package devicecheck

import (
"io/ioutil"
"net/http"
"strings"
"testing"
)

Expand All @@ -10,13 +12,15 @@ func TestClient_ValidateDeviceToken(t *testing.T) {

cases := map[string]struct {
client Client
noErr bool
}{
"invalid key": {
client: Client{
api: newAPI(Development),
cred: NewCredentialFile("unknown_file.p8"),
jwt: newJWT("issuer", "keyID"),
},
noErr: false,
},
"invalid url": {
client: Client{
Expand All @@ -27,13 +31,26 @@ func TestClient_ValidateDeviceToken(t *testing.T) {
cred: NewCredentialFile("revoked_private_key.p8"),
jwt: newJWT("issuer", "keyID"),
},
noErr: false,
},
"invalid device token": {
client: Client{
api: newAPI(Development),
cred: NewCredentialFile("revoked_private_key.p8"),
jwt: newJWT("issuer", "keyID"),
},
noErr: false,
},
"status ok": {
client: Client{
api: newAPIWithHTTPClient(newMockHTTPClient(&http.Response{
StatusCode: http.StatusOK,
Body: ioutil.NopCloser(strings.NewReader("success")),
}), Development),
cred: NewCredentialFile("revoked_private_key.p8"),
jwt: newJWT("issuer", "keyID"),
},
noErr: true,
},
}

Expand All @@ -44,8 +61,14 @@ func TestClient_ValidateDeviceToken(t *testing.T) {

err := c.client.ValidateDeviceToken("device_token")

if err == nil {
t.Error("want 'not nil', got 'nil'")
if c.noErr {
if err != nil {
t.Errorf("want 'nil', got '%+v'", err)
}
} else {
if err == nil {
t.Error("want 'not nil', got 'nil'")
}
}
})
}
Expand Down

0 comments on commit 9f0a343

Please sign in to comment.