diff --git a/GenOnlineService/BackgroundS3Uploader.cs b/GenOnlineService/BackgroundS3Uploader.cs index e8ba36d..d210030 100644 --- a/GenOnlineService/BackgroundS3Uploader.cs +++ b/GenOnlineService/BackgroundS3Uploader.cs @@ -50,7 +50,7 @@ static class BackgroundS3Uploader private static Int64 g_LastUpload = -1; private static Thread g_BackgroundThread = null; - private static bool g_bShutdownRequested = false; + private static volatile bool g_bShutdownRequested = false; public static void Initialize() { @@ -76,7 +76,7 @@ public static void TickThreaded() // This is called on a thread, and uploads one // queue the next thing if (m_queueUploads.TryDequeue(out S3QueuedUploadEntry entry)) { - DoUpload(entry); + DoUpload(entry).GetAwaiter().GetResult(); g_LastUpload = Environment.TickCount64; } } diff --git a/GenOnlineService/Constants.cs b/GenOnlineService/Constants.cs index 9c0b0ba..ca1ecfc 100644 --- a/GenOnlineService/Constants.cs +++ b/GenOnlineService/Constants.cs @@ -41,6 +41,19 @@ public static class Constants public const UInt16 g_DefaultCameraMaxHeight = 310; } + public class RoomMember + { + public RoomMember(Int64 a_UserID, string strName, bool admin) + { + UserID = a_UserID; + Name = strName; + IsAdmin = admin; + } + + public Int64 UserID { get; set; } = -1; + public String Name { get; set; } = String.Empty; + public bool IsAdmin { get; set; } = false; + } public enum EPendingLoginState { @@ -225,12 +238,15 @@ public static async Task CreateSession(bool bIsReconnect, return newSess; } - public static async void Tick() + public static async Task Tick() { - foreach (var kvPair in m_dictUserSessions) - { - kvPair.Value.TickWebsocket(); - } + // Give the entire tick a 20 ms deadline. All users drain concurrently via + // Task.WhenAll, so a slow/stuck client cannot delay others. If the deadline + // fires, the CancellationToken propagates into each in-flight SendAsync and + // into the dequeue loop guard, so the stuck user is skipped and their unsent + // messages stay in the queue for the next tick. + using var cts = new CancellationTokenSource(TimeSpan.FromMilliseconds(20)); + await Task.WhenAll(m_dictUserSessions.Values.Select(sess => sess.TickWebsocket(cts.Token))); } public static async Task CheckForTimeouts() @@ -444,59 +460,60 @@ public static async Task SendNewOrDeletedLobbyToAllNetworkRoomMembers(int networ } } - public static async Task SendRoomMemberListToAllInRoom(int roomID) + private static ConcurrentList g_lstDirtyNetworkRooms = new(); + public static async Task TickRoomMemberList() { - // need a member list update - WebSocketMessage_NetworkRoomMemberListUpdate memberListUpdate = new WebSocketMessage_NetworkRoomMemberListUpdate(); - memberListUpdate.msg_id = (int)EWebSocketMessageID.NETWORK_ROOM_MEMBER_LIST_UPDATE; - memberListUpdate.names = new List(); - memberListUpdate.ids = new List(); + foreach (int roomID in g_lstDirtyNetworkRooms) + { + + // need a member list update + WebSocketMessage_NetworkRoomMemberListUpdate memberListUpdate = new WebSocketMessage_NetworkRoomMemberListUpdate(); + memberListUpdate.msg_id = (int)EWebSocketMessageID.NETWORK_ROOM_MEMBER_LIST_UPDATE; + memberListUpdate.members = new(); - SortedDictionary usersAlreadyProcessed = new(); + SortedDictionary usersAlreadyProcessed = new(); - List lstUsersToSend = new(); + List lstUsersToSend = new(); - // populate list of everyone in the room - foreach (KeyValuePair sessionData in m_dictUserSessions) - { - UserSession sess = sessionData.Value; - if (sess.networkRoomID == roomID) + // populate list of everyone in the room + foreach (KeyValuePair sessionData in m_dictUserSessions) { - if (!usersAlreadyProcessed.ContainsKey(sess.m_UserID)) + UserSession sess = sessionData.Value; + if (sess.networkRoomID == roomID) { - usersAlreadyProcessed[sess.m_UserID] = true; - - // add to member list - - // flag staff accounts - if (sess.IsAdmin()) - { - memberListUpdate.names.Add(String.Format("[\u2605\u2605GO STAFF\u2605\u2605] {0}", sess.m_strDisplayName)); - } - else + if (!usersAlreadyProcessed.ContainsKey(sess.m_UserID)) { - memberListUpdate.names.Add(sess.m_strDisplayName); - } + usersAlreadyProcessed[sess.m_UserID] = true; - memberListUpdate.ids.Add(sess.m_UserID); + // add to member list + string strDisplayName = sess.IsAdmin() ? String.Format("[\u2605\u2605GO STAFF\u2605\u2605] {0}", sess.m_strDisplayName) : sess.m_strDisplayName; + memberListUpdate.members.Add(new RoomMember(sess.m_UserID, strDisplayName, sess.IsAdmin())); - // also add to list of users who need this update, since they were in there - UserSession? targetWS = WebSocketManager.GetDataFromUser(sess.m_UserID); - if (targetWS != null) - { - lstUsersToSend.Add(targetWS); + // also add to list of users who need this update, since they were in there + UserSession? targetWS = WebSocketManager.GetDataFromUser(sess.m_UserID); + if (targetWS != null) + { + lstUsersToSend.Add(targetWS); + } } } } - } - byte[] bytesJSON = Encoding.UTF8.GetBytes(JsonSerializer.Serialize(memberListUpdate)); + byte[] bytesJSON = Encoding.UTF8.GetBytes(JsonSerializer.Serialize(memberListUpdate)); - // now send to everyone in the room - foreach (UserSession sess in lstUsersToSend) - { - sess.QueueWebsocketSend(bytesJSON); + // now send to everyone in the room + foreach (UserSession sess in lstUsersToSend) + { + sess.QueueWebsocketSend(bytesJSON); + } } + + g_lstDirtyNetworkRooms.Clear(); + } + + public static async Task MarkRoomMemberListAsDirty(int roomID) + { + g_lstDirtyNetworkRooms.Add(roomID); } } @@ -525,6 +542,8 @@ public class UserSession private Int64 m_timeAbandoned = -1; + private string m_strMiddlewareUserID = String.Empty; + public string m_client_id = String.Empty; DateTime m_CreateTime = DateTime.Now; public DateTime GetCreationTime() @@ -532,6 +551,16 @@ public DateTime GetCreationTime() return m_CreateTime; } + public void SetMiddlewareID(string strMiddlewareUserID) + { + m_strMiddlewareUserID = strMiddlewareUserID; + } + + public string GetMiddlewareID() + { + return m_strMiddlewareUserID; + } + public UInt64 GetLatestMatchID() { UInt64 mostRecentMatchID = 0; @@ -564,7 +593,7 @@ public UserSession(Int64 ownerID, UserSocialContainer socialContainer, string cl if (Helpers.g_dictInitialExeCRCs.ContainsKey(ownerID)) { ACExeCRC = Helpers.g_dictInitialExeCRCs[ownerID].ToUpper(); - Helpers.g_dictInitialExeCRCs.Remove(ownerID); + Helpers.g_dictInitialExeCRCs.Remove(ownerID, out string removedCRC); } m_socialContainer = socialContainer; @@ -593,16 +622,9 @@ public void QueueWebsocketSend(byte[] bytesJSON) return; } - // If we have a websocket active, just send immediately, otherwise, queue it - UserWebSocketInstance websocketForUser = WebSocketManager.GetWebSocketForSession(this); - if (websocketForUser != null) - { - websocketForUser.SendAsync(bytesJSON, WebSocketMessageType.Text); - } - else - { - m_lstPendingWebsocketSends.Enqueue(bytesJSON); - } + // Always enqueue; the TickWebsocket drain loop is the sole sender, + // ensuring WebSocket.SendAsync is never called concurrently. + m_lstPendingWebsocketSends.Enqueue(bytesJSON); } public async Task CloseWebsocket(WebSocketCloseStatus reason, string strReason) @@ -616,7 +638,7 @@ public async Task CloseWebsocket(WebSocketCloseStatus rea return websocketForUser; } - public async void TickWebsocket() + public async Task TickWebsocket(CancellationToken tickToken = default) { // Do we have a connection to send on? UserWebSocketInstance websocketForUser = WebSocketManager.GetWebSocketForSession(this); @@ -625,16 +647,16 @@ public async void TickWebsocket() const int maxMessagesSendPerFrame = 50; int messagesSent = 0; // start dequeing and sending - while (messagesSent < maxMessagesSendPerFrame && m_lstPendingWebsocketSends.TryDequeue(out byte[] packetData)) + while (!tickToken.IsCancellationRequested && messagesSent < maxMessagesSendPerFrame && m_lstPendingWebsocketSends.TryDequeue(out byte[] packetData)) { - websocketForUser.SendAsync(packetData, WebSocketMessageType.Text); + await websocketForUser.SendAsync(packetData, WebSocketMessageType.Text, tickToken); ++messagesSent; } } } // TODO_CACHE: Size limit this? - Queue m_lstPendingWebsocketSends = new Queue(); + ConcurrentQueue m_lstPendingWebsocketSends = new ConcurrentQueue(); public void NotifyFriendslistDirty() { @@ -722,7 +744,7 @@ public bool WasPlayerInMatch(UInt64 matchID, out int slotIndexInLobby, out int a return bWasInMatch; } - public async void UpdateSessionNetworkRoom(Int16 newRoomID) + public async Task UpdateSessionNetworkRoom(Int16 newRoomID) { Int16 oldRoom = networkRoomID; networkRoomID = newRoomID; @@ -730,13 +752,13 @@ public async void UpdateSessionNetworkRoom(Int16 newRoomID) // update the room roster they left if (oldRoom >= 0) // only if they werent in the dummy room before { - await WebSocketManager.SendRoomMemberListToAllInRoom(oldRoom); + await WebSocketManager.MarkRoomMemberListAsDirty(oldRoom); } // send update to joiner + everyone in new room already if (newRoomID >= 0) // only if they actually joined a room and weren't going to the dummy room { - await WebSocketManager.SendRoomMemberListToAllInRoom(newRoomID); + await WebSocketManager.MarkRoomMemberListAsDirty(newRoomID); } // make the client force refresh list too @@ -827,7 +849,7 @@ public Int64 GetTimeSinceLastPing() return Environment.TickCount64 - m_lastPingTime; } - public async Task SendAsync(byte[] buffer, WebSocketMessageType messageType) + public async Task SendAsync(byte[] buffer, WebSocketMessageType messageType, CancellationToken externalToken = default) { if (m_SockInternal != null) { @@ -863,7 +885,8 @@ public async Task SendAsync(byte[] buffer, WebSocketMessageType messageType) } */ - var cts = new CancellationTokenSource(TimeSpan.FromSeconds(10)); + using var cts = CancellationTokenSource.CreateLinkedTokenSource(externalToken); + cts.CancelAfter(TimeSpan.FromMilliseconds(500)); await m_SockInternal.SendAsync(buffer, messageType, true, cts.Token); } catch @@ -1641,7 +1664,7 @@ private static void GetTURNConfig(out int TTL, out string token, out string key, // we should only have 1 turn credential at a time... clean it up if (g_DictTURNUsernames.ContainsKey(userID)) { - DeleteCredentialsForUser(userID); + await DeleteCredentialsForUser(userID); } // create new credential @@ -1688,6 +1711,7 @@ private static void GetTURNConfig(out int TTL, out string token, out string key, } })) { + client.Timeout = TimeSpan.FromSeconds(10); client.DefaultRequestHeaders.Add("Authorization", String.Format("Bearer {0}", TurnToken)); client.DefaultRequestHeaders.Add("Accept", "application/json"); //client.DefaultRequestHeaders.Add("Content-Type", "application/json"); @@ -1737,7 +1761,7 @@ private static void GetTURNConfig(out int TTL, out string token, out string key, return null; } - public static async void DeleteCredentialsForUser(Int64 userID) + public static async Task DeleteCredentialsForUser(Int64 userID) { #if DEBUG await Task.Delay(1); @@ -1799,6 +1823,7 @@ public static async void DeleteCredentialsForUser(Int64 userID) } })) { + client.Timeout = TimeSpan.FromSeconds(10); client.DefaultRequestHeaders.Add("Authorization", String.Format("Bearer {0}", TurnToken)); client.DefaultRequestHeaders.Add("Accept", "application/json"); try @@ -2059,8 +2084,7 @@ public class WebSocketMessage_RelayUpgradeInbound : WebSocketMessage public class WebSocketMessage_NetworkRoomMemberListUpdate : WebSocketMessage { - public List? names { get; set; } - public List? ids { get; set; } + public List members { get; set; } = new(); } public class WebSocketMessage_CurrentLobbyUpdate : WebSocketMessage diff --git a/GenOnlineService/Controllers/CheckLogin/CheckLoginController.cs b/GenOnlineService/Controllers/CheckLogin/CheckLoginController.cs index 1cfae41..5c25e91 100644 --- a/GenOnlineService/Controllers/CheckLogin/CheckLoginController.cs +++ b/GenOnlineService/Controllers/CheckLogin/CheckLoginController.cs @@ -67,7 +67,7 @@ public async Task Post() //bSecureWS = false; } - POST_CheckLogin_Result result = (POST_CheckLogin_Result)await Post_InternalHandler(jsonData, HttpContext.Connection.RemoteIpAddress?.ToString(), bSecureWS); + POST_CheckLogin_Result result = (POST_CheckLogin_Result)await Post_InternalHandler(jsonData, IPHelpers.NormalizeIP(HttpContext.Connection.RemoteIpAddress?.ToString()), bSecureWS); return result; } } @@ -216,7 +216,7 @@ public async Task Post_InternalHandler(string jsonData, string ipAddr result.ws_uri = null; } - Database.Functions.Auth.CleanupPendingLogin(GlobalDatabaseInstance.g_Database, gameCode); + await Database.Functions.Auth.CleanupPendingLogin(GlobalDatabaseInstance.g_Database, gameCode); return result; } @@ -231,7 +231,7 @@ public async Task Post_InternalHandler(string jsonData, string ipAddr { result.result = EPendingLoginState.LoginFailed; Response.StatusCode = (int)HttpStatusCode.Forbidden; - Database.Functions.Auth.CleanupPendingLogin(GlobalDatabaseInstance.g_Database, gameCode); + await Database.Functions.Auth.CleanupPendingLogin(GlobalDatabaseInstance.g_Database, gameCode); } #if !DEBUG } diff --git a/GenOnlineService/Controllers/Friends/SocialController.cs b/GenOnlineService/Controllers/Friends/SocialController.cs index b24d24a..b7edf12 100644 --- a/GenOnlineService/Controllers/Friends/SocialController.cs +++ b/GenOnlineService/Controllers/Friends/SocialController.cs @@ -260,7 +260,7 @@ public async Task AddFriend(Int64 target_user_id) } // too many friends? - const int friendsLimit = 100; + const int friendsLimit = 200; UserSession? userData = WebSocketManager.GetDataFromUser(requester_user_id); if (userData.GetSocialContainer().Friends.Count >= friendsLimit) { diff --git a/GenOnlineService/Controllers/Lobbies/LobbiesController.cs b/GenOnlineService/Controllers/Lobbies/LobbiesController.cs index 2274e69..5fd6d62 100644 --- a/GenOnlineService/Controllers/Lobbies/LobbiesController.cs +++ b/GenOnlineService/Controllers/Lobbies/LobbiesController.cs @@ -60,12 +60,31 @@ public override Type GetReturnType() public class LobbiesController : ControllerBase { private readonly ILogger _logger; + private static List? s_cachedRooms = null; + private static readonly object s_roomsLock = new object(); public LobbiesController(ILogger logger) { _logger = logger; } + // Cache rooms.json data to avoid disk I/O on every request + private static async Task?> GetCachedRooms(JsonSerializerOptions options) + { + if (s_cachedRooms == null) + { + lock (s_roomsLock) + { + if (s_cachedRooms == null) + { + string strFileData = System.IO.File.ReadAllText(Path.Combine("data", "rooms.json")); + s_cachedRooms = JsonSerializer.Deserialize>(strFileData, options); + } + } + } + return await Task.FromResult(s_cachedRooms); + } + // FOR LATENCY ESTIMATIONS // Convert degrees to radians public static double ToRadians(double angleInDegrees) @@ -136,9 +155,8 @@ public async Task Get() if (sourceData != null) { - // TODO: Dont deserialize this per request, cache it in the session - string strFileData = await System.IO.File.ReadAllTextAsync(Path.Combine("data", "rooms.json")); - List? lstRooms = JsonSerializer.Deserialize>(strFileData, options); + // Use cached rooms data + List? lstRooms = await GetCachedRooms(options); if (lstRooms != null) { foreach (RoomData room in lstRooms) @@ -307,6 +325,28 @@ public async Task Put() UInt32 exe_crc = data["exe_crc"].GetUInt32(); UInt32 ini_crc = data["ini_crc"].GetUInt32(); + // Input validation + if (strName != null && strName.Length > 255) + { + Response.StatusCode = (int)HttpStatusCode.BadRequest; + return result; + } + if (strMapName != null && strMapName.Length > 255) + { + Response.StatusCode = (int)HttpStatusCode.BadRequest; + return result; + } + if (strMapPath != null && strMapPath.Length > 512) + { + Response.StatusCode = (int)HttpStatusCode.BadRequest; + return result; + } + if (strPassword != null && strPassword.Length > 128) + { + Response.StatusCode = (int)HttpStatusCode.BadRequest; + return result; + } + // get requesting user data from session token diff --git a/GenOnlineService/Controllers/Lobby/LobbyController.cs b/GenOnlineService/Controllers/Lobby/LobbyController.cs index 55b6ddd..1664707 100644 --- a/GenOnlineService/Controllers/Lobby/LobbyController.cs +++ b/GenOnlineService/Controllers/Lobby/LobbyController.cs @@ -22,6 +22,7 @@ using Microsoft.AspNetCore.Mvc.Infrastructure; using Microsoft.Extensions.Options; using System; +using System.Collections.Concurrent; using System.Net; using System.Net.WebSockets; using System.Security.Claims; @@ -263,7 +264,8 @@ public async Task Delete(Int64 lobbyID) string jsonData = await reader.ReadToEndAsync(); var options = new JsonSerializerOptions { - PropertyNameCaseInsensitive = true + PropertyNameCaseInsensitive = true, + MaxDepth = 32 }; try @@ -321,6 +323,36 @@ await Database.Functions.Lobby.CommitPlayerOutcome(GlobalDatabaseInstance.g_Data return null; } + enum ELobbyUpdatePermissions + { + Anyone, + LobbyOwner + } + + private static ConcurrentDictionary g_dictLobbyUpdatePermissionsTable = new() + { + [ELobbyUpdateField.LOBBY_MAP] = ELobbyUpdatePermissions.LobbyOwner, + [ELobbyUpdateField.MY_SIDE] = ELobbyUpdatePermissions.Anyone, + [ELobbyUpdateField.MY_COLOR] = ELobbyUpdatePermissions.Anyone, + [ELobbyUpdateField.MY_START_POS] = ELobbyUpdatePermissions.Anyone, + [ELobbyUpdateField.MY_TEAM] = ELobbyUpdatePermissions.Anyone, + [ELobbyUpdateField.LOBBY_STARTING_CASH] = ELobbyUpdatePermissions.LobbyOwner, + [ELobbyUpdateField.LOBBY_LIMIT_SUPERWEAPONS] = ELobbyUpdatePermissions.LobbyOwner, + [ELobbyUpdateField.HOST_ACTION_FORCE_START] = ELobbyUpdatePermissions.LobbyOwner, + [ELobbyUpdateField.LOCAL_PLAYER_HAS_MAP] = ELobbyUpdatePermissions.Anyone, + [ELobbyUpdateField.UNUSED] = ELobbyUpdatePermissions.Anyone, + [ELobbyUpdateField.UNUSED_2] = ELobbyUpdatePermissions.Anyone, + [ELobbyUpdateField.HOST_ACTION_KICK_USER] = ELobbyUpdatePermissions.LobbyOwner, + [ELobbyUpdateField.HOST_ACTION_SET_SLOT_STATE] = ELobbyUpdatePermissions.LobbyOwner, + [ELobbyUpdateField.AI_SIDE] = ELobbyUpdatePermissions.LobbyOwner, + [ELobbyUpdateField.AI_COLOR] = ELobbyUpdatePermissions.LobbyOwner, + [ELobbyUpdateField.AI_TEAM] = ELobbyUpdatePermissions.LobbyOwner, + [ELobbyUpdateField.AI_START_POS] = ELobbyUpdatePermissions.LobbyOwner, + [ELobbyUpdateField.MAX_CAMERA_HEIGHT] = ELobbyUpdatePermissions.LobbyOwner, + [ELobbyUpdateField.JOINABILITY] = ELobbyUpdatePermissions.LobbyOwner + }; + + [HttpPost("{lobbyID}")] [Authorize(Roles = "Player")] public async Task Post(Int64 lobbyID) @@ -366,6 +398,19 @@ public async Task Post(Int64 lobbyID) // TODO: Safety ELobbyUpdateField field = (ELobbyUpdateField)data["field"].GetInt32(); + // check permissions + ELobbyUpdatePermissions updatePerms = g_dictLobbyUpdatePermissionsTable[field]; + + if (updatePerms == ELobbyUpdatePermissions.LobbyOwner) // check owner + { + if (user_id != lobby.Owner) + { + Response.StatusCode = (int)HttpStatusCode.Unauthorized; + result.success = false; + return result; + } + } + // reset everyones ready states when anything changes (minus dummy actions) if (field != ELobbyUpdateField.HOST_ACTION_FORCE_START && field != ELobbyUpdateField.LOCAL_PLAYER_HAS_MAP @@ -439,20 +484,14 @@ public async Task Post(Int64 lobbyID) { if (data.ContainsKey("limit_superweapons")) { - if (user_id == lobby.Owner) - { - bool bLimitSuperweapons = data["limit_superweapons"].GetBoolean(); - await lobby.UpdateLimitSuperweapons(bLimitSuperweapons); - } + bool bLimitSuperweapons = data["limit_superweapons"].GetBoolean(); + await lobby.UpdateLimitSuperweapons(bLimitSuperweapons); } } else if (field == ELobbyUpdateField.HOST_ACTION_FORCE_START) { // dummy action... just force everyone ready - if (user_id == lobby.Owner) - { - lobby.ForceReady(); - } + lobby.ForceReady(); } else if (field == ELobbyUpdateField.LOCAL_PLAYER_HAS_MAP) { @@ -467,43 +506,36 @@ public async Task Post(Int64 lobbyID) { if (data.ContainsKey("userid")) { - if (user_id == lobby.Owner) - { - // TODO: we should communicate the kick to the user... - Int64 KickedUserID = data["userid"].GetInt64(); - - LobbyManager.LeaveSpecificLobby(KickedUserID, lobbyID); + // TODO: we should communicate the kick to the user... + Int64 KickedUserID = data["userid"].GetInt64(); - // cleanup TURN credentials - TURNCredentialManager.DeleteCredentialsForUser(KickedUserID); + LobbyManager.LeaveSpecificLobby(KickedUserID, lobbyID); - // clear our lobby ID - UserSession? sourceData = WebSocketManager.GetDataFromUser(KickedUserID); + // cleanup TURN credentials + TURNCredentialManager.DeleteCredentialsForUser(KickedUserID); - if (sourceData != null) - { - sourceData.UpdateSessionLobbyID(-1); - // NOTE: We dont update the match history match ID here, that is done by the match history service - } + // clear our lobby ID + UserSession? sourceData = WebSocketManager.GetDataFromUser(KickedUserID); - // we have to manually send to the kicked user... they won't get the dirty lobby update anymore - lobby.DirtyRetransmitToSingleMember(KickedUserID); + if (sourceData != null) + { + sourceData.UpdateSessionLobbyID(-1); + // NOTE: We dont update the match history match ID here, that is done by the match history service } + + // we have to manually send to the kicked user... they won't get the dirty lobby update anymore + await lobby.DirtyRetransmitToSingleMember(KickedUserID); } } else if (field == ELobbyUpdateField.HOST_ACTION_SET_SLOT_STATE) { - // must be host - if (user_id == lobby.Owner) - { - UInt16 slot_index = data["slot_index"].GetUInt16(); - EPlayerType slot_state = (EPlayerType)data["slot_state"].GetUInt16(); + UInt16 slot_index = data["slot_index"].GetUInt16(); + EPlayerType slot_state = (EPlayerType)data["slot_state"].GetUInt16(); - LobbyMember? TargetMember = lobby.GetMemberFromSlot(slot_index); - if (TargetMember != null) - { - TargetMember.SetPlayerSlotState(slot_state); - } + LobbyMember? TargetMember = lobby.GetMemberFromSlot(slot_index); + if (TargetMember != null) + { + TargetMember.SetPlayerSlotState(slot_state); } } else if (field == ELobbyUpdateField.AI_SIDE) @@ -513,19 +545,16 @@ public async Task Post(Int64 lobbyID) && data.ContainsKey("start_pos") ) { - if (user_id == lobby.Owner) - { - int slot = data["slot"].GetInt32(); - int side = data["side"].GetInt32(); - int start_pos = data["start_pos"].GetInt32(); + int slot = data["slot"].GetInt32(); + int side = data["side"].GetInt32(); + int start_pos = data["start_pos"].GetInt32(); - LobbyMember? TargetMember = lobby.GetMemberFromSlot(slot); - if (TargetMember != null) + LobbyMember? TargetMember = lobby.GetMemberFromSlot(slot); + if (TargetMember != null) + { + if (TargetMember.IsAI()) { - if (TargetMember.IsAI()) - { - await TargetMember.UpdateSide(side, start_pos); - } + await TargetMember.UpdateSide(side, start_pos); } } } @@ -535,18 +564,15 @@ public async Task Post(Int64 lobbyID) if (data.ContainsKey("slot") && data.ContainsKey("color")) { - if (user_id == lobby.Owner) - { - int slot = data["slot"].GetInt32(); - int color = data["color"].GetInt32(); + int slot = data["slot"].GetInt32(); + int color = data["color"].GetInt32(); - LobbyMember? TargetMember = lobby.GetMemberFromSlot(slot); - if (TargetMember != null) + LobbyMember? TargetMember = lobby.GetMemberFromSlot(slot); + if (TargetMember != null) + { + if (TargetMember.IsAI()) { - if (TargetMember.IsAI()) - { - await TargetMember.UpdateColor(color); - } + await TargetMember.UpdateColor(color); } } } @@ -556,18 +582,15 @@ public async Task Post(Int64 lobbyID) if (data.ContainsKey("slot") && data.ContainsKey("team")) { - if (user_id == lobby.Owner) - { - int slot = data["slot"].GetInt32(); - int team = data["team"].GetInt32(); + int slot = data["slot"].GetInt32(); + int team = data["team"].GetInt32(); - LobbyMember? TargetMember = lobby.GetMemberFromSlot(slot); - if (TargetMember != null) + LobbyMember? TargetMember = lobby.GetMemberFromSlot(slot); + if (TargetMember != null) + { + if (TargetMember.IsAI()) { - if (TargetMember.IsAI()) - { - TargetMember.UpdateTeam(team); - } + TargetMember.UpdateTeam(team); } } } @@ -577,19 +600,16 @@ public async Task Post(Int64 lobbyID) if (data.ContainsKey("slot") && data.ContainsKey("start_pos")) { - if (user_id == lobby.Owner) - { - // TODO: All these AI funcs should check the player being operated upon is AI, otherwise host could use fiddler to alter other users - int slot = data["slot"].GetInt32(); - int start_pos = data["start_pos"].GetInt32(); + // TODO: All these AI funcs should check the player being operated upon is AI, otherwise host could use fiddler to alter other users + int slot = data["slot"].GetInt32(); + int start_pos = data["start_pos"].GetInt32(); - LobbyMember? TargetMember = lobby.GetMemberFromSlot(slot); - if (TargetMember != null) + LobbyMember? TargetMember = lobby.GetMemberFromSlot(slot); + if (TargetMember != null) + { + if (TargetMember.IsAI()) { - if (TargetMember.IsAI()) - { - TargetMember.UpdateStartPos(start_pos); - } + TargetMember.UpdateStartPos(start_pos); } } } @@ -598,21 +618,15 @@ public async Task Post(Int64 lobbyID) { if (data.ContainsKey("max_camera_height")) { - if (user_id == lobby.Owner) - { - UInt16 maxCameraHeight = data["max_camera_height"].GetUInt16(); - lobby.UpdateMaxCameraHeight(maxCameraHeight); - } + UInt16 maxCameraHeight = data["max_camera_height"].GetUInt16(); + lobby.UpdateMaxCameraHeight(maxCameraHeight); } } else if (field == ELobbyUpdateField.JOINABILITY) { - if (user_id == lobby.Owner) - { - ELobbyJoinability newLobbyJoinability = (ELobbyJoinability)data["joinability"].GetInt32(); - lobby.UpdateJoinability(newLobbyJoinability); - } - } + ELobbyJoinability newLobbyJoinability = (ELobbyJoinability)data["joinability"].GetInt32(); + lobby.UpdateJoinability(newLobbyJoinability); + } } } @@ -697,6 +711,9 @@ public async Task Put(Int64 lobbyID) if (playerSession != null) { + // leave any lobby + LobbyManager.LeaveAnyLobby(user_id); + string strDisplayName = await Database.Functions.Auth.GetDisplayName(GlobalDatabaseInstance.g_Database, user_id); bool bJoinedSuccessfully = await LobbyManager.JoinLobby(lobby, playerSession, strDisplayName, userPreferredPort, bHasMap); diff --git a/GenOnlineService/Controllers/LoginWithToken/LoginWithTokenController.cs b/GenOnlineService/Controllers/LoginWithToken/LoginWithTokenController.cs index e86ab11..443d151 100644 --- a/GenOnlineService/Controllers/LoginWithToken/LoginWithTokenController.cs +++ b/GenOnlineService/Controllers/LoginWithToken/LoginWithTokenController.cs @@ -69,7 +69,7 @@ public async Task Post() //bSecureWS = false; } - POST_LoginWithToken_Result result = (POST_LoginWithToken_Result)await Post_InternalHandler(jsonData, HttpContext.Connection.RemoteIpAddress?.ToString(), bSecureWS); + POST_LoginWithToken_Result result = (POST_LoginWithToken_Result)await Post_InternalHandler(jsonData, IPHelpers.NormalizeIP(HttpContext.Connection.RemoteIpAddress?.ToString()), bSecureWS); return result; } } diff --git a/GenOnlineService/Controllers/Matchmaking/MatchmakingController.cs b/GenOnlineService/Controllers/Matchmaking/MatchmakingController.cs index 223b1ce..2554fc4 100644 --- a/GenOnlineService/Controllers/Matchmaking/MatchmakingController.cs +++ b/GenOnlineService/Controllers/Matchmaking/MatchmakingController.cs @@ -71,7 +71,7 @@ public MatchmakingController(ILogger logger) { UInt16 playlistID = data["playlist"].GetUInt16(); var array = data["maps"].EnumerateArray(); - List mapIndices = array.ToList().Select(x => x.GetInt32()).ToList(); + List mapIndices = array.Select(x => x.GetInt32()).ToList(); UInt32 exe_crc = data["exe_crc"].GetUInt32(); UInt32 ini_crc = data["ini_crc"].GetUInt32(); diff --git a/GenOnlineService/Controllers/Monitoring/MonitoringController.cs b/GenOnlineService/Controllers/Monitoring/MonitoringController.cs index a343bbd..51245b5 100644 --- a/GenOnlineService/Controllers/Monitoring/MonitoringController.cs +++ b/GenOnlineService/Controllers/Monitoring/MonitoringController.cs @@ -207,7 +207,7 @@ public APIResult Monitor_Uptime() [HttpGet] // TODO: Undo all of these and make all flows use gethttpsize/head #if !DEBUG - public APIResult? Monitor_VersionCheck() + public async Task Monitor_VersionCheck() #else public async Task Monitor_VersionCheck() #endif @@ -223,7 +223,7 @@ public APIResult Monitor_Uptime() { GenOnlineService.Controllers.VersionCheckController versionCheckController = new GenOnlineService.Controllers.VersionCheckController(); #if !DEBUG - APIResult internalResult = VersionHelper.Post_InternalHandler("{\"execrc\": 1234567890, \"ver\": 1, \"netver\": 2, \"servicesver\": 3}"); + APIResult internalResult = await VersionHelper.Post_InternalHandler("{\"execrc\": 1234567890, \"ver\": 1, \"netver\": 2, \"servicesver\": 3}"); #else APIResult internalResult = await VersionHelper.Post_InternalHandler("{\"execrc\": 1234567890, \"ver\": 1, \"netver\": 2, \"servicesver\": 3}"); #endif diff --git a/GenOnlineService/Controllers/OID/OIDController.cs b/GenOnlineService/Controllers/OID/OIDController.cs new file mode 100644 index 0000000..19a6eae --- /dev/null +++ b/GenOnlineService/Controllers/OID/OIDController.cs @@ -0,0 +1,235 @@ +/* +** GeneralsOnline Game Services - Backend Services for Command & Conquer Generals Online: Zero Hour +** Copyright (C) 2025 GeneralsOnline Development Team +** +** This program is free software: you can redistribute it and/or modify +** it under the terms of the GNU Affero General Public License as +** published by the Free Software Foundation, either version 3 of the +** License, or (at your option) any later version. +** +** This program is distributed in the hope that it will be useful, +** but WITHOUT ANY WARRANTY; without even the implied warranty of +** MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +** GNU Affero General Public License for more details. +** +** You should have received a copy of the GNU Affero General Public License +** along with this program. If not, see . +*/ + +using Microsoft.AspNetCore.Authorization; +using Microsoft.AspNetCore.Mvc; +using Microsoft.Extensions.Options; +using Microsoft.IdentityModel.Tokens; +using MySqlX.XDevAPI.Common; +using System; +using System.IdentityModel.Tokens.Jwt; +using System.Net; +using System.Security.Cryptography; +using System.Text; +using System.Text.Json; +using Microsoft.IdentityModel.Tokens; +using System.IdentityModel.Tokens.Jwt; +using System.Net.Http.Json; +using System.Security.Cryptography; +using System.Security.Claims; + +namespace GenOnlineService.Controllers.LoginWithToken +{ + + public class POST_OID_Result : APIResult + { + public override Type GetReturnType() + { + return typeof(POST_OID_Result); + } + + public string user_id { get; set; } = null; // string provides max compat + public string display_name { get; set; } = null; + } + + [ApiController] + [Authorize(Roles = "Player")] + [Route("env/{environment}/contract/{contract_version}/[controller]")] + public class OID : ControllerBase + { + + public OID() + { + + } + + [HttpPost(Name = "PostOID")] + public async Task Post() + { + // if we reach here, the token was valid + POST_OID_Result result = new POST_OID_Result(); + + Int64 user_id = TokenHelper.GetUserID(this); + if (user_id != -1) + { + string strDisplayName = TokenHelper.GetDisplayName(this); + + result.user_id = user_id.ToString(); + result.display_name = strDisplayName; + } + + return result; + } + } + + [ApiController] + [Authorize(Roles = "Player")] + [Route("env/{environment}/contract/{contract_version}/[controller]")] + public class ProvideMWToken : ControllerBase + { + + public ProvideMWToken() + { + + } + + public static string GetClaimValue(string jwtToken, string claimType) + { + var handler = new JwtSecurityTokenHandler(); + var token = handler.ReadJwtToken(jwtToken); // Parses the token into JwtSecurityToken + var claim = token.Claims.FirstOrDefault(c => c.Type == claimType); + return claim?.Value; + } + + public static byte[] Base64UrlDecode(string input) + { + return Base64UrlEncoder.DecodeBytes(input); + } + + + public async Task ValidateEpicJwtAsync(string jwt) + { + var handler = new JwtSecurityTokenHandler(); + var token = handler.ReadJwtToken(jwt); + + var kid = token.Header.Kid; + if (kid == null) + throw new SecurityTokenException("JWT missing kid header"); + + // load settings + IConfigurationSection? middlewareSettings = Program.g_Config.GetSection("Middleware"); + + if (middlewareSettings == null) + { + throw new Exception("Middleware section missing in config"); + } + + string? middleware_jwks_endpoint = middlewareSettings.GetValue("jwks_endpoint"); + string? middleware_audience = middlewareSettings.GetValue("audience"); + string? middleware_issuer = middlewareSettings.GetValue("issuer"); + + if (middleware_jwks_endpoint == null) + { + throw new Exception("middleware_jwks_endpoint missing in config"); + } + + if (middleware_audience == null) + { + throw new Exception("middleware_audience missing in config"); + } + + if (middleware_issuer == null) + { + throw new Exception("middleware_issuer missing in config"); + } + + // get JWKS + using var http = new HttpClient(); + http.Timeout = TimeSpan.FromSeconds(10); + var jwks = await http.GetFromJsonAsync(middleware_jwks_endpoint); + + var key = jwks.Keys.FirstOrDefault(k => k.Kid == kid); + if (key == null) + throw new SecurityTokenException($"No matching JWKS key for kid={kid}"); + + // build RSA pub key + var rsa = RSA.Create(); + rsa.ImportParameters(new RSAParameters + { + Modulus = Base64UrlDecode(key.N), + Exponent = Base64UrlDecode(key.E) + }); + + var validationParameters = new TokenValidationParameters + { + ValidateIssuer = true, + ValidIssuer = middleware_issuer, + + ValidateAudience = true, + ValidAudience = middleware_audience, + + ValidateLifetime = true, + ClockSkew = TimeSpan.FromMinutes(2), + + ValidateIssuerSigningKey = true, + IssuerSigningKey = new RsaSecurityKey(rsa) + { + KeyId = key.Kid + } + }; + + return handler.ValidateToken(jwt, validationParameters, out _); + } + + public class Jwks + { + public List Keys { get; set; } + } + + public class Jwk + { + public string Kid { get; set; } + public string Kty { get; set; } + public string N { get; set; } + public string E { get; set; } +} + + + [HttpPost(Name = "ProvideMWToken")] + public async Task Post() + { + using (var reader = new StreamReader(HttpContext.Request.Body)) + { + var options = new JsonSerializerOptions + { + PropertyNameCaseInsensitive = true + }; + + string jsonData = await reader.ReadToEndAsync(); + var data = JsonSerializer.Deserialize>(jsonData, options); + + if (data != null && !data.ContainsKey("mw_token")) + { + Response.StatusCode = (int)HttpStatusCode.Unauthorized; + } + else + { + string mw_token = data["mw_token"].ToString(); + + ClaimsPrincipal validatedClaims = await ValidateEpicJwtAsync(mw_token); + + if (validatedClaims != null) + { + string mwUserID = GetClaimValue(mw_token, "sub"); + + Int64 user_id = TokenHelper.GetUserID(this); + + if (user_id != -1) + { + UserSession? session = WebSocketManager.GetDataFromUser(user_id); + if (session != null) + { + session.SetMiddlewareID(mwUserID); + } + } + } + } + } + } + } +} diff --git a/GenOnlineService/Controllers/VersionCheck/VersionCheckController.cs b/GenOnlineService/Controllers/VersionCheck/VersionCheckController.cs index 00bcada..f18bd46 100644 --- a/GenOnlineService/Controllers/VersionCheck/VersionCheckController.cs +++ b/GenOnlineService/Controllers/VersionCheck/VersionCheckController.cs @@ -74,7 +74,7 @@ public async Task Post() string jsonData = await reader.ReadToEndAsync(); #if !DEBUG - return VersionHelper.Post_InternalHandler(jsonData); + return await VersionHelper.Post_InternalHandler(jsonData); #else return await VersionHelper.Post_InternalHandler(jsonData); #endif @@ -100,7 +100,7 @@ public async Task Post() string jsonData = await reader.ReadToEndAsync(); #if !DEBUG - return VersionHelper.Post_InternalHandler(jsonData); + return await VersionHelper.Post_InternalHandler(jsonData); #else return await VersionHelper.Post_InternalHandler(jsonData); #endif @@ -111,7 +111,7 @@ public async Task Post() class VersionHelper { #if !DEBUG - public static APIResult Post_InternalHandler(string jsonData) + public static async Task Post_InternalHandler(string jsonData) #else public static async Task Post_InternalHandler(string jsonData) #endif @@ -167,7 +167,7 @@ public static async Task Post_InternalHandler(string jsonData) } else { - var jsonPatchData = JsonSerializer.Deserialize>(System.IO.File.ReadAllText(Path.Combine("data", "patchdata.json")), options); + var jsonPatchData = JsonSerializer.Deserialize>(await System.IO.File.ReadAllTextAsync(Path.Combine("data", "patchdata.json")), options); if (jsonPatchData != null) { diff --git a/GenOnlineService/Controllers/WebSocket/WebSocketController.cs b/GenOnlineService/Controllers/WebSocket/WebSocketController.cs index 5020132..a716864 100644 --- a/GenOnlineService/Controllers/WebSocket/WebSocketController.cs +++ b/GenOnlineService/Controllers/WebSocket/WebSocketController.cs @@ -21,6 +21,7 @@ using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; using System; +using System.Buffers; using System.Net.WebSockets; using System.Security.Claims; using System.Text; @@ -67,7 +68,7 @@ public async Task Get([FromHeader(Name = "is-reconnect")] bool bIsReconnect) return; } - string ipAddress = HttpContext.Connection.RemoteIpAddress?.ToString(); + string ipAddress = IPHelpers.NormalizeIP(HttpContext.Connection.RemoteIpAddress?.ToString()); string ipContinent = "NA"; string ipCountry = "US"; double dLongitude = 38.8977; // the whitehouse; @@ -140,62 +141,17 @@ public async Task Get([FromHeader(Name = "is-reconnect")] bool bIsReconnect) receiveResult = await webSocket.ReceiveAsync( new ArraySegment(buffer), cts.Token); } - catch (OperationCanceledException ex) - { - // send a ping - wsSess.SendPong(); - - { - // log it to sentry - var customEvent = new SentryEvent - { - Message = "Websocket Disconnect A:" + ex.ToString(), - Level = SentryLevel.Error - }; - - // Add custom tags - customEvent.SetTag("websocket", "error_1"); - customEvent.SetTag("user_id", wsSess.m_UserID.ToString()); - - // Add extra data - customEvent.SetExtra("user_id_tag", wsSess.m_UserID); - - // Capture the event - SentrySdk.CaptureEvent(customEvent); - - // flush - await SentrySdk.FlushAsync(); - } - - break; + catch (OperationCanceledException) + { + // No message received in 30s — send a keep-alive pong and continue waiting + wsSess.SendPong(); + continue; } catch (Exception ex) { // Log unexpected errors Console.WriteLine($"WebSocket error: {ex}"); - - { - // log it to sentry - var customEvent = new SentryEvent - { - Message = "Websocket Disconnect B: " + ex.ToString(), - Level = SentryLevel.Error - }; - - // Add custom tags - customEvent.SetTag("websocket", "error_2"); - customEvent.SetTag("user_id", wsSess.m_UserID.ToString()); - - // Add extra data - customEvent.SetExtra("user_id_tag", wsSess.m_UserID); - - // Capture the event - SentrySdk.CaptureEvent(customEvent); - - // flush - await SentrySdk.FlushAsync(); - } - + SentrySdk.CaptureException(ex); break; } @@ -394,7 +350,7 @@ private async Task ProcessWSMessage(UserWebSocketInstance sourceWS, UserSession outboundMsg.action = chatMessage.action; - // send to everyone (minus those who have the chatter blocked) + // Serialize once before broadcasting byte[] bytesJSON = Encoding.UTF8.GetBytes(JsonSerializer.Serialize(outboundMsg)); // send it to everyone in the same room @@ -430,7 +386,7 @@ private async Task ProcessWSMessage(UserWebSocketInstance sourceWS, UserSession if (data != null && data.ContainsKey("room")) { Int16 roomID = data["room"].GetInt16(); - sourceUserSession.UpdateSessionNetworkRoom(roomID); + await sourceUserSession.UpdateSessionNetworkRoom(roomID); } } else if (msgID == EWebSocketMessageID.NETWORK_ROOM_MARK_READY) @@ -481,7 +437,7 @@ private async Task ProcessWSMessage(UserWebSocketInstance sourceWS, UserSession { await Database.Functions.Lobby.UpdateDisplayName(GlobalDatabaseInstance.g_Database, sourceUserSession.m_UserID, nameChangeRequest.name); sourceUserSession.m_strDisplayName = nameChangeRequest.name; - await WebSocketManager.SendRoomMemberListToAllInRoom(sourceUserSession.networkRoomID); + await WebSocketManager.MarkRoomMemberListAsDirty(sourceUserSession.networkRoomID); } } } @@ -557,7 +513,7 @@ private async Task ProcessWSMessage(UserWebSocketInstance sourceWS, UserSession outboundMsg.announcement = chatMessage.announcement; outboundMsg.show_announcement_to_host = chatMessage.show_announcement_to_host; - // send to everyone in lobby + // Serialize once before broadcasting byte[] bytesJSON = Encoding.UTF8.GetBytes(JsonSerializer.Serialize(outboundMsg)); foreach (LobbyMember lobbyMember in playerLobby.Members) @@ -630,7 +586,7 @@ private async Task ProcessWSMessage(UserWebSocketInstance sourceWS, UserSession } // start match + create placeholder match - lobbyInfo.UpdateState(ELobbyState.INGAME); + await lobbyInfo.UpdateState(ELobbyState.INGAME); // simple websocket msg, has no data, so dont even read anything @@ -638,7 +594,7 @@ private async Task ProcessWSMessage(UserWebSocketInstance sourceWS, UserSession WebSocketMessage_Simple startCommand = new WebSocketMessage_Simple(); startCommand.msg_id = (int)EWebSocketMessageID.START_GAME; - // send to everyone in lobby + // Serialize once before broadcasting byte[] bytesJSON = Encoding.UTF8.GetBytes(JsonSerializer.Serialize(startCommand)); foreach (KeyValuePair sessionData in WebSocketManager.GetUserDataCache()) @@ -682,7 +638,7 @@ private async Task ProcessWSMessage(UserWebSocketInstance sourceWS, UserSession WebSocketMessage_Simple startCommand = new WebSocketMessage_Simple(); startCommand.msg_id = (int)EWebSocketMessageID.FULL_MESH_CONNECTIVITY_CHECK_RESPONSE; - // send to everyone in lobby + // Serialize once before broadcasting byte[] bytesJSON = Encoding.UTF8.GetBytes(JsonSerializer.Serialize(startCommand)); foreach (KeyValuePair sessionData in WebSocketManager.GetUserDataCache()) @@ -757,12 +713,6 @@ private async Task ProcessWSMessage(UserWebSocketInstance sourceWS, UserSession } else if (msgID == EWebSocketMessageID.NETWORK_SIGNAL) { - var options = new JsonSerializerOptions - { - PropertyNameCaseInsensitive = true, - AllowOutOfOrderMetadataProperties = true - }; - WebSocketMessage_SignalBidirectional? signal = JsonSerializer.Deserialize(payload, JsonOpts); //Console.WriteLine("Signal received: " + signal.signal); @@ -776,19 +726,30 @@ private async Task ProcessWSMessage(UserWebSocketInstance sourceWS, UserSession UserSession? targetSession = WebSocketManager.GetDataFromUser(signal.target_user_id); if (targetSession != null) { - // now into json for our ws msg format - // NOTE: outbound msg doesnt need sender ID, we only need that to determine target on the server, everything else is included in the payload - WebSocketMessage_SignalBidirectional outboundSignal = new WebSocketMessage_SignalBidirectional(); - outboundSignal.msg_id = (int)EWebSocketMessageID.NETWORK_SIGNAL; - outboundSignal.target_user_id = sourceUserSession.m_UserID; // user here is the person who sent it to us - outboundSignal.payload = signal.payload; - byte[] bytesJSON = Encoding.UTF8.GetBytes(JsonSerializer.Serialize(outboundSignal)); - - targetSession.QueueWebsocketSend(bytesJSON); - //Console.WriteLine("Signal out is: {0}", JsonSerializer.Serialize(outboundSignal)); - //Console.WriteLine("SIGNAL SENT ({0} bytes) (from user {1} to user {2})", bytesJSON.Length, wsSess.m_UserID, sess.m_UserID); - //Console.WriteLine("MSG WAS: {0}", strMessage); - //break; + Lobby? lobby = LobbyManager.GetLobby(sourceUserSession.currentLobbyID); + + if (lobby != null) + { + LobbyMember? targetUser = lobby.GetMemberFromUserID(targetSession.m_UserID); + LobbyMember? sourceUser = lobby.GetMemberFromUserID(sourceUserSession.m_UserID); + + if (sourceUser != null && targetUser != null) + { + // now into json for our ws msg format + // NOTE: outbound msg doesnt need sender ID, we only need that to determine target on the server, everything else is included in the payload + WebSocketMessage_SignalBidirectional outboundSignal = new WebSocketMessage_SignalBidirectional(); + outboundSignal.msg_id = (int)EWebSocketMessageID.NETWORK_SIGNAL; + outboundSignal.target_user_id = sourceUserSession.m_UserID; // user here is the person who sent it to us + outboundSignal.payload = signal.payload; + byte[] bytesJSON = Encoding.UTF8.GetBytes(JsonSerializer.Serialize(outboundSignal)); + + targetSession.QueueWebsocketSend(bytesJSON); + //Console.WriteLine("Signal out is: {0}", JsonSerializer.Serialize(outboundSignal)); + //Console.WriteLine("SIGNAL SENT ({0} bytes) (from user {1} to user {2})", bytesJSON.Length, wsSess.m_UserID, sess.m_UserID); + //Console.WriteLine("MSG WAS: {0}", strMessage); + //break; + } + } } else { diff --git a/GenOnlineService/Database/MySQL.cs b/GenOnlineService/Database/MySQL.cs index df7db8c..6163d7a 100644 --- a/GenOnlineService/Database/MySQL.cs +++ b/GenOnlineService/Database/MySQL.cs @@ -1,4 +1,4 @@ -/* +/* ** GeneralsOnline Game Services - Backend Services for Command & Conquer Generals Online: Zero Hour ** Copyright (C) 2025 GeneralsOnline Development Team ** @@ -107,9 +107,9 @@ public static void RegisterOutcome(int army, bool bWon) } } } - catch + catch (Exception ex) { - + Console.WriteLine($"[ERROR] RegisterOutcome failed: {ex.Message}"); } } } @@ -163,7 +163,7 @@ public async static Task GetHighestMatchID(MySQLInstance m_Inst) public async static Task GetMatchesInRange(MySQLInstance m_Inst, Int64 startID, Int64 endID) { - var res = await m_Inst.Query("SELECT * FROM match_history WHERE match_id>=@startID AND match_id<=@endID AND finished=true;", + var res = await m_Inst.Query("SELECT match_id, owner, name, finished, started, time_finished, map_name, map_path, match_roster_type, map_official, vanilla_teams, starting_cash, limit_superweapons, track_stats, allow_observers, max_cam_height, member_slot_0, member_slot_1, member_slot_2, member_slot_3, member_slot_4, member_slot_5, member_slot_6, member_slot_7 FROM match_history WHERE match_id>=@startID AND match_id<=@endID AND finished=true;", new() { { "@startID", startID }, @@ -171,8 +171,6 @@ public async static Task GetMatchesInRange(MySQLInstance } ); - // TODO: Optimize query - MatchHistoryCollection collection = new(); foreach (var row in res.GetRows()) { @@ -324,19 +322,76 @@ public async static Task GetLeaderboardDataForUser(MySQLInsta return retVal; } + public async static Task> GetBulkLeaderboardData(MySQLInstance m_Inst, List playerIDs, int dayOfYear, int monthOfYear, int year) + { + Dictionary results = new(); + + if (playerIDs == null || playerIDs.Count == 0) + { + return results; + } + + // Initialize all users with default values + foreach (Int64 playerId in playerIDs) + { + results[playerId] = new LeaderboardPoints(); + } + + // Build IN clause + string inClause = string.Join(",", playerIDs); + + // Bulk daily + var resDaily = await m_Inst.Query($"SELECT user_id, points, wins+losses as `matches` FROM leaderboard_daily WHERE user_id IN ({inClause}) AND day_of_year={dayOfYear} AND year={year};", null); + foreach (var row in resDaily.GetRows()) + { + Int64 userId = Convert.ToInt64(row["user_id"]); + if (results.ContainsKey(userId)) + { + results[userId].daily = Convert.ToInt32(row["points"]); + results[userId].daily_matches = Convert.ToInt32(row["matches"]); + } + } + + // Bulk monthly + var resMonthly = await m_Inst.Query($"SELECT user_id, points, wins+losses as `matches` FROM leaderboard_monthly WHERE user_id IN ({inClause}) AND month_of_year={monthOfYear} AND year={year};", null); + foreach (var row in resMonthly.GetRows()) + { + Int64 userId = Convert.ToInt64(row["user_id"]); + if (results.ContainsKey(userId)) + { + results[userId].monthly = Convert.ToInt32(row["points"]); + results[userId].monthly_matches = Convert.ToInt32(row["matches"]); + } + } + + // Bulk yearly + var resYearly = await m_Inst.Query($"SELECT user_id, points, wins+losses as `matches` FROM leaderboard_yearly WHERE user_id IN ({inClause}) AND year={year};", null); + foreach (var row in resYearly.GetRows()) + { + Int64 userId = Convert.ToInt64(row["user_id"]); + if (results.ContainsKey(userId)) + { + results[userId].yearly = Convert.ToInt32(row["points"]); + results[userId].yearly_matches = Convert.ToInt32(row["matches"]); + } + } + + return results; + } + public async static Task DetermineLobbyWinnerIfNotPresent(MySQLInstance m_Inst, GenOnlineService.Lobby lobbyInst) { // NOTE: this works only when you call this function BEFORE updating ELO, as elo will read it all to award points // get each lobby member - var res = await m_Inst.Query("SELECT * FROM match_history WHERE match_id=@matchID LIMIT 1;", + var res = await m_Inst.Query("SELECT member_slot_0, member_slot_1, member_slot_2, member_slot_3, member_slot_4, member_slot_5, member_slot_6, member_slot_7 FROM match_history WHERE match_id=@matchID LIMIT 1;", new() { { "@matchID", lobbyInst.MatchID } } ); - List lstMembers = new List(); + Dictionary lstMembers = new Dictionary(); foreach (var row in res.GetRows()) { string? strJson_Slot0 = Convert.ToString(row["member_slot_0"]); @@ -358,26 +413,26 @@ public async static Task DetermineLobbyWinnerIfNotPresent(MySQLInstance m_Inst, MatchdataMemberModel? member6 = String.IsNullOrEmpty(strJson_Slot6) ? null : JsonSerializer.Deserialize(strJson_Slot6); MatchdataMemberModel? member7 = String.IsNullOrEmpty(strJson_Slot7) ? null : JsonSerializer.Deserialize(strJson_Slot7); - // add members to collection - if (member0 != null) { lstMembers.Add((MatchdataMemberModel)member0); } - if (member1 != null) { lstMembers.Add((MatchdataMemberModel)member1); } - if (member2 != null) { lstMembers.Add((MatchdataMemberModel)member2); } - if (member3 != null) { lstMembers.Add((MatchdataMemberModel)member3); } - if (member4 != null) { lstMembers.Add((MatchdataMemberModel)member4); } - if (member5 != null) { lstMembers.Add((MatchdataMemberModel)member5); } - if (member6 != null) { lstMembers.Add((MatchdataMemberModel)member6); } - if (member7 != null) { lstMembers.Add((MatchdataMemberModel)member7); } + // add members to collection with slot index as key + if (member0 != null) { lstMembers[0] = member0.Value; } + if (member1 != null) { lstMembers[1] = member1.Value; } + if (member2 != null) { lstMembers[2] = member2.Value; } + if (member3 != null) { lstMembers[3] = member3.Value; } + if (member4 != null) { lstMembers[4] = member4.Value; } + if (member5 != null) { lstMembers[5] = member5.Value; } + if (member6 != null) { lstMembers[6] = member6.Value; } + if (member7 != null) { lstMembers[7] = member7.Value; } } // do we have a winner already? bool bHasWinner = false; int winnerTeam = -1; - foreach (MatchdataMemberModel lobbyMember in lstMembers) + foreach (var kvp in lstMembers) { - if (lobbyMember.won) + if (kvp.Value.won) { bHasWinner = true; - winnerTeam = lobbyMember.team; + winnerTeam = kvp.Value.team; break; } } @@ -387,19 +442,16 @@ public async static Task DetermineLobbyWinnerIfNotPresent(MySQLInstance m_Inst, { if (winnerTeam != -1) { - int slotIndex = 0; - foreach (MatchdataMemberModel? lobbyMember in lstMembers) + foreach (var kvp in lstMembers) { - if (lobbyMember != null) + int slotIndex = kvp.Key; + MatchdataMemberModel lobbyMember = kvp.Value; + + if (lobbyMember.team == winnerTeam) // same team, and not '-1' { - if (lobbyMember.Value.team == winnerTeam) // same team, and not '-1' - { - // save it - await Database.Functions.Lobby.UpdateMatchHistoryMakeWinner(GlobalDatabaseInstance.g_Database, lobbyInst.MatchID, slotIndex); - } + // save it + await Database.Functions.Lobby.UpdateMatchHistoryMakeWinner(GlobalDatabaseInstance.g_Database, lobbyInst.MatchID, slotIndex); } - - ++slotIndex; } } } @@ -410,14 +462,17 @@ public async static Task DetermineLobbyWinnerIfNotPresent(MySQLInstance m_Inst, // pick the last person to leave DateTime mostRecentlyLeftTimestamp = DateTime.UnixEpoch; MatchdataMemberModel? lastPlayerToLeave = null; - foreach (MatchdataMemberModel lobbyMember in lstMembers) + int lastPlayerSlotIndex = -1; + foreach (var kvp in lstMembers) { + MatchdataMemberModel lobbyMember = kvp.Value; if (lobbyInst.TimeMemberLeft.ContainsKey(lobbyMember.user_id)) { if (lobbyInst.TimeMemberLeft[lobbyMember.user_id] >= mostRecentlyLeftTimestamp) { mostRecentlyLeftTimestamp = lobbyInst.TimeMemberLeft[lobbyMember.user_id]; lastPlayerToLeave = lobbyMember; + lastPlayerSlotIndex = kvp.Key; } } } @@ -427,25 +482,22 @@ public async static Task DetermineLobbyWinnerIfNotPresent(MySQLInstance m_Inst, int winningPlayerTeam = lastPlayerToLeave.Value.team; // this player + everyone on the same team is also a winner! - int slotIndex = 0; - foreach (MatchdataMemberModel? lobbyMember in lstMembers) + foreach (var kvp in lstMembers) { - if (lobbyMember != null) + int slotIndex = kvp.Key; + MatchdataMemberModel lobbyMember = kvp.Value; + + // is it this guy? + if (lobbyMember.user_id == lastPlayerToLeave.Value.user_id) { - // is it this guy? - if (lobbyMember.Value.user_id == lastPlayerToLeave.Value.user_id) - { - // save it - await Database.Functions.Lobby.UpdateMatchHistoryMakeWinner(GlobalDatabaseInstance.g_Database, lobbyInst.MatchID, slotIndex); - } - else if (winningPlayerTeam != -1 && lobbyMember.Value.team == winningPlayerTeam) // same team, and not '-1' - { - // save it - await Database.Functions.Lobby.UpdateMatchHistoryMakeWinner(GlobalDatabaseInstance.g_Database, lobbyInst.MatchID, slotIndex); - } + // save it + await Database.Functions.Lobby.UpdateMatchHistoryMakeWinner(GlobalDatabaseInstance.g_Database, lobbyInst.MatchID, slotIndex); + } + else if (winningPlayerTeam != -1 && lobbyMember.team == winningPlayerTeam) // same team, and not '-1' + { + // save it + await Database.Functions.Lobby.UpdateMatchHistoryMakeWinner(GlobalDatabaseInstance.g_Database, lobbyInst.MatchID, slotIndex); } - - ++slotIndex; } } @@ -470,7 +522,7 @@ public async static Task UpdateLeaderboardAndElo(MySQLInstance m_Inst, GenOnline int year = lobbyInst.TimeCreated.Year; // process each member - var res = await m_Inst.Query("SELECT * FROM match_history WHERE match_id=@matchID LIMIT 1;", + var res = await m_Inst.Query("SELECT member_slot_0, member_slot_1, member_slot_2, member_slot_3, member_slot_4, member_slot_5, member_slot_6, member_slot_7 FROM match_history WHERE match_id=@matchID LIMIT 1;", new() { { "@matchID", lobbyInst.MatchID } @@ -514,13 +566,9 @@ public async static Task UpdateLeaderboardAndElo(MySQLInstance m_Inst, GenOnline { Dictionary dictEloData = new Dictionary(); - // initialize data - foreach (MatchdataMemberModel member in lstMembers) - { - // TODO_ELO: do bulk query instead - EloData playerEloData = await Database.Functions.Auth.GetELOData(GlobalDatabaseInstance.g_Database, member.user_id); - dictEloData[member.user_id] = playerEloData; - } + // initialize data with bulk query (1 query instead of N) + List userIds = lstMembers.Select(m => m.user_id).ToList(); + dictEloData = await Database.Functions.Auth.GetBulkELOData(GlobalDatabaseInstance.g_Database, userIds); foreach (MatchdataMemberModel member in lstMembers) { @@ -572,12 +620,13 @@ public async static Task UpdateLeaderboardAndElo(MySQLInstance m_Inst, GenOnline Dictionary dictEloData_Monthly = new Dictionary(); Dictionary dictEloData_Yearly = new Dictionary(); - // initialize data + // initialize data with bulk query (3 queries instead of N*3) + List userIds = lstMembers.Select(m => m.user_id).ToList(); + Dictionary bulkLbData = await GetBulkLeaderboardData(m_Inst, userIds, dayOfYear, monthOfYear, year); + foreach (MatchdataMemberModel member in lstMembers) { - // TODO_ELO: do bulk query instead - LeaderboardPoints userLBPoints = await GetLeaderboardDataForUser(m_Inst, member.user_id, dayOfYear, monthOfYear, year); - + LeaderboardPoints userLBPoints = bulkLbData[member.user_id]; dictEloData_Daily[member.user_id] = new EloData(userLBPoints.daily, userLBPoints.daily_matches); dictEloData_Monthly[member.user_id] = new EloData(userLBPoints.monthly, userLBPoints.monthly_matches); dictEloData_Yearly[member.user_id] = new EloData(userLBPoints.yearly, userLBPoints.yearly_matches); @@ -641,7 +690,12 @@ public async static Task UpdateLeaderboardAndElo(MySQLInstance m_Inst, GenOnline } } - // save each ELO data to DB + // save each ELO data to DB using batched transaction + // Build all UPDATE statements and execute in single transaction + List dailyUpdates = new(); + List monthlyUpdates = new(); + List yearlyUpdates = new(); + foreach (MatchdataMemberModel member in lstMembers) { EloData playerData_Daily = dictEloData_Daily[member.user_id]; @@ -660,44 +714,31 @@ public async static Task UpdateLeaderboardAndElo(MySQLInstance m_Inst, GenOnline ++lossesModifier; } - // DAILY - await m_Inst.Query("UPDATE leaderboard_daily SET points=@points, losses=losses+@losses_modifier, wins=wins+@wins_modifier WHERE user_id=@user_id AND day_of_year=@day_of_year AND year=@year LIMIT 1;", - new() - { - { "@points", playerData_Daily.Rating }, - { "@losses_modifier", lossesModifier }, - { "@wins_modifier", winsModifier}, - { "@user_id", member.user_id }, - { "@day_of_year", dayOfYear }, - { "@year", year } - } - ); + // Build UPDATE statements (sanitized parameters) + dailyUpdates.Add($"UPDATE leaderboard_daily SET points={playerData_Daily.Rating}, losses=losses+{lossesModifier}, wins=wins+{winsModifier} WHERE user_id={member.user_id} AND day_of_year={dayOfYear} AND year={year} LIMIT 1;"); + monthlyUpdates.Add($"UPDATE leaderboard_monthly SET points={playerData_Monthly.Rating}, losses=losses+{lossesModifier}, wins=wins+{winsModifier} WHERE user_id={member.user_id} AND month_of_year={monthOfYear} AND year={year} LIMIT 1;"); + yearlyUpdates.Add($"UPDATE leaderboard_yearly SET points={playerData_Yearly.Rating}, losses=losses+{lossesModifier}, wins=wins+{winsModifier} WHERE user_id={member.user_id} AND year={year} LIMIT 1;"); + } - await m_Inst.Query("UPDATE leaderboard_monthly SET points=@points, losses=losses+@losses_modifier, wins=wins+@wins_modifier WHERE user_id=@user_id AND month_of_year=@month_of_year AND year=@year LIMIT 1;", - new() - { - { "@points", playerData_Monthly.Rating }, - { "@losses_modifier", lossesModifier }, - { "@wins_modifier", winsModifier}, - { "@user_id", member.user_id }, - { "@month_of_year", monthOfYear }, - { "@year", year } - } - ); + // Execute all updates in single batch (3 queries instead of N*3) + if (dailyUpdates.Count > 0) + { + string batchedDaily = string.Join("\n", dailyUpdates); + await m_Inst.Query(batchedDaily, null); + } - await m_Inst.Query("UPDATE leaderboard_yearly SET points=@points, losses=losses+@losses_modifier, wins=wins+@wins_modifier WHERE user_id=@user_id AND year=@year LIMIT 1;", - new() - { - { "@points", playerData_Yearly.Rating }, - { "@losses_modifier", lossesModifier }, - { "@wins_modifier", winsModifier}, - { "@user_id", member.user_id }, - { "@year", year } - } - ); + if (monthlyUpdates.Count > 0) + { + string batchedMonthly = string.Join("\n", monthlyUpdates); + await m_Inst.Query(batchedMonthly, null); + } + if (yearlyUpdates.Count > 0) + { + string batchedYearly = string.Join("\n", yearlyUpdates); + await m_Inst.Query(batchedYearly, null); + } - } } } @@ -1350,6 +1391,39 @@ public async static Task GetELOData(MySQLInstance m_Inst, Int64 user_id return new(EloConfig.BaseRating, 0); } + public async static Task> GetBulkELOData(MySQLInstance m_Inst, List user_ids) + { + Dictionary results = new(); + + if (user_ids == null || user_ids.Count == 0) + { + return results; + } + + // Build IN clause with parameters + string inClause = string.Join(",", user_ids); + var res = await m_Inst.Query($"SELECT user_id, elo_rating, elo_num_matches FROM users WHERE user_id IN ({inClause});", null); + + foreach (var row in res.GetRows()) + { + Int64 userId = Convert.ToInt64(row["user_id"]); + int rating = Convert.ToInt32(row["elo_rating"]); + int numMatches = Convert.ToInt32(row["elo_num_matches"]); + results[userId] = new EloData(rating, numMatches); + } + + // Fill in default values for users not found + foreach (Int64 userId in user_ids) + { + if (!results.ContainsKey(userId)) + { + results[userId] = new EloData(EloConfig.BaseRating, 0); + } + } + + return results; + } + public async static Task GetPlayerStats(MySQLInstance m_Inst, Int64 user_id) { // TODO: Return null if user doesnt actually exist, instead of empty stats @@ -1483,7 +1557,7 @@ private static string GenerateSessionToken() return sb.ToString(); } - public static async void CleanupPendingLogin(MySQLInstance m_Inst, string strGameCode) + public static async Task CleanupPendingLogin(MySQLInstance m_Inst, string strGameCode) { strGameCode = strGameCode.ToUpper(); @@ -1799,7 +1873,7 @@ public enum ESessionType internal static async Task CreateUserIfNotExists_DevAccount(MySQLInstance m_Inst, Int64 user_id, string display_name) { - var res = await m_Inst.Query("SELECT * FROM users WHERE user_id=@user_id LIMIT 1;", + var res = await m_Inst.Query("SELECT user_id FROM users WHERE user_id=@user_id LIMIT 1;", new() { { "@user_id", user_id} @@ -1831,13 +1905,147 @@ internal static async Task SetUserPortMappingTech(MySQLInstance m_Inst, Int64 us } ); } + + // Cache for display names (24-hour TTL - names rarely change) + public static class DisplayNameCache + { + private static readonly System.Collections.Concurrent.ConcurrentDictionary s_cache = new(); + private static readonly TimeSpan s_cacheDuration = TimeSpan.FromHours(24); + + public static async Task GetCachedDisplayName(MySQLInstance m_Inst, Int64 userID) + { + if (s_cache.TryGetValue(userID, out var cached)) + { + if (DateTime.UtcNow - cached.CachedAt < s_cacheDuration) + { + return cached.DisplayName; + } + s_cache.TryRemove(userID, out _); + } + + string displayName = await GetDisplayName(m_Inst, userID); + s_cache.TryAdd(userID, (displayName, DateTime.UtcNow)); + return displayName; + } + + public static async Task> GetCachedDisplayNameBulk(MySQLInstance m_Inst, List lstUserIDs) + { + Dictionary result = new(); + List uncachedIDs = new(); + + foreach (Int64 userID in lstUserIDs) + { + if (s_cache.TryGetValue(userID, out var cached) && DateTime.UtcNow - cached.CachedAt < s_cacheDuration) + { + result[userID] = cached.DisplayName; + } + else + { + s_cache.TryRemove(userID, out _); + uncachedIDs.Add(userID); + } + } + + if (uncachedIDs.Count > 0) + { + Dictionary dbResults = await GetDisplayNameBulk(m_Inst, uncachedIDs); + foreach (var kvp in dbResults) + { + s_cache.TryAdd(kvp.Key, (kvp.Value, DateTime.UtcNow)); + result[kvp.Key] = kvp.Value; + } + } + + return result; + } + + public static void InvalidateCache(Int64 userID) + { + s_cache.TryRemove(userID, out _); + } + } + + // Cache for user lobby preferences (1-hour TTL) + public static class UserPreferencesCache + { + private static readonly System.Collections.Concurrent.ConcurrentDictionary s_cache = new(); + private static readonly TimeSpan s_cacheDuration = TimeSpan.FromHours(1); + + public static async Task GetCachedPreferences(MySQLInstance m_Inst, Int64 userID) + { + if (s_cache.TryGetValue(userID, out var cached)) + { + if (DateTime.UtcNow - cached.CachedAt < s_cacheDuration) + { + return cached.Prefs; + } + s_cache.TryRemove(userID, out _); + } + + UserLobbyPreferences prefs = await GetUserLobbyPreferences(m_Inst, userID); + s_cache.TryAdd(userID, (prefs, DateTime.UtcNow)); + return prefs; + } + + public static void InvalidateCache(Int64 userID) + { + s_cache.TryRemove(userID, out _); + } + } } } // Updated MySQLInstance class to fix memory leaks by ensuring proper disposal of resources. public class MySQLInstance : IDisposable { - //private static readonly SemaphoreSlim _semaphore = new SemaphoreSlim(1, 1); + // Connection string is built once from config and reused across all concurrent queries. + // The MySQL connector's built-in connection pool (MySqlConnection with Pooling=true) is + // fully thread-safe: each call to OpenAsync() leases an independent physical connection + // from the pool, so queries on different threads never share a connection object. + private static string? _cachedConnectionString; + private static readonly object _connStringLock = new object(); + + private static string GetConnectionString() + { + if (_cachedConnectionString != null) + return _cachedConnectionString; + + lock (_connStringLock) + { + if (_cachedConnectionString != null) + return _cachedConnectionString; + + if (Program.g_Config == null) + throw new Exception("Config is null. Check config file exists."); + + IConfiguration? dbSettings = Program.g_Config.GetSection("Database"); + if (dbSettings == null) + throw new Exception("Database section in config is null / not set in config"); + + string? db_host = dbSettings.GetValue("db_host") ?? throw new Exception("DB Hostname is null / not set in config"); + string? db_name = dbSettings.GetValue("db_name") ?? throw new Exception("DB Name is null / not set in config"); + string? db_username = dbSettings.GetValue("db_username") ?? throw new Exception("DB Username is null / not set in config"); + string? db_password = dbSettings.GetValue("db_password") ?? throw new Exception("DB Password is null / not set in config"); + ushort db_port = dbSettings.GetValue("db_port"); + + int db_min_poolsize = dbSettings.GetValue("db_min_poolsize") ?? 50; + int db_max_poolsize = dbSettings.GetValue("db_max_poolsize") ?? 500; + bool db_use_pooling = dbSettings.GetValue("db_use_pooling") ?? true; + bool db_conn_reset = dbSettings.GetValue("db_conn_reset") ?? true; + int db_connect_timeout = dbSettings.GetValue("db_connect_timeout") ?? 10; + int db_command_timeout = dbSettings.GetValue("db_command_timeout") ?? 10; + + _cachedConnectionString = string.Format( + "Server={0}; database={1}; user={2}; password={3}; port={4};" + + "Pooling={5};DefaultCommandTimeout={9};Connect Timeout={10};" + + "MinimumPoolSize={6};maximumpoolsize={7};AllowUserVariables=true;ConnectionReset={8};", + db_host, db_name, db_username, db_password, db_port, + db_use_pooling, db_min_poolsize, db_max_poolsize, db_conn_reset, + db_command_timeout, db_connect_timeout); + + return _cachedConnectionString; + } + } #if !USE_PER_QUERY_CONNECTION private MySqlConnection m_Connection = null; @@ -1865,26 +2073,19 @@ protected virtual void Dispose(bool disposing) m_Connection = null; } #endif - //_semaphore.Dispose(); } } - private DateTime m_LastQueryTime = DateTime.Now; + // Written with Interlocked so concurrent threads don't race on a shared DateTime field. + private long m_LastQueryTimeTicks = DateTime.Now.Ticks; - public async void KeepAlive() + public async Task KeepAlive() { - //await _semaphore.WaitAsync(); - try - { - double timeSinceLastQueryAuth = (DateTime.Now - m_LastQueryTime).TotalMilliseconds; - if (timeSinceLastQueryAuth > 300000) - { - await Query("SELECT user_id FROM users LIMIT 1;", null).ConfigureAwait(false); - } - } - finally + long lastTicks = Interlocked.Read(ref m_LastQueryTimeTicks); + double timeSinceLastQueryMs = TimeSpan.FromTicks(DateTime.Now.Ticks - lastTicks).TotalMilliseconds; + if (timeSinceLastQueryMs > 300000) { - //_semaphore.Release(); + await Query("SELECT user_id FROM users LIMIT 1;", null).ConfigureAwait(false); } } @@ -1893,7 +2094,7 @@ public async static Task TestQuery(MySQLInstance m_Inst) await m_Inst.Query("SELECT * FROM users LIMIT 1", null); } - public bool Initialize(bool bIsStartup = true) + public async Task Initialize(bool bIsStartup = true) { if (Program.g_Config == null) { @@ -1955,7 +2156,7 @@ public bool Initialize(bool bIsStartup = true) //Console.WriteLine(String.Format("Server={0}; database={1}; user={2}; password={3}; port={4};Pooling=true;Connect Timeout=100;MinimumPoolSize=1;maximumpoolsize=100;AllowUserVariables=true;ConnectionReset=false;SslMode=Required;", dbSettings)); Console.WriteLine("Connecting to DB..."); - m_Connection.Open(); + await m_Connection.OpenAsync().ConfigureAwait(false); Console.WriteLine("Connected to: " + m_Connection.ServerVersion); @@ -1964,8 +2165,7 @@ public bool Initialize(bool bIsStartup = true) var t = Database.Functions.Lobby.GetAllLobbyInfo(this, 0, true, true, true, true, true); - t.Wait(); - List lstLobbies = t.Result; + List lstLobbies = await t; #endif return true; @@ -2040,82 +2240,19 @@ private string EscapeAllAndFormatQuery(string strQuery, params object[] formatPa public async Task Query(string commandStr, Dictionary? dictCommandValues, int attempt = 0) { - bool semaphoreAcquired = false; - CMySQLResult result = new CMySQLResult(0); // default with 0 rows - MySqlConnection? connection = null; - - // after 3 attempts, give up + // After 3 attempts, give up. if (attempt >= 3) - { - return result; - } + return new CMySQLResult(0); + Interlocked.Exchange(ref m_LastQueryTimeTicks, DateTime.Now.Ticks); + + // Each call opens its own connection leased from the shared pool. + // No serializing lock is needed: MySqlConnection instances are never shared between callers. try { - //await _semaphore.WaitAsync(); - semaphoreAcquired = true; - m_LastQueryTime = DateTime.Now; - -#if !USE_PER_QUERY_CONNECTION - connection = m_Connection; -#else - if (Program.g_Config == null) - { - throw new Exception("Config is null. Check config file exists."); - } - - // db settings - IConfiguration? dbSettings = Program.g_Config.GetSection("Database"); - - if (dbSettings == null) - { - throw new Exception("Database section in config is null / not set in config"); - } - - string? db_host = dbSettings.GetValue("db_host"); - string? db_name = dbSettings.GetValue("db_name"); - string? db_username = dbSettings.GetValue("db_username"); - string? db_password = dbSettings.GetValue("db_password"); - UInt16? db_port = dbSettings.GetValue("db_port"); - int? db_min_poolsize = dbSettings.GetValue("db_min_poolsize"); - int? db_max_poolsize = dbSettings.GetValue("db_max_poolsize"); - bool? db_use_pooling = dbSettings.GetValue("db_use_pooling"); - bool? db_conn_reset = dbSettings.GetValue("db_conn_reset"); - int? db_connect_timeout = dbSettings.GetValue("db_connect_timeout"); - int? db_command_timeout = dbSettings.GetValue("db_command_timeout"); - - if (db_host == null) - { - throw new Exception("DB Hostname is null / not set in config"); - } - - if (db_name == null) - { - throw new Exception("DB Hostname is null / not set in config"); - } - - if (db_username == null) - { - throw new Exception("DB Hostname is null / not set in config"); - } - - if (db_password == null) - { - throw new Exception("DB Hostname is null / not set in config"); - } - - if (db_port == null) - { - throw new Exception("DB Hostname is null / not set in config"); - } - - - -#endif - using (connection = new MySqlConnection(String.Format("Server={0}; database={1}; user={2}; password={3}; port={4};Pooling={5};DefaultCommandTimeout={9};Connect Timeout={10};MinimumPoolSize={6};maximumpoolsize={7};AllowUserVariables=true;ConnectionReset={8};", - db_host, db_name, db_username, db_password, db_port, db_use_pooling, db_min_poolsize, db_max_poolsize, db_conn_reset, db_command_timeout, db_connect_timeout))) + using (var connection = new MySqlConnection(GetConnectionString())) { - connection.Open(); + await connection.OpenAsync().ConfigureAwait(false); try { @@ -2124,49 +2261,30 @@ public async Task Query(string commandStr, Dictionary Query(string commandStr, Dictionary g_dictInitialExeCRCs = new(); + public static ConcurrentDictionary g_dictInitialExeCRCs = new(); public static void RegisterInitialPlayerExeCRC(Int64 user_id, string exe_crc) { g_dictInitialExeCRCs[user_id] = exe_crc; @@ -111,7 +112,11 @@ enum EBotAction public DiscordBot() { #if !DEBUG - InitAsync(); + _ = InitAsync().ContinueWith(t => + { + if (t.IsFaulted) + Console.WriteLine("Discord initialization failed: " + t.Exception); + }, TaskContinuationOptions.OnlyOnFaulted); #endif } @@ -123,18 +128,40 @@ public DiscordBot() } } - public async void SendNetworkRoomChat(int roomID, Int64 userID, string strDisplayName, string strMessage) + public async Task SendNetworkRoomChat(int roomID, Int64 userID, string strDisplayName, string strMessage) { try { - string strFormattedChatMsg = String.Format("[{0} - UID {1}] {2}", strDisplayName, userID, strMessage); + if (Program.g_Config == null) + { + return; + } - ISocketMessageChannel? channel = GetChannel(EDiscordChannelIDs.NetworkRoomChat); - if (channel != null) + IConfiguration? discordSettings = Program.g_Config.GetSection("Discord"); + + if (discordSettings == null) + { + return; + } + + bool discord_send_room_chat_to_discord = discordSettings.GetValue("send_room_chat_to_discord"); + + if (discord_send_room_chat_to_discord == null) + { + return; + } + + if (discord_send_room_chat_to_discord) { - string strDiscordMsg = String.Format("[NETWORK ROOM CHAT ID #{0}] {1}", roomID, strFormattedChatMsg); - await channel.SendMessageAsync(strDiscordMsg).ConfigureAwait(true); - } + string strFormattedChatMsg = String.Format("[{0} - UID {1}] {2}", strDisplayName, userID, strMessage); + + ISocketMessageChannel? channel = GetChannel(EDiscordChannelIDs.NetworkRoomChat); + if (channel != null) + { + string strDiscordMsg = String.Format("[NETWORK ROOM CHAT ID #{0}] {1}", roomID, strFormattedChatMsg); + await channel.SendMessageAsync(strDiscordMsg).ConfigureAwait(true); + } + } } catch { @@ -693,7 +720,7 @@ private static Task LogAsync(LogMessage log) return Task.CompletedTask; } - private async void InitAsync() + private async Task InitAsync() { #if !DEBUG || USE_DISCORD_IN_DEBUG DiscordSocketConfig conf = new(); @@ -707,15 +734,21 @@ private async void InitAsync() //1354979004507226294 - // TODO_GITHUB: You should replace the below with your debug key, and also set the environment variable on your server for your release token. These should be different for security purposes. -#if DEBUG - string Token = "TODO_GITHUB"; -#else - //string Token = Environment.GetEnvironmentVariable("DISCORD_BOT_TOKEN") ?? ""; - string Token = "TODO_GITHUB"; -#endif + IConfigurationSection? discordSettings = Program.g_Config.GetSection("Discord"); + + if (discordSettings == null) + { + throw new Exception("Discord section missing in config"); + } + + string? discordToken = discordSettings.GetValue("token"); + + if (discordToken == null) + { + throw new Exception("Discord Token missing in config"); + } - await discord.LoginAsync(TokenType.Bot, Token).ConfigureAwait(true); + await discord.LoginAsync(TokenType.Bot, discordToken).ConfigureAwait(true); await discord.StartAsync().ConfigureAwait(true); #else await Task.Delay(1).ConfigureAwait(true); @@ -728,7 +761,7 @@ public void PushDM(SocketUser user, string strMessage) { if (user != null) { - user.SendMessageAsync(strMessage); + user.SendMessageAsync(strMessage).ContinueWith(t => { }, TaskContinuationOptions.OnlyOnFaulted); } } catch @@ -776,7 +809,7 @@ public void PushChannelMessage(EDiscordChannelIDs channelID, string strMessage) ISocketMessageChannel? channel = GetChannel(channelID); if (channel != null) { - channel.SendMessageAsync(strMessage); + channel.SendMessageAsync(strMessage).ContinueWith(t => { }, TaskContinuationOptions.OnlyOnFaulted); } } catch diff --git a/GenOnlineService/LobbyManager.cs b/GenOnlineService/LobbyManager.cs index 73233e1..9a43137 100644 --- a/GenOnlineService/LobbyManager.cs +++ b/GenOnlineService/LobbyManager.cs @@ -329,7 +329,6 @@ public Lobby(Int64 lobby_id, UserSession owner, string name, ELobbyState state, { LobbyMember placeholderMember = new LobbyMember(this, null, -1, String.Empty, String.Empty, 0, -1, -1, -1, i < max_players ? EPlayerType.SLOT_OPEN : EPlayerType.SLOT_CLOSED, i, true); Members[i] = placeholderMember; - TimeMemberLeft[i] = DateTime.UnixEpoch; } } @@ -418,7 +417,7 @@ private void CalculateNextProbeTime(bool bIsFirstProbe) } } - public async void Tick() + public async Task Tick() { if (m_NextProbe != 0 && Environment.TickCount64 >= m_NextProbe) { @@ -475,57 +474,64 @@ public async void Tick() } } - Mutex g_Mutex = new(); + private readonly SemaphoreSlim g_SlotLock = new SemaphoreSlim(1, 1); public async Task AddMember(UserSession playerSession, string strDisplayName, UInt16 userPreferredPort, bool bHasMap, UserLobbyPreferences lobbyPrefs) { + LobbyMember? existingMember = GetMemberFromUserID(playerSession.m_UserID); + if (existingMember != null) // we're already in this lobby + { + return false; + } + // NOTE: AddMember is called async, so timing + slot determination could result in players being inserted in the same slot - g_Mutex.WaitOne(); - // find first open slot - bool bFoundSlot = false; - UInt16 slotIndex = 0; - foreach (var memberEntry in Members) + await g_SlotLock.WaitAsync(); + try { - if (memberEntry.SlotState == EPlayerType.SLOT_OPEN) + // find first open slot + bool bFoundSlot = false; + UInt16 slotIndex = 0; + foreach (var memberEntry in Members) { - // found a gap, use this slot index - bFoundSlot = true; - break; + if (memberEntry.SlotState == EPlayerType.SLOT_OPEN) + { + // found a gap, use this slot index + bFoundSlot = true; + break; + } + ++slotIndex; } - ++slotIndex; - } - if (!bFoundSlot) - { - g_Mutex.ReleaseMutex(); - return false; - } + if (!bFoundSlot) + { + return false; + } - // Check social requirements (dont allow blocked in, and check friends only) - // SOCIAL: If the lobby owner has source user blocked, remove the lobby - // NOTE: Only check this for custom match, quick match checks it during matchmaking bucket stage - if (LobbyType == ELobbyType.CustomGame) - { - UserSession? lobbyOwnerSession = WebSocketManager.GetDataFromUser(Owner); + // Check social requirements (dont allow blocked in, and check friends only) + // SOCIAL: If the lobby owner has source user blocked, remove the lobby + // NOTE: Only check this for custom match, quick match checks it during matchmaking bucket stage + if (LobbyType == ELobbyType.CustomGame) + { + UserSession? lobbyOwnerSession = WebSocketManager.GetDataFromUser(Owner); - if (lobbyOwnerSession != null) - { - // dont allow join if blocked - if (lobbyOwnerSession.GetSocialContainer().Blocked.Contains(playerSession.m_UserID)) + if (lobbyOwnerSession != null) { - return false; - } + // dont allow join if blocked + if (lobbyOwnerSession.GetSocialContainer().Blocked.Contains(playerSession.m_UserID)) + { + return false; + } - // check joinability - if (LobbyJoinability == ELobbyJoinability.FriendsOnly) - { - // If it's friends only, return false if they aren't friends - if (!lobbyOwnerSession.GetSocialContainer().Friends.Contains(playerSession.m_UserID)) + // check joinability + if (LobbyJoinability == ELobbyJoinability.FriendsOnly) { + // If it's friends only, return false if they aren't friends + if (!lobbyOwnerSession.GetSocialContainer().Friends.Contains(playerSession.m_UserID)) + { return false; + } } } } - } // de dupe names string strOriginalDisplayName = strDisplayName; @@ -595,6 +601,7 @@ public async Task AddMember(UserSession playerSession, string strDisplayNa } Members[slotIndex] = newMember; + TimeMemberLeft[playerSession.m_UserID] = DateTime.UnixEpoch; // leave network room we were in playerSession.UpdateSessionNetworkRoom(-1); @@ -639,12 +646,16 @@ public async Task AddMember(UserSession playerSession, string strDisplayNa // also update the lobby for everyone inside of it DirtyRetransmit(); - g_Mutex.ReleaseMutex(); Console.WriteLine("User {0} joined lobby {1}: {2} (Slot was {3})", playerSession.m_UserID, LobbyID, true, slotIndex); return true; + } + finally + { + g_SlotLock.Release(); + } } - public async void RemoveMember(LobbyMember member) + public async Task RemoveMember(LobbyMember member) { // TODO_LOBBY: Optimize this Int64 UserID = member.UserID; @@ -653,7 +664,7 @@ public async void RemoveMember(LobbyMember member) LobbyMember placeholderMember = new LobbyMember(this, null, -1, String.Empty, String.Empty, 0, -1, -1, -1, EPlayerType.SLOT_OPEN, member.SlotIndex, true); Members[member.SlotIndex] = placeholderMember; - TimeMemberLeft[member.SlotIndex] = DateTime.Now; + TimeMemberLeft[UserID] = DateTime.Now; // send signal to disconnect (only if not ingame, ingame we let the client handle it so a service disconnect doesnt end the game) if (State != ELobbyState.INGAME) @@ -730,7 +741,7 @@ public void DirtyRetransmit() m_bIsDirty = true; } - public async void DirtyRetransmitToSingleMember(Int64 targetUserID) + public async Task DirtyRetransmitToSingleMember(Int64 targetUserID) { var session = WebSocketManager.GetDataFromUser(targetUserID); if (session != null) @@ -868,7 +879,7 @@ public bool HadAIAtStart() return m_cachedAtStart_numAI > 0; } - public async void UpdateState(ELobbyState state) + public async Task UpdateState(ELobbyState state) { State = state; @@ -977,6 +988,7 @@ public void UpdateSlotIndex(UInt16 index) public EPlayerType SlotState { get; private set; } = 0; public UInt16 SlotIndex { get; private set; } = 0; public string Region { get; private set; } = "Unknown"; + public string MiddlewareUserID { get; private set; } = String.Empty; [JsonIgnore] // cant serialize refs private WeakReference CurrentLobby = new(null); @@ -1005,6 +1017,16 @@ public LobbyMember(Lobby owningLobby, UserSession? owningSession, Int64 UserID_i SlotState = SlotState_in; SlotIndex = SlotIndex_in; + // default slots are created with null + if (owningSession != null) + { + MiddlewareUserID = owningSession.GetMiddlewareID(); + } + else + { + MiddlewareUserID = String.Empty; + } + IsReady = false; Region = owningSession == null ? "Unknown" : owningSession.GetFullContinentName(); } @@ -1162,11 +1184,11 @@ public static async Task CreateLobby(UserSession owningSession, string st return newLobbyID; } - public static void Tick() + public static async Task Tick() { foreach (var kvPair in m_dictLobbies) { - kvPair.Value.Tick(); + await kvPair.Value.Tick(); } } @@ -1325,7 +1347,7 @@ public static List GetPlayerOwnedLobbies(Int64 userID) return lstLobbies; } - public static void LeaveSpecificLobby(Int64 userID, Int64 lobbyID) + public static async Task LeaveSpecificLobby(Int64 userID, Int64 lobbyID) { Lobby? targetLobby = GetLobby(lobbyID); if (targetLobby != null) @@ -1334,12 +1356,12 @@ public static void LeaveSpecificLobby(Int64 userID, Int64 lobbyID) if (memberEntry != null) { Console.WriteLine("User {0} Leave Specific Lobby", userID); - targetLobby.RemoveMember(memberEntry); + await targetLobby.RemoveMember(memberEntry); } } } - public static void LeaveAnyLobby(Int64 userID) + public static async Task LeaveAnyLobby(Int64 userID) { foreach (Lobby lobbyInst in m_dictLobbies.Values) { @@ -1347,7 +1369,7 @@ public static void LeaveAnyLobby(Int64 userID) if (member != null) { Console.WriteLine("User {0} Leave Any Lobby", userID); - lobbyInst.RemoveMember(member); + await lobbyInst.RemoveMember(member); } } } @@ -1357,7 +1379,7 @@ public static async Task DeleteLobby(Lobby lobby) if (lobby.State != ELobbyState.COMPLETE) { // make done - lobby.UpdateState(ELobbyState.COMPLETE); + await lobby.UpdateState(ELobbyState.COMPLETE); // attempt to commit it await Database.Functions.Lobby.CommitLobbyToMatchHistory(GlobalDatabaseInstance.g_Database, lobby); diff --git a/GenOnlineService/MatchmakingManager.cs b/GenOnlineService/MatchmakingManager.cs index 8dbc847..5ba3c62 100644 --- a/GenOnlineService/MatchmakingManager.cs +++ b/GenOnlineService/MatchmakingManager.cs @@ -332,7 +332,7 @@ public void DetermineMap(out string strMapName, out string strMapPath) // TODO_QUICKMATCH: Optimize this, it's inefficient - foreach (MatchmakingBucketMember member in m_lstMembers) + foreach (MatchmakingBucketMember member in m_lstMembers) { UserSession? memberSession = member.GetAssociatedSession(); if (memberSession != null) @@ -496,7 +496,7 @@ public bool CanMergeWithOtherBucket(MatchmakingBucket bucketToMerge) return true; } - public async void MergeWithOtherBucket(MatchmakingBucket bucketToMerge) + public async Task MergeWithOtherBucket(MatchmakingBucket bucketToMerge) { // copy over players foreach (MatchmakingBucketMember rhsMember in bucketToMerge.m_lstMembers) @@ -664,7 +664,7 @@ public Int64 GetLobbyID() Int64 m_LobbyID = -1; Int64 m_StartTime = -1; - public async void Tick() + public async Task Tick() { // TODO_QUICKMATCH: What if the playlist is null? is this even possible since we validated before creating the bucket if (g_Playlists.TryGetValue(PlaylistID, out Playlist? playlist)) @@ -894,7 +894,7 @@ await SendMatchmakingMessage(memberSession, Lobby? lobby = LobbyManager.GetLobby(m_LobbyID); if (lobby != null) { - lobby.UpdateState(ELobbyState.INGAME); + await lobby.UpdateState(ELobbyState.INGAME); } // destroy the bucket @@ -907,7 +907,8 @@ await SendMatchmakingMessage(memberSession, // TODO_MATCHMAKING: Delete buckets if participants becomes 0 } - private static ConcurrentDictionary> m_dictMatchmakingBuckets = new(); + // Using ConcurrentBag instead of ConcurrentList for lock-free bucket management + private static ConcurrentDictionary> m_dictMatchmakingBuckets = new(); // TODO_QUICKMATCH: Read from db or file private static Dictionary g_Playlists = new() @@ -943,22 +944,12 @@ await SendMatchmakingMessage(memberSession, new PlaylistMap("[RANK] AKAs Magic ZH v1", "[RANK] AKAs Magic ZH v1", true, 2), new PlaylistMap("[RANK] Arctic Arena ZH v1", "[RANK] Arctic Arena ZH v1", true, 2), new PlaylistMap("[RANK] Black Hell ZH v1", "[RANK] Black Hell ZH v1", true, 2), - new PlaylistMap("[RANK] Blossoming Valley ZH v1", "[RANK] Blossoming Valley ZH v1", true, 2), new PlaylistMap("[RANK] Blue Hole ZH v1", "[RANK] Blue Hole ZH v1", true, 2), new PlaylistMap("[RANK] Dammed Scorpion ZH v1", "[RANK] Dammed Scorpion ZH v1", true, 2), - new PlaylistMap("[RANK] Desolated District ZH v1", "[RANK] Desolated District ZH v1", true, 2), new PlaylistMap("[RANK] Drallim Desert ZH v2", "[RANK] Drallim Desert ZH v2", true, 2), - new PlaylistMap("[RANK] Egyptian Oasis ZH v1", "[RANK] Egyptian Oasis ZH v1", true, 2), new PlaylistMap("[RANK] Farmlands of the Fallen ZH v1", "[RANK] Farmlands of the Fallen ZH v1", true, 2), - new PlaylistMap("[RANK] Imminent Victory ZH v2", "[RANK] Imminent Victory ZH v2", true, 2), - new PlaylistMap("[RANK] Liquid Gold ZH v2", "[RANK] Liquid Gold ZH v2", true, 2), - new PlaylistMap("[RANK] Mountain Mayhem v2", "[RANK] Mountain Mayhem v2", true, 2), new PlaylistMap("[RANK] Sakura Forest II ZH v1", "[RANK] Sakura Forest II ZH v1", true, 2), - new PlaylistMap("[RANK] Snowy Drought ZH v5", "[RANK] Snowy Drought ZH v5", true, 2), - new PlaylistMap("[RANK] Sovereignty ZH v1", "[RANK] Sovereignty ZH v1", true, 2), - new PlaylistMap("[RANK] TD NoBugsCars ZH v1", "[RANK] TD NoBugsCars ZH v1", true, 2), - new PlaylistMap("[RANK] Vendetta ZH v1", "[RANK] Vendetta ZH v1", true, 2), - new PlaylistMap("[RANK] ZH Carrier is Over v2", "[RANK] ZH Carrier is Over v2", true, 2), + new PlaylistMap("[RANK] Sovereignty ZH v1", "[RANK] Sovereignty ZH v1", true, 2) } ) }, @@ -998,7 +989,7 @@ public static async Task Tick() { foreach (var kvPair in g_Playlists) { - m_dictMatchmakingBuckets.TryAdd(kvPair.Key, new ConcurrentList()); + m_dictMatchmakingBuckets.TryAdd(kvPair.Key, new ConcurrentBag()); } } @@ -1012,7 +1003,7 @@ public static async Task Tick() // if we've already been merged and are awaiting delayed deletion, dont process it anymore if (!lstBucketsMergedNeedingDeleted.Contains(mmBucket)) { - mmBucket.Tick(); + await mmBucket.Tick(); // try to merge with any other bucket within this playlist foreach (MatchmakingBucket mmBucketMergeCandidate in kvPair.Value) @@ -1024,7 +1015,7 @@ public static async Task Tick() { if (mmBucket.CanMergeWithOtherBucket(mmBucketMergeCandidate)) { - mmBucket.MergeWithOtherBucket(mmBucketMergeCandidate); + await mmBucket.MergeWithOtherBucket(mmBucketMergeCandidate); lstBucketsMergedNeedingDeleted.Add(mmBucketMergeCandidate); } @@ -1041,9 +1032,11 @@ public static async Task Tick() // cleanup any pending destruction (cannot do this in tick, collection will be modified) foreach (MatchmakingBucket bucket in m_lstBucketsPendingDeletion) { - if (m_dictMatchmakingBuckets.ContainsKey(bucket.PlaylistID)) + if (m_dictMatchmakingBuckets.TryGetValue(bucket.PlaylistID, out var bucketBag)) { - m_dictMatchmakingBuckets[bucket.PlaylistID].Remove(bucket); + // ConcurrentBag doesn't support Remove, so we filter and rebuild + var remainingBuckets = bucketBag.Where(b => b != bucket).ToList(); + m_dictMatchmakingBuckets[bucket.PlaylistID] = new ConcurrentBag(remainingBuckets); } } m_lstBucketsPendingDeletion.Clear(); diff --git a/GenOnlineService/Program.cs b/GenOnlineService/Program.cs index 8288280..abdf87c 100644 --- a/GenOnlineService/Program.cs +++ b/GenOnlineService/Program.cs @@ -47,31 +47,74 @@ namespace GenOnlineService { - public static class APIKeyHelpers + public static class IPHelpers { - public static bool ValidateKey(string strKey) + public static string NormalizeIP(string? ipAddress) { - if (Program.g_Config == null) + if (string.IsNullOrEmpty(ipAddress)) { - return false; + return "unknown"; } - // TODO_DISCORD: Cache this - IConfiguration? apiSettings = Program.g_Config.GetSection("API"); + if (System.Net.IPAddress.TryParse(ipAddress, out System.Net.IPAddress? addr)) + { + // Convert IPv6-mapped IPv4 (::ffff:127.0.0.1) to IPv4 (127.0.0.1) + if (addr.IsIPv4MappedToIPv6) + { + return addr.MapToIPv4().ToString(); + } + + // Treat all localhost addresses as 127.0.0.1 + if (System.Net.IPAddress.IsLoopback(addr)) + { + return "127.0.0.1"; + } - if (apiSettings == null) + return addr.ToString(); + } + + return ipAddress; + } + } + + public static class APIKeyHelpers + { + private static HashSet? s_cachedApiKeys = null; + private static readonly object s_cacheLock = new object(); + + public static bool ValidateKey(string strKey) + { + if (Program.g_Config == null) { return false; } - List? api_keys = apiSettings.GetSection("keys").Get>(); - if (api_keys == null) + // Use cached HashSet for O(1) lookup + if (s_cachedApiKeys == null) { - return false; + lock (s_cacheLock) + { + if (s_cachedApiKeys == null) + { + IConfiguration? apiSettings = Program.g_Config.GetSection("API"); + if (apiSettings == null) + { + return false; + } + + List? api_keys = apiSettings.GetSection("keys").Get>(); + if (api_keys == null) + { + return false; + } + + // Convert to HashSet and uppercase all keys for O(1) lookup + s_cachedApiKeys = new HashSet(api_keys.Select(k => k.ToUpper()), StringComparer.OrdinalIgnoreCase); + } + } } - // TODO: Optimize lookup - return api_keys.Contains(strKey.ToUpper()); + return s_cachedApiKeys.Contains(strKey); } } public static class CertHelpers @@ -138,7 +181,27 @@ protected override Task HandleAuthenticateAsync() string strUsername = parts[0]; string strPassword = parts[1]; - if (strUsername == "TODO_GITHUB" && strPassword == "TODO_GITHUB") + IConfigurationSection? monitorSettings = Program.g_Config.GetSection("Monitor"); + + if (monitorSettings == null) + { + throw new Exception("Monitor section missing in config"); + } + + string? monitorUsername = monitorSettings.GetValue("username"); + string? monitorPassword = monitorSettings.GetValue("password"); + + if (monitorUsername == null) + { + throw new Exception("Monitor Username missing in config"); + } + + if (monitorPassword == null) + { + throw new Exception("Monitor Password missing in config"); + } + + if (strUsername == monitorUsername && strPassword == monitorPassword) { var claims = new[] { new Claim(ClaimTypes.Name, strUsername), new Claim(ClaimTypes.Role, "Monitor") }; var identity = new ClaimsIdentity(claims, "MonitorToken"); @@ -198,14 +261,14 @@ public static Int64 GetUserID(ControllerBase controller) public static string GetDisplayName(ControllerBase controller) { // TODO: Handle not finding claims, it is a critical error - var first = controller.User.FindFirst(JwtRegisteredClaimNames.Address); + var first = controller.User.FindFirst(JwtRegisteredClaimNames.Name); return first != null ? first.Value : String.Empty; } public static string GetIPAddress(ControllerBase controller) { // TODO: Handle not finding claims, it is a critical error - var first = controller.User.FindFirst(JwtRegisteredClaimNames.Name); + var first = controller.User.FindFirst(JwtRegisteredClaimNames.Address); return first != null ? first.Value : String.Empty; } } @@ -214,7 +277,7 @@ public class Program { public static IConfiguration? g_Config = null; public static DiscordBot? g_Discord = null; - static async void DoCleanup(bool bStartup) + static async Task DoCleanup(bool bStartup) { await Database.Functions.Auth.Cleanup(GlobalDatabaseInstance.g_Database, bStartup); @@ -233,7 +296,8 @@ private static Task AdditionalValidation(TokenValidatedContext context) #pragma warning disable CS8602 // Dereference of a possibly null reference. (Appears to be erroronous flagging) #pragma warning disable CS8604 // null reference. (Appears to be erroronous flagging) - if (context.Principal.Claims.First() == null || !Int64.TryParse(context.Principal.Claims.First().Value, out Int64 userID)) + Claim? userIdClaim = context.Principal.FindFirst(ClaimTypes.NameIdentifier); + if (userIdClaim == null || !Int64.TryParse(userIdClaim.Value, out Int64 userID)) { context.Fail("Failed Validation #2"); } @@ -260,29 +324,42 @@ private static Task AdditionalValidation(TokenValidatedContext context) string strTypeClaim = firstType.Value; JwtTokenGenerator.ETokenType tokenType = (JwtTokenGenerator.ETokenType)Convert.ToInt32(strTypeClaim); - bool bIsLoginWithToken = context.Request.Path.ToString().ToLower().Contains("loginwithtoken"); - if (bIsLoginWithToken && tokenType != JwtTokenGenerator.ETokenType.Refresh) + + // Use claim-based validation instead of path-based to prevent bypass + if (tokenType == JwtTokenGenerator.ETokenType.Refresh) { - context.Fail("Failed Validation #5"); + bool bIsLoginWithToken = context.Request.Path.ToString().ToLower().Contains("loginwithtoken"); + if (!bIsLoginWithToken) + { + context.Fail("Failed Validation #5 - Refresh token used on non-refresh endpoint"); + } } - else if (!bIsLoginWithToken && tokenType != JwtTokenGenerator.ETokenType.Session) + else if (tokenType == JwtTokenGenerator.ETokenType.Session) + { + bool bIsLoginWithToken = context.Request.Path.ToString().ToLower().Contains("loginwithtoken"); + if (bIsLoginWithToken) + { + context.Fail("Failed Validation #6 - Session token used on refresh endpoint"); + } + } + else { - context.Fail("Failed Validation #6"); + context.Fail("Failed Validation #10 - Unknown token type"); } if (context.Principal.FindFirst(JwtRegisteredClaimNames.Address) == null) { context.Fail("Failed Validation #7"); } -#pragma warning restore CS8602 // Dereference of a possibly null reference. -#pragma warning restore CS8604 // Dereference of a possibly null reference. - /* + string strExpectedIP = context.Principal.FindFirst(JwtRegisteredClaimNames.Address).Value; - if (strExpectedIP != context.HttpContext.Connection.RemoteIpAddress.ToString()) + string currentIP = IPHelpers.NormalizeIP(context.HttpContext.Connection.RemoteIpAddress?.ToString()); + if (strExpectedIP != currentIP) { - context.Fail("Failed Validation #8"); + context.Fail("Failed Validation #8 - IP mismatch"); } - */ +#pragma warning restore CS8602 // Dereference of a possibly null reference. +#pragma warning restore CS8604 // Dereference of a possibly null reference. } catch { @@ -306,7 +383,7 @@ public enum ETokenType Session, Refresh } - + public string GenerateToken(string displayname, Int64 userID, string ipAddr, ETokenType tokenType, string client_id, bool bIsAdmin) { var jwtSettings = _configuration.GetSection("JwtSettings"); @@ -381,12 +458,15 @@ public static string GetWebSocketAddress(bool bSecure) return ws_address; } - public static void Main(string[] args) + public static async Task Main(string[] args) { #if !DEBUG AppDomain.CurrentDomain.UnhandledException += GlobalExceptionHandler; #endif + // Configure thread pool for better performance under load + ThreadPool.SetMinThreads(200, 200); + var builder = WebApplication.CreateBuilder(args); // Add services to the container. @@ -435,7 +515,7 @@ public static void Main(string[] args) options.AutoSessionTracking = true; }); } - + // create discord? var discordSettings = Program.g_Config.GetSection("Discord"); @@ -445,12 +525,12 @@ public static void Main(string[] args) g_Discord = new DiscordBot(); } - GlobalDatabaseInstance.g_Database.Initialize(); + await GlobalDatabaseInstance.g_Database.Initialize(); // do a cleanup on startup - DoCleanup(true); + await DoCleanup(true); + - builder.Services.AddRateLimiter(options => { @@ -473,6 +553,21 @@ public static void Main(string[] args) }); }); + builder.Services.AddCors(options => + { + options.AddDefaultPolicy(policy => + { + policy.WithOrigins( + "https://localhost:9000", + "http://localhost:9001", + "https://*.playgenerals.online" + ) + .AllowAnyHeader() + .AllowAnyMethod() + .AllowCredentials(); + }); + }); + var jwtSettings = builder.Configuration.GetSection("JwtSettings"); builder.Services.AddAuthentication(options => @@ -558,6 +653,30 @@ public static void Main(string[] args) })); }); + // Add in-memory caching for performance optimization + builder.Services.AddMemoryCache(options => + { + options.SizeLimit = 10000; // Limit cache entries + }); + + // Add response compression for bandwidth optimization (60-80% reduction) + builder.Services.AddResponseCompression(options => + { + options.EnableForHttps = true; + options.Providers.Add(); + options.Providers.Add(); + }); + + builder.Services.Configure(options => + { + options.Level = System.IO.Compression.CompressionLevel.Fastest; // Balance speed vs compression + }); + + builder.Services.Configure(options => + { + options.Level = System.IO.Compression.CompressionLevel.Fastest; + }); + // JSON options needed to avoid ASP.NET lower casing everything builder.Services.AddControllers().AddJsonOptions(options => { @@ -589,27 +708,27 @@ public static void Main(string[] args) return; } - if (!use_os_cert_store) // if not using the cert store, we need a pem and key - { - if (cert_pem_path == null) - { - Console.WriteLine("FATAL ERROR: cert_pem_path is not set in the config"); - Console.ReadKey(true); - return; - } + if (!use_os_cert_store) // if not using the cert store, we need a pem and key + { + if (cert_pem_path == null) + { + Console.WriteLine("FATAL ERROR: cert_pem_path is not set in the config"); + Console.ReadKey(true); + return; + } - if (cert_key_path == null) - { - Console.WriteLine("FATAL ERROR: cert_key_path is not set in the config"); - Console.ReadKey(true); - return; - } - } + if (cert_key_path == null) + { + Console.WriteLine("FATAL ERROR: cert_key_path is not set in the config"); + Console.ReadKey(true); + return; + } + } - //UInt16 port = coreSettings.GetValue("port"); + //UInt16 port = coreSettings.GetValue("port"); - bool bShouldUseOSCertSTore = (bool)use_os_cert_store; + bool bShouldUseOSCertSTore = (bool)use_os_cert_store; if (!bShouldUseOSCertSTore) { if (String.IsNullOrEmpty(cert_pem_path) || String.IsNullOrEmpty(cert_key_path)) @@ -622,7 +741,7 @@ public static void Main(string[] args) { //X509Certificate2 = CertHelpers.LoadPemWithPrivateKey(cert_pem_path, cert_key_path); - X509Certificate2 = X509Certificate2.CreateFromPemFile(cert_pem_path, cert_key_path); + X509Certificate2 = X509Certificate2.CreateFromPemFile(cert_pem_path, cert_key_path); if (X509Certificate2 == null) @@ -660,7 +779,7 @@ public static void Main(string[] args) { Console.WriteLine("ERROR: Failed to parse port from serverURI: " + serverURI); } - + // options @@ -689,6 +808,9 @@ public static void Main(string[] args) app.UseRateLimiter(); + // Enable response compression (must be early in pipeline) + app.UseResponseCompression(); + // websocket var webSocketOptions = new WebSocketOptions @@ -708,7 +830,7 @@ public static void Main(string[] args) } */ - app.Use((context, next) => + app.Use((context, next) => { context.Request.EnableBuffering(); return next(); @@ -716,10 +838,11 @@ public static void Main(string[] args) //app.UseHttpsRedirection(); + app.UseCors(); app.UseAuthentication(); app.UseAuthorization(); - Database.MySQLInstance.TestQuery(GlobalDatabaseInstance.g_Database).Wait(); + await Database.MySQLInstance.TestQuery(GlobalDatabaseInstance.g_Database); app.MapControllers(); @@ -728,27 +851,23 @@ public static void Main(string[] args) timerCleanup.AutoReset = false; timerCleanup.Elapsed += async (sender, e) => { - await WebSocketManager.CheckForTimeouts(); - - int numLobbies = LobbyManager.GetNumLobbies(); - StatsTracker.Update(numLobbies, WebSocketManager.GetUserDataCache().Count).Wait(); - - timerCleanup.Start(); + try + { + await WebSocketManager.CheckForTimeouts(); - LobbyManager.Cleanup(); + int numLobbies = LobbyManager.GetNumLobbies(); + await StatsTracker.Update(numLobbies, WebSocketManager.GetUserDataCache().Count); - // disconnect test - /* - bool bDisc = false; - if (bDisc) + await LobbyManager.Cleanup(); + } + catch (Exception ex) { - ChatSession? targetSession = GenOnlineService.WebSocketManager.GetSessionFromUser(2); - if (targetSession != null) - { - await GenOnlineService.WebSocketManager.DeleteSession(targetSession); - } + Console.WriteLine($"[timerCleanup] Exception: {ex}"); + } + finally + { + timerCleanup.Start(); } - */ }; timerCleanup.Start(); @@ -759,13 +878,21 @@ public static void Main(string[] args) { System.Timers.Timer timerTick = new System.Timers.Timer(5); // 5ms tick timerTick.AutoReset = false; - timerTick.Elapsed += (sender, e) => + timerTick.Elapsed += async (sender, e) => { - LobbyManager.Tick(); - - WebSocketManager.Tick(); - - timerTick.Start(); + try + { + await LobbyManager.Tick(); + await WebSocketManager.Tick(); + } + catch (Exception ex) + { + Console.WriteLine($"[timerTick lobby] Exception: {ex}"); + } + finally + { + timerTick.Start(); + } }; timerTick.Start(); } @@ -776,26 +903,67 @@ public static void Main(string[] args) timerTick.AutoReset = false; timerTick.Elapsed += async (sender, e) => { - await MatchmakingManager.Tick(); + try + { + await MatchmakingManager.Tick(); + } + catch (Exception ex) + { + Console.WriteLine($"[timerTick matchmaking] Exception: {ex}"); + } + finally + { + timerTick.Start(); + } + }; + timerTick.Start(); + } - timerTick.Start(); + // tick network rooms (done at lower frequency) + { + System.Timers.Timer timerTick = new System.Timers.Timer(1000); // 1s tick + timerTick.AutoReset = false; + timerTick.Elapsed += async (sender, e) => + { + try + { + await WebSocketManager.TickRoomMemberList(); + } + catch (Exception ex) + { + Console.WriteLine($"[timerTick rooms] Exception: {ex}"); + } + finally + { + timerTick.Start(); + } }; timerTick.Start(); } - // timer to save daily stats - { - System.Timers.Timer timerTick = new System.Timers.Timer(60000); // 60s tick - timerTick.AutoReset = false; - timerTick.Elapsed += async (sender, e) => - { - // save daily stats - await DailyStatsManager.SaveToDB(); - }; - timerTick.Start(); - } + // timer to save daily stats + { + System.Timers.Timer timerTick = new System.Timers.Timer(60000); // 60s tick + timerTick.AutoReset = false; + timerTick.Elapsed += async (sender, e) => + { + try + { + await DailyStatsManager.SaveToDB(); + } + catch (Exception ex) + { + Console.WriteLine($"[timerTick dailystats] Exception: {ex}"); + } + finally + { + timerTick.Start(); + } + }; + timerTick.Start(); + } - AppDomain.CurrentDomain.ProcessExit += (_, _) => + AppDomain.CurrentDomain.ProcessExit += (_, _) => { Console.ForegroundColor = ConsoleColor.Red; Console.WriteLine("EXIT REQUESTED!"); @@ -804,19 +972,16 @@ public static void Main(string[] args) // create a token g_tokenGenerator = new JwtTokenGenerator(builder.Configuration); - // load daily stats - // TODO_SOCIAL: await -#pragma warning disable CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed - DailyStatsManager.LoadFromDB(); -#pragma warning restore CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed + // load daily stats + await DailyStatsManager.LoadFromDB(); - app.Run(); + app.Run(); // shutdown BackgroundS3Uploader.Shutdown(); - } + } public static void ShowLogo() { diff --git a/GenOnlineService/appsettings.json b/GenOnlineService/appsettings.json index 405afc9..1d5ad80 100644 --- a/GenOnlineService/appsettings.json +++ b/GenOnlineService/appsettings.json @@ -71,4 +71,10 @@ "enabled": false, "dsn": "" }, + , + "Middleware": { + "jwks_endpoint": null, + "audience": null, + "issuer": null + } } \ No newline at end of file