diff --git a/MareSynchronosServer/MareSynchronosShared/RequirementHandlers/ExistingUserRequirementHandler.cs b/MareSynchronosServer/MareSynchronosShared/RequirementHandlers/ExistingUserRequirementHandler.cs index 5eaba34..e5f7c26 100644 --- a/MareSynchronosServer/MareSynchronosShared/RequirementHandlers/ExistingUserRequirementHandler.cs +++ b/MareSynchronosServer/MareSynchronosShared/RequirementHandlers/ExistingUserRequirementHandler.cs @@ -3,12 +3,15 @@ using MareSynchronosShared.Utils; using Microsoft.AspNetCore.Authorization; using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.Logging; +using System.Collections.Concurrent; namespace MareSynchronosShared.RequirementHandlers; public class ExistingUserRequirementHandler : AuthorizationHandler { private readonly IDbContextFactory _dbContextFactory; private readonly ILogger _logger; + private readonly static ConcurrentDictionary _existingUserDict = []; + private readonly static ConcurrentDictionary _existingDiscordDict = []; public ExistingUserRequirementHandler(IDbContextFactory dbContext, ILogger logger) { @@ -18,21 +21,59 @@ public class ExistingUserRequirementHandler : AuthorizationHandler string.Equals(g.Type, MareClaimTypes.Uid, StringComparison.Ordinal))?.Value; - if (uid == null) context.Fail(); + try + { + var uid = context.User.Claims.SingleOrDefault(g => string.Equals(g.Type, MareClaimTypes.Uid, StringComparison.Ordinal))?.Value; + if (uid == null) + { + context.Fail(); + return; + } - var discordIdString = context.User.Claims.SingleOrDefault(g => string.Equals(g.Type, MareClaimTypes.DiscordId, StringComparison.Ordinal))?.Value; - if (discordIdString == null) context.Fail(); + var discordIdString = context.User.Claims.SingleOrDefault(g => string.Equals(g.Type, MareClaimTypes.DiscordId, StringComparison.Ordinal))?.Value; + if (discordIdString == null) + { + context.Fail(); + return; + } + if (!ulong.TryParse(discordIdString, out ulong discordId)) + { + context.Fail(); + return; + } - using var dbContext = await _dbContextFactory.CreateDbContextAsync().ConfigureAwait(false); - var user = await dbContext.Users.AsNoTracking().SingleOrDefaultAsync(b => b.UID == uid).ConfigureAwait(false); - if (user == null) context.Fail(); + using var dbContext = await _dbContextFactory.CreateDbContextAsync().ConfigureAwait(false); - if (!ulong.TryParse(discordIdString, out ulong discordId)) context.Fail(); + if (!_existingUserDict.TryGetValue(uid, out (bool Exists, DateTime LastCheck) existingUser) + || DateTime.UtcNow.Subtract(existingUser.LastCheck).TotalHours > 1) + { + var userExists = await dbContext.Users.SingleOrDefaultAsync(context => context.UID == uid).ConfigureAwait(false) != null; + _existingUserDict[uid] = existingUser = (userExists, DateTime.UtcNow); + } + if (!existingUser.Exists) + { + context.Fail(); + return; + } - var discordUser = await dbContext.LodeStoneAuth.AsNoTracking().SingleOrDefaultAsync(b => b.DiscordId == discordId).ConfigureAwait(false); - if (discordUser == null) context.Fail(); + if (!_existingDiscordDict.TryGetValue(discordId, out (bool Exists, DateTime LastCheck) existingDiscordUser) + || DateTime.UtcNow.Subtract(existingDiscordUser.LastCheck).TotalHours > 1) + { + var discordUserExists = await dbContext.LodeStoneAuth.AsNoTracking().SingleOrDefaultAsync(b => b.DiscordId == discordId).ConfigureAwait(false) != null; + _existingDiscordDict[discordId] = existingDiscordUser = (discordUserExists, DateTime.UtcNow); + } - context.Succeed(requirement); + if (!existingDiscordUser.Exists) + { + context.Fail(); + return; + } + + context.Succeed(requirement); + } + catch (Exception e) + { + _logger.LogWarning(e, "ExistingUserRequirementHandler failed"); + } } } \ No newline at end of file