diff --git a/src/component/internet.lua b/src/component/internet.lua index e4dd317..77ab19e 100644 --- a/src/component/internet.lua +++ b/src/component/internet.lua @@ -5,6 +5,7 @@ if not okay then cprint("Cannot use internet component: " .. socket) return end +require("support.http_patch") local url = require("socket.url") local okay, http = pcall(require, "ssl.https") if not okay then @@ -53,6 +54,7 @@ function obj.connect(address, port) -- Opens a new TCP connection. Returns the h -- TODO: not OC behaviour, but needed to prevent hanging client:settimeout(10) local connected = false + local closed = false local function connect() cprint("(socket) connect",host,port) local did, err = client:connect(host,port) @@ -60,39 +62,49 @@ function obj.connect(address, port) -- Opens a new TCP connection. Returns the h if did then connected = true client:settimeout(0) + else + pcall(client.close,client) + closed = true end end local fakesocket = { read = function(n) cprint("(socket) read",n) - -- TODO: Error handling + -- TODO: Better Error handling + if closed then return nil, "connection lost" end if not connected then connect() return "" end if type(n) ~= "number" then n = math.huge end local data, err, part = client:receive(n) if err == nil or err == "timeout" or part ~= "" then return data or part else + if err == "closed" then closed = true err = "connection lost" end return nil, err end end, write = function(data) cprint("(socket) write",data) + -- TODO: Better Error handling + if closed then return nil, "connection lost" end if not connected then connect() return 0 end checkArg(1,data,"string") local data, err, part = client:send(data) if err == nil or err == "timeout" or part ~= 0 then return data or part else + if err == "closed" then closed = true err = "connection lost" end return nil, err end end, close = function() cprint("(socket) close") pcall(client.close,client) + closed = true end, finishConnect = function() cprint("(socket) finishConnect") -- TODO: Does this actually error? + if closed then return nil, "connection lost" end return connected end } @@ -109,15 +121,43 @@ function obj.request(url, postData) -- Starts an HTTP request. If this returns t postData = nil end -- TODO: This works ... but is slow. - local page, _, headers, status = http.request(url, postData) - local protocol, code, message = status:match("(.-) (.-) (.*)") - code = tonumber(code) + -- TODO: Infact so slow, it can trigger the machine's sethook, so we have to work around that. + local hookf,hookm,hookc = debug.gethook() + local co = coroutine.running() + debug.sethook(co) + local page, err, headers, status = http.request(url, postData) + debug.sethook(co,hookf,hookm,hookc) + if not page then + cprint("(request) request failed",err) + end + -- Experimental fix for headers + if headers ~= nil then + local oldheaders = headers + headers = {} + for k,v in pairs(oldheaders) do + local name = k:gsub("^.",string.upper):gsub("%-.",string.upper) + if type(v) == "table" then + v.n = #v + headers[name] = v + else + headers[name] = {v,n=1} + end + end + end + local procotol, code, message + if status then + protocol, code, message = status:match("(.-) (.-) (.*)") + code = tonumber(code) + end + local closed = false local fakesocket = { read = function(n) cprint("(socket) read",n) -- OC doesn't actually return n bytes when requested. - if page == nil then + if closed then return nil, "connection lost" + elseif headers == nil then + return nil, "Connection refused" elseif page == "" then return nil else @@ -129,14 +169,23 @@ function obj.request(url, postData) -- Starts an HTTP request. If this returns t end, response = function() cprint("(socket) response") + if headers == nil then + return nil + end return code, message, headers end, close = function() cprint("(request) close") + closed = true page = nil end, finishConnect = function() cprint("(socket) finishConnect") + if closed then + return nil, "connection lost" + elseif headers == nil then + return nil, "Connection refused" + end return true end } diff --git a/src/support/http_patch.lua b/src/support/http_patch.lua new file mode 100644 index 0000000..090768b --- /dev/null +++ b/src/support/http_patch.lua @@ -0,0 +1,43 @@ +-- Welcome to hack town! +-- Patch luasocket's http library to be less stupid +local function gsub_escape(str) + return str:gsub("[%(%)%.%%%+%-%*%?%[%]%^%$]", "%%%0").."" +end +cprint("http_patch start") +-- Patch data +local patches = { + {[[if headers[name] then headers[name] = headers[name] .. ", " .. value]],[[if headers[name] then if type(headers[name]) == "string" then headers[name] = {headers[name]} end headers[name][#headers[name]+1] = value]]}, +} +package.loaded["socket.http"] = nil +local path = package.searchpath("socket.http",package.path) +if path then + local file, err = io.open(path,"rb") + if not file then + cprint("Failed to patch socket.http: " .. err) + return + end + local data = file:read("*a") + file:close() + for i = 1,#patches do + local newdata = data:gsub(gsub_escape(patches[i][1]), (patches[i][2]:gsub("%%","%%%%").."")) + if newdata == data then + cprint("Patch " .. i .. " failed") + else + data = newdata + end + end + local fn, err = load(data,"="..path) + if not fn then + cprint("Failed to compile socket.http: " .. err) + return + end + local ok, err = pcall(fn) + if not ok then + cprint("Failed to load socket.http: " .. err) + return + end + package.loaded["socket.http"] = err +else + cprint("Could not find socket.http") +end +cprint("http_patch end")