Skip to content

Commit

Permalink
Merge pull request #8 from sony/feature/pure_websocket
Browse files Browse the repository at this point in the history
pure websocket support
  • Loading branch information
KazuhitoNakamura authored Mar 22, 2021
2 parents 75608c9 + ce1b3f9 commit 9b2af41
Show file tree
Hide file tree
Showing 11 changed files with 963 additions and 65 deletions.
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
appsync-client-go
=================

[![Go Reference](https://pkg.go.dev/badge/github.com/sony/appsync-client-go.svg)](https://pkg.go.dev/github.com/sony/appsync-client-go)
[![Job Status](https://inspecode.rocro.com/badges/github.com/sony/appsync-client-go/status?token=VN4s0UD-m44_nY-nP0kSWRE3aVQiTg4UY2oTpm8r_Zc&branch=master)](https://inspecode.rocro.com/jobs/github.com/sony/appsync-client-go/latest?completed=true&branch=master)
[![Report](https://inspecode.rocro.com/badges/github.com/sony/appsync-client-go/report?token=VN4s0UD-m44_nY-nP0kSWRE3aVQiTg4UY2oTpm8r_Zc&branch=master)](https://inspecode.rocro.com/reports/github.com/sony/appsync-client-go/branch/master/summary)
[![Job Status](https://docstand.rocro.com/badges/github.com/sony/appsync-client-go/status?token=VN4s0UD-m44_nY-nP0kSWRE3aVQiTg4UY2oTpm8r_Zc&branch=master)](https://docstand.rocro.com/jobs/github.com/sony/appsync-client-go/latest?completed=true&branch=master)
[![godoc](https://docstand.rocro.com/badges/github.com/sony/appsync-client-go/documentation/godoc?token=VN4s0UD-m44_nY-nP0kSWRE3aVQiTg4UY2oTpm8r_Zc&branch=master)](https://docstand.rocro.com/docs/github.com/sony/appsync-client-go/branch/master/godoc/github.com/sony/appsync-client-go/)


[AWS AppSync](https://aws.amazon.com/jp/appsync/) Go client library

Expand All @@ -13,6 +13,7 @@ Features

* GraphQL Query(Queries, Mutations and Subscriptions).
* MQTT over Websocket for subscriptions.
* Pure Websockets subscriptions.

Getting Started
---------------
Expand Down
46 changes: 45 additions & 1 deletion appsync_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"log"
"strings"

"github.com/sony/appsync-client-go/internal/appsynctest"

Expand Down Expand Up @@ -59,7 +60,7 @@ func ExampleClient_Post_mutation() {
// Hi, AppSync!
}

func ExampleClient_Post_subscription() {
func ExampleClient_mqtt_subscription() {
server := appsynctest.NewAppSyncEchoServer()
defer server.Close()

Expand Down Expand Up @@ -108,3 +109,46 @@ func ExampleClient_Post_subscription() {
// Output:
// Hi, AppSync!
}

func ExampleClient_graphqlws_subscription() {
server := appsynctest.NewAppSyncEchoServer()
defer server.Close()

client := appsync.NewClient(appsync.NewGraphQLClient(graphql.NewClient(server.URL)))
subscription := `subscription SubscribeToEcho() { subscribeToEcho }`

ch := make(chan *graphql.Response)
subscriber := appsync.NewPureWebSocketSubscriber(
strings.Replace(server.URL, "http", "ws", 1),
graphql.PostRequest{
Query: subscription,
},
func(r *graphql.Response) { ch <- r },
func(err error) { log.Println(err) },
)

if err := subscriber.Start(); err != nil {
log.Fatalln(err)
}
defer subscriber.Stop()

mutation := `mutation Echo($message: String!) { echo(message: $message) }`
variables := json.RawMessage(fmt.Sprintf(`{ "message": "%s" }`, "Hi, AppSync!"))
_, err := client.Post(graphql.PostRequest{
Query: mutation,
Variables: &variables,
})
if err != nil {
log.Fatal(err)
}

response := <-ch
data := new(string)
if err := response.DataAs(data); err != nil {
log.Fatalln(err, response)
}
fmt.Println(*data)

// Output:
// Hi, AppSync!
}
File renamed without changes.
File renamed without changes.
2 changes: 0 additions & 2 deletions graphql/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"net/url"
"time"

v4 "github.com/aws/aws-sdk-go/aws/signer/v4"
"github.com/cenkalti/backoff"
)

Expand All @@ -32,7 +31,6 @@ type Client struct {
timeout time.Duration
maxElapsedTime time.Duration
header http.Header
signer *v4.Signer
}

// NewClient returns a Client instance.
Expand Down
179 changes: 140 additions & 39 deletions internal/appsynctest/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,56 +21,98 @@ var (
mqttEchoTopic = "echo"
schema = `
schema {
query: Query
mutation: Mutation
subscription: Subscription
query: Query
mutation: Mutation
subscription: Subscription
}
type Query {
message: String!
message: String!
}
type Mutation {
echo(message: String!): String!
echo(message: String!): String!
}
type Subscription {
subscribeToEcho: String!
subscribeToEcho: String!
}
`
initialMessage = "Hello, AppSync!"

gqlwsconnack = `
{
"type": "connection_ack",
"payload" : {
"connectionTimeoutMs": 300000
}
}
`
gqlwsstartackfmt = `
{
"type": "start_ack",
"id" : "%s"
}
`
gqlwsdatafmt = `
{
"type": "data",
"id" : "%s",
"payload": %s
}
`
gqlwscompletefmt = `
{
"type": "complete",
"id" : "%s"
}
`
)

type mqttPublisher struct {
w http.ResponseWriter
sessions mqttSessions
w http.ResponseWriter
mqttSessions mqttSessions
grapqhWsSessions grapqhWsSessions
}

func (m *mqttPublisher) Header() http.Header {
return m.w.Header()
}

func (m *mqttPublisher) Write(payload []byte) (int, error) {
go func() {
pub := packets.NewControlPacket(packets.Publish).(*packets.PublishPacket)
pub.TopicName = mqttEchoTopic
pub.Payload = payload
for sub := range m.sessions {
writer, err := sub.NextWriter(websocket.BinaryMessage)
if err != nil {
log.Println(err)
continue
}
if err := pub.Write(writer); err != nil {
log.Println(err)
continue
if len(m.mqttSessions) != 0 {
go func() {
pub := packets.NewControlPacket(packets.Publish).(*packets.PublishPacket)
pub.TopicName = mqttEchoTopic
pub.Payload = payload
for s := range m.mqttSessions {
writer, err := s.NextWriter(websocket.BinaryMessage)
if err != nil {
log.Println(err)
continue
}
if err := pub.Write(writer); err != nil {
log.Println(err)
continue
}
if err := writer.Close(); err != nil {
log.Println(err)
continue
}
}
if err := writer.Close(); err != nil {
log.Println(err)
continue
}()
}
if len(m.grapqhWsSessions) != 0 {
go func() {
for s := range m.grapqhWsSessions {
data := json.RawMessage(fmt.Sprintf(gqlwsdatafmt, "id", string(payload)))
if err := s.WriteJSON(data); err != nil {
log.Println(err)
continue
}
}
}
}()
}()
}
return m.w.Write(payload)
}

Expand All @@ -95,21 +137,21 @@ func (e *echoResolver) SubscribeToEcho() string {
return e.message
}

func mqttWsSession(ws *websocket.Conn, onConnected func(ws *websocket.Conn, clientId string), onDisconnected func(ws *websocket.Conn)) {
func mqttWsSession(ws *websocket.Conn, onConnected func(ws *websocket.Conn), onDisconnected func(ws *websocket.Conn)) {
defer func() {
if err := ws.Close(); err != nil {
log.Println(err)
}
}()

for {
mt, reader, err := ws.NextReader()
mt, r, err := ws.NextReader()
if err != nil {
log.Println(err)
return
}

cp, err := packets.ReadPacket(reader)
cp, err := packets.ReadPacket(r)
if err != nil {
log.Println(err)
return
Expand All @@ -119,7 +161,7 @@ func mqttWsSession(ws *websocket.Conn, onConnected func(ws *websocket.Conn, clie
switch cp.(type) {
case *packets.ConnectPacket:
ack = packets.NewControlPacket(packets.Connack)
onConnected(ws, cp.(*packets.ConnectPacket).ClientIdentifier)
onConnected(ws)
case *packets.SubscribePacket:
ack = packets.NewControlPacket(packets.Suback)
ack.(*packets.SubackPacket).MessageID = cp.(*packets.SubscribePacket).MessageID
Expand Down Expand Up @@ -150,15 +192,46 @@ func mqttWsSession(ws *websocket.Conn, onConnected func(ws *websocket.Conn, clie
}
}

func graphQLWsSession(ws *websocket.Conn, onConnected func(ws *websocket.Conn), onDisconnected func(ws *websocket.Conn)) {
defer func() {
if err := ws.Close(); err != nil {
log.Println(err)
}
}()

for {
msg := map[string]interface{}{}
if err := ws.ReadJSON(&msg); err != nil {
return
}

var ack json.RawMessage
switch msg["type"].(string) {
case "connection_init":
ack = json.RawMessage(gqlwsconnack)
onConnected(ws)
case "start":
ack = json.RawMessage(fmt.Sprintf(gqlwsstartackfmt, msg["id"].(string)))
case "stop":
ack = json.RawMessage(fmt.Sprintf(gqlwscompletefmt, msg["id"].(string)))
onDisconnected(ws)
}
if err := ws.WriteJSON(ack); err != nil {
log.Println(err)
return
}
}
}

func newQueryHandlerFunc(h relay.Handler) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
h.ServeHTTP(w, r)
}
}

func newMutationHandlerFunc(h relay.Handler, sessions mqttSessions) http.HandlerFunc {
func newMutationHandlerFunc(h relay.Handler, mqtt mqttSessions, graphqlws grapqhWsSessions) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
h.ServeHTTP(&mqttPublisher{w, sessions}, r)
h.ServeHTTP(&mqttPublisher{w, mqtt, graphqlws}, r)
}
}

Expand Down Expand Up @@ -200,10 +273,16 @@ func newSubscriptionHandlerFunc() http.HandlerFunc {
}

func isMqttWs(r *http.Request) bool {
return r.Method == http.MethodGet && r.Header.Get("Upgrade") == "websocket"
return r.Method == http.MethodGet && r.Header.Get("Upgrade") == "websocket" &&
r.Header.Get("Sec-Websocket-Protocol") == "mqtt"
}

func newMqttWsHandlerFunc(onConnected func(ws *websocket.Conn, clientId string), onDisconnected func(ws *websocket.Conn)) http.HandlerFunc {
func isGraphQLWs(r *http.Request) bool {
return r.Method == http.MethodGet && r.Header.Get("Upgrade") == "websocket" &&
r.Header.Get("Sec-Websocket-Protocol") == "graphql-ws"
}

func newMqttWsHandlerFunc(onConnected func(ws *websocket.Conn), onDisconnected func(ws *websocket.Conn)) http.HandlerFunc {
upgrader := websocket.Upgrader{}
return func(w http.ResponseWriter, r *http.Request) {
ws, err := upgrader.Upgrade(w, r, nil)
Expand All @@ -214,18 +293,35 @@ func newMqttWsHandlerFunc(onConnected func(ws *websocket.Conn, clientId string),
}
}

type mqttSessions map[*websocket.Conn]string
type mqttSessions map[*websocket.Conn]bool
type grapqhWsSessions map[*websocket.Conn]bool

func newGraphQLWsHandlerFunc(onConnected func(ws *websocket.Conn), onDisconnected func(ws *websocket.Conn)) http.HandlerFunc {
upgrader := websocket.Upgrader{}
return func(w http.ResponseWriter, r *http.Request) {
ws, err := upgrader.Upgrade(w, r, nil)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
}
go graphQLWsSession(ws, onConnected, onDisconnected)
}
}

func newAppSyncEchoHandlerFunc(initialMessage string) http.HandlerFunc {
s := graphqlgo.MustParseSchema(schema, &echoResolver{initialMessage})
handler := relay.Handler{Schema: s}
sessions := make(mqttSessions)
mqttSessions := make(mqttSessions)
grapqhWsSessions := make(grapqhWsSessions)
query := newQueryHandlerFunc(handler)
mutation := newMutationHandlerFunc(handler, sessions)
mutation := newMutationHandlerFunc(handler, mqttSessions, grapqhWsSessions)
subscription := newSubscriptionHandlerFunc()
mqttws := newMqttWsHandlerFunc(
func(ws *websocket.Conn, clientId string) { sessions[ws] = clientId },
func(ws *websocket.Conn) { delete(sessions, ws) },
func(ws *websocket.Conn) { mqttSessions[ws] = true },
func(ws *websocket.Conn) { delete(mqttSessions, ws) },
)
graphqlws := newGraphQLWsHandlerFunc(
func(ws *websocket.Conn) { grapqhWsSessions[ws] = true },
func(ws *websocket.Conn) { delete(grapqhWsSessions, ws) },
)
return func(w http.ResponseWriter, r *http.Request) {
body, err := ioutil.ReadAll(r.Body)
Expand All @@ -242,6 +338,11 @@ func newAppSyncEchoHandlerFunc(initialMessage string) http.HandlerFunc {
return
}

if isGraphQLWs(r) {
graphqlws.ServeHTTP(w, r)
return
}

req := new(graphql.PostRequest)
if err := json.Unmarshal(body, req); err != nil {
log.Println(err)
Expand Down
Loading

0 comments on commit 9b2af41

Please sign in to comment.