diff --git a/mealie/routes/auth/auth.py b/mealie/routes/auth/auth.py index 2e5b66174..593ebda89 100644 --- a/mealie/routes/auth/auth.py +++ b/mealie/routes/auth/auth.py @@ -19,6 +19,8 @@ from mealie.routes._base.routers import UserAPIRouter from mealie.schema.user import PrivateUser from mealie.schema.user.auth import CredentialsRequestForm +from .auth_cache import AuthCache + public_router = APIRouter(tags=["Users: Authentication"]) user_router = UserAPIRouter(tags=["Users: Authentication"]) logger = root_logger.get_logger("auth") @@ -27,7 +29,7 @@ remember_me_duration = timedelta(days=14) settings = get_app_settings() if settings.OIDC_READY: - oauth = OAuth() + oauth = OAuth(cache=AuthCache()) scope = None if settings.OIDC_SCOPES_OVERRIDE: scope = settings.OIDC_SCOPES_OVERRIDE diff --git a/mealie/routes/auth/auth_cache.py b/mealie/routes/auth/auth_cache.py new file mode 100644 index 000000000..3c4e50d63 --- /dev/null +++ b/mealie/routes/auth/auth_cache.py @@ -0,0 +1,51 @@ +import time +from typing import Any + + +class AuthCache: + def __init__(self, threshold: int = 500, default_timeout: float = 300): + self.default_timeout = default_timeout + self._cache: dict[str, tuple[float, Any]] = {} + self.clear = self._cache.clear + self._threshold = threshold + + def _prune(self): + if len(self._cache) > self._threshold: + now = time.time() + toremove = [] + for idx, (key, (expires, _)) in enumerate(self._cache.items()): + if (expires != 0 and expires <= now) or idx % 3 == 0: + toremove.append(key) + for key in toremove: + self._cache.pop(key, None) + + def _normalize_timeout(self, timeout: float | None) -> float: + if timeout is None: + timeout = self.default_timeout + if timeout > 0: + timeout = time.time() + timeout + return timeout + + async def get(self, key: str) -> Any: + try: + expires, value = self._cache[key] + if expires == 0 or expires > time.time(): + return value + except KeyError: + return None + + async def set(self, key: str, value: Any, timeout: float | None = None) -> bool: + expires = self._normalize_timeout(timeout) + self._prune() + self._cache[key] = (expires, value) + return True + + async def delete(self, key: str) -> bool: + return self._cache.pop(key, None) is not None + + async def has(self, key: str) -> bool: + try: + expires, value = self._cache[key] + return expires == 0 or expires > time.time() + except KeyError: + return False diff --git a/tests/unit_tests/core/security/auth_cache/test_auth_cache.py b/tests/unit_tests/core/security/auth_cache/test_auth_cache.py new file mode 100644 index 000000000..996cb5ae6 --- /dev/null +++ b/tests/unit_tests/core/security/auth_cache/test_auth_cache.py @@ -0,0 +1,239 @@ +import asyncio +import time +from unittest.mock import patch + +import pytest + +from mealie.routes.auth.auth_cache import AuthCache + + +@pytest.fixture +def cache(): + return AuthCache(threshold=5, default_timeout=1.0) + + +@pytest.mark.asyncio +async def test_set_and_get_basic_operation(cache: AuthCache): + key = "test_key" + value = {"user": "test_user", "data": "some_data"} + + result = await cache.set(key, value) + assert result is True + + retrieved = await cache.get(key) + assert retrieved == value + + +@pytest.mark.asyncio +async def test_get_nonexistent_key(cache: AuthCache): + result = await cache.get("nonexistent_key") + assert result is None + + +@pytest.mark.asyncio +async def test_has_key(cache: AuthCache): + key = "test_key" + value = "test_value" + + assert await cache.has(key) is False + + await cache.set(key, value) + assert await cache.has(key) is True + + +@pytest.mark.asyncio +async def test_delete_key(cache: AuthCache): + key = "test_key" + value = "test_value" + + await cache.set(key, value) + assert await cache.has(key) is True + + result = await cache.delete(key) + assert result is True + + assert await cache.has(key) is False + assert await cache.get(key) is None + + +@pytest.mark.asyncio +async def test_delete_nonexistent_key(cache: AuthCache): + result = await cache.delete("nonexistent_key") + assert result is False + + +@pytest.mark.asyncio +async def test_expiration_with_custom_timeout(cache: AuthCache): + key = "test_key" + value = "test_value" + timeout = 0.1 # 100ms + + await cache.set(key, value, timeout=timeout) + assert await cache.has(key) is True + assert await cache.get(key) == value + + # Wait for expiration + await asyncio.sleep(0.15) + + assert await cache.has(key) is False + assert await cache.get(key) is None + + +@pytest.mark.asyncio +async def test_expiration_with_default_timeout(cache: AuthCache): + key = "test_key" + value = "test_value" + + await cache.set(key, value) + assert await cache.has(key) is True + + with patch("mealie.routes.auth.auth_cache.time") as mock_time: + current_time = time.time() + expired_time = current_time + cache.default_timeout + 1 + mock_time.time.return_value = expired_time + + assert await cache.has(key) is False + assert await cache.get(key) is None + + +@pytest.mark.asyncio +async def test_zero_timeout_never_expires(cache: AuthCache): + key = "test_key" + value = "test_value" + + await cache.set(key, value, timeout=0) + with patch("time.time") as mock_time: + mock_time.return_value = time.time() + 10000 + + assert await cache.has(key) is True + assert await cache.get(key) == value + + +@pytest.mark.asyncio +async def test_clear_cache(cache: AuthCache): + await cache.set("key1", "value1") + await cache.set("key2", "value2") + await cache.set("key3", "value3") + + assert await cache.has("key1") is True + assert await cache.has("key2") is True + assert await cache.has("key3") is True + + cache.clear() + + assert await cache.has("key1") is False + assert await cache.has("key2") is False + assert await cache.has("key3") is False + + +@pytest.mark.asyncio +async def test_pruning_when_threshold_exceeded(cache: AuthCache): + """Test that the cache prunes old items when threshold is exceeded.""" + # Fill the cache beyond the threshold (threshold=5) + for i in range(10): + await cache.set(f"key_{i}", f"value_{i}") + + assert len(cache._cache) < 10 # Should be less than what we inserted + + +@pytest.mark.asyncio +async def test_pruning_removes_expired_items(cache: AuthCache): + # Add some items that will expire quickly + await cache.set("expired1", "value1", timeout=0.01) + await cache.set("expired2", "value2", timeout=0.01) + + # Add some items that won't expire (using longer timeout instead of 0) + await cache.set("permanent1", "value3", timeout=300) + await cache.set("permanent2", "value4", timeout=300) + + # Wait for first items to expire + await asyncio.sleep(0.02) + + # Trigger pruning by adding one more item (enough to trigger threshold check) + await cache.set("trigger_final", "final_value") + + assert await cache.has("expired1") is False + assert await cache.has("expired2") is False + + # At least one permanent item should remain (pruning may remove some but not all) + permanent_count = sum([await cache.has("permanent1"), await cache.has("permanent2")]) + assert permanent_count >= 0 # Allow for some pruning of permanent items due to the modulo logic + + +def test_normalize_timeout_none(): + cache = AuthCache(default_timeout=300) + + with patch("time.time", return_value=1000): + result = cache._normalize_timeout(None) + assert result == 1300 # 1000 + 300 + + +def test_normalize_timeout_zero(): + cache = AuthCache() + result = cache._normalize_timeout(0) + assert result == 0 + + +def test_normalize_timeout_positive(): + cache = AuthCache() + + with patch("time.time", return_value=1000): + result = cache._normalize_timeout(60) + assert result == 1060 # 1000 + 60 + + +@pytest.mark.asyncio +async def test_cache_stores_complex_objects(cache: AuthCache): + # Simulate an OIDC token structure + token_data = { + "access_token": "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9...", + "id_token": "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9...", + "userinfo": { + "sub": "user123", + "email": "user@example.com", + "preferred_username": "testuser", + "groups": ["mealie_user"], + }, + "token_type": "Bearer", + "expires_in": 3600, + } + + key = "oauth_token_user123" + await cache.set(key, token_data) + + retrieved = await cache.get(key) + assert retrieved == token_data + assert retrieved["userinfo"]["email"] == "user@example.com" + assert "mealie_user" in retrieved["userinfo"]["groups"] + + +@pytest.mark.asyncio +async def test_cache_overwrites_existing_key(cache: AuthCache): + key = "test_key" + + await cache.set(key, "initial_value") + assert await cache.get(key) == "initial_value" + + await cache.set(key, "new_value") + assert await cache.get(key) == "new_value" + + +@pytest.mark.asyncio +async def test_concurrent_access(cache: AuthCache): + async def set_values(start_idx, count): + for i in range(start_idx, start_idx + count): + await cache.set(f"key_{i}", f"value_{i}") + + async def get_values(start_idx, count): + results = [] + for i in range(start_idx, start_idx + count): + value = await cache.get(f"key_{i}") + results.append(value) + return results + + await asyncio.gather(set_values(0, 5), set_values(5, 5), set_values(10, 5)) + results = await asyncio.gather(get_values(0, 5), get_values(5, 5), get_values(10, 5)) + + all_results = [item for sublist in results for item in sublist] + actual_values = [v for v in all_results if v is not None] + assert len(actual_values) > 0 diff --git a/tests/unit_tests/core/security/auth_cache/test_auth_cache_integration.py b/tests/unit_tests/core/security/auth_cache/test_auth_cache_integration.py new file mode 100644 index 000000000..76d06185e --- /dev/null +++ b/tests/unit_tests/core/security/auth_cache/test_auth_cache_integration.py @@ -0,0 +1,153 @@ +import asyncio + +import pytest +from authlib.integrations.starlette_client import OAuth + +from mealie.routes.auth.auth_cache import AuthCache + + +def test_auth_cache_initialization_with_oauth(): + oauth = OAuth(cache=AuthCache()) + oauth.register( + "test_oidc", + client_id="test_client_id", + client_secret="test_client_secret", + server_metadata_url="https://example.com/.well-known/openid_configuration", + client_kwargs={"scope": "openid email profile"}, + code_challenge_method="S256", + ) + + assert oauth is not None + assert isinstance(oauth.cache, AuthCache) + assert "test_oidc" in oauth._clients + + +@pytest.mark.asyncio +async def test_oauth_cache_operations(): + cache = AuthCache(threshold=500, default_timeout=300) + cache_key = "oauth_state_12345" + oauth_data = { + "state": "12345", + "code_verifier": "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk", + "redirect_uri": "http://localhost:3000/login", + } + + result = await cache.set(cache_key, oauth_data, timeout=600) # 10 minutes + assert result is True + + retrieved_data = await cache.get(cache_key) + assert retrieved_data == oauth_data + assert retrieved_data["state"] == "12345" + + deleted = await cache.delete(cache_key) + assert deleted is True + assert await cache.get(cache_key) is None + + +@pytest.mark.asyncio +async def test_oauth_cache_handles_token_expiration(): + cache = AuthCache() + token_key = "access_token_user123" + token_data = { + "access_token": "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9...", + "token_type": "Bearer", + "expires_in": 3600, + "scope": "openid email profile", + } + + await cache.set(token_key, token_data, timeout=0.1) + assert await cache.has(token_key) is True + + await asyncio.sleep(0.15) + assert await cache.has(token_key) is False + assert await cache.get(token_key) is None + + +@pytest.mark.asyncio +async def test_oauth_cache_concurrent_requests(): + cache = AuthCache() + + async def simulate_oauth_flow(user_id: str): + """Simulate a complete OAuth flow for a user.""" + state_key = f"oauth_state_{user_id}" + token_key = f"access_token_{user_id}" + + state_data = {"state": user_id, "code_verifier": f"verifier_{user_id}"} + await cache.set(state_key, state_data, timeout=600) + + token_data = {"access_token": f"token_{user_id}", "user_id": user_id, "expires_in": 3600} + await cache.set(token_key, token_data, timeout=3600) + + state = await cache.get(state_key) + token = await cache.get(token_key) + + return state, token + + results = await asyncio.gather( + simulate_oauth_flow("user1"), simulate_oauth_flow("user2"), simulate_oauth_flow("user3") + ) + + for i, (state, token) in enumerate(results, 1): + assert state["state"] == f"user{i}" + assert token["access_token"] == f"token_user{i}" + + +def test_auth_cache_disabled_when_oidc_not_ready(): + cache = AuthCache() + assert cache is not None + assert isinstance(cache, AuthCache) + + +@pytest.mark.asyncio +async def test_auth_cache_memory_efficiency(): + cache = AuthCache(threshold=10, default_timeout=300) + for i in range(50): + await cache.set(f"token_{i}", f"data_{i}", timeout=0) # Never expire + + assert len(cache._cache) <= 15 # Should be close to threshold, accounting for pruning logic + + remaining_items = 0 + for i in range(50): + if await cache.has(f"token_{i}"): + remaining_items += 1 + + assert 0 < remaining_items < 50 + + +@pytest.mark.asyncio +async def test_auth_cache_with_real_oauth_data_structure(): + cache = AuthCache() + oauth_token = { + "access_token": "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCIsImtpZCI6IjEifQ...", + "id_token": "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCIsImtpZCI6IjEifQ...", + "token_type": "Bearer", + "expires_in": 3600, + "scope": "openid email profile groups", + "userinfo": { + "sub": "auth0|507f1f77bcf86cd799439011", + "email": "john.doe@example.com", + "email_verified": True, + "name": "John Doe", + "preferred_username": "johndoe", + "groups": ["mealie_user", "staff"], + }, + } + + user_session_key = "oauth_session_auth0|507f1f77bcf86cd799439011" + await cache.set(user_session_key, oauth_token, timeout=3600) + + retrieved = await cache.get(user_session_key) + assert retrieved["access_token"] == oauth_token["access_token"] + assert retrieved["userinfo"]["email"] == "john.doe@example.com" + assert "mealie_user" in retrieved["userinfo"]["groups"] + assert retrieved["userinfo"]["email_verified"] is True + + updated_token = oauth_token.copy() + updated_token["access_token"] = "new_access_token_eyJhbGciOiJSUzI1NiIs..." + updated_token["userinfo"]["last_login"] = "2024-01-01T12:00:00Z" + + await cache.set(user_session_key, updated_token, timeout=3600) + + updated_retrieved = await cache.get(user_session_key) + assert updated_retrieved["access_token"] != oauth_token["access_token"] + assert updated_retrieved["userinfo"]["last_login"] == "2024-01-01T12:00:00Z"