Skip to content

Commit

Permalink
Fix match and party data message callbacks (#79)
Browse files Browse the repository at this point in the history
* Fix match and party data message callbacks

* Update CHANGELOG.md

* Fixed tests

* Update test.lua

* Excluded more socket messages
  • Loading branch information
britzl authored Jun 14, 2024
1 parent 905c9dd commit 80d46d4
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 32 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.

## [Unreleased]

## [3.3.0] - 2024-06-14
### Fixed
- Fixed issue with wrong argument name for `nakama.rpc_func`
- Updated several of the socket messages so that they no longer incorrectly wait for a response from the server.

## [3.2.0] - 2023-12-11
### Changed
Expand Down
48 changes: 42 additions & 6 deletions codegen/realtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@
if message.match_data then
message.match_data.data = b64.decode(message.match_data.data)
end
if message.cid then
local callback = socket.requests[message.cid]
if callback then
callback(message)
end
socket.requests[message.cid] = nil
end
for event_id,_ in pairs(message) do
if socket.events[event_id] then
socket.events[event_id](message)
Expand All @@ -28,12 +35,23 @@
if message.match_data_send and message.match_data_send.data then
message.match_data_send.data = b64.encode(message.match_data_send.data)
end
if callback then
socket.engine.socket_send(socket, message, callback)
if message.cid then
socket.requests[message.cid] = callback
socket.engine.socket_send(socket, message)
else
socket.engine.socket_send(socket, message)
callback({})
end
else
return async(function(done)
socket.engine.socket_send(socket, message, done)
if message.cid then
socket.requests[message.cid] = done
socket.engine.socket_send(socket, message)
else
socket.engine.socket_send(socket, message)
done({})
end
end)
end
end
Expand All @@ -45,6 +63,10 @@
assert(type(socket) == "table", "The created instance must be a table")
socket.client = client
socket.engine = client.engine
-- callbacks
socket.cid = 0
socket.requests = {}
-- event handlers are registered here
socket.events = {}
Expand Down Expand Up @@ -168,7 +190,6 @@ def type_to_lua(t):
elif t == "map":
return "table"
else:
print("WARNING: Unknown type '%s' - Will use type 'table'" % t)
return "table"


Expand All @@ -194,7 +215,7 @@ def parse_proto_message(message):
return properties


def message_to_lua(message_id, api):
def message_to_lua(message_id, api, wait_for_callback):
message = get_proto_message(message_id, api)
if not message:
print("Unable to find message %s" % message_id)
Expand All @@ -219,7 +240,11 @@ def message_to_lua(message_id, api):
lua = lua + " assert(socket)\n"
for prop in props:
lua = lua + " assert(%s == nil or _G.type(%s) == '%s')\n" % (prop["name"], prop["name"], prop["type"])
if wait_for_callback:
lua = lua + " socket.cid = socket.cid + 1\n"
lua = lua + " local message = {\n"
if wait_for_callback:
lua = lua + " cid = tostring(socket.cid),\n"
lua = lua + " %s = {\n" % message_id
for prop in props:
lua = lua + " %s = %s,\n" % (prop["name"], prop["name"])
Expand Down Expand Up @@ -253,17 +278,28 @@ def event_to_lua(event_id, api):


def messages_to_lua(rtapi):
# list of message names that should generate Lua code
CHANNEL_MESSAGES = [ "ChannelJoin", "ChannelLeave", "ChannelMessageSend", "ChannelMessageRemove", "ChannelMessageUpdate" ]
MATCH_MESSAGES = [ "MatchDataSend", "MatchCreate", "MatchJoin", "MatchLeave" ]
MATCHMAKER_MESSAGES = [ "MatchmakerAdd", "MatchmakerRemove" ]
PARTY_MESSAGES = [ "PartyCreate", "PartyJoin", "PartyLeave", "PartyPromote", "PartyAccept", "PartyRemove", "PartyClose", "PartyJoinRequestList", "PartyMatchmakerAdd", "PartyMatchmakerRemove", "PartyDataSend" ]
STATUS_MESSAGES = [ "StatusFollow", "StatusUnfollow", "StatusUpdate" ]
ALL_MESSAGES = CHANNEL_MESSAGES + MATCH_MESSAGES + MATCHMAKER_MESSAGES + PARTY_MESSAGES + STATUS_MESSAGES

# list of messages that do not expect a server response
CHANNEL_MESSAGES_NOCB = [ "ChannelLeave" ]
MATCH_MESSAGES_NOCB = [ "MatchLeave", "MatchDataSend"]
MATCHMAKER_MESSAGES_NOCB = [ "MatchmakerRemove" ]
PARTY_MESSAGES_NOCB = [ "PartyDataSend", "PartyAccept", "PartyClose", "PartyJoin", "PartyLeave", "PartyPromote", "PartyRemove", "PartyMatchmakerRemove" ]
STATUS_MESSAGES_NOCB = [ "StatusUnfollow", "StatusUpdate" ]

NO_CALLBACK_MESSAGES = CHANNEL_MESSAGES_NOCB + MATCH_MESSAGES_NOCB + MATCHMAKER_MESSAGES_NOCB + PARTY_MESSAGES_NOCB + STATUS_MESSAGES_NOCB

ids = []
lua = ""
for message_id in ALL_MESSAGES:
lua = lua + message_to_lua(message_id, rtapi)
wait_for_callback = (message_id not in NO_CALLBACK_MESSAGES)
lua = lua + message_to_lua(message_id, rtapi, wait_for_callback)
ids.append(message_id)

return { "ids": ids, "lua": lua }
Expand Down
22 changes: 2 additions & 20 deletions nakama/engine/defold.lua
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,6 @@ function M.socket_create(config, on_message)
local socket = {}
socket.config = config
socket.scheme = config.use_ssl and "wss" or "ws"

socket.cid = 0
socket.requests = {}
socket.on_message = on_message

return socket
Expand All @@ -159,18 +156,7 @@ end
-- internal on_message, calls user defined socket.on_message function
local function on_message(socket, message)
message = json.decode(message)
if not message.cid then
socket.on_message(socket, message)
return
end

local callback = socket.requests[message.cid]
if not callback then
log("Unable to find callback for cid", message.cid)
return
end
socket.requests[message.cid] = nil
callback(message)
socket.on_message(socket, message)
end

--- Connect a created socket using web sockets.
Expand Down Expand Up @@ -210,13 +196,9 @@ end
--- Send a socket message.
-- @param socket The socket table, see socket_create.
-- @param message The message string to send.
-- @param callback The callback function.
function M.socket_send(socket, message, callback)
function M.socket_send(socket, message)
assert(socket and socket.connection, "You must provide a socket")
assert(message, "You must provide a message to send")
socket.cid = socket.cid + 1
message.cid = tostring(socket.cid)
socket.requests[message.cid] = callback

local data = json.encode(message)
-- Fix encoding of match_create and status_update messages to send {} instead of []
Expand Down
4 changes: 1 addition & 3 deletions nakama/engine/test.lua
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,8 @@ function M.socket_connect(socket, callback)
callback(result)
end

function M.socket_send(socket, message, callback)
function M.socket_send(socket, message)
table.insert(socket_send_queue, message)
local result = {}
callback(result)
end


Expand Down
51 changes: 48 additions & 3 deletions nakama/socket.lua
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,13 @@ local function on_socket_message(socket, message)
if message.match_data then
message.match_data.data = b64.decode(message.match_data.data)
end
if message.cid then
local callback = socket.requests[message.cid]
if callback then
callback(message)
end
socket.requests[message.cid] = nil
end
for event_id,_ in pairs(message) do
if socket.events[event_id] then
socket.events[event_id](message)
Expand All @@ -22,12 +29,23 @@ local function socket_send(socket, message, callback)
if message.match_data_send and message.match_data_send.data then
message.match_data_send.data = b64.encode(message.match_data_send.data)
end

if callback then
socket.engine.socket_send(socket, message, callback)
if message.cid then
socket.requests[message.cid] = callback
socket.engine.socket_send(socket, message)
else
socket.engine.socket_send(socket, message)
callback({})
end
else
return async(function(done)
socket.engine.socket_send(socket, message, done)
if message.cid then
socket.requests[message.cid] = done
socket.engine.socket_send(socket, message)
else
socket.engine.socket_send(socket, message)
done({})
end
end)
end
end
Expand All @@ -39,6 +57,10 @@ function M.create(client)
assert(type(socket) == "table", "The created instance must be a table")
socket.client = client
socket.engine = client.engine

-- callbacks
socket.cid = 0
socket.requests = {}

-- event handlers are registered here
socket.events = {}
Expand Down Expand Up @@ -107,7 +129,9 @@ function M.channel_join(socket, target, type, persistence, hidden, callback)
assert(type == nil or _G.type(type) == 'number')
assert(persistence == nil or _G.type(persistence) == 'boolean')
assert(hidden == nil or _G.type(hidden) == 'boolean')
socket.cid = socket.cid + 1
local message = {
cid = tostring(socket.cid),
channel_join = {
target = target,
type = type,
Expand Down Expand Up @@ -142,7 +166,9 @@ function M.channel_message_send(socket, channel_id, content, callback)
assert(socket)
assert(channel_id == nil or _G.type(channel_id) == 'string')
assert(content == nil or _G.type(content) == 'string')
socket.cid = socket.cid + 1
local message = {
cid = tostring(socket.cid),
channel_message_send = {
channel_id = channel_id,
content = content,
Expand All @@ -160,7 +186,9 @@ function M.channel_message_remove(socket, channel_id, message_id, callback)
assert(socket)
assert(channel_id == nil or _G.type(channel_id) == 'string')
assert(message_id == nil or _G.type(message_id) == 'string')
socket.cid = socket.cid + 1
local message = {
cid = tostring(socket.cid),
channel_message_remove = {
channel_id = channel_id,
message_id = message_id,
Expand All @@ -180,7 +208,9 @@ function M.channel_message_update(socket, channel_id, message_id, content, callb
assert(channel_id == nil or _G.type(channel_id) == 'string')
assert(message_id == nil or _G.type(message_id) == 'string')
assert(content == nil or _G.type(content) == 'string')
socket.cid = socket.cid + 1
local message = {
cid = tostring(socket.cid),
channel_message_update = {
channel_id = channel_id,
message_id = message_id,
Expand Down Expand Up @@ -224,7 +254,9 @@ end
function M.match_create(socket, name, callback)
assert(socket)
assert(name == nil or _G.type(name) == 'string')
socket.cid = socket.cid + 1
local message = {
cid = tostring(socket.cid),
match_create = {
name = name,
}
Expand All @@ -243,7 +275,9 @@ function M.match_join(socket, match_id, token, metadata, callback)
assert(match_id == nil or _G.type(match_id) == 'string')
assert(token == nil or _G.type(token) == 'string')
assert(metadata == nil or _G.type(metadata) == 'table')
socket.cid = socket.cid + 1
local message = {
cid = tostring(socket.cid),
match_join = {
match_id = match_id,
token = token,
Expand Down Expand Up @@ -285,7 +319,9 @@ function M.matchmaker_add(socket, min_count, max_count, query, string_properties
assert(string_properties == nil or _G.type(string_properties) == 'table')
assert(numeric_properties == nil or _G.type(numeric_properties) == 'table')
assert(count_multiple == nil or _G.type(count_multiple) == 'number')
socket.cid = socket.cid + 1
local message = {
cid = tostring(socket.cid),
matchmaker_add = {
min_count = min_count,
max_count = max_count,
Expand Down Expand Up @@ -322,7 +358,9 @@ function M.party_create(socket, open, max_size, callback)
assert(socket)
assert(open == nil or _G.type(open) == 'boolean')
assert(max_size == nil or _G.type(max_size) == 'number')
socket.cid = socket.cid + 1
local message = {
cid = tostring(socket.cid),
party_create = {
open = open,
max_size = max_size,
Expand Down Expand Up @@ -437,7 +475,9 @@ end
function M.party_join_request_list(socket, party_id, callback)
assert(socket)
assert(party_id == nil or _G.type(party_id) == 'string')
socket.cid = socket.cid + 1
local message = {
cid = tostring(socket.cid),
party_join_request_list = {
party_id = party_id,
}
Expand All @@ -464,7 +504,9 @@ function M.party_matchmaker_add(socket, party_id, min_count, max_count, query, s
assert(string_properties == nil or _G.type(string_properties) == 'table')
assert(numeric_properties == nil or _G.type(numeric_properties) == 'table')
assert(count_multiple == nil or _G.type(count_multiple) == 'number')
socket.cid = socket.cid + 1
local message = {
cid = tostring(socket.cid),
party_matchmaker_add = {
party_id = party_id,
min_count = min_count,
Expand Down Expand Up @@ -526,7 +568,9 @@ function M.status_follow(socket, user_ids, usernames, callback)
assert(socket)
assert(user_ids == nil or _G.type(user_ids) == 'string')
assert(usernames == nil or _G.type(usernames) == 'string')
socket.cid = socket.cid + 1
local message = {
cid = tostring(socket.cid),
status_follow = {
user_ids = user_ids,
usernames = usernames,
Expand Down Expand Up @@ -760,3 +804,4 @@ M.ERROR_RUNTIME_FUNCTION_NOT_FOUND = 6
M.ERROR_RUNTIME_FUNCTION_EXCEPTION = 7

return M

0 comments on commit 80d46d4

Please sign in to comment.