diff --git a/MareSynchronosServer/MareSynchronosAuthService/Controllers/AuthControllerBase.cs b/MareSynchronosServer/MareSynchronosAuthService/Controllers/AuthControllerBase.cs index b7e59fa..30ccf7b 100644 --- a/MareSynchronosServer/MareSynchronosAuthService/Controllers/AuthControllerBase.cs +++ b/MareSynchronosServer/MareSynchronosAuthService/Controllers/AuthControllerBase.cs @@ -21,13 +21,13 @@ public abstract class AuthControllerBase : Controller protected readonly ILogger Logger; protected readonly IHttpContextAccessor HttpAccessor; protected readonly IConfigurationService Configuration; - protected readonly MareDbContext MareDbContext; + protected readonly IDbContextFactory MareDbContextFactory; protected readonly SecretKeyAuthenticatorService SecretKeyAuthenticatorService; private readonly IRedisDatabase _redis; private readonly GeoIPService _geoIPProvider; protected AuthControllerBase(ILogger logger, - IHttpContextAccessor accessor, MareDbContext mareDbContext, + IHttpContextAccessor accessor, IDbContextFactory mareDbContext, SecretKeyAuthenticatorService secretKeyAuthenticatorService, IConfigurationService configuration, IRedisDatabase redisDb, GeoIPService geoIPProvider) @@ -36,14 +36,14 @@ public abstract class AuthControllerBase : Controller HttpAccessor = accessor; _redis = redisDb; _geoIPProvider = geoIPProvider; - MareDbContext = mareDbContext; + MareDbContextFactory = mareDbContext; SecretKeyAuthenticatorService = secretKeyAuthenticatorService; Configuration = configuration; } - protected async Task GenericAuthResponse(string charaIdent, SecretKeyAuthReply authResult) + protected async Task GenericAuthResponse(MareDbContext dbContext, string charaIdent, SecretKeyAuthReply authResult) { - if (await IsIdentBanned(charaIdent)) + if (await IsIdentBanned(dbContext, charaIdent)) { Logger.LogWarning("Authenticate:IDENTBAN:{id}:{ident}", authResult.Uid, charaIdent); return Unauthorized("Your XIV service account is banned from using the service."); @@ -114,9 +114,10 @@ public abstract class AuthControllerBase : Controller protected async Task EnsureBan(string uid, string? primaryUid, string charaIdent) { - if (!MareDbContext.BannedUsers.Any(c => c.CharacterIdentification == charaIdent)) + using var dbContext = await MareDbContextFactory.CreateDbContextAsync(); + if (!dbContext.BannedUsers.Any(c => c.CharacterIdentification == charaIdent)) { - MareDbContext.BannedUsers.Add(new Banned() + dbContext.BannedUsers.Add(new Banned() { CharacterIdentification = charaIdent, Reason = "Autobanned CharacterIdent (" + uid + ")", @@ -125,35 +126,35 @@ public abstract class AuthControllerBase : Controller var uidToLookFor = primaryUid ?? uid; - var primaryUserAuth = await MareDbContext.Auth.FirstAsync(f => f.UserUID == uidToLookFor); + var primaryUserAuth = await dbContext.Auth.FirstAsync(f => f.UserUID == uidToLookFor); primaryUserAuth.MarkForBan = false; primaryUserAuth.IsBanned = true; - var lodestone = await MareDbContext.LodeStoneAuth.Include(a => a.User).FirstOrDefaultAsync(c => c.User.UID == uidToLookFor); + var lodestone = await dbContext.LodeStoneAuth.Include(a => a.User).FirstOrDefaultAsync(c => c.User.UID == uidToLookFor); if (lodestone != null) { - if (!MareDbContext.BannedRegistrations.Any(c => c.DiscordIdOrLodestoneAuth == lodestone.HashedLodestoneId)) + if (!dbContext.BannedRegistrations.Any(c => c.DiscordIdOrLodestoneAuth == lodestone.HashedLodestoneId)) { - MareDbContext.BannedRegistrations.Add(new BannedRegistrations() + dbContext.BannedRegistrations.Add(new BannedRegistrations() { DiscordIdOrLodestoneAuth = lodestone.HashedLodestoneId, }); } - if (!MareDbContext.BannedRegistrations.Any(c => c.DiscordIdOrLodestoneAuth == lodestone.DiscordId.ToString())) + if (!dbContext.BannedRegistrations.Any(c => c.DiscordIdOrLodestoneAuth == lodestone.DiscordId.ToString())) { - MareDbContext.BannedRegistrations.Add(new BannedRegistrations() + dbContext.BannedRegistrations.Add(new BannedRegistrations() { DiscordIdOrLodestoneAuth = lodestone.DiscordId.ToString(), }); } } - await MareDbContext.SaveChangesAsync(); + await dbContext.SaveChangesAsync(); } - protected async Task IsIdentBanned(string charaIdent) + protected async Task IsIdentBanned(MareDbContext dbContext, string charaIdent) { - return await MareDbContext.BannedUsers.AsNoTracking().AnyAsync(u => u.CharacterIdentification == charaIdent).ConfigureAwait(false); + return await dbContext.BannedUsers.AsNoTracking().AnyAsync(u => u.CharacterIdentification == charaIdent).ConfigureAwait(false); } } diff --git a/MareSynchronosServer/MareSynchronosAuthService/Controllers/JwtController.cs b/MareSynchronosServer/MareSynchronosAuthService/Controllers/JwtController.cs index 2424cfd..9aa5980 100644 --- a/MareSynchronosServer/MareSynchronosAuthService/Controllers/JwtController.cs +++ b/MareSynchronosServer/MareSynchronosAuthService/Controllers/JwtController.cs @@ -16,7 +16,7 @@ namespace MareSynchronosAuthService.Controllers; public class JwtController : AuthControllerBase { public JwtController(ILogger logger, - IHttpContextAccessor accessor, MareDbContext mareDbContext, + IHttpContextAccessor accessor, IDbContextFactory mareDbContext, SecretKeyAuthenticatorService secretKeyAuthenticatorService, IConfigurationService configuration, IRedisDatabase redisDb, GeoIPService geoIPProvider) @@ -29,28 +29,30 @@ public class JwtController : AuthControllerBase [HttpPost(MareAuth.Auth_CreateIdent)] public async Task CreateToken(string auth, string charaIdent) { - return await AuthenticateInternal(auth, charaIdent).ConfigureAwait(false); + using var dbContext = await MareDbContextFactory.CreateDbContextAsync(); + return await AuthenticateInternal(dbContext, auth, charaIdent).ConfigureAwait(false); } [Authorize(Policy = "Authenticated")] [HttpGet(MareAuth.Auth_RenewToken)] public async Task RenewToken() { + using var dbContext = await MareDbContextFactory.CreateDbContextAsync(); try { var uid = HttpContext.User.Claims.Single(p => string.Equals(p.Type, MareClaimTypes.Uid, StringComparison.Ordinal))!.Value; var ident = HttpContext.User.Claims.Single(p => string.Equals(p.Type, MareClaimTypes.CharaIdent, StringComparison.Ordinal))!.Value; var alias = HttpContext.User.Claims.SingleOrDefault(p => string.Equals(p.Type, MareClaimTypes.Alias))?.Value ?? string.Empty; - if (await MareDbContext.Auth.Where(u => u.UserUID == uid || u.PrimaryUserUID == uid).AnyAsync(a => a.MarkForBan)) + if (await dbContext.Auth.Where(u => u.UserUID == uid || u.PrimaryUserUID == uid).AnyAsync(a => a.MarkForBan)) { - var userAuth = await MareDbContext.Auth.SingleAsync(u => u.UserUID == uid); + var userAuth = await dbContext.Auth.SingleAsync(u => u.UserUID == uid); await EnsureBan(uid, userAuth.PrimaryUserUID, ident); return Unauthorized("Your Mare account is banned."); } - if (await IsIdentBanned(ident)) + if (await IsIdentBanned(dbContext, ident)) { return Unauthorized("Your XIV service account is banned from using the service."); } @@ -65,7 +67,7 @@ public class JwtController : AuthControllerBase } } - protected async Task AuthenticateInternal(string auth, string charaIdent) + protected async Task AuthenticateInternal(MareDbContext dbContext, string auth, string charaIdent) { try { @@ -76,7 +78,7 @@ public class JwtController : AuthControllerBase var authResult = await SecretKeyAuthenticatorService.AuthorizeAsync(ip, auth); - return await GenericAuthResponse(charaIdent, authResult); + return await GenericAuthResponse(dbContext, charaIdent, authResult); } catch (Exception ex) { diff --git a/MareSynchronosServer/MareSynchronosAuthService/Controllers/OAuthController.cs b/MareSynchronosServer/MareSynchronosAuthService/Controllers/OAuthController.cs index 8deef0d..a0e3cf9 100644 --- a/MareSynchronosServer/MareSynchronosAuthService/Controllers/OAuthController.cs +++ b/MareSynchronosServer/MareSynchronosAuthService/Controllers/OAuthController.cs @@ -26,7 +26,7 @@ public class OAuthController : AuthControllerBase private static readonly ConcurrentDictionary _cookieOAuthResponse = []; public OAuthController(ILogger logger, - IHttpContextAccessor accessor, MareDbContext mareDbContext, + IHttpContextAccessor accessor, IDbContextFactory mareDbContext, SecretKeyAuthenticatorService secretKeyAuthenticatorService, IConfigurationService configuration, IRedisDatabase redisDb, GeoIPService geoIPProvider) @@ -135,7 +135,9 @@ public class OAuthController : AuthControllerBase if (discordUserId == 0) return BadRequest("Failed to get Discord ID from login token"); - var mareUser = await MareDbContext.LodeStoneAuth.Include(u => u.User).SingleOrDefaultAsync(u => u.DiscordId == discordUserId); + using var dbContext = await MareDbContextFactory.CreateDbContextAsync(); + + var mareUser = await dbContext.LodeStoneAuth.Include(u => u.User).SingleOrDefaultAsync(u => u.DiscordId == discordUserId); if (mareUser == null) return BadRequest("Could not find a Mare user associated to this Discord account."); @@ -213,11 +215,12 @@ public class OAuthController : AuthControllerBase public async Task> GetAvailableUIDs() { string primaryUid = HttpContext.User.Claims.Single(c => string.Equals(c.Type, MareClaimTypes.Uid, StringComparison.Ordinal))!.Value; + using var dbContext = await MareDbContextFactory.CreateDbContextAsync(); - var mareUser = await MareDbContext.Auth.AsNoTracking().Include(u => u.User).FirstOrDefaultAsync(f => f.UserUID == primaryUid).ConfigureAwait(false); + var mareUser = await dbContext.Auth.AsNoTracking().Include(u => u.User).FirstOrDefaultAsync(f => f.UserUID == primaryUid).ConfigureAwait(false); if (mareUser == null || mareUser.User == null) return []; var uid = mareUser.User.UID; - var allUids = await MareDbContext.Auth.AsNoTracking().Include(u => u.User).Where(a => a.UserUID == uid || a.PrimaryUserUID == uid).ToListAsync().ConfigureAwait(false); + var allUids = await dbContext.Auth.AsNoTracking().Include(u => u.User).Where(a => a.UserUID == uid || a.PrimaryUserUID == uid).ToListAsync().ConfigureAwait(false); var result = allUids.OrderBy(u => u.UserUID == uid ? 0 : 1).ThenBy(u => u.UserUID).Select(u => (u.UserUID, u.User.Alias)).ToDictionary(); return result; } @@ -226,10 +229,12 @@ public class OAuthController : AuthControllerBase [HttpPost(MareAuth.OAuth_CreateOAuth)] public async Task CreateTokenWithOAuth(string uid, string charaIdent) { - return await AuthenticateOAuthInternal(uid, charaIdent); + using var dbContext = await MareDbContextFactory.CreateDbContextAsync(); + + return await AuthenticateOAuthInternal(dbContext, uid, charaIdent); } - private async Task AuthenticateOAuthInternal(string requestedUid, string charaIdent) + private async Task AuthenticateOAuthInternal(MareDbContext dbContext, string requestedUid, string charaIdent) { try { @@ -241,7 +246,7 @@ public class OAuthController : AuthControllerBase var authResult = await SecretKeyAuthenticatorService.AuthorizeOauthAsync(ip, primaryUid, requestedUid); - return await GenericAuthResponse(charaIdent, authResult); + return await GenericAuthResponse(dbContext, charaIdent, authResult); } catch (Exception ex) {