Skip to content

Commit

Permalink
handle session wrap around
Browse files Browse the repository at this point in the history
  • Loading branch information
t0350 committed Sep 27, 2023
1 parent c178c39 commit 6b23784
Showing 1 changed file with 39 additions and 9 deletions.
48 changes: 39 additions & 9 deletions lualib/skynet.lua
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,38 @@ local watching_session = {}
local error_queue = {}
local fork_queue = { h = 1, t = 0 }

local sessionstep = 1
local maxint32 = ((1<<31) - 1)

local function check_session_wraparound(next_session)
sessionstep = maxint32 + 1
while session_id_coroutine[next_session - maxint32] do
next_session = c.genid() + sessionstep
end
end

local function auxsend(addr, id, session, msg, sz)
session = c.send(addr, id, session, msg, sz)
if not session then
return nil
end
local next_session = session + sessionstep
if next_session > maxint32 then -- wrap around
check_session_wraparound(next_session)
end
return session
end

local function ticommand(timeout)
local session = c.intcommand("TIMEOUT", timeout)
assert(session)
local next_session = session + sessionstep
if next_session > maxint32 then -- wrap around
check_session_wraparound(next_session)
end
return session
end

do ---- request/select
local function send_requests(self)
local sessions = {}
Expand All @@ -85,7 +117,7 @@ do ---- request/select
c.trace(tag, "call", 4)
c.send(addr, skynet.PTYPE_TRACE, 0, tag)
end
local session = c.send(addr, p.id , nil , p.pack(tunpack(req, 3, req.n)))
local session = auxsend(addr, p.id , nil , p.pack(tunpack(req, 3, req.n)))
if session == nil then
err = err or {}
err[#err+1] = req
Expand Down Expand Up @@ -188,7 +220,7 @@ do ---- request/select
self._error = send_requests(self)
self._resp = {}
if timeout then
self._timeout = c.intcommand("TIMEOUT",timeout)
self._timeout = ticommand(timeout)
session_id_coroutine[self._timeout] = self._thread
end

Expand Down Expand Up @@ -373,8 +405,7 @@ end
skynet.trace_timeout(false) -- turn off by default

function skynet.timeout(ti, func)
local session = c.intcommand("TIMEOUT",ti)
assert(session)
local session = ticommand(ti)
local co = co_create_for_timeout(func, ti)
assert(session_id_coroutine[session] == nil)
session_id_coroutine[session] = co
Expand All @@ -392,8 +423,7 @@ local function suspend_sleep(session, token)
end

function skynet.sleep(ti, token)
local session = c.intcommand("TIMEOUT",ti)
assert(session)
local session = ticommand(ti)
token = token or coroutine.running()
local succ, ret = suspend_sleep(session, token)
sleep_session[token] = nil
Expand Down Expand Up @@ -605,7 +635,7 @@ function skynet.call(addr, typename, ...)
end

local p = proto[typename]
local session = c.send(addr, p.id , nil , p.pack(...))
local session = auxsend(addr, p.id , nil , p.pack(...))
if session == nil then
error("call to invalid address " .. skynet.address(addr))
end
Expand All @@ -619,15 +649,15 @@ function skynet.rawcall(addr, typename, msg, sz)
c.send(addr, skynet.PTYPE_TRACE, 0, tag)
end
local p = proto[typename]
local session = assert(c.send(addr, p.id , nil , msg, sz), "call to invalid address")
local session = assert(auxsend(addr, p.id , nil , msg, sz), "call to invalid address")
return yield_call(addr, session)
end

function skynet.tracecall(tag, addr, typename, msg, sz)
c.trace(tag, "tracecall begin")
c.send(addr, skynet.PTYPE_TRACE, 0, tag)
local p = proto[typename]
local session = assert(c.send(addr, p.id , nil , msg, sz), "call to invalid address")
local session = assert(auxsend(addr, p.id , nil , msg, sz), "call to invalid address")
local msg, sz = yield_call(addr, session)
c.trace(tag, "tracecall end")
return msg, sz
Expand Down

0 comments on commit 6b23784

Please sign in to comment.