use dbcontext factory?

This commit is contained in:
Stanley Dimant
2024-10-29 14:14:11 +01:00
parent 74633e1337
commit 19d819045c
3 changed files with 38 additions and 30 deletions

View File

@@ -21,13 +21,13 @@ public abstract class AuthControllerBase : Controller
protected readonly ILogger Logger; protected readonly ILogger Logger;
protected readonly IHttpContextAccessor HttpAccessor; protected readonly IHttpContextAccessor HttpAccessor;
protected readonly IConfigurationService<AuthServiceConfiguration> Configuration; protected readonly IConfigurationService<AuthServiceConfiguration> Configuration;
protected readonly MareDbContext MareDbContext; protected readonly IDbContextFactory<MareDbContext> MareDbContextFactory;
protected readonly SecretKeyAuthenticatorService SecretKeyAuthenticatorService; protected readonly SecretKeyAuthenticatorService SecretKeyAuthenticatorService;
private readonly IRedisDatabase _redis; private readonly IRedisDatabase _redis;
private readonly GeoIPService _geoIPProvider; private readonly GeoIPService _geoIPProvider;
protected AuthControllerBase(ILogger logger, protected AuthControllerBase(ILogger logger,
IHttpContextAccessor accessor, MareDbContext mareDbContext, IHttpContextAccessor accessor, IDbContextFactory<MareDbContext> mareDbContext,
SecretKeyAuthenticatorService secretKeyAuthenticatorService, SecretKeyAuthenticatorService secretKeyAuthenticatorService,
IConfigurationService<AuthServiceConfiguration> configuration, IConfigurationService<AuthServiceConfiguration> configuration,
IRedisDatabase redisDb, GeoIPService geoIPProvider) IRedisDatabase redisDb, GeoIPService geoIPProvider)
@@ -36,14 +36,14 @@ public abstract class AuthControllerBase : Controller
HttpAccessor = accessor; HttpAccessor = accessor;
_redis = redisDb; _redis = redisDb;
_geoIPProvider = geoIPProvider; _geoIPProvider = geoIPProvider;
MareDbContext = mareDbContext; MareDbContextFactory = mareDbContext;
SecretKeyAuthenticatorService = secretKeyAuthenticatorService; SecretKeyAuthenticatorService = secretKeyAuthenticatorService;
Configuration = configuration; Configuration = configuration;
} }
protected async Task<IActionResult> GenericAuthResponse(string charaIdent, SecretKeyAuthReply authResult) protected async Task<IActionResult> GenericAuthResponse(MareDbContext dbContext, string charaIdent, SecretKeyAuthReply authResult)
{ {
if (await IsIdentBanned(charaIdent)) if (await IsIdentBanned(dbContext, charaIdent))
{ {
Logger.LogWarning("Authenticate:IDENTBAN:{id}:{ident}", authResult.Uid, charaIdent); Logger.LogWarning("Authenticate:IDENTBAN:{id}:{ident}", authResult.Uid, charaIdent);
return Unauthorized("Your XIV service account is banned from using the service."); return Unauthorized("Your XIV service account is banned from using the service.");
@@ -114,9 +114,10 @@ public abstract class AuthControllerBase : Controller
protected async Task EnsureBan(string uid, string? primaryUid, string charaIdent) protected async Task EnsureBan(string uid, string? primaryUid, string charaIdent)
{ {
if (!MareDbContext.BannedUsers.Any(c => c.CharacterIdentification == charaIdent)) using var dbContext = await MareDbContextFactory.CreateDbContextAsync();
if (!dbContext.BannedUsers.Any(c => c.CharacterIdentification == charaIdent))
{ {
MareDbContext.BannedUsers.Add(new Banned() dbContext.BannedUsers.Add(new Banned()
{ {
CharacterIdentification = charaIdent, CharacterIdentification = charaIdent,
Reason = "Autobanned CharacterIdent (" + uid + ")", Reason = "Autobanned CharacterIdent (" + uid + ")",
@@ -125,35 +126,35 @@ public abstract class AuthControllerBase : Controller
var uidToLookFor = primaryUid ?? uid; var uidToLookFor = primaryUid ?? uid;
var primaryUserAuth = await MareDbContext.Auth.FirstAsync(f => f.UserUID == uidToLookFor); var primaryUserAuth = await dbContext.Auth.FirstAsync(f => f.UserUID == uidToLookFor);
primaryUserAuth.MarkForBan = false; primaryUserAuth.MarkForBan = false;
primaryUserAuth.IsBanned = true; primaryUserAuth.IsBanned = true;
var lodestone = await MareDbContext.LodeStoneAuth.Include(a => a.User).FirstOrDefaultAsync(c => c.User.UID == uidToLookFor); var lodestone = await dbContext.LodeStoneAuth.Include(a => a.User).FirstOrDefaultAsync(c => c.User.UID == uidToLookFor);
if (lodestone != null) if (lodestone != null)
{ {
if (!MareDbContext.BannedRegistrations.Any(c => c.DiscordIdOrLodestoneAuth == lodestone.HashedLodestoneId)) if (!dbContext.BannedRegistrations.Any(c => c.DiscordIdOrLodestoneAuth == lodestone.HashedLodestoneId))
{ {
MareDbContext.BannedRegistrations.Add(new BannedRegistrations() dbContext.BannedRegistrations.Add(new BannedRegistrations()
{ {
DiscordIdOrLodestoneAuth = lodestone.HashedLodestoneId, DiscordIdOrLodestoneAuth = lodestone.HashedLodestoneId,
}); });
} }
if (!MareDbContext.BannedRegistrations.Any(c => c.DiscordIdOrLodestoneAuth == lodestone.DiscordId.ToString())) if (!dbContext.BannedRegistrations.Any(c => c.DiscordIdOrLodestoneAuth == lodestone.DiscordId.ToString()))
{ {
MareDbContext.BannedRegistrations.Add(new BannedRegistrations() dbContext.BannedRegistrations.Add(new BannedRegistrations()
{ {
DiscordIdOrLodestoneAuth = lodestone.DiscordId.ToString(), DiscordIdOrLodestoneAuth = lodestone.DiscordId.ToString(),
}); });
} }
} }
await MareDbContext.SaveChangesAsync(); await dbContext.SaveChangesAsync();
} }
protected async Task<bool> IsIdentBanned(string charaIdent) protected async Task<bool> IsIdentBanned(MareDbContext dbContext, string charaIdent)
{ {
return await MareDbContext.BannedUsers.AsNoTracking().AnyAsync(u => u.CharacterIdentification == charaIdent).ConfigureAwait(false); return await dbContext.BannedUsers.AsNoTracking().AnyAsync(u => u.CharacterIdentification == charaIdent).ConfigureAwait(false);
} }
} }

View File

@@ -16,7 +16,7 @@ namespace MareSynchronosAuthService.Controllers;
public class JwtController : AuthControllerBase public class JwtController : AuthControllerBase
{ {
public JwtController(ILogger<JwtController> logger, public JwtController(ILogger<JwtController> logger,
IHttpContextAccessor accessor, MareDbContext mareDbContext, IHttpContextAccessor accessor, IDbContextFactory<MareDbContext> mareDbContext,
SecretKeyAuthenticatorService secretKeyAuthenticatorService, SecretKeyAuthenticatorService secretKeyAuthenticatorService,
IConfigurationService<AuthServiceConfiguration> configuration, IConfigurationService<AuthServiceConfiguration> configuration,
IRedisDatabase redisDb, GeoIPService geoIPProvider) IRedisDatabase redisDb, GeoIPService geoIPProvider)
@@ -29,28 +29,30 @@ public class JwtController : AuthControllerBase
[HttpPost(MareAuth.Auth_CreateIdent)] [HttpPost(MareAuth.Auth_CreateIdent)]
public async Task<IActionResult> CreateToken(string auth, string charaIdent) public async Task<IActionResult> CreateToken(string auth, string charaIdent)
{ {
return await AuthenticateInternal(auth, charaIdent).ConfigureAwait(false); using var dbContext = await MareDbContextFactory.CreateDbContextAsync();
return await AuthenticateInternal(dbContext, auth, charaIdent).ConfigureAwait(false);
} }
[Authorize(Policy = "Authenticated")] [Authorize(Policy = "Authenticated")]
[HttpGet(MareAuth.Auth_RenewToken)] [HttpGet(MareAuth.Auth_RenewToken)]
public async Task<IActionResult> RenewToken() public async Task<IActionResult> RenewToken()
{ {
using var dbContext = await MareDbContextFactory.CreateDbContextAsync();
try try
{ {
var uid = HttpContext.User.Claims.Single(p => string.Equals(p.Type, MareClaimTypes.Uid, StringComparison.Ordinal))!.Value; var uid = HttpContext.User.Claims.Single(p => string.Equals(p.Type, MareClaimTypes.Uid, StringComparison.Ordinal))!.Value;
var ident = HttpContext.User.Claims.Single(p => string.Equals(p.Type, MareClaimTypes.CharaIdent, StringComparison.Ordinal))!.Value; var ident = HttpContext.User.Claims.Single(p => string.Equals(p.Type, MareClaimTypes.CharaIdent, StringComparison.Ordinal))!.Value;
var alias = HttpContext.User.Claims.SingleOrDefault(p => string.Equals(p.Type, MareClaimTypes.Alias))?.Value ?? string.Empty; var alias = HttpContext.User.Claims.SingleOrDefault(p => string.Equals(p.Type, MareClaimTypes.Alias))?.Value ?? string.Empty;
if (await MareDbContext.Auth.Where(u => u.UserUID == uid || u.PrimaryUserUID == uid).AnyAsync(a => a.MarkForBan)) if (await dbContext.Auth.Where(u => u.UserUID == uid || u.PrimaryUserUID == uid).AnyAsync(a => a.MarkForBan))
{ {
var userAuth = await MareDbContext.Auth.SingleAsync(u => u.UserUID == uid); var userAuth = await dbContext.Auth.SingleAsync(u => u.UserUID == uid);
await EnsureBan(uid, userAuth.PrimaryUserUID, ident); await EnsureBan(uid, userAuth.PrimaryUserUID, ident);
return Unauthorized("Your Mare account is banned."); return Unauthorized("Your Mare account is banned.");
} }
if (await IsIdentBanned(ident)) if (await IsIdentBanned(dbContext, ident))
{ {
return Unauthorized("Your XIV service account is banned from using the service."); return Unauthorized("Your XIV service account is banned from using the service.");
} }
@@ -65,7 +67,7 @@ public class JwtController : AuthControllerBase
} }
} }
protected async Task<IActionResult> AuthenticateInternal(string auth, string charaIdent) protected async Task<IActionResult> AuthenticateInternal(MareDbContext dbContext, string auth, string charaIdent)
{ {
try try
{ {
@@ -76,7 +78,7 @@ public class JwtController : AuthControllerBase
var authResult = await SecretKeyAuthenticatorService.AuthorizeAsync(ip, auth); var authResult = await SecretKeyAuthenticatorService.AuthorizeAsync(ip, auth);
return await GenericAuthResponse(charaIdent, authResult); return await GenericAuthResponse(dbContext, charaIdent, authResult);
} }
catch (Exception ex) catch (Exception ex)
{ {

View File

@@ -26,7 +26,7 @@ public class OAuthController : AuthControllerBase
private static readonly ConcurrentDictionary<string, string> _cookieOAuthResponse = []; private static readonly ConcurrentDictionary<string, string> _cookieOAuthResponse = [];
public OAuthController(ILogger<OAuthController> logger, public OAuthController(ILogger<OAuthController> logger,
IHttpContextAccessor accessor, MareDbContext mareDbContext, IHttpContextAccessor accessor, IDbContextFactory<MareDbContext> mareDbContext,
SecretKeyAuthenticatorService secretKeyAuthenticatorService, SecretKeyAuthenticatorService secretKeyAuthenticatorService,
IConfigurationService<AuthServiceConfiguration> configuration, IConfigurationService<AuthServiceConfiguration> configuration,
IRedisDatabase redisDb, GeoIPService geoIPProvider) IRedisDatabase redisDb, GeoIPService geoIPProvider)
@@ -135,7 +135,9 @@ public class OAuthController : AuthControllerBase
if (discordUserId == 0) if (discordUserId == 0)
return BadRequest("Failed to get Discord ID from login token"); return BadRequest("Failed to get Discord ID from login token");
var mareUser = await MareDbContext.LodeStoneAuth.Include(u => u.User).SingleOrDefaultAsync(u => u.DiscordId == discordUserId); using var dbContext = await MareDbContextFactory.CreateDbContextAsync();
var mareUser = await dbContext.LodeStoneAuth.Include(u => u.User).SingleOrDefaultAsync(u => u.DiscordId == discordUserId);
if (mareUser == null) if (mareUser == null)
return BadRequest("Could not find a Mare user associated to this Discord account."); return BadRequest("Could not find a Mare user associated to this Discord account.");
@@ -213,11 +215,12 @@ public class OAuthController : AuthControllerBase
public async Task<Dictionary<string, string>> GetAvailableUIDs() public async Task<Dictionary<string, string>> GetAvailableUIDs()
{ {
string primaryUid = HttpContext.User.Claims.Single(c => string.Equals(c.Type, MareClaimTypes.Uid, StringComparison.Ordinal))!.Value; string primaryUid = HttpContext.User.Claims.Single(c => string.Equals(c.Type, MareClaimTypes.Uid, StringComparison.Ordinal))!.Value;
using var dbContext = await MareDbContextFactory.CreateDbContextAsync();
var mareUser = await MareDbContext.Auth.AsNoTracking().Include(u => u.User).FirstOrDefaultAsync(f => f.UserUID == primaryUid).ConfigureAwait(false); var mareUser = await dbContext.Auth.AsNoTracking().Include(u => u.User).FirstOrDefaultAsync(f => f.UserUID == primaryUid).ConfigureAwait(false);
if (mareUser == null || mareUser.User == null) return []; if (mareUser == null || mareUser.User == null) return [];
var uid = mareUser.User.UID; var uid = mareUser.User.UID;
var allUids = await MareDbContext.Auth.AsNoTracking().Include(u => u.User).Where(a => a.UserUID == uid || a.PrimaryUserUID == uid).ToListAsync().ConfigureAwait(false); var allUids = await dbContext.Auth.AsNoTracking().Include(u => u.User).Where(a => a.UserUID == uid || a.PrimaryUserUID == uid).ToListAsync().ConfigureAwait(false);
var result = allUids.OrderBy(u => u.UserUID == uid ? 0 : 1).ThenBy(u => u.UserUID).Select(u => (u.UserUID, u.User.Alias)).ToDictionary(); var result = allUids.OrderBy(u => u.UserUID == uid ? 0 : 1).ThenBy(u => u.UserUID).Select(u => (u.UserUID, u.User.Alias)).ToDictionary();
return result; return result;
} }
@@ -226,10 +229,12 @@ public class OAuthController : AuthControllerBase
[HttpPost(MareAuth.OAuth_CreateOAuth)] [HttpPost(MareAuth.OAuth_CreateOAuth)]
public async Task<IActionResult> CreateTokenWithOAuth(string uid, string charaIdent) public async Task<IActionResult> CreateTokenWithOAuth(string uid, string charaIdent)
{ {
return await AuthenticateOAuthInternal(uid, charaIdent); using var dbContext = await MareDbContextFactory.CreateDbContextAsync();
return await AuthenticateOAuthInternal(dbContext, uid, charaIdent);
} }
private async Task<IActionResult> AuthenticateOAuthInternal(string requestedUid, string charaIdent) private async Task<IActionResult> AuthenticateOAuthInternal(MareDbContext dbContext, string requestedUid, string charaIdent)
{ {
try try
{ {
@@ -241,7 +246,7 @@ public class OAuthController : AuthControllerBase
var authResult = await SecretKeyAuthenticatorService.AuthorizeOauthAsync(ip, primaryUid, requestedUid); var authResult = await SecretKeyAuthenticatorService.AuthorizeOauthAsync(ip, primaryUid, requestedUid);
return await GenericAuthResponse(charaIdent, authResult); return await GenericAuthResponse(dbContext, charaIdent, authResult);
} }
catch (Exception ex) catch (Exception ex)
{ {