diff --git a/kernel/include/net/network.hpp b/kernel/include/net/network.hpp index 2bd6bb44..ccf78b54 100644 --- a/kernel/include/net/network.hpp +++ b/kernel/include/net/network.hpp @@ -115,6 +115,8 @@ std::expected wait_for_packet(char* buffer, socket_fd_t socket_fd); */ std::expected wait_for_packet(char* buffer, socket_fd_t socket_fd, size_t ms); +void propagate_packet(const ethernet::packet& packet, socket_protocol protocol); + } // end of network namespace #endif diff --git a/kernel/src/net/dns_layer.cpp b/kernel/src/net/dns_layer.cpp index 858b4b4b..4742c6f0 100644 --- a/kernel/src/net/dns_layer.cpp +++ b/kernel/src/net/dns_layer.cpp @@ -180,6 +180,8 @@ void network::dns::decode(network::interface_descriptor& /*interface*/, network: logging::logf(logging::log_level::TRACE, "dns: Refused\n"); } } + + network::propagate_packet(packet, network::socket_protocol::DNS); } std::expected network::dns::prepare_packet_query(network::interface_descriptor& interface, network::ip::address target_ip, uint16_t source_port, uint16_t identification, size_t payload_size) { diff --git a/kernel/src/net/icmp_layer.cpp b/kernel/src/net/icmp_layer.cpp index 06fb6fa2..b6dfc4c5 100644 --- a/kernel/src/net/icmp_layer.cpp +++ b/kernel/src/net/icmp_layer.cpp @@ -111,23 +111,7 @@ void network::icmp::decode(network::interface_descriptor& interface, network::et break; } - // TODO Need something better for this - - for(size_t pid = 0; pid < scheduler::MAX_PROCESS; ++pid){ - auto state = scheduler::get_process_state(pid); - if(state != scheduler::process_state::EMPTY && state != scheduler::process_state::NEW && state != scheduler::process_state::KILLED){ - for(auto& socket : scheduler::get_sockets(pid)){ - if(socket.listen && socket.protocol == network::socket_protocol::ICMP){ - auto copy = packet; - copy.payload = new char[copy.payload_size]; - std::copy_n(packet.payload, packet.payload_size, copy.payload); - - socket.listen_packets.push(copy); - socket.listen_queue.wake_up(); - } - } - } - } + network::propagate_packet(packet, network::socket_protocol::ICMP); } std::expected network::icmp::prepare_packet(network::interface_descriptor& interface, network::ip::address target_ip, size_t payload_size, type t, size_t code){ diff --git a/kernel/src/net/network.cpp b/kernel/src/net/network.cpp index 0188dbb9..f1d66c31 100644 --- a/kernel/src/net/network.cpp +++ b/kernel/src/net/network.cpp @@ -24,6 +24,7 @@ #include "tlib/errors.hpp" #include "net/icmp_layer.hpp" +#include "net/dns_layer.hpp" namespace { @@ -218,7 +219,7 @@ std::expected network::open(network::socket_domain domain, return std::make_expected_from_error(std::ERROR_SOCKET_INVALID_TYPE); } - if(protocol != socket_protocol::ICMP){ + if(protocol != socket_protocol::ICMP && protocol != socket_protocol::DNS){ return std::make_expected_from_error(std::ERROR_SOCKET_INVALID_PROTOCOL); } @@ -242,19 +243,37 @@ std::tuple network::prepare_packet(socket_fd_t socket_fd, void* auto& socket = scheduler::get_socket(socket_fd); - switch(socket.protocol){ - case network::socket_protocol::ICMP: + auto return_from_packet = [&socket](std::expected& packet) -> std::tuple { + if (packet) { + auto fd = socket.register_packet(*packet); + + return {fd, packet->index}; + } else { + return {-packet.error(), 0}; + } + }; + + switch (socket.protocol) { + case network::socket_protocol::ICMP: { auto descriptor = static_cast(desc); auto& interface = select_interface(descriptor->target_ip); - auto packet = network::icmp::prepare_packet(buffer, interface, descriptor->target_ip, descriptor->payload_size, descriptor->type, descriptor->code); + auto packet = network::icmp::prepare_packet(buffer, interface, descriptor->target_ip, descriptor->payload_size, descriptor->type, descriptor->code); - if(packet){ - auto fd = socket.register_packet(*packet); + return return_from_packet(packet); + } - return {fd, packet->index}; + case network::socket_protocol::DNS: { + auto descriptor = static_cast(desc); + auto& interface = select_interface(descriptor->target_ip); + + if(descriptor->query){ + auto packet = network::dns::prepare_packet_query(buffer, interface, descriptor->target_ip, descriptor->source_port, descriptor->identification, descriptor->payload_size); + + return return_from_packet(packet); } else { - return {-packet.error(), 0}; + return {-std::ERROR_SOCKET_INVALID_PACKET_DESCRIPTOR, 0}; } + } } return {-std::ERROR_SOCKET_UNIMPLEMENTED, 0}; @@ -279,6 +298,12 @@ std::expected network::finalize_packet(socket_fd_t socket_fd, size_t packe network::icmp::finalize_packet(interface, packet); socket.erase_packet(packet_fd); + return std::make_expected(); + + case network::socket_protocol::DNS: + network::dns::finalize_packet(interface, packet); + socket.erase_packet(packet_fd); + return std::make_expected(); } @@ -358,3 +383,23 @@ std::expected network::wait_for_packet(char* buffer, socket_fd_t socket_ return {packet.index}; } + +void network::propagate_packet(const ethernet::packet& packet, socket_protocol protocol){ + // TODO Need something better for this + + for(size_t pid = 0; pid < scheduler::MAX_PROCESS; ++pid){ + auto state = scheduler::get_process_state(pid); + if(state != scheduler::process_state::EMPTY && state != scheduler::process_state::NEW && state != scheduler::process_state::KILLED){ + for(auto& socket : scheduler::get_sockets(pid)){ + if(socket.listen && socket.protocol == protocol){ + auto copy = packet; + copy.payload = new char[copy.payload_size]; + std::copy_n(packet.payload, packet.payload_size, copy.payload); + + socket.listen_packets.push(copy); + socket.listen_queue.wake_up(); + } + } + } + } +} diff --git a/programs/nslookup/Makefile b/programs/nslookup/Makefile new file mode 100644 index 00000000..ed48b0e4 --- /dev/null +++ b/programs/nslookup/Makefile @@ -0,0 +1,14 @@ +.PHONY: default clean + +EXEC_NAME=nslookup + +default: link + +include ../../cpp.mk + +$(eval $(call program_compile_cpp_folder,src)) +$(eval $(call program_link_executable,$(EXEC_NAME))) + +clean: + @ echo -e "Remove compiled files" + @ rm -rf debug diff --git a/programs/nslookup/src/main.cpp b/programs/nslookup/src/main.cpp new file mode 100644 index 00000000..636000ba --- /dev/null +++ b/programs/nslookup/src/main.cpp @@ -0,0 +1,207 @@ +//======================================================================= +// Copyright Baptiste Wicht 2013-2016. +// Distributed under the terms of the MIT License. +// (See accompanying file LICENSE or copy at +// http://www.opensource.org/licenses/MIT) +//======================================================================= + +#include +#include +#include +#include +#include + +namespace { + +static constexpr const size_t retries = 10; +static constexpr const size_t timeout_ms = 2000; + +bool send_request(tlib::socket& sock, const std::string& domain){ + auto parts = std::split(domain, '.'); + + size_t characters = domain.size() - (parts.size() - 1); // The dots are not included + size_t labels = parts.size(); + + tlib::dns::packet_descriptor desc; + desc.payload_size = labels + characters + 1 + 2 * 2; + desc.target_ip = tlib::ip::make_address(10, 0, 2, 3); + desc.source_port = 3456; + desc.identification = 0x666; + desc.query = true; + + auto packet = sock.prepare_packet(&desc); + + if (!sock) { + tlib::printf("nslookup: prepare_packet error: %s\n", std::error_message(sock.error())); + return false; + } + + auto* payload = reinterpret_cast(packet.payload + packet.index); + + size_t i = 0; + for (auto& part : parts) { + payload[i++] = part.size(); + + for (size_t j = 0; j < part.size(); ++j) { + payload[i++] = part[j]; + } + } + + payload[i++] = 0; + + auto* q_type = reinterpret_cast(packet.payload + packet.index + i); + *q_type = 0x0100; // A Record + + auto* q_class = reinterpret_cast(packet.payload + packet.index + i + 2); + *q_class = 0x0100; // IN (internet) + + sock.finalize_packet(packet); + + if (!sock) { + tlib::printf("nslookup: finalize_packet error: %s\n", std::error_message(sock.error())); + return false; + } + + return true; +} + +} // end of anonymous namespace + +int main(int argc, char* argv[]) { + if (argc != 2) { + tlib::print_line("usage: nslookup domain"); + return 1; + } + + std::string domain(argv[1]); + + tlib::socket sock(tlib::socket_domain::AF_INET, tlib::socket_type::RAW, tlib::socket_protocol::DNS); + + if (!sock) { + tlib::printf("nslookup: socket error: %s\n", std::error_message(sock.error())); + return 1; + } + + sock.listen(true); + + if (!sock) { + tlib::printf("nslookup: socket error: %s\n", std::error_message(sock.error())); + return 1; + } + + size_t tries = 0; + + if(!send_request(sock, domain)){ + return 1; + } + + auto before = tlib::ms_time(); + auto after = before; + + while (true) { + // Make sure we don't wait for more than the timeout + if (after > before + timeout_ms) { + break; + } + + auto remaining = timeout_ms - (after - before); + + auto p = sock.wait_for_packet(remaining); + if (!sock) { + tlib::printf("nslookup: wait_for_packet error: %s\n", std::error_message(sock.error())); + return 1; + } else { + auto* dns_header = reinterpret_cast(p.payload + p.index); + + auto identification = tlib::switch_endian_16(dns_header->identification); + + // Only handle packet with the correct identification + if (identification == 0x666) { + auto questions = tlib::switch_endian_16(dns_header->questions); + auto answers = tlib::switch_endian_16(dns_header->answers); + + auto flags = dns_header->flags; + auto qr = flags >> 15; + + // Only handle Response + if (qr) { + auto rcode = flags & 0xF; + + if (rcode == 0x0 && answers > 0) { + auto* payload = p.payload + p.index + sizeof(tlib::dns::header); + + // Decode the questions (simply wrap around it) + + for (size_t i = 0; i < questions; ++i) { + size_t length; + auto domain = tlib::dns::decode_domain(payload, length); + + payload += length; + payload += 4; + + tlib::printf("DNS Question %s\n", domain.c_str()); + } + + tlib::printf("DNS Answers:\n"); + + for (size_t i = 0; i < answers; ++i) { + auto label = static_cast(*payload); + + std::string domain; + if (label > 64) { + // This is a pointer + auto pointer = tlib::switch_endian_16(*reinterpret_cast(payload)); + auto offset = pointer & (0xFFFF >> 2); + + payload += 2; + + size_t ignored; + domain = tlib::dns::decode_domain(p.payload + p.index + offset, ignored); + } else { + tlib::printf("nslookup: Cannot read DNS response\n"); + return 1; + } + + auto rr_type = tlib::switch_endian_16(*reinterpret_cast(payload)); + payload += 2; + + auto rr_class = tlib::switch_endian_16(*reinterpret_cast(payload)); + payload += 2; + + auto ttl = tlib::switch_endian_32(*reinterpret_cast(payload)); + payload += 4; + + auto rd_length = tlib::switch_endian_16(*reinterpret_cast(payload)); + payload += 2; + + if (rr_type == 0x1 && rr_class == 0x1) { + auto ip = reinterpret_cast(payload); + + tlib::printf(" Answer %u Domain %s Type %u Class %u TTL %u IP: %u.%u.%u.%u\n", + i, domain.c_str(), rr_type, rr_class, ttl, ip[3], ip[2], ip[1], ip[0]); + } else { + tlib::printf("nslookup: Cannot read DNS response\n"); + return 1; + } + + payload += rd_length; + } + + break; + } else { + // There was an error, retry + if(++tries == retries || !send_request(sock, domain)){ + return 1; + } + } + } + } + } + + after = tlib::ms_time(); + } + + sock.listen(false); + + return 0; +} diff --git a/tlib/include/tlib/dns.hpp b/tlib/include/tlib/dns.hpp new file mode 100644 index 00000000..3eb3e105 --- /dev/null +++ b/tlib/include/tlib/dns.hpp @@ -0,0 +1,27 @@ +//======================================================================= +// Copyright Baptiste Wicht 2013-2016. +// Distributed under the terms of the MIT License. +// (See accompanying file LICENSE or copy at +// http://www.opensource.org/licenses/MIT) +//======================================================================= + +#ifndef TLIB_NET_DNS_H +#define TLIB_NET_DNS_H + +#include + +#include "tlib/net_constants.hpp" + +ASSERT_ONLY_THOR_PROGRAM + +namespace tlib { + +namespace dns { + +std::string decode_domain(char* payload, size_t& offset); + +} // end of namespace dns + +} // end of namespace tlib + +#endif diff --git a/tlib/include/tlib/errors.hpp b/tlib/include/tlib/errors.hpp index 128b78ec..7c559d20 100644 --- a/tlib/include/tlib/errors.hpp +++ b/tlib/include/tlib/errors.hpp @@ -12,34 +12,35 @@ namespace std { -constexpr const size_t ERROR_NOT_EXISTS = 1; -constexpr const size_t ERROR_NOT_EXECUTABLE = 2; -constexpr const size_t ERROR_FAILED_EXECUTION = 3; -constexpr const size_t ERROR_NOTHING_MOUNTED = 4; -constexpr const size_t ERROR_INVALID_FILE_PATH = 5; -constexpr const size_t ERROR_DIRECTORY = 6; -constexpr const size_t ERROR_INVALID_FILE_DESCRIPTOR = 7; -constexpr const size_t ERROR_FAILED = 8; -constexpr const size_t ERROR_EXISTS = 9; -constexpr const size_t ERROR_BUFFER_SMALL = 10; -constexpr const size_t ERROR_INVALID_FILE_SYSTEM = 11; -constexpr const size_t ERROR_DISK_FULL = 12; -constexpr const size_t ERROR_PERMISSION_DENIED = 13; -constexpr const size_t ERROR_INVALID_OFFSET = 14; -constexpr const size_t ERROR_UNSUPPORTED = 15; -constexpr const size_t ERROR_INVALID_COUNT = 16; -constexpr const size_t ERROR_INVALID_REQUEST = 17; -constexpr const size_t ERROR_INVALID_DEVICE = 18; -constexpr const size_t ERROR_ALREADY_MOUNTED = 19; -constexpr const size_t ERROR_SOCKET_INVALID_DOMAIN = 20; -constexpr const size_t ERROR_SOCKET_INVALID_TYPE = 21; -constexpr const size_t ERROR_SOCKET_INVALID_PROTOCOL = 22; -constexpr const size_t ERROR_SOCKET_INVALID_FD = 23; -constexpr const size_t ERROR_SOCKET_UNIMPLEMENTED = 24; -constexpr const size_t ERROR_SOCKET_NO_INTERFACE = 25; -constexpr const size_t ERROR_SOCKET_INVALID_PACKET_FD = 26; -constexpr const size_t ERROR_SOCKET_NOT_LISTEN = 27; -constexpr const size_t ERROR_SOCKET_TIMEOUT = 28; +constexpr const size_t ERROR_NOT_EXISTS = 1; +constexpr const size_t ERROR_NOT_EXECUTABLE = 2; +constexpr const size_t ERROR_FAILED_EXECUTION = 3; +constexpr const size_t ERROR_NOTHING_MOUNTED = 4; +constexpr const size_t ERROR_INVALID_FILE_PATH = 5; +constexpr const size_t ERROR_DIRECTORY = 6; +constexpr const size_t ERROR_INVALID_FILE_DESCRIPTOR = 7; +constexpr const size_t ERROR_FAILED = 8; +constexpr const size_t ERROR_EXISTS = 9; +constexpr const size_t ERROR_BUFFER_SMALL = 10; +constexpr const size_t ERROR_INVALID_FILE_SYSTEM = 11; +constexpr const size_t ERROR_DISK_FULL = 12; +constexpr const size_t ERROR_PERMISSION_DENIED = 13; +constexpr const size_t ERROR_INVALID_OFFSET = 14; +constexpr const size_t ERROR_UNSUPPORTED = 15; +constexpr const size_t ERROR_INVALID_COUNT = 16; +constexpr const size_t ERROR_INVALID_REQUEST = 17; +constexpr const size_t ERROR_INVALID_DEVICE = 18; +constexpr const size_t ERROR_ALREADY_MOUNTED = 19; +constexpr const size_t ERROR_SOCKET_INVALID_DOMAIN = 20; +constexpr const size_t ERROR_SOCKET_INVALID_TYPE = 21; +constexpr const size_t ERROR_SOCKET_INVALID_PROTOCOL = 22; +constexpr const size_t ERROR_SOCKET_INVALID_FD = 23; +constexpr const size_t ERROR_SOCKET_UNIMPLEMENTED = 24; +constexpr const size_t ERROR_SOCKET_NO_INTERFACE = 25; +constexpr const size_t ERROR_SOCKET_INVALID_PACKET_FD = 26; +constexpr const size_t ERROR_SOCKET_NOT_LISTEN = 27; +constexpr const size_t ERROR_SOCKET_TIMEOUT = 28; +constexpr const size_t ERROR_SOCKET_INVALID_PACKET_DESCRIPTOR = 29; inline const char* error_message(size_t error){ switch(error){ @@ -99,6 +100,8 @@ inline const char* error_message(size_t error){ return "The socket is not configured to listen"; case ERROR_SOCKET_TIMEOUT: return "Network timeout"; + case ERROR_SOCKET_INVALID_PACKET_DESCRIPTOR: + return "The packet descriptor for the packet to send is invalid"; default: return "Unknonwn error"; } diff --git a/tlib/include/tlib/net.hpp b/tlib/include/tlib/net.hpp index c3a830f9..5aa30a2f 100644 --- a/tlib/include/tlib/net.hpp +++ b/tlib/include/tlib/net.hpp @@ -92,6 +92,17 @@ private: size_t error_code; ///< The error code }; +inline uint16_t switch_endian_16(uint16_t nb) { + return (nb >> 8) | (nb << 8); +} + +inline uint32_t switch_endian_32(uint32_t nb) { + return ((nb >> 24) & 0xff) | + ((nb << 8) & 0xff0000) | + ((nb >> 8) & 0xff00) | + ((nb << 24) & 0xff000000); +} + } // end of namespace tlib #endif diff --git a/tlib/include/tlib/net_constants.hpp b/tlib/include/tlib/net_constants.hpp index cdc32387..5057d1b0 100644 --- a/tlib/include/tlib/net_constants.hpp +++ b/tlib/include/tlib/net_constants.hpp @@ -103,6 +103,7 @@ struct packet_descriptor { ip::address target_ip; uint16_t source_port; uint16_t identification; + bool query; }; } // end of dns namespace diff --git a/tlib/src/dns.cpp b/tlib/src/dns.cpp new file mode 100644 index 00000000..2747b184 --- /dev/null +++ b/tlib/src/dns.cpp @@ -0,0 +1,35 @@ +//======================================================================= +// Copyright Baptiste Wicht 2013-2016. +// Distributed under the terms of the MIT License. +// (See accompanying file LICENSE or copy at +// http://www.opensource.org/licenses/MIT) +//======================================================================= + +#include "tlib/dns.hpp" +#include "tlib/malloc.hpp" + +std::string tlib::dns::decode_domain(char* payload, size_t& offset) { + std::string domain; + + offset = 0; + + while (true) { + auto label_size = static_cast(*(payload + offset)); + ++offset; + + if (!label_size) { + break; + } + + if (!domain.empty()) { + domain += '.'; + } + + for (size_t i = 0; i < label_size; ++i) { + domain += *(payload + offset); + ++offset; + } + } + + return domain; +}