diff --git a/lualib/skynet.lua b/lualib/skynet.lua index 9b18467af..df3bfa2ee 100644 --- a/lualib/skynet.lua +++ b/lualib/skynet.lua @@ -69,6 +69,38 @@ local watching_session = {} local error_queue = {} local fork_queue = { h = 1, t = 0 } +local sessionstep = 1 +local maxint32 = ((1<<31) - 1) + +local function check_session_wraparound(next_session) + sessionstep = maxint32 + 1 + while session_id_coroutine[next_session - maxint32] do + next_session = c.genid() + sessionstep + end +end + +local function auxsend(addr, id, session, msg, sz) + session = c.send(addr, id, session, msg, sz) + if not session then + return nil + end + local next_session = session + sessionstep + if next_session > maxint32 then -- wrap around + check_session_wraparound(next_session) + end + return session +end + +local function ticommand(timeout) + local session = c.intcommand("TIMEOUT", timeout) + assert(session) + local next_session = session + sessionstep + if next_session > maxint32 then -- wrap around + check_session_wraparound(next_session) + end + return session +end + do ---- request/select local function send_requests(self) local sessions = {} @@ -85,7 +117,7 @@ do ---- request/select c.trace(tag, "call", 4) c.send(addr, skynet.PTYPE_TRACE, 0, tag) end - local session = c.send(addr, p.id , nil , p.pack(tunpack(req, 3, req.n))) + local session = auxsend(addr, p.id , nil , p.pack(tunpack(req, 3, req.n))) if session == nil then err = err or {} err[#err+1] = req @@ -188,7 +220,7 @@ do ---- request/select self._error = send_requests(self) self._resp = {} if timeout then - self._timeout = c.intcommand("TIMEOUT",timeout) + self._timeout = ticommand(timeout) session_id_coroutine[self._timeout] = self._thread end @@ -373,8 +405,7 @@ end skynet.trace_timeout(false) -- turn off by default function skynet.timeout(ti, func) - local session = c.intcommand("TIMEOUT",ti) - assert(session) + local session = ticommand(ti) local co = co_create_for_timeout(func, ti) assert(session_id_coroutine[session] == nil) session_id_coroutine[session] = co @@ -392,8 +423,7 @@ local function suspend_sleep(session, token) end function skynet.sleep(ti, token) - local session = c.intcommand("TIMEOUT",ti) - assert(session) + local session = ticommand(ti) token = token or coroutine.running() local succ, ret = suspend_sleep(session, token) sleep_session[token] = nil @@ -605,7 +635,7 @@ function skynet.call(addr, typename, ...) end local p = proto[typename] - local session = c.send(addr, p.id , nil , p.pack(...)) + local session = auxsend(addr, p.id , nil , p.pack(...)) if session == nil then error("call to invalid address " .. skynet.address(addr)) end @@ -619,7 +649,7 @@ function skynet.rawcall(addr, typename, msg, sz) c.send(addr, skynet.PTYPE_TRACE, 0, tag) end local p = proto[typename] - local session = assert(c.send(addr, p.id , nil , msg, sz), "call to invalid address") + local session = assert(auxsend(addr, p.id , nil , msg, sz), "call to invalid address") return yield_call(addr, session) end @@ -627,7 +657,7 @@ function skynet.tracecall(tag, addr, typename, msg, sz) c.trace(tag, "tracecall begin") c.send(addr, skynet.PTYPE_TRACE, 0, tag) local p = proto[typename] - local session = assert(c.send(addr, p.id , nil , msg, sz), "call to invalid address") + local session = assert(auxsend(addr, p.id , nil , msg, sz), "call to invalid address") local msg, sz = yield_call(addr, session) c.trace(tag, "tracecall end") return msg, sz