diff --git a/kernel/include/net/udp_layer.hpp b/kernel/include/net/udp_layer.hpp index 2aba48b2..f8db3e83 100644 --- a/kernel/include/net/udp_layer.hpp +++ b/kernel/include/net/udp_layer.hpp @@ -61,6 +61,11 @@ std::expected finalize_packet(network::interface_descriptor& interface, ne std::expected client_bind(network::socket& socket, size_t server_port, network::ip::address server); std::expected client_unbind(network::socket& socket); +std::expected receive(char* buffer, network::socket& socket, size_t n); +std::expected receive(char* buffer, network::socket& socket, size_t n, size_t ms); + +std::expected send(char* target_buffer, network::socket& socket, const char* buffer, size_t n); + } // end of upd namespace } // end of network namespace diff --git a/kernel/src/net/network.cpp b/kernel/src/net/network.cpp index 71e951b5..c7d66fde 100644 --- a/kernel/src/net/network.cpp +++ b/kernel/src/net/network.cpp @@ -274,12 +274,12 @@ std::expected network::open(network::socket_domain domain, } // Make sure the socket protocol is valid - if(protocol != socket_protocol::ICMP && protocol != socket_protocol::DNS && protocol != socket_protocol::TCP){ + if(protocol != socket_protocol::ICMP && protocol != socket_protocol::DNS && protocol != socket_protocol::TCP && protocol != socket_protocol::UDP){ return std::make_expected_from_error(std::ERROR_SOCKET_INVALID_PROTOCOL); } // Make sure the socket protocol is valid for the given socket type - if(type == socket_type::DGRAM && !(protocol == socket_protocol::DNS)){ + if(type == socket_type::DGRAM && !(protocol == socket_protocol::DNS || protocol == socket_protocol::UDP)){ return std::make_expected_from_error(std::ERROR_SOCKET_INVALID_TYPE_PROTOCOL); } @@ -367,6 +367,9 @@ std::expected network::send(socket_fd_t socket_fd, const char* buffer, siz case network::socket_protocol::TCP: return network::tcp::send(target_buffer, socket, buffer, n); + case network::socket_protocol::UDP: + return network::udp::send(target_buffer, socket, buffer, n); + default: return std::make_unexpected(std::ERROR_SOCKET_UNIMPLEMENTED); } @@ -388,6 +391,9 @@ std::expected network::receive(socket_fd_t socket_fd, char* buffer, size } switch (socket.protocol) { + case network::socket_protocol::UDP: + return network::udp::receive(buffer, socket, n); + case network::socket_protocol::TCP: return network::tcp::receive(buffer, socket, n); @@ -412,6 +418,9 @@ std::expected network::receive(socket_fd_t socket_fd, char* buffer, size } switch (socket.protocol) { + case network::socket_protocol::UDP: + return network::udp::receive(buffer, socket, n, ms); + case network::socket_protocol::TCP: return network::tcp::receive(buffer, socket, n, ms); diff --git a/kernel/src/net/udp_layer.cpp b/kernel/src/net/udp_layer.cpp index e6883b8f..11fe4ae6 100644 --- a/kernel/src/net/udp_layer.cpp +++ b/kernel/src/net/udp_layer.cpp @@ -184,3 +184,93 @@ std::expected network::udp::client_unbind(network::socket& sock){ return {}; } + +std::expected network::udp::send(char* target_buffer, network::socket& socket, const char* buffer, size_t n){ + auto& connection = socket.get_connection_data(); + + // Make sure stream sockets are connected + if(!connection.connected){ + return std::make_unexpected(std::ERROR_SOCKET_NOT_CONNECTED); + } + + network::udp::packet_descriptor desc{n}; + auto packet = user_prepare_packet(target_buffer, socket, &desc); + + if (packet) { + for(size_t i = 0; i < n; ++i){ + packet->payload[packet->index + i] = buffer[i]; + } + + auto target_ip = connection.server_address; + auto& interface = network::select_interface(target_ip); + return network::udp::finalize_packet(interface, *packet); + } + + return std::make_unexpected(packet.error()); +} + +std::expected network::udp::receive(char* buffer, network::socket& socket, size_t n){ + auto& connection = socket.get_connection_data(); + + // Make sure stream sockets are connected + if(!connection.connected){ + return std::make_unexpected(std::ERROR_SOCKET_NOT_CONNECTED); + } + + if(socket.listen_packets.empty()){ + socket.listen_queue.wait(); + } + + auto packet = socket.listen_packets.pop(); + + auto* udp_header = reinterpret_cast(packet.payload + packet.tag(2)); + auto payload_len = switch_endian_16(udp_header->length); + + if(payload_len > n){ + delete[] packet.payload; + + return std::make_unexpected(std::ERROR_BUFFER_SMALL); + } + + std::copy_n(packet.payload + packet.index, payload_len, buffer); + + delete[] packet.payload; + + return payload_len; +} + +std::expected network::udp::receive(char* buffer, network::socket& socket, size_t n, size_t ms){ + auto& connection = socket.get_connection_data(); + + // Make sure stream sockets are connected + if(!connection.connected){ + return std::make_unexpected(std::ERROR_SOCKET_NOT_CONNECTED); + } + + if(socket.listen_packets.empty()){ + if(!ms){ + return std::make_unexpected(std::ERROR_SOCKET_TIMEOUT); + } + + if(!socket.listen_queue.wait_for(ms)){ + return std::make_unexpected(std::ERROR_SOCKET_TIMEOUT); + } + } + + auto packet = socket.listen_packets.pop(); + + auto* udp_header = reinterpret_cast(packet.payload + packet.tag(2)); + auto payload_len = switch_endian_16(udp_header->length); + + if(payload_len > n){ + delete[] packet.payload; + + return std::make_unexpected(std::ERROR_BUFFER_SMALL); + } + + std::copy_n(packet.payload + packet.index, payload_len, buffer); + + delete[] packet.payload; + + return payload_len; +} diff --git a/tlib/src/net.cpp b/tlib/src/net.cpp index 182ba36e..235d487d 100644 --- a/tlib/src/net.cpp +++ b/tlib/src/net.cpp @@ -322,9 +322,10 @@ void tlib::socket::client_bind(tlib::ip::address server) { } auto status = tlib::client_bind(fd, server); - if (!status) { - _bound = false; + if (status) { + _bound = true; } else { + _bound = false; error_code = status.error(); } } @@ -335,10 +336,11 @@ void tlib::socket::client_bind(tlib::ip::address server, size_t port) { } auto status = tlib::client_bind(fd, server, port); - if (!status) { - _bound = false; + if (status) { + _bound = true; } else { error_code = status.error(); + _bound = false; } } @@ -370,6 +372,7 @@ void tlib::socket::connect(tlib::ip::address server, size_t port) { _connected = true; } else { error_code = status.error(); + _connected = false; } }