diff --git a/TrueCraft/QueryProtocol.cs b/TrueCraft/QueryProtocol.cs index 47490ef..5fdf9c3 100644 --- a/TrueCraft/QueryProtocol.cs +++ b/TrueCraft/QueryProtocol.cs @@ -1,4 +1,5 @@ using System; +using System.Collections.Concurrent; using System.Collections.Generic; using System.Linq; using System.Text; @@ -19,12 +20,11 @@ namespace TrueCraft private IMultiplayerServer Server; private CancellationTokenSource CToken; - private readonly byte[] ProtocolVersion = new byte[] { 0xFE, 0xFD }; + private readonly byte[] ProtocolVersion = { 0xFE, 0xFD }; private readonly byte Type_Handshake = 0x09; private readonly byte Type_Stat = 0x00; - private Dictionary UserList; - private object UserLock = new object(); + private ConcurrentDictionary UserList; public QueryProtocol(IMultiplayerServer server) { @@ -35,6 +35,7 @@ namespace TrueCraft { Port = Program.ServerConfiguration.QueryPort; Udp = new UdpClient(Port); + UserList = new ConcurrentDictionary(); Timer = new Timer(ResetUserList, null, 0, 30000); CToken = new CancellationTokenSource(); Udp.BeginReceive(HandleReceive, null); @@ -43,7 +44,7 @@ namespace TrueCraft private void HandleReceive(IAsyncResult ar) { if (CToken.IsCancellationRequested) return; - + try { var clientEP = new IPEndPoint(IPAddress.Any, Port); @@ -72,75 +73,108 @@ namespace TrueCraft private void HandleHandshake(byte[] buffer, IPEndPoint clientEP) { - var stream = GetStream(buffer); - int sessionId = GetSessionId(stream); - - var user = new QueryUser { SessionId = sessionId, ChallengeToken = Rnd.Next() }; - lock (UserLock) + using (var ms = new MemoryStream(buffer)) { - if (UserList.ContainsKey(clientEP)) - UserList.Remove(clientEP); - UserList.Add(clientEP, user); + using (var stream = new BinaryReader(ms)) + { + int sessionId = GetSessionId(stream); + + var user = new QueryUser { SessionId = sessionId, ChallengeToken = Rnd.Next() }; + + if (UserList.ContainsKey(clientEP)) + { + QueryUser u; + while (!UserList.TryRemove(clientEP, out u)) + Thread.Sleep(1); + } + + UserList[clientEP] = user; + + using (var response = new MemoryStream()) + { + using (var writer = new BinaryWriter(response)) + { + WriteHead(Type_Handshake, user, writer); + WriteStringToStream(user.ChallengeToken.ToString(), response); + SendResponse(response.ToArray(), clientEP); + } + } + } + } - - var response = GetStream(); - WriteHead(Type_Handshake, user, response); - WriteStringToStream(user.ChallengeToken.ToString(), response.BaseStream); - - SendResponse(response, clientEP); } private void HandleBasicStat(byte[] buffer, IPEndPoint clientEP) { - var stream = GetStream(buffer); - int sessionId = GetSessionId(stream); - int token = GetToken(stream); + using (var ms = new MemoryStream(buffer)) + { + using (var stream = new BinaryReader(ms)) + { + int sessionId = GetSessionId(stream); + int token = GetToken(stream); - var user = GetUser(clientEP); - if (user.ChallengeToken != token || user.SessionId != sessionId) throw new Exception("Invalid credentials"); + var user = GetUser(clientEP); + if (user.ChallengeToken != token || user.SessionId != sessionId) throw new Exception("Invalid credentials"); - var stats = GetStats(); - var response = GetStream(); - WriteHead(Type_Stat, user, response); - WriteStringToStream(stats["hostname"], response.BaseStream); - WriteStringToStream(stats["gametype"], response.BaseStream); - WriteStringToStream(stats["numplayers"], response.BaseStream); - WriteStringToStream(stats["maxplayers"], response.BaseStream); - byte[] hostport = BitConverter.GetBytes(UInt16.Parse(stats["hostport"])); - Array.Reverse(hostport);//The specification needs little endian short - response.Write(hostport); - WriteStringToStream(stats["hostip"], response.BaseStream); + var stats = GetStats(); + using (var response = new MemoryStream()) + { + using (var writer = new BinaryWriter(response)) + { + WriteHead(Type_Stat, user, writer); + WriteStringToStream(stats["hostname"], response); + WriteStringToStream(stats["gametype"], response); + WriteStringToStream(stats["numplayers"], response); + WriteStringToStream(stats["maxplayers"], response); + byte[] hostport = BitConverter.GetBytes(ushort.Parse(stats["hostport"])); + Array.Reverse(hostport);//The specification needs little endian short + writer.Write(hostport); + WriteStringToStream(stats["hostip"], response); - SendResponse(response, clientEP); + SendResponse(response.ToArray(), clientEP); + } + } + } + } } private void HandleFullStat(byte[] buffer, IPEndPoint clientEP) { - var stream = GetStream(buffer); - int sessionId = GetSessionId(stream); - int token = GetToken(stream); - - var user = GetUser(clientEP); - if (user.ChallengeToken != token || user.SessionId != sessionId) throw new Exception("Invalid credentials"); - - var stats = GetStats(); - var response = GetStream(); - WriteHead(Type_Stat, user, response); - WriteStringToStream("SPLITNUM\0\0", response.BaseStream); - foreach (var pair in stats) + using (var stream = new MemoryStream(buffer)) { - WriteStringToStream(pair.Key, response.BaseStream); - WriteStringToStream(pair.Value, response.BaseStream); - } - response.Write((byte)0x00); - response.Write((byte)0x01); - WriteStringToStream("player_\0", response.BaseStream); - var players = GetPlayers(); - foreach (string player in players) - WriteStringToStream(player, response.BaseStream); - response.Write((byte)0x00); + using (var reader = new BinaryReader(stream)) + { + int sessionId = GetSessionId(reader); + int token = GetToken(reader); - SendResponse(response, clientEP); + var user = GetUser(clientEP); + if (user.ChallengeToken != token || user.SessionId != sessionId) throw new Exception("Invalid credentials"); + + var stats = GetStats(); + using (var response = new MemoryStream()) + { + using (var writer = new BinaryWriter(response)) + { + WriteHead(Type_Stat, user, writer); + WriteStringToStream("SPLITNUM\0\0", response); + foreach (var pair in stats) + { + WriteStringToStream(pair.Key, response); + WriteStringToStream(pair.Value, response); + } + writer.Write((byte)0x00); + writer.Write((byte)0x01); + WriteStringToStream("player_\0", response); + var players = GetPlayers(); + foreach (string player in players) + WriteStringToStream(player, response); + writer.Write((byte)0x00); + + SendResponse(response.ToArray(), clientEP); + } + } + } + } } private bool CheckVersion(byte[] ver) @@ -157,40 +191,39 @@ namespace TrueCraft stream.BaseStream.Position = 7; return stream.ReadInt32(); } - private BinaryReader GetStream(byte[] buffer) - { return new BinaryReader(new MemoryStream(buffer)); } - private BinaryWriter GetStream() - { return new BinaryWriter(new MemoryStream()); } + private void WriteHead(byte type, QueryUser user, BinaryWriter stream) { stream.Write(type); stream.Write(user.SessionId); } - private void SendResponse(BinaryWriter res, IPEndPoint destination) + + private void SendResponse(byte[] res, IPEndPoint destination) { - byte[] data = ((MemoryStream)res.BaseStream).ToArray(); - Udp.Send(data, data.Length, destination); + Udp.Send(res, res.Length, destination); } - private QueryUser GetUser(IPEndPoint ep) + private QueryUser GetUser(IPEndPoint ipe) { - QueryUser user; - lock (UserLock) - if (!UserList.TryGetValue(ep, out user)) throw new Exception("Undefined user"); - return user; + if (!UserList.ContainsKey(ipe)) + throw new Exception("Undefined user"); + + return UserList[ipe]; } private Dictionary GetStats() { - var stats = new Dictionary(); - stats.Add("hostname", Program.ServerConfiguration.MOTD); - stats.Add("gametype", "SMP"); - stats.Add("game_id", "TRUECRAFT"); - stats.Add("version", "1.0"); - stats.Add("plugins", "TrueCraft"); - stats.Add("map", Server.Worlds.First().Name); - stats.Add("numplayers", Server.Clients.Count.ToString()); - stats.Add("maxplayers", "64"); - stats.Add("hostport", Program.ServerConfiguration.ServerPort.ToString()); - stats.Add("hostip", Program.ServerConfiguration.ServerAddress); + var stats = new Dictionary + { + {"hostname", Program.ServerConfiguration.MOTD}, + {"gametype", "SMP"}, + {"game_id", "TRUECRAFT"}, + {"version", "1.0"}, + {"plugins", "TrueCraft"}, + {"map", Server.Worlds.First().Name}, + {"numplayers", Server.Clients.Count.ToString()}, + {"maxplayers", "64"}, + {"hostport", Program.ServerConfiguration.ServerPort.ToString()}, + {"hostip", Program.ServerConfiguration.ServerAddress} + }; return stats; } private List GetPlayers() @@ -223,14 +256,14 @@ namespace TrueCraft public void Stop() { + Timer.Dispose(); CToken.Cancel(); Udp.Close(); } private void ResetUserList(object state) { - lock (UserLock) - UserList = new Dictionary(); + UserList.Clear(); } struct QueryUser