| ----------------------------------------------------------------------------- |
| -- TFTP support for the Lua language |
| -- LuaSocket toolkit. |
| -- Author: Diego Nehab |
| -- RCS ID: $Id: tftp.lua,v 1.16 2005/11/22 08:33:29 diego Exp $ |
| ----------------------------------------------------------------------------- |
| |
| ----------------------------------------------------------------------------- |
| -- Load required files |
| ----------------------------------------------------------------------------- |
| local base = _G |
| local table = require("table") |
| local math = require("math") |
| local string = require("string") |
| local socket = require("socket") |
| local ltn12 = require("ltn12") |
| local url = require("socket.url") |
| module("socket.tftp") |
| |
| ----------------------------------------------------------------------------- |
| -- Program constants |
| ----------------------------------------------------------------------------- |
-- Short aliases for the byte <-> character converters used by the packet
-- builders and parsers below.
local char, byte = string.char, string.byte

-- Default TFTP server port; assigned without `local`, so under module()
-- it becomes a field of the socket.tftp module table.
PORT = 69

-- TFTP opcodes, as carried in the first two bytes of every packet.
local OP_RRQ, OP_WRQ, OP_DATA, OP_ACK, OP_ERROR = 1, 2, 3, 4, 5

-- Opcode value -> mnemonic (the array index is the opcode number).
local OP_INV = { "RRQ", "WRQ", "DATA", "ACK", "ERROR" }
| |
| ----------------------------------------------------------------------------- |
| -- Packet creation functions |
| ----------------------------------------------------------------------------- |
-- Builds a read-request (RRQ) packet: the two-byte opcode, then the file
-- name and the transfer mode, each followed by a zero byte.
local function RRQ(source, mode)
    local zero = char(0)
    return char(0, OP_RRQ) .. source .. zero .. mode .. zero
end
| |
-- Builds a write-request (WRQ) packet: the two-byte WRQ opcode, then the
-- file name and the transfer mode, each followed by a zero byte.
-- The layout matches RRQ; only the opcode differs.
local function WRQ(source, mode)
    return char(0, OP_WRQ) .. source .. char(0) .. mode .. char(0)
end
| |
-- Builds an ACK packet: the two-byte ACK opcode followed by the block
-- number as a 16-bit big-endian integer.
-- The block number is reduced modulo 65536 first, so the two-byte field
-- stays valid (each byte <= 255) when a long transfer's counter passes
-- 65535. The `%` operator replaces math.mod, which is deprecated in
-- Lua 5.1 and missing from later versions.
local function ACK(block)
    block = block % 65536
    local low = block % 256
    local high = (block - low) / 256
    return char(0, OP_ACK, high, low)
end
| |
-- Returns the opcode of a datagram: its first two bytes read as a
-- 16-bit big-endian integer.
local function get_OP(dgram)
    local hi, lo = byte(dgram, 1, 2)
    return hi * 256 + lo
end
| |
| ----------------------------------------------------------------------------- |
| -- Packet analysis functions |
| ----------------------------------------------------------------------------- |
-- Splits a DATA packet into its block number (bytes 3-4, big-endian)
-- and its payload (everything from byte 5 on).
local function split_DATA(dgram)
    local hi, lo = byte(dgram, 3, 4)
    return hi * 256 + lo, string.sub(dgram, 5)
end
| |
-- Decodes an ERROR packet into a readable message of the form
-- "error code <n>: <text>".
-- Layout: opcode (bytes 1-2), error code (bytes 3-4, big-endian), then the
-- message text terminated by a zero byte.
local function get_ERROR(dgram)
    -- A truncated packet yields code 0 instead of an arithmetic-on-nil error.
    local hi, lo = byte(dgram, 3, 4)
    local code = (hi or 0) * 256 + (lo or 0)
    -- Use a plain find for the terminator: Lua 5.1 patterns cannot contain
    -- a literal zero byte, so a pattern such as "(.*)\000" would be cut short
    -- and keep the terminator in the message. If the terminator is missing,
    -- the rest of the datagram is used as the message.
    local stop = string.find(dgram, "\000", 5, true)
    local msg = string.sub(dgram, 5, (stop or 0) - 1)
    return string.format("error code %d: %s", code, msg)
end
| |
| ----------------------------------------------------------------------------- |
| -- The real work |
| ----------------------------------------------------------------------------- |
-- Downloads a file with a TFTP read request in "octet" mode.
-- gett fields:
--   host  (required) server name or address; resolved with socket.dns.toip
--   port  server port; falls back to the module's PORT when absent
--   path  remote file name; a leading "/" is stripped and the rest is
--         URL-unescaped before being sent
--   sink  optional LTN12 sink for the file contents; defaults to
--         ltn12.sink.null()
-- Returns 1 once the final (shorter than 512 bytes) DATA packet has been
-- acknowledged. Failures are raised through socket.try / the local `try`,
-- which closes the UDP socket first; get() wraps this function with
-- socket.protect.
local function tget(gett)
    local retries, dgram, datahost, dataport, err, code
    -- number of the last block written to the sink (0 = none yet)
    local last = 0
    socket.try(gett.host, "missing host")
    local con = socket.try(socket.udp())
    local try = socket.newtry(function() con:close() end)
    -- convert from name to ip if needed, without mutating the caller's table
    local host = try(socket.dns.toip(gett.host))
    local port = gett.port or PORT
    con:settimeout(1)
    -- first packet gives data host/port to be used for data transfers
    local path = string.gsub(gett.path or "", "^/", "")
    path = url.unescape(path)
    retries = 0
    repeat
        try(con:sendto(RRQ(path, "octet"), host, port))
        dgram, datahost, dataport = con:receivefrom()
        retries = retries + 1
    until dgram or datahost ~= "timeout" or retries > 5
    -- on failure receivefrom returns nil plus the error message in datahost
    try(dgram, datahost)
    -- associate socket with data host/port
    try(con:setpeername(datahost, dataport))
    -- default sink
    local sink = gett.sink or ltn12.sink.null()
    -- process all data packets
    while true do
        -- decode packet; only build the error text for real ERROR packets
        code = get_OP(dgram)
        if code == OP_ERROR then try(nil, get_ERROR(dgram)) end
        try(code == OP_DATA, "unhandled opcode " .. code)
        -- get data packet parts
        local block, data = split_DATA(dgram)
        -- write only the expected next block (repeats are skipped); block
        -- numbers are 16 bits wide and wrap from 65535 back to 0
        if block == (last + 1) % 65536 then
            try(sink(data))
            last = block
        end
        -- last packet brings less than 512 bytes of data
        if string.len(data) < 512 then
            try(con:send(ACK(block)))
            try(con:close())
            try(sink(nil))
            return 1
        end
        -- acknowledge and get the next packet, retrying on timeouts
        retries = 0
        repeat
            try(con:send(ACK(last)))
            dgram, err = con:receive()
            retries = retries + 1
        until dgram or err ~= "timeout" or retries > 5
        try(dgram, err)
    end
end
| |
-- URL components that parse() passes to url.parse as defaults for any
-- parts the URL string leaves out.
local default = {
    scheme = "tftp",
    path = "/",
    port = PORT,
}
| |
-- Parses URL string `u` with url.parse, filling missing parts from
-- `default`. Raises through socket.try when parsing fails, when the
-- scheme is not "tftp", or when no host is present. Returns the parsed
-- table.
local function parse(u)
    local parsed = socket.try(url.parse(u, default))
    local scheme = parsed.scheme
    socket.try(scheme == "tftp", "invalid scheme '" .. scheme .. "'")
    socket.try(parsed.host, "invalid host")
    return parsed
end
| |
-- String form of a download: fetches the file named by URL `u` into an
-- in-memory table sink and returns the concatenated contents.
local function sget(u)
    local request = parse(u)
    local chunks = {}
    request.sink = ltn12.sink.table(chunks)
    tget(request)
    return table.concat(chunks)
end
| |
-- Public entry point (a module field under module()).
-- A string argument is treated as a URL and the file contents are
-- returned (via sget). Any other argument is passed to tget as a request
-- table. The call is wrapped in socket.protect, which handles the errors
-- raised through socket.try/newtry (per LuaSocket's protect/try contract).
get = socket.protect(function(gett)
    if base.type(gett) ~= "string" then
        return tget(gett)
    end
    return sget(gett)
end)
| |