add alias to jwt, remove caching from auth, remove db usage from files

This commit is contained in:
rootdarkarchon
2024-01-14 11:57:18 +01:00
parent 2c7ff6f73a
commit 63286127a2
7 changed files with 35 additions and 98 deletions

View File

@@ -1,3 +1,3 @@
namespace MareSynchronosServer.Authentication; namespace MareSynchronosServer.Authentication;
public record SecretKeyAuthReply(bool Success, string Uid, string PrimaryUid, bool TempBan, bool Permaban); public record SecretKeyAuthReply(bool Success, string Uid, string PrimaryUid, string Alias, bool TempBan, bool Permaban);

View File

@@ -10,30 +10,24 @@ namespace MareSynchronosServer.Authentication;
public class SecretKeyAuthenticatorService public class SecretKeyAuthenticatorService
{ {
private readonly MareMetrics _metrics; private readonly MareMetrics _metrics;
private readonly IServiceScopeFactory _serviceScopeFactory; private readonly IDbContextFactory<MareDbContext> _dbContextFactory;
private readonly IConfigurationService<MareConfigurationAuthBase> _configurationService; private readonly IConfigurationService<MareConfigurationAuthBase> _configurationService;
private readonly ILogger<SecretKeyAuthenticatorService> _logger; private readonly ILogger<SecretKeyAuthenticatorService> _logger;
private readonly ConcurrentDictionary<string, SecretKeyAuthReply> _cachedPositiveResponses = new(StringComparer.Ordinal);
private readonly ConcurrentDictionary<string, SecretKeyFailedAuthorization> _failedAuthorizations = new(StringComparer.Ordinal); private readonly ConcurrentDictionary<string, SecretKeyFailedAuthorization> _failedAuthorizations = new(StringComparer.Ordinal);
public SecretKeyAuthenticatorService(MareMetrics metrics, IServiceScopeFactory serviceScopeFactory, IConfigurationService<MareConfigurationAuthBase> configuration, ILogger<SecretKeyAuthenticatorService> logger) public SecretKeyAuthenticatorService(MareMetrics metrics, IDbContextFactory<MareDbContext> dbContextFactory,
IConfigurationService<MareConfigurationAuthBase> configuration, ILogger<SecretKeyAuthenticatorService> logger)
{ {
_logger = logger; _logger = logger;
_configurationService = configuration; _configurationService = configuration;
_metrics = metrics; _metrics = metrics;
_serviceScopeFactory = serviceScopeFactory; _dbContextFactory = dbContextFactory;
} }
public async Task<SecretKeyAuthReply> AuthorizeAsync(string ip, string hashedSecretKey) public async Task<SecretKeyAuthReply> AuthorizeAsync(string ip, string hashedSecretKey)
{ {
_metrics.IncCounter(MetricsAPI.CounterAuthenticationRequests); _metrics.IncCounter(MetricsAPI.CounterAuthenticationRequests);
if (_cachedPositiveResponses.TryGetValue(hashedSecretKey, out var cachedPositiveResponse))
{
_metrics.IncCounter(MetricsAPI.CounterAuthenticationCacheHits);
return cachedPositiveResponse;
}
if (_failedAuthorizations.TryGetValue(ip, out var existingFailedAuthorization) if (_failedAuthorizations.TryGetValue(ip, out var existingFailedAuthorization)
&& existingFailedAuthorization.FailedAttempts > _configurationService.GetValueOrDefault(nameof(MareConfigurationAuthBase.FailedAuthForTempBan), 5)) && existingFailedAuthorization.FailedAttempts > _configurationService.GetValueOrDefault(nameof(MareConfigurationAuthBase.FailedAuthForTempBan), 5))
{ {
@@ -50,12 +44,12 @@ public class SecretKeyAuthenticatorService
_failedAuthorizations.Remove(ip, out _); _failedAuthorizations.Remove(ip, out _);
}); });
} }
return new(Success: false, Uid: null, PrimaryUid: null, TempBan: true, Permaban: false); return new(Success: false, Uid: null, PrimaryUid: null, Alias: null, TempBan: true, Permaban: false);
} }
using var scope = _serviceScopeFactory.CreateScope(); using var context = await _dbContextFactory.CreateDbContextAsync().ConfigureAwait(false);
using var context = scope.ServiceProvider.GetService<MareDbContext>(); var authReply = await context.Auth.Include(a => a.User).AsNoTracking()
var authReply = await context.Auth.AsNoTracking().SingleOrDefaultAsync(u => u.HashedKey == hashedSecretKey).ConfigureAwait(false); .SingleOrDefaultAsync(u => u.HashedKey == hashedSecretKey).ConfigureAwait(false);
var isBanned = authReply?.IsBanned ?? false; var isBanned = authReply?.IsBanned ?? false;
var primaryUid = authReply?.PrimaryUserUID ?? authReply?.UserUID; var primaryUid = authReply?.PrimaryUserUID ?? authReply?.UserUID;
@@ -65,21 +59,13 @@ public class SecretKeyAuthenticatorService
isBanned = isBanned || primaryUser.IsBanned; isBanned = isBanned || primaryUser.IsBanned;
} }
SecretKeyAuthReply reply = new(authReply != null, authReply?.UserUID, authReply?.PrimaryUserUID ?? authReply?.UserUID, TempBan: false, isBanned); SecretKeyAuthReply reply = new(authReply != null, authReply?.UserUID,
authReply?.PrimaryUserUID ?? authReply?.UserUID, authReply.User.Alias ?? string.Empty, TempBan: false, isBanned);
if (reply.Success) if (reply.Success)
{ {
_metrics.IncCounter(MetricsAPI.CounterAuthenticationSuccesses); _metrics.IncCounter(MetricsAPI.CounterAuthenticationSuccesses);
_metrics.IncGauge(MetricsAPI.GaugeAuthenticationCacheEntries); _metrics.IncGauge(MetricsAPI.GaugeAuthenticationCacheEntries);
_cachedPositiveResponses[hashedSecretKey] = reply;
_ = Task.Run(async () =>
{
await Task.Delay(TimeSpan.FromMinutes(5)).ConfigureAwait(false);
_cachedPositiveResponses.TryRemove(hashedSecretKey, out _);
_metrics.DecGauge(MetricsAPI.GaugeAuthenticationCacheEntries);
});
} }
else else
{ {
@@ -107,6 +93,6 @@ public class SecretKeyAuthenticatorService
} }
} }
return new(Success: false, Uid: null, PrimaryUid: null, TempBan: false, Permaban: false); return new(Success: false, Uid: null, PrimaryUid: null, Alias: null, TempBan: false, Permaban: false);
} }
} }

View File

@@ -60,6 +60,7 @@ public class JwtController : Controller
{ {
var uid = HttpContext.User.Claims.Single(p => string.Equals(p.Type, MareClaimTypes.Uid, StringComparison.Ordinal))!.Value; var uid = HttpContext.User.Claims.Single(p => string.Equals(p.Type, MareClaimTypes.Uid, StringComparison.Ordinal))!.Value;
var ident = HttpContext.User.Claims.Single(p => string.Equals(p.Type, MareClaimTypes.CharaIdent, StringComparison.Ordinal))!.Value; var ident = HttpContext.User.Claims.Single(p => string.Equals(p.Type, MareClaimTypes.CharaIdent, StringComparison.Ordinal))!.Value;
var alias = HttpContext.User.Claims.SingleOrDefault(p => string.Equals(p.Type, MareClaimTypes.Alias))?.Value ?? string.Empty;
if (await _mareDbContext.Auth.Where(u => u.UserUID == uid || u.PrimaryUserUID == uid).AnyAsync(a => a.IsBanned)) if (await _mareDbContext.Auth.Where(u => u.UserUID == uid || u.PrimaryUserUID == uid).AnyAsync(a => a.IsBanned))
{ {
@@ -74,7 +75,7 @@ public class JwtController : Controller
} }
_logger.LogInformation("RenewToken:SUCCESS:{id}:{ident}", uid, ident); _logger.LogInformation("RenewToken:SUCCESS:{id}:{ident}", uid, ident);
return await CreateJwtFromId(uid, ident); return await CreateJwtFromId(uid, ident, alias);
} }
catch (Exception ex) catch (Exception ex)
{ {
@@ -108,7 +109,7 @@ public class JwtController : Controller
if (!authResult.Success && authResult.TempBan) if (!authResult.Success && authResult.TempBan)
{ {
_logger.LogWarning("Authenticate:TEMPBAN:{id}:{ident}", authResult.Uid ?? "NOUID", charaIdent); _logger.LogWarning("Authenticate:TEMPBAN:{id}:{ident}", authResult.Uid ?? "NOUID", charaIdent);
return Unauthorized("You are temporarily banned. Try connecting again in 5 minutes."); return Unauthorized("Due to an excessive amount of failed authentication attempts you are temporarily banned. Check your Secret Key configuration and try connecting again in 5 minutes.");
} }
if (authResult.Permaban) if (authResult.Permaban)
{ {
@@ -126,7 +127,7 @@ public class JwtController : Controller
} }
_logger.LogInformation("Authenticate:SUCCESS:{id}:{ident}", authResult.Uid, charaIdent); _logger.LogInformation("Authenticate:SUCCESS:{id}:{ident}", authResult.Uid, charaIdent);
return await CreateJwtFromId(authResult.Uid, charaIdent); return await CreateJwtFromId(authResult.Uid, charaIdent, authResult.Alias ?? string.Empty);
} }
catch (Exception ex) catch (Exception ex)
{ {
@@ -150,12 +151,13 @@ public class JwtController : Controller
return handler.CreateJwtSecurityToken(token); return handler.CreateJwtSecurityToken(token);
} }
private async Task<IActionResult> CreateJwtFromId(string uid, string charaIdent) private async Task<IActionResult> CreateJwtFromId(string uid, string charaIdent, string alias)
{ {
var token = CreateJwt(new List<Claim>() var token = CreateJwt(new List<Claim>()
{ {
new Claim(MareClaimTypes.Uid, uid), new Claim(MareClaimTypes.Uid, uid),
new Claim(MareClaimTypes.CharaIdent, charaIdent), new Claim(MareClaimTypes.CharaIdent, charaIdent),
new Claim(MareClaimTypes.Alias, alias),
new Claim(MareClaimTypes.Expires, DateTime.UtcNow.AddHours(6).Ticks.ToString(CultureInfo.InvariantCulture)), new Claim(MareClaimTypes.Expires, DateTime.UtcNow.AddHours(6).Ticks.ToString(CultureInfo.InvariantCulture)),
new Claim(MareClaimTypes.Continent, await _geoIPProvider.GetCountryFromIP(_accessor)) new Claim(MareClaimTypes.Continent, await _geoIPProvider.GetCountryFromIP(_accessor))
}); });

View File

@@ -3,6 +3,7 @@
public static class MareClaimTypes public static class MareClaimTypes
{ {
public const string Uid = "uid"; public const string Uid = "uid";
public const string Alias = "alias";
public const string CharaIdent = "character_identification"; public const string CharaIdent = "character_identification";
public const string Internal = "internal"; public const string Internal = "internal";
public const string Expires = "expiration_date"; public const string Expires = "expiration_date";

View File

@@ -14,4 +14,5 @@ public class ControllerBase : Controller
protected string MareUser => HttpContext.User.Claims.First(f => string.Equals(f.Type, MareClaimTypes.Uid, StringComparison.Ordinal)).Value; protected string MareUser => HttpContext.User.Claims.First(f => string.Equals(f.Type, MareClaimTypes.Uid, StringComparison.Ordinal)).Value;
protected string Continent => HttpContext.User.Claims.FirstOrDefault(f => string.Equals(f.Type, MareClaimTypes.Continent, StringComparison.Ordinal))?.Value ?? "*"; protected string Continent => HttpContext.User.Claims.FirstOrDefault(f => string.Equals(f.Type, MareClaimTypes.Continent, StringComparison.Ordinal))?.Value ?? "*";
protected bool IsPriority => !string.IsNullOrEmpty(HttpContext.User.Claims.FirstOrDefault(f => string.Equals(f.Type, MareClaimTypes.Alias, StringComparison.Ordinal))?.Value ?? string.Empty);
} }

View File

@@ -1,5 +1,4 @@
using MareSynchronos.API.Routes; using MareSynchronos.API.Routes;
using MareSynchronosShared.Data;
using MareSynchronosStaticFilesServer.Services; using MareSynchronosStaticFilesServer.Services;
using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.Mvc;
@@ -10,14 +9,12 @@ public class RequestController : ControllerBase
{ {
private readonly CachedFileProvider _cachedFileProvider; private readonly CachedFileProvider _cachedFileProvider;
private readonly RequestQueueService _requestQueue; private readonly RequestQueueService _requestQueue;
private readonly MareDbContext _mareDbContext;
private static readonly SemaphoreSlim _parallelRequestSemaphore = new(500); private static readonly SemaphoreSlim _parallelRequestSemaphore = new(500);
public RequestController(ILogger<RequestController> logger, CachedFileProvider cachedFileProvider, RequestQueueService requestQueue, MareDbContext mareDbContext) : base(logger) public RequestController(ILogger<RequestController> logger, CachedFileProvider cachedFileProvider, RequestQueueService requestQueue) : base(logger)
{ {
_cachedFileProvider = cachedFileProvider; _cachedFileProvider = cachedFileProvider;
_requestQueue = requestQueue; _requestQueue = requestQueue;
_mareDbContext = mareDbContext;
} }
[HttpGet] [HttpGet]
@@ -28,7 +25,7 @@ public class RequestController : ControllerBase
try try
{ {
_requestQueue.RemoveFromQueue(requestId, MareUser); _requestQueue.RemoveFromQueue(requestId, MareUser, IsPriority);
return Ok(); return Ok();
} }
catch (OperationCanceledException) { return BadRequest(); } catch (OperationCanceledException) { return BadRequest(); }
@@ -53,7 +50,7 @@ public class RequestController : ControllerBase
} }
Guid g = Guid.NewGuid(); Guid g = Guid.NewGuid();
await _requestQueue.EnqueueUser(new(g, MareUser, files.ToList()), _mareDbContext); _requestQueue.EnqueueUser(new(g, MareUser, files.ToList()), IsPriority);
return Ok(g); return Ok(g);
} }
@@ -72,8 +69,8 @@ public class RequestController : ControllerBase
try try
{ {
if (!await _requestQueue.StillEnqueued(requestId, MareUser, _mareDbContext)) if (!_requestQueue.StillEnqueued(requestId, MareUser, IsPriority))
await _requestQueue.EnqueueUser(new(requestId, MareUser, files.ToList()), _mareDbContext); _requestQueue.EnqueueUser(new(requestId, MareUser, files.ToList()), IsPriority);
return Ok(); return Ok();
} }
catch (OperationCanceledException) { return BadRequest(); } catch (OperationCanceledException) { return BadRequest(); }

View File

@@ -5,9 +5,6 @@ using Microsoft.AspNetCore.SignalR;
using System.Collections.Concurrent; using System.Collections.Concurrent;
using System.Timers; using System.Timers;
using MareSynchronos.API.SignalR; using MareSynchronos.API.SignalR;
using MareSynchronosShared.Data;
using Microsoft.EntityFrameworkCore;
using System.Linq;
namespace MareSynchronosStaticFilesServer.Services; namespace MareSynchronosStaticFilesServer.Services;
@@ -24,7 +21,6 @@ public class RequestQueueService : IHostedService
private readonly SemaphoreSlim _queueProcessingSemaphore = new(1); private readonly SemaphoreSlim _queueProcessingSemaphore = new(1);
private readonly SemaphoreSlim _queueSemaphore = new(1); private readonly SemaphoreSlim _queueSemaphore = new(1);
private readonly UserQueueEntry[] _userQueueRequests; private readonly UserQueueEntry[] _userQueueRequests;
private readonly ConcurrentDictionary<string, PriorityEntry> _priorityCache = new(StringComparer.Ordinal);
private int _queueLimitForReset; private int _queueLimitForReset;
private readonly int _queueReleaseSeconds; private readonly int _queueReleaseSeconds;
private System.Timers.Timer _queueTimer; private System.Timers.Timer _queueTimer;
@@ -48,49 +44,11 @@ public class RequestQueueService : IHostedService
req.MarkActive(); req.MarkActive();
} }
private async Task<bool> IsHighPriority(string uid, MareDbContext mareDbContext) public void EnqueueUser(UserRequest request, bool isPriority)
{
if (!_priorityCache.TryGetValue(uid, out PriorityEntry entry) || entry.LastChecked.Add(TimeSpan.FromHours(6)) < DateTime.UtcNow)
{
var user = await mareDbContext.Users.FirstOrDefaultAsync(u => u.UID == uid).ConfigureAwait(false);
entry = new(user != null && !string.IsNullOrEmpty(user.Alias), DateTime.UtcNow);
_priorityCache[uid] = entry;
}
return entry.IsHighPriority;
}
public async Task EnqueueUser(UserRequest request, MareDbContext mareDbContext)
{ {
_logger.LogDebug("Enqueueing req {guid} from {user} for {file}", request.RequestId, request.User, string.Join(", ", request.FileIds)); _logger.LogDebug("Enqueueing req {guid} from {user} for {file}", request.RequestId, request.User, string.Join(", ", request.FileIds));
bool isPriorityQueue = await IsHighPriority(request.User, mareDbContext).ConfigureAwait(false); GetQueue(isPriority).Enqueue(request);
if (_queueProcessingSemaphore.CurrentCount == 0)
{
if (isPriorityQueue) _priorityQueue.Enqueue(request);
else _queue.Enqueue(request);
return;
}
try
{
await _queueSemaphore.WaitAsync().ConfigureAwait(false);
if (isPriorityQueue) _priorityQueue.Enqueue(request);
else _queue.Enqueue(request);
return;
}
catch (Exception ex)
{
_logger.LogError(ex, "Error during EnqueueUser");
}
finally
{
_queueSemaphore.Release();
}
throw new Exception("Error during EnqueueUser");
} }
public void FinishRequest(Guid request) public void FinishRequest(Guid request)
@@ -115,10 +73,9 @@ public class RequestQueueService : IHostedService
return userQueueRequest != null && userRequest != null && userQueueRequest.ExpirationDate > DateTime.UtcNow; return userQueueRequest != null && userRequest != null && userQueueRequest.ExpirationDate > DateTime.UtcNow;
} }
public void RemoveFromQueue(Guid requestId, string user) public void RemoveFromQueue(Guid requestId, string user, bool isPriority)
{ {
var existingRequest = _priorityQueue.FirstOrDefault(f => f.RequestId == requestId && string.Equals(f.User, user, StringComparison.Ordinal)) var existingRequest = GetQueue(isPriority).FirstOrDefault(f => f.RequestId == requestId && string.Equals(f.User, user, StringComparison.Ordinal));
?? _queue.FirstOrDefault(f => f.RequestId == requestId && string.Equals(f.User, user, StringComparison.Ordinal));
if (existingRequest == null) if (existingRequest == null)
{ {
var activeSlot = _userQueueRequests.FirstOrDefault(r => r != null && string.Equals(r.UserRequest.User, user, StringComparison.Ordinal) && r.UserRequest.RequestId == requestId); var activeSlot = _userQueueRequests.FirstOrDefault(r => r != null && string.Equals(r.UserRequest.User, user, StringComparison.Ordinal) && r.UserRequest.RequestId == requestId);
@@ -146,14 +103,11 @@ public class RequestQueueService : IHostedService
return Task.CompletedTask; return Task.CompletedTask;
} }
public async Task<bool> StillEnqueued(Guid request, string user, MareDbContext mareDbContext) private ConcurrentQueue<UserRequest> GetQueue(bool isPriority) => isPriority ? _priorityQueue : _queue;
public bool StillEnqueued(Guid request, string user, bool isPriority)
{ {
bool isPriorityQueue = await IsHighPriority(user, mareDbContext).ConfigureAwait(false); return GetQueue(isPriority).Any(c => c.RequestId == request && string.Equals(c.User, user, StringComparison.Ordinal));
if (isPriorityQueue)
{
return _priorityQueue.Any(c => c.RequestId == request && string.Equals(c.User, user, StringComparison.Ordinal));
}
return _queue.Any(c => c.RequestId == request && string.Equals(c.User, user, StringComparison.Ordinal));
} }
public Task StopAsync(CancellationToken cancellationToken) public Task StopAsync(CancellationToken cancellationToken)
@@ -183,11 +137,7 @@ public class RequestQueueService : IHostedService
return; return;
} }
Parallel.For(0, _userQueueRequests.Length, new ParallelOptions() for (int i = 0; i < _userQueueRequests.Length; i++)
{
MaxDegreeOfParallelism = 10,
},
async (i) =>
{ {
try try
{ {
@@ -227,7 +177,7 @@ public class RequestQueueService : IHostedService
{ {
_logger.LogWarning(ex, "Error during inside queue processing"); _logger.LogWarning(ex, "Error during inside queue processing");
} }
}); }
} }
catch (Exception ex) catch (Exception ex)
{ {