diff --git a/kernel/include/net/network.hpp b/kernel/include/net/network.hpp index f17d3907..527f3e6d 100644 --- a/kernel/include/net/network.hpp +++ b/kernel/include/net/network.hpp @@ -104,6 +104,14 @@ int64_t listen(size_t socket_fd, bool listen); */ int64_t wait_for_packet(char* buffer, size_t socket_fd); +/*! + * \brief Wait for a packet, for some time + * \param socket_fd The file descriptor of the packet + * \param ms The maximum time, in milliseconds, to wait for a packet + * \return the packet index + */ +int64_t wait_for_packet(char* buffer, size_t socket_fd, size_t ms); + } // end of network namespace #endif diff --git a/kernel/src/net/network.cpp b/kernel/src/net/network.cpp index c26b7d5f..30f5d842 100644 --- a/kernel/src/net/network.cpp +++ b/kernel/src/net/network.cpp @@ -264,3 +264,37 @@ int64_t network::wait_for_packet(char* buffer, size_t socket_fd){ return packet.index; } + +int64_t network::wait_for_packet(char* buffer, size_t socket_fd, size_t ms){ + if(!scheduler::has_socket(socket_fd)){ + return -std::ERROR_SOCKET_INVALID_FD; + } + + auto& socket = scheduler::get_socket(socket_fd); + + if(!socket.listen){ + return -std::ERROR_SOCKET_NOT_LISTEN; + } + + logging::logf(logging::log_level::TRACE, "network: %u wait for packet on socket %u\n", scheduler::get_pid(), socket_fd); + + if(socket.listen_packets.empty()){ + if(!ms){ + return -std::ERROR_SOCKET_TIMEOUT; + } + + if(!socket.listen_queue.sleep(ms)){ + return -std::ERROR_SOCKET_TIMEOUT; + } + } + + auto packet = socket.listen_packets.pop(); + std::copy_n(packet.payload, packet.payload_size, buffer); + + // The memory was allocated as a copy by the decoding process, it is safe to remove it here + delete[] packet.payload; + + logging::logf(logging::log_level::TRACE, "network: %u received packet on socket %u\n", scheduler::get_pid(), socket_fd); + + return packet.index; +} diff --git a/kernel/src/system_calls.cpp b/kernel/src/system_calls.cpp index f61604a8..fdec5be1 100644 --- a/kernel/src/system_calls.cpp +++ b/kernel/src/system_calls.cpp @@ -390,6 +390,17 @@ void sc_wait_for_packet(interrupt::syscall_regs* regs){ regs->rbx = reinterpret_cast(user_buffer); } +void sc_wait_for_packet_ms(interrupt::syscall_regs* regs){ + auto socket_fd = regs->rbx; + auto user_buffer = reinterpret_cast(regs->rcx); + auto ms = regs->rdx; + + auto index = network::wait_for_packet(user_buffer, socket_fd, ms); + + regs->rax = index; + regs->rbx = reinterpret_cast(user_buffer); +} + } //End of anonymous namespace void system_call_entry(interrupt::syscall_regs* regs){ @@ -629,6 +640,10 @@ void system_call_entry(interrupt::syscall_regs* regs){ sc_wait_for_packet(regs); break; + case 0x3006: + sc_wait_for_packet_ms(regs); + break; + // Special system calls case 0x6666: diff --git a/programs/ping/src/main.cpp b/programs/ping/src/main.cpp index 066d30ec..e48602bc 100644 --- a/programs/ping/src/main.cpp +++ b/programs/ping/src/main.cpp @@ -71,20 +71,24 @@ int main(int argc, char* argv[]) { return 1; } - auto p = tlib::wait_for_packet(*socket); + auto p = tlib::wait_for_packet(*socket, 2000); if (!p) { - tlib::printf("ping: wait_for_packet error: %s\n", std::error_message(p.error())); - return 1; - } - - auto* icmp_header = reinterpret_cast(p->payload + p->index); - - auto command_type = static_cast(icmp_header->type); - - if(command_type == tlib::icmp::type::ECHO_REPLY){ - tlib::printf("Reply received from %s\n", ip.c_str()); + if(p.error() == std::ERROR_SOCKET_TIMEOUT){ + tlib::printf("%s unreachable\n", ip.c_str()); + } else { + tlib::printf("ping: wait_for_packet error: %s\n", std::error_message(p.error())); + return 1; + } } else { - tlib::printf("Unhandled command type received\n"); + auto* icmp_header = reinterpret_cast(p->payload + p->index); + + auto command_type = static_cast(icmp_header->type); + + if(command_type == tlib::icmp::type::ECHO_REPLY){ + tlib::printf("Reply received from %s\n", ip.c_str()); + } else { + tlib::printf("Unhandled command type received\n"); + } } tlib::release_packet(*p); diff --git a/tlib/include/tlib/net.hpp b/tlib/include/tlib/net.hpp index 896ad630..1d7f04d9 100644 --- a/tlib/include/tlib/net.hpp +++ b/tlib/include/tlib/net.hpp @@ -30,6 +30,7 @@ std::expected prepare_packet(size_t socket_fd, void* desc); std::expected finalize_packet(size_t socket_fd, packet p); std::expected listen(size_t socket_fd, bool l); std::expected wait_for_packet(size_t socket_fd); +std::expected wait_for_packet(size_t socket_fd, size_t ms); void release_packet(packet& packet); } // end of namespace tlib diff --git a/tlib/src/net.cpp b/tlib/src/net.cpp index eec6dacf..499f56bc 100644 --- a/tlib/src/net.cpp +++ b/tlib/src/net.cpp @@ -103,6 +103,27 @@ std::expected tlib::wait_for_packet(size_t socket_fd){ } } +std::expected tlib::wait_for_packet(size_t socket_fd, size_t ms){ + auto buffer = malloc(2048); + + int64_t code; + uint64_t payload; + asm volatile("mov rax, 0x3006; mov rbx, %[socket]; mov rcx, %[buffer]; mov rdx, %[ms]; int 50; mov %[code], rax; mov %[payload], rbx;" + : [payload] "=m" (payload), [code] "=m" (code) + : [socket] "g" (socket_fd), [buffer] "g" (reinterpret_cast(buffer)), [ms] "g" (ms) + : "rax", "rbx", "rcx"); + + if(code < 0){ + free(buffer); + return std::make_expected_from_error(-code); + } else { + tlib::packet p; + p.index = code; + p.payload = reinterpret_cast(payload); + return std::make_expected(p); + } +} + void tlib::release_packet(packet& packet){ if(packet.payload){ free(packet.payload);