From 6bc6580bb38835f7b896d11c047fecbd3db9ba0d Mon Sep 17 00:00:00 2001 From: Baptiste Wicht Date: Mon, 12 Sep 2016 22:26:47 +0200 Subject: [PATCH] Basic resolve support in tlib --- programs/nslookup/src/main.cpp | 8 +++ tlib/include/tlib/dns.hpp | 2 + tlib/src/dns.cpp | 120 +++++++++++++++++++++++++++++++++ 3 files changed, 130 insertions(+) diff --git a/programs/nslookup/src/main.cpp b/programs/nslookup/src/main.cpp index c808fa2b..652ca8f1 100644 --- a/programs/nslookup/src/main.cpp +++ b/programs/nslookup/src/main.cpp @@ -37,6 +37,14 @@ int main(int argc, char* argv[]) { std::string domain(argv[1]); + auto resolved = tlib::dns::resolve(domain); + + if(resolved){ + tlib::printf("resolve(%s): %s\n", domain.c_str(), resolved->c_str()); + } else { + tlib::printf("resolve(%s): failed: %s\n", domain.c_str(), std::error_message(resolved.error())); + } + tlib::socket sock(tlib::socket_domain::AF_INET, tlib::socket_type::DGRAM, tlib::socket_protocol::DNS); sock.client_bind(); diff --git a/tlib/include/tlib/dns.hpp b/tlib/include/tlib/dns.hpp index 33f81bfd..0b952c6f 100644 --- a/tlib/include/tlib/dns.hpp +++ b/tlib/include/tlib/dns.hpp @@ -22,6 +22,8 @@ namespace dns { std::string decode_domain(char* payload, size_t& offset); std::expected send_request(tlib::socket& sock, const std::string& domain, uint16_t rr_type = 0x1, uint16_t rr_class = 0x1); +std::expected resolve(const std::string& domain, size_t timeout = 1000, size_t retries = 1); + } // end of namespace dns } // end of namespace tlib diff --git a/tlib/src/dns.cpp b/tlib/src/dns.cpp index bb9d776e..1658b331 100644 --- a/tlib/src/dns.cpp +++ b/tlib/src/dns.cpp @@ -7,6 +7,9 @@ #include "tlib/dns.hpp" #include "tlib/malloc.hpp" +#include "tlib/system.hpp" +#include "tlib/errors.hpp" +#include "tlib/print.hpp" std::string tlib::dns::decode_domain(char* payload, size_t& offset) { std::string domain; @@ -79,3 +82,120 @@ std::expected tlib::dns::send_request(tlib::socket& sock, const std::strin return {}; } + +std::expected tlib::dns::resolve(const std::string& domain, size_t timeout_ms, size_t retries){ + tlib::socket sock(tlib::socket_domain::AF_INET, tlib::socket_type::DGRAM, tlib::socket_protocol::DNS); + + sock.client_bind(); + sock.listen(true); + + if (!sock) { + return std::make_unexpected(sock.error());; + } + + size_t tries = 0; + + auto sr = send_request(sock, domain, 0x1, 0x1); + if(!sr){ + return std::make_unexpected(sr.error());; + } + + auto before = tlib::ms_time(); + auto after = before; + + while (true) { + // Make sure we don't wait for more than the timeout + if (after > before + timeout_ms) { + break; + } + + auto remaining = timeout_ms - (after - before); + + auto p = sock.wait_for_packet(remaining); + if (!sock) { + return std::make_unexpected(sock.error());; + } else { + auto* dns_header = reinterpret_cast(p.payload + p.index); + + auto identification = tlib::switch_endian_16(dns_header->identification); + + // Only handle packet with the correct identification + if (identification == 0x666) { + auto questions = tlib::switch_endian_16(dns_header->questions); + auto answers = tlib::switch_endian_16(dns_header->answers); + + auto flags = dns_header->flags; + auto qr = flags >> 15; + + // Only handle Response + if (qr) { + auto rcode = flags & 0xF; + + if (rcode == 0x0 && answers > 0) { + auto* payload = p.payload + p.index + sizeof(tlib::dns::header); + + // Decode the questions (simply wrap around it) + + for (size_t i = 0; i < questions; ++i) { + size_t length; + tlib::dns::decode_domain(payload, length); + + payload += length; + payload += 4; + } + + for (size_t i = 0; i < answers; ++i) { + auto label = static_cast(*payload); + + if (label > 64) { + // This is a pointer + auto pointer = tlib::switch_endian_16(*reinterpret_cast(payload)); + auto offset = pointer & (0xFFFF >> 2); + + payload += 2; + + size_t ignored; + tlib::dns::decode_domain(p.payload + p.index + offset, ignored); + } else { + return std::make_unexpected(std::ERROR_UNSUPPORTED); + } + + auto rr_type = tlib::switch_endian_16(*reinterpret_cast(payload)); + payload += 2; + + auto rr_class = tlib::switch_endian_16(*reinterpret_cast(payload)); + payload += 2; + + payload += 4; // TTL + + auto rd_length = tlib::switch_endian_16(*reinterpret_cast(payload)); + payload += 2; + + if(rr_class == 1){ // IN (Internet) class + if (rr_type == 0x1) { // A record + auto ip = reinterpret_cast(payload); + + auto result = sprintf("%u.%u.%u.%u", ip[3], ip[2], ip[1], ip[0]); + return std::make_expected(result); + } + } + + payload += rd_length; + } + + break; + } else { + // There was an error, retry + if(++tries == retries || !send_request(sock, domain)){ + return std::make_unexpected(std::ERROR_SOCKET_TIMEOUT); + } + } + } + } + } + + after = tlib::ms_time(); + } + + return std::make_unexpected(std::ERROR_SOCKET_TIMEOUT); +}