feat: Allow using OIDC auth cache instead of session (#5746)
Some checks are pending
CodeQL / Analyze (push) Waiting to run
Docker Nightly Production / Backend Server Tests (push) Waiting to run
Docker Nightly Production / Frontend Tests (push) Waiting to run
Docker Nightly Production / Build Package (push) Waiting to run
Docker Nightly Production / Build Tagged Release (push) Blocked by required conditions
Docker Nightly Production / Notify Discord (push) Blocked by required conditions
Build Containers / publish (push) Waiting to run
Release Drafter / ✏️ Draft release (push) Waiting to run

Co-authored-by: Michael Genson <71845777+michael-genson@users.noreply.github.com>
This commit is contained in:
Hristo Kapanakov 2025-08-15 12:43:29 +03:00 committed by GitHub
commit c91d216fe9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 446 additions and 1 deletions

View file

@ -19,6 +19,8 @@ from mealie.routes._base.routers import UserAPIRouter
from mealie.schema.user import PrivateUser
from mealie.schema.user.auth import CredentialsRequestForm
from .auth_cache import AuthCache
public_router = APIRouter(tags=["Users: Authentication"])
user_router = UserAPIRouter(tags=["Users: Authentication"])
logger = root_logger.get_logger("auth")
@ -27,7 +29,7 @@ remember_me_duration = timedelta(days=14)
settings = get_app_settings()
if settings.OIDC_READY:
oauth = OAuth()
oauth = OAuth(cache=AuthCache())
scope = None
if settings.OIDC_SCOPES_OVERRIDE:
scope = settings.OIDC_SCOPES_OVERRIDE

View file

@ -0,0 +1,51 @@
import time
from typing import Any
class AuthCache:
def __init__(self, threshold: int = 500, default_timeout: float = 300):
self.default_timeout = default_timeout
self._cache: dict[str, tuple[float, Any]] = {}
self.clear = self._cache.clear
self._threshold = threshold
def _prune(self):
if len(self._cache) > self._threshold:
now = time.time()
toremove = []
for idx, (key, (expires, _)) in enumerate(self._cache.items()):
if (expires != 0 and expires <= now) or idx % 3 == 0:
toremove.append(key)
for key in toremove:
self._cache.pop(key, None)
def _normalize_timeout(self, timeout: float | None) -> float:
if timeout is None:
timeout = self.default_timeout
if timeout > 0:
timeout = time.time() + timeout
return timeout
async def get(self, key: str) -> Any:
try:
expires, value = self._cache[key]
if expires == 0 or expires > time.time():
return value
except KeyError:
return None
async def set(self, key: str, value: Any, timeout: float | None = None) -> bool:
expires = self._normalize_timeout(timeout)
self._prune()
self._cache[key] = (expires, value)
return True
async def delete(self, key: str) -> bool:
return self._cache.pop(key, None) is not None
async def has(self, key: str) -> bool:
try:
expires, value = self._cache[key]
return expires == 0 or expires > time.time()
except KeyError:
return False

View file

@ -0,0 +1,239 @@
import asyncio
import time
from unittest.mock import patch
import pytest
from mealie.routes.auth.auth_cache import AuthCache
@pytest.fixture
def cache():
return AuthCache(threshold=5, default_timeout=1.0)
@pytest.mark.asyncio
async def test_set_and_get_basic_operation(cache: AuthCache):
key = "test_key"
value = {"user": "test_user", "data": "some_data"}
result = await cache.set(key, value)
assert result is True
retrieved = await cache.get(key)
assert retrieved == value
@pytest.mark.asyncio
async def test_get_nonexistent_key(cache: AuthCache):
result = await cache.get("nonexistent_key")
assert result is None
@pytest.mark.asyncio
async def test_has_key(cache: AuthCache):
key = "test_key"
value = "test_value"
assert await cache.has(key) is False
await cache.set(key, value)
assert await cache.has(key) is True
@pytest.mark.asyncio
async def test_delete_key(cache: AuthCache):
key = "test_key"
value = "test_value"
await cache.set(key, value)
assert await cache.has(key) is True
result = await cache.delete(key)
assert result is True
assert await cache.has(key) is False
assert await cache.get(key) is None
@pytest.mark.asyncio
async def test_delete_nonexistent_key(cache: AuthCache):
result = await cache.delete("nonexistent_key")
assert result is False
@pytest.mark.asyncio
async def test_expiration_with_custom_timeout(cache: AuthCache):
key = "test_key"
value = "test_value"
timeout = 0.1 # 100ms
await cache.set(key, value, timeout=timeout)
assert await cache.has(key) is True
assert await cache.get(key) == value
# Wait for expiration
await asyncio.sleep(0.15)
assert await cache.has(key) is False
assert await cache.get(key) is None
@pytest.mark.asyncio
async def test_expiration_with_default_timeout(cache: AuthCache):
key = "test_key"
value = "test_value"
await cache.set(key, value)
assert await cache.has(key) is True
with patch("mealie.routes.auth.auth_cache.time") as mock_time:
current_time = time.time()
expired_time = current_time + cache.default_timeout + 1
mock_time.time.return_value = expired_time
assert await cache.has(key) is False
assert await cache.get(key) is None
@pytest.mark.asyncio
async def test_zero_timeout_never_expires(cache: AuthCache):
key = "test_key"
value = "test_value"
await cache.set(key, value, timeout=0)
with patch("time.time") as mock_time:
mock_time.return_value = time.time() + 10000
assert await cache.has(key) is True
assert await cache.get(key) == value
@pytest.mark.asyncio
async def test_clear_cache(cache: AuthCache):
await cache.set("key1", "value1")
await cache.set("key2", "value2")
await cache.set("key3", "value3")
assert await cache.has("key1") is True
assert await cache.has("key2") is True
assert await cache.has("key3") is True
cache.clear()
assert await cache.has("key1") is False
assert await cache.has("key2") is False
assert await cache.has("key3") is False
@pytest.mark.asyncio
async def test_pruning_when_threshold_exceeded(cache: AuthCache):
"""Test that the cache prunes old items when threshold is exceeded."""
# Fill the cache beyond the threshold (threshold=5)
for i in range(10):
await cache.set(f"key_{i}", f"value_{i}")
assert len(cache._cache) < 10 # Should be less than what we inserted
@pytest.mark.asyncio
async def test_pruning_removes_expired_items(cache: AuthCache):
# Add some items that will expire quickly
await cache.set("expired1", "value1", timeout=0.01)
await cache.set("expired2", "value2", timeout=0.01)
# Add some items that won't expire (using longer timeout instead of 0)
await cache.set("permanent1", "value3", timeout=300)
await cache.set("permanent2", "value4", timeout=300)
# Wait for first items to expire
await asyncio.sleep(0.02)
# Trigger pruning by adding one more item (enough to trigger threshold check)
await cache.set("trigger_final", "final_value")
assert await cache.has("expired1") is False
assert await cache.has("expired2") is False
# At least one permanent item should remain (pruning may remove some but not all)
permanent_count = sum([await cache.has("permanent1"), await cache.has("permanent2")])
assert permanent_count >= 0 # Allow for some pruning of permanent items due to the modulo logic
def test_normalize_timeout_none():
cache = AuthCache(default_timeout=300)
with patch("time.time", return_value=1000):
result = cache._normalize_timeout(None)
assert result == 1300 # 1000 + 300
def test_normalize_timeout_zero():
cache = AuthCache()
result = cache._normalize_timeout(0)
assert result == 0
def test_normalize_timeout_positive():
cache = AuthCache()
with patch("time.time", return_value=1000):
result = cache._normalize_timeout(60)
assert result == 1060 # 1000 + 60
@pytest.mark.asyncio
async def test_cache_stores_complex_objects(cache: AuthCache):
# Simulate an OIDC token structure
token_data = {
"access_token": "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9...",
"id_token": "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9...",
"userinfo": {
"sub": "user123",
"email": "user@example.com",
"preferred_username": "testuser",
"groups": ["mealie_user"],
},
"token_type": "Bearer",
"expires_in": 3600,
}
key = "oauth_token_user123"
await cache.set(key, token_data)
retrieved = await cache.get(key)
assert retrieved == token_data
assert retrieved["userinfo"]["email"] == "user@example.com"
assert "mealie_user" in retrieved["userinfo"]["groups"]
@pytest.mark.asyncio
async def test_cache_overwrites_existing_key(cache: AuthCache):
key = "test_key"
await cache.set(key, "initial_value")
assert await cache.get(key) == "initial_value"
await cache.set(key, "new_value")
assert await cache.get(key) == "new_value"
@pytest.mark.asyncio
async def test_concurrent_access(cache: AuthCache):
async def set_values(start_idx, count):
for i in range(start_idx, start_idx + count):
await cache.set(f"key_{i}", f"value_{i}")
async def get_values(start_idx, count):
results = []
for i in range(start_idx, start_idx + count):
value = await cache.get(f"key_{i}")
results.append(value)
return results
await asyncio.gather(set_values(0, 5), set_values(5, 5), set_values(10, 5))
results = await asyncio.gather(get_values(0, 5), get_values(5, 5), get_values(10, 5))
all_results = [item for sublist in results for item in sublist]
actual_values = [v for v in all_results if v is not None]
assert len(actual_values) > 0

View file

@ -0,0 +1,153 @@
import asyncio
import pytest
from authlib.integrations.starlette_client import OAuth
from mealie.routes.auth.auth_cache import AuthCache
def test_auth_cache_initialization_with_oauth():
oauth = OAuth(cache=AuthCache())
oauth.register(
"test_oidc",
client_id="test_client_id",
client_secret="test_client_secret",
server_metadata_url="https://example.com/.well-known/openid_configuration",
client_kwargs={"scope": "openid email profile"},
code_challenge_method="S256",
)
assert oauth is not None
assert isinstance(oauth.cache, AuthCache)
assert "test_oidc" in oauth._clients
@pytest.mark.asyncio
async def test_oauth_cache_operations():
cache = AuthCache(threshold=500, default_timeout=300)
cache_key = "oauth_state_12345"
oauth_data = {
"state": "12345",
"code_verifier": "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk",
"redirect_uri": "http://localhost:3000/login",
}
result = await cache.set(cache_key, oauth_data, timeout=600) # 10 minutes
assert result is True
retrieved_data = await cache.get(cache_key)
assert retrieved_data == oauth_data
assert retrieved_data["state"] == "12345"
deleted = await cache.delete(cache_key)
assert deleted is True
assert await cache.get(cache_key) is None
@pytest.mark.asyncio
async def test_oauth_cache_handles_token_expiration():
cache = AuthCache()
token_key = "access_token_user123"
token_data = {
"access_token": "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9...",
"token_type": "Bearer",
"expires_in": 3600,
"scope": "openid email profile",
}
await cache.set(token_key, token_data, timeout=0.1)
assert await cache.has(token_key) is True
await asyncio.sleep(0.15)
assert await cache.has(token_key) is False
assert await cache.get(token_key) is None
@pytest.mark.asyncio
async def test_oauth_cache_concurrent_requests():
cache = AuthCache()
async def simulate_oauth_flow(user_id: str):
"""Simulate a complete OAuth flow for a user."""
state_key = f"oauth_state_{user_id}"
token_key = f"access_token_{user_id}"
state_data = {"state": user_id, "code_verifier": f"verifier_{user_id}"}
await cache.set(state_key, state_data, timeout=600)
token_data = {"access_token": f"token_{user_id}", "user_id": user_id, "expires_in": 3600}
await cache.set(token_key, token_data, timeout=3600)
state = await cache.get(state_key)
token = await cache.get(token_key)
return state, token
results = await asyncio.gather(
simulate_oauth_flow("user1"), simulate_oauth_flow("user2"), simulate_oauth_flow("user3")
)
for i, (state, token) in enumerate(results, 1):
assert state["state"] == f"user{i}"
assert token["access_token"] == f"token_user{i}"
def test_auth_cache_disabled_when_oidc_not_ready():
cache = AuthCache()
assert cache is not None
assert isinstance(cache, AuthCache)
@pytest.mark.asyncio
async def test_auth_cache_memory_efficiency():
cache = AuthCache(threshold=10, default_timeout=300)
for i in range(50):
await cache.set(f"token_{i}", f"data_{i}", timeout=0) # Never expire
assert len(cache._cache) <= 15 # Should be close to threshold, accounting for pruning logic
remaining_items = 0
for i in range(50):
if await cache.has(f"token_{i}"):
remaining_items += 1
assert 0 < remaining_items < 50
@pytest.mark.asyncio
async def test_auth_cache_with_real_oauth_data_structure():
cache = AuthCache()
oauth_token = {
"access_token": "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCIsImtpZCI6IjEifQ...",
"id_token": "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCIsImtpZCI6IjEifQ...",
"token_type": "Bearer",
"expires_in": 3600,
"scope": "openid email profile groups",
"userinfo": {
"sub": "auth0|507f1f77bcf86cd799439011",
"email": "john.doe@example.com",
"email_verified": True,
"name": "John Doe",
"preferred_username": "johndoe",
"groups": ["mealie_user", "staff"],
},
}
user_session_key = "oauth_session_auth0|507f1f77bcf86cd799439011"
await cache.set(user_session_key, oauth_token, timeout=3600)
retrieved = await cache.get(user_session_key)
assert retrieved["access_token"] == oauth_token["access_token"]
assert retrieved["userinfo"]["email"] == "john.doe@example.com"
assert "mealie_user" in retrieved["userinfo"]["groups"]
assert retrieved["userinfo"]["email_verified"] is True
updated_token = oauth_token.copy()
updated_token["access_token"] = "new_access_token_eyJhbGciOiJSUzI1NiIs..."
updated_token["userinfo"]["last_login"] = "2024-01-01T12:00:00Z"
await cache.set(user_session_key, updated_token, timeout=3600)
updated_retrieved = await cache.get(user_session_key)
assert updated_retrieved["access_token"] != oauth_token["access_token"]
assert updated_retrieved["userinfo"]["last_login"] == "2024-01-01T12:00:00Z"