| ----------------------------------------------------------------------------- |
| -- A hacked dispatcher module |
| -- LuaSocket sample files |
| -- Author: Diego Nehab |
| -- RCS ID: $$ |
| ----------------------------------------------------------------------------- |
| local base = _G |
| local table = require("table") |
| local socket = require("socket") |
| local coroutine = require("coroutine") |
| module("dispatch") |
| |
-- Idle limit for client sockets, in seconds. step() compares it with the
-- time elapsed (via socket.gettime()) since a socket's pending operation
-- was scheduled; once exceeded, the waiting call is made to return
-- nil, "timeout".
TIMEOUT = 60
| |
| ----------------------------------------------------------------------------- |
| -- We implement 3 types of dispatchers: |
| -- sequential |
| -- coroutine |
| -- threaded |
-- The user can choose whichever one is needed.
| ----------------------------------------------------------------------------- |
-- registry of dispatcher factories, keyed by mode name
local handlert = {}

-- Create a new dispatcher for the given mode (a key of handlert, e.g.
-- "sequential" or "coroutine"). The mode defaults to "coroutine".
-- Raises an error naming the mode if no such dispatcher is registered,
-- instead of failing with an opaque "attempt to call a nil value".
function newhandler(mode)
    mode = mode or "coroutine"
    local factory = handlert[mode]
    if not factory then
        -- level 2: blame the caller that passed the bad mode
        base.error("unknown dispatcher mode: " .. base.tostring(mode), 2)
    end
    return factory()
end
| |
-- start() for the sequential handler: the handler argument is unused;
-- func runs to completion immediately and its results are returned.
local seqstart = function(_, func)
    return func()
end
| |
-- The sequential handler wraps nothing: tcp() is the plain socket.tcp and
-- start() simply calls the given function and returns its results.
function handlert.sequential()
    local handler = {}
    handler.start = seqstart
    handler.tcp = socket.tcp
    return handler
end
| |
-----------------------------------------------------------------------------
-- Mega hack. Don't try to do this at home.
-----------------------------------------------------------------------------
-- We can't yield across calls to the original socket.protect, and our
-- wrapped sockets yield on every blocking operation. So we replace it with
-- a coxpcall-style version: the protected function runs in its own
-- coroutine, and every yield it makes is forwarded to whoever is running
-- us (the dispatcher).
-- Make sure you don't require any module that uses socket.protect before
-- loading this hack, or that module will keep the original version.
function socket.protect(f)
    return function(...)
        local co = coroutine.create(f)
        -- 'arg' is the implicit vararg table (Lua 5.0-style varargs)
        local args = arg
        while true do
            local results = {coroutine.resume(co, base.unpack(args))}
            local status = table.remove(results, 1)
            if not status then
                local err = results[1]
                -- table errors (as raised through LuaSocket's try, presumably
                -- {message}) become the usual nil, message return.
                -- 'base.type' is required: module() without package.seeall
                -- hides the global 'type' from this environment.
                if base.type(err) == "table" then
                    return nil, err[1]
                end
                -- any other error is re-raised unchanged; level 0 avoids
                -- prefixing this line's position onto the message
                base.error(err, 0)
            end
            if coroutine.status(co) == "suspended" then
                -- f yielded: pass its values up to our resumer, then feed
                -- whatever we are resumed with back into f
                args = {coroutine.yield(base.unpack(results))}
            else
                -- f returned normally: hand its results to the caller
                return base.unpack(results)
            end
        end
    end
end
| |
| ----------------------------------------------------------------------------- |
| -- Simple set data structure. O(1) everything. |
| ----------------------------------------------------------------------------- |
-- Create an empty set. Members live densely packed in the array part of
-- the returned table (so it can be handed straight to socket.select), and
-- a private reverse index maps each member to its position. insert()
-- ignores values already present; remove() ignores values not present.
local function newset()
    -- member value -> its index in the array part
    local position = {}
    local methods = {}

    function methods.insert(items, value)
        if position[value] then return end
        table.insert(items, value)
        position[value] = table.getn(items)
    end

    function methods.remove(items, value)
        local slot = position[value]
        if not slot then return end
        position[value] = nil
        -- pop the last member; unless it is the one being removed, move it
        -- into the vacated slot so the array stays free of holes
        local last = table.remove(items)
        if last ~= value then
            position[last] = slot
            items[slot] = last
        end
    end

    return base.setmetatable({}, {__index = methods})
end
| |
| ----------------------------------------------------------------------------- |
| -- socket.tcp() wrapper for the coroutine dispatcher |
| ----------------------------------------------------------------------------- |
-- Wrap a TCP object for the coroutine dispatcher. The wrap behaves like the
-- socket, but send/receive/connect/accept yield (operation, tcp) to the
-- dispatcher instead of blocking, and return nil, "timeout" when the
-- dispatcher resumes them with "timeout".
-- When tcp is nil (a failed socket.tcp()/accept()), returns nil, error.
local function cowrap(dispatcher, tcp, error)
    if not tcp then return nil, error end
    -- the real socket is always non-blocking; waiting happens by yielding
    tcp:settimeout(0)
    -- set when the user asks for a zero timeout: receive then returns
    -- partial results instead of waiting for the whole pattern
    local zero = false
    local wrap = {}

    -- hand control to the dispatcher until tcp is ready for `operation`;
    -- true means the dispatcher woke us up to report a timeout
    local function waitfor(operation)
        return coroutine.yield(operation, tcp) == "timeout"
    end

    -- the real timeout stays 0; only remember whether the user wants zero
    function wrap:settimeout(value, mode)
        zero = (value == 0)
        return 1
    end

    -- keep sending until done or a non-timeout error occurs
    function wrap:send(data, first, last)
        local sent = (first or 1) - 1
        while true do
            if waitfor(dispatcher.sending) then return nil, "timeout" end
            local result, err
            result, err, sent = tcp:send(data, sent + 1, last)
            if err ~= "timeout" then return result, err, sent end
        end
    end

    -- keep receiving (accumulating the partial result) until done, a
    -- non-timeout error occurs, or the user asked for zero timeout
    function wrap:receive(pattern, partial)
        while true do
            if waitfor(dispatcher.receiving) then return nil, "timeout" end
            local value, err
            value, err, partial = tcp:receive(pattern, partial)
            if zero or err ~= "timeout" then return value, err, partial end
        end
    end

    -- start a non-blocking connect; if it is in progress, wait until the
    -- socket becomes writable and connect again to learn the outcome
    function wrap:connect(host, port)
        local result, err = tcp:connect(host, port)
        if err ~= "timeout" then return result, err end
        if waitfor(dispatcher.sending) then return nil, "timeout" end
        result, err = tcp:connect(host, port)
        if result or err == "already connected" then return 1 end
        return nil, "non-blocking connect failed"
    end

    -- wait until readable (a connection arrived), then accept and wrap it
    function wrap:accept()
        repeat
            if waitfor(dispatcher.receiving) then return nil, "timeout" end
            local client, err = tcp:accept()
            if err ~= "timeout" then return cowrap(dispatcher, client, err) end
        until false
    end

    -- drop every dispatcher reference to tcp, then close the real socket
    function wrap:close()
        dispatcher.stamp[tcp] = nil
        for _, operation in base.ipairs({dispatcher.sending,
                                         dispatcher.receiving}) do
            operation.set:remove(tcp)
            operation.cortn[tcp] = nil
        end
        return tcp:close()
    end

    -- methods not overridden above are forwarded to the real socket, with
    -- the wrap (first argument) replaced by tcp; forwarders are cached
    return base.setmetatable(wrap, {__index = function(object, key)
        local forward = function(...)
            arg[1] = tcp
            return tcp[key](base.unpack(arg))
        end
        object[key] = forward
        return forward
    end})
end
| |
| |
| ----------------------------------------------------------------------------- |
| -- Our coroutine dispatcher |
| ----------------------------------------------------------------------------- |
-- metatable shared by all coroutine dispatchers; its __index table holds
-- the dispatcher methods (start, step)
local cometat = {}
cometat.__index = {}
| |
-- Record that cortn is waiting for tcp to become ready for `operation`.
-- The arguments are a coroutine followed by the results of resuming it
-- (cortn, status, operation, tcp), as produced by wakeup() and start().
-- If the coroutine raised an error (status false), that error is
-- re-raised. If there is no coroutine or it yielded no operation (e.g. it
-- finished), nothing is scheduled.
function schedule(cortn, status, operation, tcp)
    if not status then
        -- here 'operation' holds the coroutine's error value; level 0
        -- re-raises it unchanged instead of prefixing this line's position
        base.error(operation, 0)
    end
    if cortn and operation then
        operation.set:insert(tcp)
        operation.cortn[tcp] = cortn
        operation.stamp[tcp] = socket.gettime()
    end
end
| |
-- Forget the coroutine waiting on tcp for this operation and stop
-- selecting tcp for it.
function kick(operation, tcp)
    local waiting, set = operation.cortn, operation.set
    waiting[tcp] = nil
    set:remove(tcp)
end
| |
-- Resume the coroutine waiting on tcp for this operation, after removing
-- it from the operation's bookkeeping. Returns the coroutine followed by
-- the results of coroutine.resume, ready to be passed to schedule().
-- If no coroutine is waiting any more (e.g. the socket was closed), it
-- returns nil, true, so that schedule() does nothing.
function wakeup(operation, tcp)
    local cortn = operation.cortn[tcp]
    if not cortn then
        return nil, true
    end
    kick(operation, tcp)
    return cortn, coroutine.resume(cortn)
end
| |
-- Wake the coroutine waiting on tcp for this operation (if any), making
-- its pending I/O call return nil, "timeout". Whatever the coroutine does
-- next is passed to schedule() like any other wakeup. A coroutine that keeps
-- doing I/O after a timeout is therefore rescheduled rather than lost, and an
-- error it raises propagates instead of being discarded.
function abort(operation, tcp)
    local cortn = operation.cortn[tcp]
    if cortn then
        kick(operation, tcp)
        schedule(cortn, coroutine.resume(cortn, "timeout"))
    end
end

-- Run one round of the dispatcher:
-- * wait up to `timeout` seconds (default 1) in socket.select;
-- * resume the coroutines whose sockets became ready, rescheduling them
--   when they yield back to us;
-- * time out client sockets whose pending operation has been waiting for
--   longer than TIMEOUT.
function cometat.__index:step(timeout)
    local readable, writable = socket.select(self.receiving.set,
        self.sending.set, timeout or 1)
    for _, tcp in base.ipairs(readable) do
        schedule(wakeup(self.receiving, tcp))
    end
    for _, tcp in base.ipairs(writable) do
        schedule(wakeup(self.sending, tcp))
    end
    local now = socket.gettime()
    -- collect the expired sockets first: aborting resumes user code, which
    -- may schedule new sockets, and adding keys to self.stamp while
    -- traversing it with pairs() is not allowed
    local expired = {}
    for tcp, stamp in base.pairs(self.stamp) do
        if tcp.class == "tcp{client}" and now - stamp > TIMEOUT then
            table.insert(expired, tcp)
        end
    end
    for _, tcp in base.ipairs(expired) do
        for _, operation in base.ipairs({self.sending, self.receiving}) do
            -- re-check before each abort: an earlier abort may have closed
            -- tcp (stamp cleared) or rescheduled it (stamp refreshed), and
            -- that fresh wait must not be timed out right away
            local stamp = self.stamp[tcp]
            if stamp and now - stamp > TIMEOUT then
                abort(operation, tcp)
            end
        end
    end
end
| |
-- Spawn func as a new coroutine, run it up to its first yield, and
-- schedule it on the operation and socket it yields.
function cometat.__index:start(func)
    local co = coroutine.create(func)
    local ok, operation, tcp = coroutine.resume(co)
    schedule(co, ok, operation, tcp)
end
| |
-- Build a coroutine dispatcher. Sockets created with its tcp() are wrapped
-- by cowrap, so their blocking calls yield back to the dispatcher.
-- start(func) spawns a coroutine; step() runs one select round.
function handlert.coroutine()
    local stamp = {}
    -- bookkeeping for one kind of pending I/O: the set of sockets to select
    -- on, the coroutine waiting on each socket, and the shared timestamps
    local function newoperation(name)
        return {name = name, set = newset(), cortn = {}, stamp = stamp}
    end
    local dispatcher = {stamp = stamp}
    dispatcher.sending = newoperation("sending")
    dispatcher.receiving = newoperation("receiving")
    dispatcher.tcp = function()
        return cowrap(dispatcher, socket.tcp())
    end
    return base.setmetatable(dispatcher, cometat)
end
| |