From 771ae94ad5789c7dd5c3b443e22440a4d810964d Mon Sep 17 00:00:00 2001 From: payonel Date: Wed, 2 Sep 2015 05:38:50 -0700 Subject: [PATCH] refactor modem code for robustness --- src/component/modem.lua | 250 +++++++++++++++++++++++++++------------- src/main.lua | 2 +- 2 files changed, 174 insertions(+), 78 deletions(-) diff --git a/src/component/modem.lua b/src/component/modem.lua index 9e1e9c3..7a879ee 100644 --- a/src/component/modem.lua +++ b/src/component/modem.lua @@ -24,27 +24,26 @@ modem_host.clients = {} -- [port_number] = true when open modem_host.open_ports = {} -local function createPacket(type, address, port, ...) - -- args are not checked here (unlike the modem api methods) - -- address can be nil, which means broadcast +function modem_host.createPacketArray(t, address, port, ...) + compCheckArg(1,t,type("")) + compCheckArg(2,address,type(""),type(0)) + compCheckArg(3,port,type(0)) + local packed = { - type or "unknown_type", - address or "no address", - modem_host.id or "no sender", - port or "no port", + t, + address, + modem_host.id, + port, 0, -- distance - table.unpack(table.pack(...)) + ... } - local datagram = ser.serialize(packed) - return datagram .. '\n' + return packed end -local function parsePacket(raw) - assert(raw) - local packed = ser.unserialize(raw) - assert(packed ~= nil) +function modem_host.packetArrayToPacket(packed) + compCheckArg(1,packed,type({})) assert(#packed >= 5) local packet = {} @@ -62,7 +61,14 @@ local function parsePacket(raw) return packet end -local function packetToArray(packet) +function modem_host.packetArrayToDatagram(packed) + compCheckArg(1,packed,type({})) + + local datagram = ser.serialize(packed) + return datagram .. '\n' +end + +function modem_host.packetToPacketArray(packet) return { packet.type, @@ -70,24 +76,78 @@ local function packetToArray(packet) packet.source, packet.port, packet.distance, - table.unpack(packet.payload) + table.unpack(packet.payload or {}) } end +function modem_host.datagramToPacketArray(datagram) + compCheckArg(1,datagram,type("")) + local packed = ser.unserialize(datagram) + return packed +end + +function modem_host.datagramToPacket(datagram) + local packed = modem_host.datagramToPacketArray(datagram) + local packet = modem_host.packedToPacket(packed) + return packet +end + +function modem_host.packetToDatagram(packet) + local packed = modem_host.packetToPacketArray(packet) + local datagram = modem_host.packetArrayToDatagram(packed) + return datagram +end + +function modem_host.readDatagram(client) -- client:receive() + local raw, err = client:receive() + if raw then cprint("received: " .. raw) end + return raw, err +end + +function modem_host.readPacketArray(client) -- client:receive() + local datagram, err = modem_host.readDatagram(client) + if not datagram then return nil, err end + return modem_host.datagramToPacketArray(datagram) +end + +function modem_host.readPacket(client) -- client:receive() + local packed, err = modem_host.readPacketArray(client) + if not packed then return nil, err end + return modem_host.packetArrayToPacket(packed) +end + +function modem_host.sendDatagram(client, datagram) + cprint("sending: " .. datagram) + return client:send(datagram) +end + +function modem_host.sendPacketArray(client, packed) + local datagram = modem_host.packetArrayToDatagram(packed) + return modem_host.sendDatagram(client, datagram) +end + +function modem_host.sendPacket(client, packet) + local datagram = modem_host.packetToDatagram(packet) + return modem_host.sendDatagram(client, datagram) +end + function modem_host.broadcast(packet) -- only host broadcasts -- this method will be hit for all broadcasted messages -- but nonhosting clients will simply not repeat the broadcast if modem_host.hosting then - local plainArray = packetToArray(packet) - local datagram = ser.serialize(plainArray) + local datagram = modem_host.packetToDatagram(packet) for addr,client in pairs(modem_host.clients) do - client:send(datagram) + modem_host.sendDatagram(client, datagram) end end end function modem_host.validTarget(target) + if target == 0 then + return true + end + if target == modem_host.id then return true end @@ -106,87 +166,112 @@ function modem_host.validTarget(target) end -- backend private methods, these are not pushed to user machine environments -function modem_host.pushMessage(target, datagram) - if not modem_host.validTarget(target) then +function modem_host.pushMessage(packet) + if not modem_host.validTarget(packet.target) then return false, "invalid target, no such client listening" --ignored end - local packet = parsePacket(datagram) table.insert(modem_host.messages, packet) return true end -function modem_host.processPendingMessages() - modem_host.recvPendingMessages() +function modem_host.dispatchPacket(packet) + if packet.target == modem_host.id then + if obj.isOpen(packet.port) then + table.insert(machine.signals, modem_host.packetToPacketArray(packet)) + end + elseif modem_host.hosting then -- if hosting we will route + for source,client in pairs(modem_host.clients) do + if source == packet.target then + modem_host.sendPacket(client, packet) + break + end + end + else -- not hosting, send to host + modem_host.sendPacket(modem_host.socket, packet) + end +end +function modem_host.processPendingMessages() -- computer address seems to be applied late if not modem_host.id then modem_host.id = component.list("computer",true)() assert(modem_host.id) end - local i = 1; - while i <= #modem_host.messages do - local packet = modem_host.messages[i] - local move = true + modem_host.recvPendingMessages() + + for _,packet in pairs(modem_host.messages) do if packet.type == 'modem_message' then - -- broadcast if no target if packet.target == 0 then - modem_host.broadcast(packet) + modem_host.broadcast(packet) -- ignored by clients + -- clean up for broadcasting to self packet.target = modem_host.id end - if packet.target == modem_host.id then - if obj.isOpen(packet.port) then - table.insert(machine.signals, packetToArray(packet)) - end - move = false - end - end - - if move then - i = i + 1 - else - table.remove(modem_host.messages, i) + modem_host.dispatchPacket(packet) end end + + modem_host.messages = {} end function modem_host.recvPendingMessages() if modem_host.hosting then - while 1 do + while true do local client = modem_host.socket:accept() if not client then break; end - local handshakeDatagram, err = client:receive() - if err then + local handshake, err = modem_host.readPacket(client) -- client:receive() + if not handshake then client:close() else - client:settimeout(0, 't') - local handshake = parsePacket(handshakeDatagram) - modem_host.clients[handshake.source] = client + local connectionResponse + local accepted = false + if handshake.type ~= "handshake" then + connectionResponse = modem_host.createPacketArray("handshake", modem_host.id, -1, + false, "unsupported message type"); + elseif modem_host.validTarget(handshake.source) then -- repeated client + connectionResponse = modem_host.createPacketArray("handshake", modem_host.id, -1, + false, "computer address conflict detected, ignoring connection"); + else + client:settimeout(0, 't') + modem_host.clients[handshake.source] = client + accepted = true + + connectionResponse = modem_host.createPacketArray("handshake", modem_host.id, -1, true); + end + + modem_host.sendPacketArray(client, connectionResponse) + + if not accepted then + client:close() + end end end -- recv all pending packets for source, client in pairs(modem_host.clients) do - local line, err = client:receive() - if not err then - modem_host.pushMessage(source, line) + local packet, err = modem_host.readPacket(client) + if packet then + modem_host.pushMessage(packet) + elseif err ~= "timeout" then + client:close() + modem_host.clients[source] = nil end end elseif modem_host.socket then - while 1 do - local line, err = modem_host.socket:receive() - if not err then - modem_host.pushMessage(modem_host.id, line) + while true do + local packet, err = modem_host.readPacket(modem_host.socket) + if packet then + modem_host.pushMessage(packet) else break end @@ -210,8 +295,18 @@ function modem_host.joinExistingMessageBoard() modem_host.hosting = nil -- send handshake data - local datagram = createPacket("client_handshake") - modem_host.send(datagram) + local packed = modem_host.createPacketArray("handshake", 0, -1) + local sendResult = modem_host.sendPacketArray(modem_host.socket, packed) + + local response, why = modem_host.readPacket(modem_host.socket) + assert(response) + assert(response.payload) + + if not response.payload[1] then + modem_host.socket:close() + modem_host.socket = nil + return false, response.payload[2], true + end end return modem_host.socket, why end @@ -221,12 +316,18 @@ function modem_host.connectMessageBoard() return true end - local ok, reason = - modem_host.joinExistingMessageBoard() or - modem_host.createNewMessageBoard() + local ok, info, critical = modem_host.joinExistingMessageBoard() + + if not ok and critical then + return nil, info + end if not ok then - return nil, reason + ok, info = modem_host.createNewMessageBoard() + end + + if not ok then + return nil, info end modem_host.socket:settimeout(0, 't') -- accept calls must be already pending @@ -237,15 +338,6 @@ function modem_host.connectMessageBoard() return true end -function modem_host.send(datagram) - -- if we are the host, we simply call pushMessage directly - if modem_host.hosting then - return modem_host.pushMessage(modem_host.id, datagram) - else - return not not modem_host.socket:send(datagram) - end -end - local wakeMessage local strength if wireless then @@ -265,8 +357,9 @@ function obj.send(address, port, ...) -- Sends the specified data to the specifi compCheckArg(2,port,"number") port=checkPort(port) - local datagram = createPacket("modem_message", address, port, ...) - return modem_host.send(datagram) + local packed = modem_host.createPacketArray("modem_message", address, port, ...) + local packet = modem_host.packetArrayToPacket(packed) + return modem_host.dispatchPacket(packet) end function obj.getWakeMessage() -- Get the current wake-up message. @@ -324,12 +417,14 @@ function obj.open(port) -- Opens the specified port. Returns true if the port wa port=checkPort(port) if obj.isOpen(port) then - return false + return false, "port already open" end -- make sure we are connected to the message board - if not modem_host.connectMessageBoard() then - return false + local ok, why = modem_host.connectMessageBoard() + + if not ok then + return false, why end modem_host.open_ports[port] = true @@ -351,8 +446,9 @@ function obj.broadcast(port, ...) -- Broadcasts the specified data on the specif return false end - local datagram = createPacket("modem_message", 0, port, ...) - return modem_host.send(datagram) + local packed = modem_host.createPacketArray("modem_message", 0, port, ...) + local packet = modem_host.packetArrayToPacket(packed) + return modem_host.dispatchPacket(packet) end local cec = {} diff --git a/src/main.lua b/src/main.lua index 7a164be..ca8beb6 100644 --- a/src/main.lua +++ b/src/main.lua @@ -49,11 +49,11 @@ if settings.components == nil then -- Read component files for parameter documentation settings.components = { {"gpu",nil,0,160,50,3}, + {"modem",nil,1,false}, {"eeprom",nil,9,"lua/bios.lua"}, {"filesystem",nil,7,"loot/OpenOS",true}, {"filesystem",nil,nil,"tmpfs",false}, {"filesystem",nil,5,nil,false}, - {"modem",nil,nil,false}, {"internet"}, {"computer"}, {"ocemu"},