Merge pull request #855 from UnknownShadow200/SockRewrite

Rewrite address parsing for sockets
This commit is contained in:
UnknownShadow200 2021-06-04 21:03:34 +10:00 committed by GitHub
commit c0dff3c980
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 142 additions and 82 deletions

View File

@ -541,7 +541,6 @@ static void DirectConnectScreen_StartClient(void* w, int idx) {
const cc_string* mppass = &DirectConnectScreen_Instance.iptMppass.text;
cc_string ip, port;
cc_uint8 raw_ip[4];
cc_uint16 raw_port;
int index = String_LastIndexOf(addr, ':');
@ -556,7 +555,7 @@ static void DirectConnectScreen_StartClient(void* w, int idx) {
if (!user->length) {
DirectConnectScreen_SetStatus("&eUsername required"); return;
}
if (!Utils_ParseIP(&ip, raw_ip)) {
if (!Socket_ValidAddress(&ip)) {
DirectConnectScreen_SetStatus("&eInvalid ip"); return;
}
if (!Convert_ParseUInt16(&port, &raw_port)) {

View File

@ -41,7 +41,7 @@ endif
ifeq ($(PLAT),darwin)
LIBS=
LDFLAGS=-rdynamic -framework Carbon -framework AGL -framework OpenGL
LDFLAGS=-rdynamic -framework Carbon -framework AGL -framework OpenGL -framework IOKit
endif
ifeq ($(PLAT),freebsd)

View File

@ -218,15 +218,15 @@ CC_API void Waitable_WaitFor(void* handle, cc_uint32 milliseconds);
/* Calls SysFonts_Register on each font that is available on this platform. */
void Platform_LoadSysFonts(void);
/* Allocates a new non-blocking socket. */
CC_API cc_result Socket_Create(cc_socket* s);
/* Returns how much data is available to be read from the given socket. */
CC_API cc_result Socket_Available(cc_socket s, int* available);
/* Returns (and resets) the last error generated by the given socket. */
CC_API cc_result Socket_GetError(cc_socket s, cc_result* result);
/* Returns non-zero if the given address is valid for a socket to connect to */
CC_API int Socket_ValidAddress(const cc_string* address);
/* Attempts to open a connection to the given IP address:port. */
CC_API cc_result Socket_Connect(cc_socket s, const cc_string* ip, int port);
/* Allocates a new non-blocking socket and then begins connecting to the given address:port. */
CC_API cc_result Socket_Connect(cc_socket* s, const cc_string* address, int port);
/* Attempts to read data from the given socket. */
CC_API cc_result Socket_Read(cc_socket s, cc_uint8* data, cc_uint32 count, cc_uint32* modified);
/* Attempts to write data to the given socket. */

View File

@ -450,16 +450,6 @@ void Platform_LoadSysFonts(void) {
/*########################################################################################################################*
*---------------------------------------------------------Socket----------------------------------------------------------*
*#########################################################################################################################*/
cc_result Socket_Create(cc_socket* s) {
int blockingMode = -1; /* non-blocking mode */
*s = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
if (*s == -1) return errno;
ioctl(*s, FIONBIO, &blockingMode);
return 0;
}
cc_result Socket_Available(cc_socket s, int* available) {
return ioctl(s, FIONREAD, available);
}
@ -469,16 +459,47 @@ cc_result Socket_GetError(cc_socket s, cc_result* result) {
return getsockopt(s, SOL_SOCKET, SO_ERROR, result, &resultSize);
}
cc_result Socket_Connect(cc_socket s, const cc_string* ip, int port) {
struct sockaddr addr;
cc_result res;
addr.sa_family = AF_INET;
static int ParseAddress(void* dst, const cc_string* address, int port, int* addrSize) {
struct sockaddr_in* addr4 = (struct sockaddr_in* )dst;
struct sockaddr_in6* addr6 = (struct sockaddr_in6*)dst;
char str[NATIVE_STR_LEN];
Platform_EncodeUtf8(str, address);
Stream_SetU16_BE( (cc_uint8*)&addr.sa_data[0], port);
if (!Utils_ParseIP(ip, (cc_uint8*)&addr.sa_data[2]))
if (inet_pton(AF_INET, str, &addr4->sin_addr) > 0) {
addr4->sin_family = AF_INET;
addr4->sin_port = htons(port);
*addrSize = sizeof(struct sockaddr_in);
return AF_INET;
}
if (inet_pton(AF_INET6, str, &addr6->sin6_addr) > 0) {
addr6->sin6_family = AF_INET6;
addr6->sin6_port = htons(port);
*addrSize = sizeof(struct sockaddr_in6);
return AF_INET6;
}
return 0;
}
int Socket_ValidAddress(const cc_string* address) {
struct sockaddr_storage addr;
int addrSize;
return ParseAddress(&addr, address, 0, &addrSize);
}
cc_result Socket_Connect(cc_socket* s, const cc_string* address, int port) {
int family, addrSize, blocking_raw = -1; /* non-blocking mode */
struct sockaddr_storage addr;
cc_result res;
*s = -1;
if (!(family = ParseAddress(&addr, address, port, &addrSize)))
return ERR_INVALID_ARGUMENT;
res = connect(s, &addr, sizeof(addr));
*s = socket(family, SOCK_STREAM, IPPROTO_TCP);
if (*s == -1) return errno;
ioctl(*s, FIONBIO, &blocking_raw);
res = connect(*s, &addr, addrSize);
return res == -1 ? errno : 0;
}
@ -1177,6 +1198,7 @@ int Platform_GetCommandLineArgs(int argc, STRING_REF char** argv, cc_string* arg
argc--; argv++; /* skip executable path argument */
#ifdef CC_BUILD_DARWIN
/* Sometimes a "-psn_0_[number]" argument is added before actual args */
if (argc) {
static const cc_string psn = String_FromConst("-psn_0_");
cc_string arg0 = String_FromReadonly(argv[0]);
@ -1186,6 +1208,7 @@ int Platform_GetCommandLineArgs(int argc, STRING_REF char** argv, cc_string* arg
count = min(argc, GAME_MAX_CMDARGS);
for (i = 0; i < count; i++) {
/* -d[directory] argument to change directory data is stored in */
if (argv[i][0] == '-' && argv[i][1] == 'd' && argv[i][2]) {
--count;
continue;
@ -1224,6 +1247,7 @@ cc_result Platform_SetDefaultCurrentDirectory(int argc, char **argv) {
static const cc_string bundle = String_FromConst(".app/Contents/MacOS/");
cc_string raw = String_Init(path, len, 0);
/* If running from within a bundle, set data folder to folder containing bundle */
if (String_CaselessEnds(&raw, &bundle)) {
len -= bundle.length;

View File

@ -257,11 +257,6 @@ extern int interop_SocketGetPending(int sock);
extern int interop_SocketGetError(int sock);
extern int interop_SocketPoll(int sock);
cc_result Socket_Create(cc_socket* s) {
*s = interop_SocketCreate();
return 0;
}
cc_result Socket_Available(cc_socket s, int* available) {
int res = interop_SocketGetPending(s);
/* returned result is negative for error */
@ -283,13 +278,16 @@ cc_result Socket_GetError(cc_socket s, cc_result* result) {
*result = 0; return -res;
}
}
int Socket_ValidAddress(const cc_string* address) { return true; }
cc_result Socket_Connect(cc_socket s, const cc_string* ip, int port) {
cc_result Socket_Connect(cc_socket* s, const cc_string* address, int port) {
char addr[NATIVE_STR_LEN];
int res;
Platform_EncodeUtf8(addr, ip);
Platform_EncodeUtf8(addr, address);
*s = interop_SocketCreate();
/* returned result is negative for error */
res = -interop_SocketConnect(s, addr, port);
res = -interop_SocketConnect(*s, addr, port);
/* error returned when invalid address provided */
if (res == _EHOSTUNREACH) return ERR_INVALID_ARGUMENT;

View File

@ -29,6 +29,12 @@ const cc_result ReturnCode_SocketInProgess = WSAEINPROGRESS;
const cc_result ReturnCode_SocketWouldBlock = WSAEWOULDBLOCK;
const cc_result ReturnCode_DirectoryExists = ERROR_ALREADY_EXISTS;
static void LoadDynamicFuncs(const cc_string* path, const struct DynamicLibSym* syms, int count) {
void* lib = DynamicLib_Load2(path);
if (!lib) { Logger_DynamicLibWarn("loading", path); return; }
DynamicLib_GetAll(lib, syms, count);
}
/*########################################################################################################################*
*---------------------------------------------------------Memory----------------------------------------------------------*
*#########################################################################################################################*/
@ -373,14 +379,31 @@ void Platform_LoadSysFonts(void) {
/*########################################################################################################################*
*---------------------------------------------------------Socket----------------------------------------------------------*
*#########################################################################################################################*/
cc_result Socket_Create(cc_socket* s) {
int blockingMode = -1; /* non-blocking mode */
static INT (WSAAPI *_WSAStringToAddressW)(LPWSTR addressString, INT addressFamily, LPVOID lpProtocolInfo, LPVOID address, LPINT addressLength);
*s = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
if (*s == -1) return WSAGetLastError();
static int FallbackParseAddress(SOCKADDR_IN* dst, const cc_string* ip, int port) {
cc_uint8* addr;
cc_string parts[4 + 1];
/* +1 in case user tries '1.1.1.1.1' */
if (String_UNSAFE_Split(ip, '.', parts, 4 + 1) != 4) return 0;
addr = (cc_uint8*)&dst->sin_addr;
ioctlsocket(*s, FIONBIO, &blockingMode);
return 0;
if (!Convert_ParseUInt8(&parts[0], &addr[0]) || !Convert_ParseUInt8(&parts[1], &addr[1]) ||
!Convert_ParseUInt8(&parts[2], &addr[2]) || !Convert_ParseUInt8(&parts[3], &addr[3]))
return 0;
dst->sin_family = AF_INET;
dst->sin_port = htons(port);
return AF_INET;
}
static void LoadWinsockFuncs(void) {
static const struct DynamicLibSym funcs[1] = {
DynamicLib_Sym(WSAStringToAddressW)
};
static const cc_string winsock32 = String_FromConst("WS2_32.DLL");
LoadDynamicFuncs(&winsock32, funcs, Array_Elems(funcs));
}
cc_result Socket_Available(cc_socket s, int* available) {
@ -392,16 +415,49 @@ cc_result Socket_GetError(cc_socket s, cc_result* result) {
return getsockopt(s, SOL_SOCKET, SO_ERROR, result, &resultSize);
}
cc_result Socket_Connect(cc_socket s, const cc_string* ip, int port) {
struct sockaddr addr;
cc_result res;
addr.sa_family = AF_INET;
static int Socket_ParseAddress(void* dst, const cc_string* address, int port) {
SOCKADDR_IN* addr4 = (SOCKADDR_IN*)dst;
SOCKADDR_IN6* addr6 = (SOCKADDR_IN6*)dst;
WCHAR str[NATIVE_STR_LEN];
DWORD size;
Platform_EncodeUtf16(str, address);
Stream_SetU16_BE( (cc_uint8*)&addr.sa_data[0], port);
if (!Utils_ParseIP(ip, (cc_uint8*)&addr.sa_data[2]))
/* Fallback for older OS versions which lack WSAStringToAddressW */
if (!_WSAStringToAddressW)
return FallbackParseAddress(addr4, address, port);
size = sizeof(*addr4);
if (!_WSAStringToAddressW(str, AF_INET, NULL, (SOCKADDR*)addr4, &size)) {
addr4->sin_port = htons(port);
return AF_INET;
}
size = sizeof(*addr6);
if (!_WSAStringToAddressW(str, AF_INET6, NULL, (SOCKADDR*)addr6, &size)) {
addr6->sin6_port = htons(port);
return AF_INET6;
}
return 0;
}
int Socket_ValidAddress(const cc_string* address) {
SOCKADDR_STORAGE addr;
return Socket_ParseAddress(&addr, address, 0);
}
cc_result Socket_Connect(cc_socket* s, const cc_string* address, int port) {
int family, blockingMode = -1; /* non-blocking mode */
SOCKADDR_STORAGE addr;
cc_result res;
*s = -1;
if (!(family = Socket_ParseAddress(&addr, address, port)))
return ERR_INVALID_ARGUMENT;
res = connect(s, &addr, sizeof(addr));
*s = socket(family, SOCK_STREAM, IPPROTO_TCP);
if (*s == -1) return WSAGetLastError();
ioctlsocket(*s, FIONBIO, &blockingMode);
res = connect(*s, (SOCKADDR*)&addr, sizeof(addr));
return res == -1 ? WSAGetLastError() : 0;
}
@ -667,11 +723,9 @@ static void LoadKernelFuncs(void) {
static const struct DynamicLibSym funcs[2] = {
DynamicLib_Sym(AttachConsole), DynamicLib_Sym(IsDebuggerPresent)
};
static const cc_string kernel32 = String_FromConst("KERNEL32.DLL");
void* lib = DynamicLib_Load2(&kernel32);
if (!lib) { Logger_DynamicLibWarn("loading", &kernel32); return; }
DynamicLib_GetAll(lib, funcs, Array_Elems(funcs));
static const cc_string kernel32 = String_FromConst("KERNEL32.DLL");
LoadDynamicFuncs(&kernel32, funcs, Array_Elems(funcs));
}
void Platform_Init(void) {
@ -685,6 +739,7 @@ void Platform_Init(void) {
if (res) Logger_SysWarn(res, "starting WSA");
LoadKernelFuncs();
LoadWinsockFuncs();
if (_IsDebuggerPresent) hasDebugger = _IsDebuggerPresent();
/* For when user runs from command prompt */
if (_AttachConsole) _AttachConsole(-1); /* ATTACH_PARENT_PROCESS */
@ -726,11 +781,9 @@ static void LoadCryptFuncs(void) {
static const struct DynamicLibSym funcs[2] = {
DynamicLib_Sym(CryptProtectData), DynamicLib_Sym(CryptUnprotectData)
};
static const cc_string crypt32 = String_FromConst("CRYPT32.DLL");
void* lib = DynamicLib_Load2(&crypt32);
if (!lib) { Logger_DynamicLibWarn("loading", &crypt32); return; }
DynamicLib_GetAll(lib, funcs, Array_Elems(funcs));
static const cc_string crypt32 = String_FromConst("CRYPT32.DLL");
LoadDynamicFuncs(&crypt32, funcs, Array_Elems(funcs));
}
cc_result Platform_Encrypt(const void* data, int len, cc_string* dst) {

View File

@ -121,7 +121,7 @@ static int Program_Run(int argc, char** argv) {
} else {
String_Copy(&Game_Username, &args[0]);
String_Copy(&Game_Mppass, &args[1]);
String_Copy(&Server.IP, &args[2]);
String_Copy(&Server.Address,&args[2]);
if (!Convert_ParseUInt16(&args[3], &port)) {
WarnInvalidArg("Invalid port", &args[3]);
@ -154,7 +154,7 @@ int main(int argc, char** argv) {
main_imdct();
#endif
Platform_LogConst("Starting " GAME_APP_NAME " ..");
String_InitArray(Server.IP, ipBuffer);
String_InitArray(Server.Address, ipBuffer);
Options_Load();
res = Program_Run(argc, argv);

View File

@ -241,7 +241,7 @@ static void MPConnection_Fail(const cc_string* reason) {
String_InitArray(msg, msgBuffer);
net_connecting = false;
String_Format2(&msg, "Failed to connect to %s:%i", &Server.IP, &Server.Port);
String_Format2(&msg, "Failed to connect to %s:%i", &Server.Address, &Server.Port);
Game_Disconnect(&msg, reason);
OnClose();
}
@ -252,7 +252,7 @@ static void MPConnection_FailConnect(cc_result result) {
String_InitArray(msg, msgBuffer);
if (result) {
String_Format3(&msg, "Error connecting to %s:%i: %i" _NL, &Server.IP, &Server.Port, &result);
String_Format3(&msg, "Error connecting to %s:%i: %i" _NL, &Server.Address, &Server.Port, &result);
Logger_Log(&msg);
}
MPConnection_Fail(&reason);
@ -292,21 +292,18 @@ static void MPConnection_BeginConnect(void) {
Blocks.CanPlace[BLOCK_STILL_WATER] = false; Blocks.CanDelete[BLOCK_STILL_WATER] = false;
Blocks.CanPlace[BLOCK_BEDROCK] = false; Blocks.CanDelete[BLOCK_BEDROCK] = false;
res = Socket_Create(&net_socket);
if (res) { MPConnection_FailConnect(res); return; }
Server.Disconnected = false;
net_connecting = true;
net_connectTimeout = Game.Time + NET_TIMEOUT_SECS;
res = Socket_Connect(net_socket, &Server.IP, Server.Port);
res = Socket_Connect(&net_socket, &Server.Address, Server.Port);
if (res == ERR_INVALID_ARGUMENT) {
static const cc_string reason = String_FromConst("Invalid IP address");
MPConnection_Fail(&reason);
} else if (res && res != ReturnCode_SocketInProgess && res != ReturnCode_SocketWouldBlock) {
MPConnection_FailConnect(res);
} else {
String_Format2(&title, "Connecting to %s:%i..", &Server.IP, &Server.Port);
Server.Disconnected = false;
net_connecting = true;
net_connectTimeout = Game.Time + NET_TIMEOUT_SECS;
String_Format2(&title, "Connecting to %s:%i..", &Server.Address, &Server.Port);
LoadingScreen_Show(&title, &String_Empty);
}
}
@ -397,7 +394,7 @@ static void MPConnection_Tick(struct ScheduledTask* task) {
if (res) {
String_InitArray(msg, msgBuffer);
String_Format3(&msg, "Error reading from %s:%i: %i" _NL, &Server.IP, &Server.Port, &res);
String_Format3(&msg, "Error reading from %s:%i: %i" _NL, &Server.Address, &Server.Port, &res);
Logger_Log(&msg);
Game_Disconnect(&title_lost, &reason_err);
@ -504,7 +501,7 @@ static void OnInit(void) {
String_InitArray(Server.MOTD, motdBuffer);
String_InitArray(Server.AppName, appBuffer);
if (!Server.IP.length) {
if (!Server.Address.length) {
SPConnection_Init();
} else {
MPConnection_Init();
@ -527,7 +524,7 @@ static void OnReset(void) {
}
static void OnFree(void) {
Server.IP.length = 0;
Server.Address.length = 0;
OnClose();
}

View File

@ -57,8 +57,8 @@ CC_VAR extern struct _ServerConnectionData {
/* Whether the server supports all of code page 437, not just ASCII. */
cc_bool SupportsFullCP437;
/* IP address of the server if multiplayer, empty string if singleplayer. */
cc_string IP;
/* Address of the server if multiplayer, empty string if singleplayer. */
cc_string Address;
/* Port of the server if multiplayer, 0 if singleplayer. */
int Port;
} Server;

View File

@ -138,16 +138,6 @@ void Utils_Resize(void** buffer, int* capacity, cc_uint32 elemSize, int defCapac
}
}
cc_bool Utils_ParseIP(const cc_string* ip, cc_uint8* data) {
cc_string parts[4 + 1];
int count = String_UNSAFE_Split(ip, '.', parts, 4 + 1);
if (count != 4) return false;
return
Convert_ParseUInt8(&parts[0], &data[0]) && Convert_ParseUInt8(&parts[1], &data[1]) &&
Convert_ParseUInt8(&parts[2], &data[2]) && Convert_ParseUInt8(&parts[3], &data[3]);
}
static const char base64_table[64] = {
'A','B','C','D','E','F','G','H','I','J','K','L','M','N','O','P',
'Q','R','S','T','U','V','W','X','Y','Z','a','b','c','d','e','f',

View File

@ -44,7 +44,6 @@ cc_uint32 Utils_CRC32(const cc_uint8* data, cc_uint32 length);
/* NOTE: This cannot be just indexed by byte value - see Utils_CRC32 implementation. */
extern const cc_uint32 Utils_Crc32Table[256];
CC_NOINLINE void Utils_Resize(void** buffer, int* capacity, cc_uint32 elemSize, int defCapacity, int expandElems);
CC_NOINLINE cc_bool Utils_ParseIP(const cc_string* ip, cc_uint8* data);
/* Converts blocks of 3 bytes into 4 ASCII characters. (pads if needed) */
/* Returns the number of ASCII characters written. */