Merge pull request #856 from UnknownShadow200/SockRewrite2

Allow connecting to domain/hostnames
This commit is contained in:
UnknownShadow200 2021-06-06 23:12:08 +10:00 committed by GitHub
commit 5c9e2e96b9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 95 additions and 46 deletions

View File

@ -534,7 +534,6 @@ static void DirectConnectScreen_Load(struct DirectConnectScreen* s) {
} }
static void DirectConnectScreen_StartClient(void* w, int idx) { static void DirectConnectScreen_StartClient(void* w, int idx) {
static const cc_string loopbackIp = String_FromConst("127.0.0.1");
static const cc_string defMppass = String_FromConst("(none)"); static const cc_string defMppass = String_FromConst("(none)");
const cc_string* user = &DirectConnectScreen_Instance.iptUsername.text; const cc_string* user = &DirectConnectScreen_Instance.iptUsername.text;
const cc_string* addr = &DirectConnectScreen_Instance.iptAddress.text; const cc_string* addr = &DirectConnectScreen_Instance.iptAddress.text;
@ -550,7 +549,6 @@ static void DirectConnectScreen_StartClient(void* w, int idx) {
ip = String_UNSAFE_Substring(addr, 0, index); ip = String_UNSAFE_Substring(addr, 0, index);
port = String_UNSAFE_SubstringAt(addr, index + 1); port = String_UNSAFE_SubstringAt(addr, index + 1);
if (String_CaselessEqualsConst(&ip, "localhost")) ip = loopbackIp;
if (!user->length) { if (!user->length) {
DirectConnectScreen_SetStatus("&eUsername required"); return; DirectConnectScreen_SetStatus("&eUsername required"); return;

View File

@ -29,6 +29,7 @@
#include <utime.h> #include <utime.h>
#include <signal.h> #include <signal.h>
#include <stdio.h> #include <stdio.h>
#include <netdb.h>
#define Socket__Error() errno #define Socket__Error() errno
static char* defaultDirectory; static char* defaultDirectory;
@ -450,6 +451,13 @@ void Platform_LoadSysFonts(void) {
/*########################################################################################################################* /*########################################################################################################################*
*---------------------------------------------------------Socket----------------------------------------------------------* *---------------------------------------------------------Socket----------------------------------------------------------*
*#########################################################################################################################*/ *#########################################################################################################################*/
union SocketAddress {
struct sockaddr_storage total;
struct sockaddr raw;
struct sockaddr_in v4;
struct sockaddr_in6 v6;
};
cc_result Socket_Available(cc_socket s, int* available) { cc_result Socket_Available(cc_socket s, int* available) {
return ioctl(s, FIONREAD, available); return ioctl(s, FIONREAD, available);
} }
@ -459,47 +467,69 @@ cc_result Socket_GetError(cc_socket s, cc_result* result) {
return getsockopt(s, SOL_SOCKET, SO_ERROR, result, &resultSize); return getsockopt(s, SOL_SOCKET, SO_ERROR, result, &resultSize);
} }
static int ParseAddress(void* dst, const cc_string* address, int port, int* addrSize) { static int ParseHost(union SocketAddress* addr, const char* host) {
struct sockaddr_in* addr4 = (struct sockaddr_in* )dst; struct addrinfo hints = { 0 };
struct sockaddr_in6* addr6 = (struct sockaddr_in6*)dst; struct addrinfo* result;
struct addrinfo* cur;
int family = 0, res;
hints.ai_family = AF_INET;
hints.ai_socktype = SOCK_STREAM;
hints.ai_protocol = IPPROTO_TCP;
res = getaddrinfo(host, NULL, &hints, &result);
if (res) return 0;
for (cur = result; cur; cur = cur->ai_next) {
if (cur->ai_family != AF_INET) continue;
family = AF_INET;
Mem_Copy(addr, cur->ai_addr, cur->ai_addrlen);
break;
}
freeaddrinfo(result);
return family;
}
static int ParseAddress(union SocketAddress* addr, const cc_string* address) {
char str[NATIVE_STR_LEN]; char str[NATIVE_STR_LEN];
Platform_EncodeUtf8(str, address); Platform_EncodeUtf8(str, address);
if (inet_pton(AF_INET, str, &addr4->sin_addr) > 0) { if (inet_pton(AF_INET, str, &addr->v4.sin_addr) > 0) return AF_INET;
addr4->sin_family = AF_INET; if (inet_pton(AF_INET6, str, &addr->v6.sin6_addr) > 0) return AF_INET6;
addr4->sin_port = htons(port); return ParseHost(addr, str);
*addrSize = sizeof(struct sockaddr_in);
return AF_INET;
}
if (inet_pton(AF_INET6, str, &addr6->sin6_addr) > 0) {
addr6->sin6_family = AF_INET6;
addr6->sin6_port = htons(port);
*addrSize = sizeof(struct sockaddr_in6);
return AF_INET6;
}
return 0;
} }
int Socket_ValidAddress(const cc_string* address) { int Socket_ValidAddress(const cc_string* address) {
struct sockaddr_storage addr; union SocketAddress addr;
int addrSize; return ParseAddress(&addr, address);
return ParseAddress(&addr, address, 0, &addrSize);
} }
cc_result Socket_Connect(cc_socket* s, const cc_string* address, int port) { cc_result Socket_Connect(cc_socket* s, const cc_string* address, int port) {
int family, addrSize, blocking_raw = -1; /* non-blocking mode */ int family, addrSize, blocking_raw = -1; /* non-blocking mode */
struct sockaddr_storage addr; union SocketAddress addr;
cc_result res; cc_result res;
*s = -1; *s = -1;
if (!(family = ParseAddress(&addr, address, port, &addrSize))) if (!(family = ParseAddress(&addr, address)))
return ERR_INVALID_ARGUMENT; return ERR_INVALID_ARGUMENT;
*s = socket(family, SOCK_STREAM, IPPROTO_TCP); *s = socket(family, SOCK_STREAM, IPPROTO_TCP);
if (*s == -1) return errno; if (*s == -1) return errno;
ioctl(*s, FIONBIO, &blocking_raw); ioctl(*s, FIONBIO, &blocking_raw);
res = connect(*s, &addr, addrSize); if (family == AF_INET6) {
addr.v6.sin6_family = AF_INET6;
addr.v6.sin6_port = htons(port);
addrSize = sizeof(addr.v6);
} else if (family == AF_INET) {
addr.v4.sin_family = AF_INET;
addr.v4.sin_port = htons(port);
addrSize = sizeof(addr.v4);
}
res = connect(*s, &addr.raw, addrSize);
return res == -1 ? errno : 0; return res == -1 ? errno : 0;
} }

View File

@ -379,31 +379,39 @@ void Platform_LoadSysFonts(void) {
/*########################################################################################################################* /*########################################################################################################################*
*---------------------------------------------------------Socket----------------------------------------------------------* *---------------------------------------------------------Socket----------------------------------------------------------*
*#########################################################################################################################*/ *#########################################################################################################################*/
static INT (WSAAPI *_WSAStringToAddressW)(LPWSTR addressString, INT addressFamily, LPVOID lpProtocolInfo, LPVOID address, LPINT addressLength); static INT (WSAAPI *_WSAStringToAddressW)(LPWSTR addressString, INT addressFamily, LPVOID protocolInfo, LPVOID address, LPINT addressLength);
static int FallbackParseAddress(SOCKADDR_IN* dst, const cc_string* ip, int port) { static INT WSAAPI FallbackParseAddress(LPWSTR addressString, INT addressFamily, LPVOID protocolInfo, LPVOID address, LPINT addressLength) {
cc_uint8* addr; SOCKADDR_IN* addr4 = (SOCKADDR_IN*)address;
cc_string parts[4 + 1]; cc_uint8* addr = (cc_uint8*)&addr4->sin_addr;
/* +1 in case user tries '1.1.1.1.1' */ cc_string ip, parts[4 + 1];
if (String_UNSAFE_Split(ip, '.', parts, 4 + 1) != 4) return 0; WCHAR tmp[NATIVE_STR_LEN];
addr = (cc_uint8*)&dst->sin_addr;
Mem_Copy(tmp, addressString, sizeof(tmp));
Platform_Utf16ToAnsi(tmp);
ip = String_FromReadonly((char*)tmp);
/* 4+1 in case user tries '1.1.1.1.1' */
if (String_UNSAFE_Split(&ip, '.', parts, 4 + 1) != 4)
return ERR_INVALID_ARGUMENT;
if (!Convert_ParseUInt8(&parts[0], &addr[0]) || !Convert_ParseUInt8(&parts[1], &addr[1]) || if (!Convert_ParseUInt8(&parts[0], &addr[0]) || !Convert_ParseUInt8(&parts[1], &addr[1]) ||
!Convert_ParseUInt8(&parts[2], &addr[2]) || !Convert_ParseUInt8(&parts[3], &addr[3])) !Convert_ParseUInt8(&parts[2], &addr[2]) || !Convert_ParseUInt8(&parts[3], &addr[3]))
return 0; return ERR_INVALID_ARGUMENT;
dst->sin_family = AF_INET; addr4->sin_family = AF_INET;
dst->sin_port = htons(port); return 0;
return AF_INET;
} }
static void LoadWinsockFuncs(void) { static void LoadWinsockFuncs(void) {
static const struct DynamicLibSym funcs[1] = { static const struct DynamicLibSym funcs[] = {
DynamicLib_Sym(WSAStringToAddressW) DynamicLib_Sym(WSAStringToAddressW)
}; };
static const cc_string winsock32 = String_FromConst("WS2_32.DLL"); static const cc_string winsock32 = String_FromConst("WS2_32.DLL");
LoadDynamicFuncs(&winsock32, funcs, Array_Elems(funcs)); LoadDynamicFuncs(&winsock32, funcs, Array_Elems(funcs));
/* Fallback for older OS versions which lack WSAStringToAddressW */
if (!_WSAStringToAddressW) _WSAStringToAddressW = FallbackParseAddress;
} }
cc_result Socket_Available(cc_socket s, int* available) { cc_result Socket_Available(cc_socket s, int* available) {
@ -411,10 +419,27 @@ cc_result Socket_Available(cc_socket s, int* available) {
} }
cc_result Socket_GetError(cc_socket s, cc_result* result) { cc_result Socket_GetError(cc_socket s, cc_result* result) {
socklen_t resultSize = sizeof(cc_result); int resultSize = sizeof(cc_result);
return getsockopt(s, SOL_SOCKET, SO_ERROR, result, &resultSize); return getsockopt(s, SOL_SOCKET, SO_ERROR, result, &resultSize);
} }
static int ParseHost(void* dst, WCHAR* host, int port) {
SOCKADDR_IN* addr4 = (SOCKADDR_IN*)dst;
struct hostent* res;
Platform_Utf16ToAnsi(host);
res = gethostbyname((char*)host);
if (!res || res->h_addrtype != AF_INET) return 0;
/* Must have at least one IPv4 address */
if (!res->h_addr_list[0]) return 0;
addr4->sin_family = AF_INET;
addr4->sin_port = htons(port);
addr4->sin_addr = *(IN_ADDR*)res->h_addr_list[0];
return AF_INET;
}
static int Socket_ParseAddress(void* dst, const cc_string* address, int port) { static int Socket_ParseAddress(void* dst, const cc_string* address, int port) {
SOCKADDR_IN* addr4 = (SOCKADDR_IN*)dst; SOCKADDR_IN* addr4 = (SOCKADDR_IN*)dst;
SOCKADDR_IN6* addr6 = (SOCKADDR_IN6*)dst; SOCKADDR_IN6* addr6 = (SOCKADDR_IN6*)dst;
@ -422,12 +447,8 @@ static int Socket_ParseAddress(void* dst, const cc_string* address, int port) {
DWORD size; DWORD size;
Platform_EncodeUtf16(str, address); Platform_EncodeUtf16(str, address);
/* Fallback for older OS versions which lack WSAStringToAddressW */
if (!_WSAStringToAddressW)
return FallbackParseAddress(addr4, address, port);
size = sizeof(*addr4); size = sizeof(*addr4);
if (!_WSAStringToAddressW(str, AF_INET, NULL, (SOCKADDR*)addr4, &size)) { if (!_WSAStringToAddressW(str, AF_INET, NULL, addr4, &size)) {
addr4->sin_port = htons(port); addr4->sin_port = htons(port);
return AF_INET; return AF_INET;
} }
@ -436,7 +457,7 @@ static int Socket_ParseAddress(void* dst, const cc_string* address, int port) {
addr6->sin6_port = htons(port); addr6->sin6_port = htons(port);
return AF_INET6; return AF_INET6;
} }
return 0; return ParseHost(dst, str, port);
} }
int Socket_ValidAddress(const cc_string* address) { int Socket_ValidAddress(const cc_string* address) {
@ -720,7 +741,7 @@ static BOOL (WINAPI *_AttachConsole)(DWORD processId);
static BOOL (WINAPI *_IsDebuggerPresent)(void); static BOOL (WINAPI *_IsDebuggerPresent)(void);
static void LoadKernelFuncs(void) { static void LoadKernelFuncs(void) {
static const struct DynamicLibSym funcs[2] = { static const struct DynamicLibSym funcs[] = {
DynamicLib_Sym(AttachConsole), DynamicLib_Sym(IsDebuggerPresent) DynamicLib_Sym(AttachConsole), DynamicLib_Sym(IsDebuggerPresent)
}; };
@ -778,7 +799,7 @@ static BOOL (WINAPI *_CryptProtectData )(DATA_BLOB* dataIn, PCWSTR dataDescr, P
static BOOL (WINAPI *_CryptUnprotectData)(DATA_BLOB* dataIn, PWSTR* dataDescr, PVOID entropy, PVOID reserved, PVOID promptStruct, DWORD flags, DATA_BLOB* dataOut); static BOOL (WINAPI *_CryptUnprotectData)(DATA_BLOB* dataIn, PWSTR* dataDescr, PVOID entropy, PVOID reserved, PVOID promptStruct, DWORD flags, DATA_BLOB* dataOut);
static void LoadCryptFuncs(void) { static void LoadCryptFuncs(void) {
static const struct DynamicLibSym funcs[2] = { static const struct DynamicLibSym funcs[] = {
DynamicLib_Sym(CryptProtectData), DynamicLib_Sym(CryptUnprotectData) DynamicLib_Sym(CryptProtectData), DynamicLib_Sym(CryptUnprotectData)
}; };