From 548e652a82ffe2a9f9ca216f3d92c79723c0d850 Mon Sep 17 00:00:00 2001 From: Hayden Date: Mon, 18 Jan 2021 15:11:01 -0900 Subject: [PATCH] test: refactor database init to support tests --- mealie/app.py | 4 +- mealie/db/database.py | 17 ++---- mealie/db/db_base.py | 32 +++------- mealie/db/db_setup.py | 22 +++++-- mealie/db/sql/db_session.py | 26 ++++---- mealie/db/sql/recipe_models.py | 6 +- mealie/routes/backup_routes.py | 12 +++- mealie/routes/meal_routes.py | 30 ++++++---- mealie/routes/migration_routes.py | 14 +++-- mealie/routes/recipe_routes.py | 49 +++++++++------ mealie/routes/setting_routes.py | 30 ++++++---- mealie/services/backups/exports.py | 20 ++++--- mealie/services/backups/imports.py | 10 ++-- mealie/services/meal_services.py | 27 +++++---- mealie/services/migrations/chowdown.py | 7 ++- mealie/services/migrations/nextcloud.py | 6 +- mealie/services/recipe_services.py | 29 +++++---- mealie/services/scheduler_services.py | 3 +- mealie/services/scrape_services.py | 7 ++- mealie/services/settings_services.py | 45 +++++++------- mealie/tests/conftest.py | 25 +++++++- .../tests/test_routes/test_recipe_routes.py | 59 +++++++++++++++++++ mealie/utils/global_scheduler.py | 2 +- 23 files changed, 304 insertions(+), 178 deletions(-) create mode 100644 mealie/tests/test_routes/test_recipe_routes.py diff --git a/mealie/app.py b/mealie/app.py index 049764bba..5444c15cd 100644 --- a/mealie/app.py +++ b/mealie/app.py @@ -3,7 +3,7 @@ from fastapi import FastAPI from fastapi.staticfiles import StaticFiles # import utils.startup as startup -from app_config import PORT, PRODUCTION, WEB_PATH, docs_url, redoc_url +from app_config import PORT, PRODUCTION, SQLITE_FILE, WEB_PATH, docs_url, redoc_url from routes import ( backup_routes, meal_routes, @@ -13,6 +13,7 @@ from routes import ( static_routes, user_routes, ) + # from utils.api_docs import generate_api_docs from utils.logger import logger @@ -25,6 +26,7 @@ app = FastAPI( ) + def mount_static_files(): app.mount("/static", StaticFiles(directory=WEB_PATH, html=True)) diff --git a/mealie/db/database.py b/mealie/db/database.py index c44bf5d3c..c8ecef7c5 100644 --- a/mealie/db/database.py +++ b/mealie/db/database.py @@ -1,5 +1,4 @@ from db.db_base import BaseDocument -from db.sql.db_session import create_session from db.sql.meal_models import MealPlanModel from db.sql.recipe_models import RecipeModel from db.sql.settings_models import SiteSettingsModel @@ -10,11 +9,12 @@ from db.sql.theme_models import SiteThemeModel - [ ] Abstract Classes to use save_new, and update from base models - [ ] Create Category and Tags Table with Many to Many relationship """ + + class _Recipes(BaseDocument): def __init__(self) -> None: self.primary_key = "slug" self.sql_model = RecipeModel - self.create_session = create_session def update_image(self, slug: str, extension: str) -> None: pass @@ -24,17 +24,14 @@ class _Meals(BaseDocument): def __init__(self) -> None: self.primary_key = "uid" self.sql_model = MealPlanModel - self.create_session = create_session class _Settings(BaseDocument): def __init__(self) -> None: self.primary_key = "name" self.sql_model = SiteSettingsModel - self.create_session = create_session - def save_new(self, main: dict, webhooks: dict) -> str: - session = create_session() + def save_new(self, session, main: dict, webhooks: dict) -> str: new_settings = self.sql_model(main.get("name"), webhooks) session.add(new_settings) @@ -47,16 +44,14 @@ class _Themes(BaseDocument): def __init__(self) -> None: self.primary_key = "name" self.sql_model = SiteThemeModel - self.create_session = create_session - def update(self, data: dict) -> dict: - session, theme_model = self._query_one( - match_value=data["name"], match_key="name" + def update(self, session, data: dict) -> dict: + theme_model = self._query_one( + session=session, match_value=data["name"], match_key="name" ) theme_model.update(**data) session.commit() - session.close() class Database: diff --git a/mealie/db/db_base.py b/mealie/db/db_base.py index 730c262aa..7318de7ad 100644 --- a/mealie/db/db_base.py +++ b/mealie/db/db_base.py @@ -2,7 +2,6 @@ from typing import Union from sqlalchemy.orm.session import Session -from db.sql.db_session import create_session from db.sql.model_base import SqlAlchemyBase @@ -11,12 +10,9 @@ class BaseDocument: self.primary_key: str self.store: str self.sql_model: SqlAlchemyBase - self.create_session = create_session - def get_all(self, limit: int = None, order_by: str = None): - session = create_session() + def get_all(self, session: Session, limit: int = None, order_by: str = None): list = [x.dict() for x in session.query(self.sql_model).all()] - session.close() if limit == 1: return list[0] @@ -24,7 +20,7 @@ class BaseDocument: return list def _query_one( - self, match_value: str, match_key: str = None + self, session: Session, match_value: str, match_key: str = None ) -> Union[Session, SqlAlchemyBase]: """Query the sql database for one item an return the sql alchemy model object. If no match key is provided the primary_key attribute will be used. @@ -36,8 +32,6 @@ class BaseDocument: Returns: Union[Session, SqlAlchemyBase]: Will return both the session and found model """ - session = self.create_session() - if match_key == None: match_key = self.primary_key @@ -45,10 +39,10 @@ class BaseDocument: session.query(self.sql_model).filter_by(**{match_key: match_value}).one() ) - return session, result + return result def get( - self, match_value: str, match_key: str = None, limit=1 + self, session: Session, match_value: str, match_key: str = None, limit=1 ) -> dict or list[dict]: """Retrieves an entry from the database by matching a key/value pair. If no key is provided the class objects primary key will be used to match against. @@ -65,39 +59,30 @@ class BaseDocument: if match_key == None: match_key = self.primary_key - session = self.create_session() result = ( session.query(self.sql_model).filter_by(**{match_key: match_value}).one() ) db_entry = result.dict() - session.close() return db_entry - def save_new(self, document: dict) -> dict: - session = self.create_session() + def save_new(self, session: Session, document: dict) -> dict: new_document = self.sql_model(**document) session.add(new_document) return_data = new_document.dict() session.commit() - - session.close() return return_data - def update(self, match_value, new_data) -> dict: - session, entry = self._query_one(match_value=match_value) + def update(self, session: Session, match_value, new_data) -> dict: + entry = self._query_one(session=session, match_value=match_value) entry.update(session=session, **new_data) return_data = entry.dict() session.commit() - session.close() - return return_data - def delete(self, primary_key_value) -> dict: - session = create_session() - + def delete(self, session: Session, primary_key_value) -> dict: result = ( session.query(self.sql_model) .filter_by(**{self.primary_key: primary_key_value}) @@ -107,4 +92,3 @@ class BaseDocument: session.delete(result) session.commit() - session.close() diff --git a/mealie/db/db_setup.py b/mealie/db/db_setup.py index f05ab8940..aa315a012 100644 --- a/mealie/db/db_setup.py +++ b/mealie/db/db_setup.py @@ -1,14 +1,26 @@ from app_config import SQLITE_FILE, USE_SQL +from sqlalchemy.orm.session import Session -from db.sql.db_session import globa_init as sql_global_init +from db.sql.db_session import sql_global_init sql_exists = True if USE_SQL: sql_exists = SQLITE_FILE.is_file() - sql_global_init(SQLITE_FILE) - - pass - + SessionLocal = sql_global_init(SQLITE_FILE) else: raise Exception("Cannot identify database type") + + +def create_session() -> Session: + global SessionLocal + return SessionLocal() + + +def generate_session() -> Session: + global SessionLocal + db = SessionLocal() + try: + yield db + finally: + db.close() diff --git a/mealie/db/sql/db_session.py b/mealie/db/sql/db_session.py index 442d401cb..8376cc593 100644 --- a/mealie/db/sql/db_session.py +++ b/mealie/db/sql/db_session.py @@ -1,29 +1,25 @@ from pathlib import Path import sqlalchemy as sa -import sqlalchemy.orm as orm from db.sql.model_base import SqlAlchemyBase -from sqlalchemy.orm.session import Session - -__factory = None +from sqlalchemy.orm import sessionmaker -def globa_init(db_file: Path): - global __factory +def sql_global_init(db_file: Path, check_thread=False): - if __factory: - return - conn_str = "sqlite:///" + str(db_file.absolute()) + SQLALCHEMY_DATABASE_URL = "sqlite:///" + str(db_file.absolute()) + # SQLALCHEMY_DATABASE_URL = "postgresql://user:password@postgresserver/db" - engine = sa.create_engine(conn_str, echo=False) + engine = sa.create_engine( + SQLALCHEMY_DATABASE_URL, + echo=False, + connect_args={"check_same_thread": check_thread}, + ) - __factory = orm.sessionmaker(bind=engine) + SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) import db.sql._all_models SqlAlchemyBase.metadata.create_all(engine) - -def create_session() -> Session: - global __factory - return __factory() + return SessionLocal diff --git a/mealie/db/sql/recipe_models.py b/mealie/db/sql/recipe_models.py index f8b7cf65b..561d07c23 100644 --- a/mealie/db/sql/recipe_models.py +++ b/mealie/db/sql/recipe_models.py @@ -50,6 +50,10 @@ class Note(SqlAlchemyBase): title = sa.Column(sa.String) text = sa.Column(sa.String) + def __init__(self, title, text) -> None: + self.title = title + self.text = text + def dict(self): return {"title": self.title, "text": self.text} @@ -169,7 +173,7 @@ class RecipeModel(SqlAlchemyBase, BaseMixins): self.categories = [Category(name=cat) for cat in categories] self.tags = [Tag(name=tag) for tag in tags] self.dateAdded = dateAdded - self.notes = [Note(note) for note in notes] + self.notes = [Note(**note) for note in notes] self.rating = rating self.orgURL = orgURL self.extras = [ApiExtras(key=key, value=value) for key, value in extras.items()] diff --git a/mealie/routes/backup_routes.py b/mealie/routes/backup_routes.py index f29c560ce..01b945393 100644 --- a/mealie/routes/backup_routes.py +++ b/mealie/routes/backup_routes.py @@ -1,10 +1,12 @@ import operator from app_config import BACKUP_DIR, TEMPLATE_DIR -from fastapi import APIRouter, HTTPException +from db.db_setup import generate_session +from fastapi import APIRouter, Depends, HTTPException from models.backup_models import BackupJob, ImportJob, Imports, LocalBackup from services.backups.exports import backup_all from services.backups.imports import ImportDatabase +from sqlalchemy.orm.session import Session from utils.snackbar import SnackResponse router = APIRouter(tags=["Import / Export"]) @@ -28,9 +30,10 @@ def available_imports(): @router.post("/api/backups/export/database/", status_code=201) -def export_database(data: BackupJob): +def export_database(data: BackupJob, db: Session = Depends(generate_session)): """Generates a backup of the recipe database in json format.""" export_path = backup_all( + session=db, tag=data.tag, templates=data.templates, export_recipes=data.options.recipes, @@ -47,10 +50,13 @@ def export_database(data: BackupJob): @router.post("/api/backups/{file_name}/import/", status_code=200) -def import_database(file_name: str, import_data: ImportJob): +def import_database( + file_name: str, import_data: ImportJob, db: Session = Depends(generate_session) +): """ Import a database backup file generated from Mealie. """ import_db = ImportDatabase( + session=db, zip_archive=import_data.name, import_recipes=import_data.recipes, force_import=import_data.force, diff --git a/mealie/routes/meal_routes.py b/mealie/routes/meal_routes.py index 8e9663043..0bf536b6d 100644 --- a/mealie/routes/meal_routes.py +++ b/mealie/routes/meal_routes.py @@ -1,24 +1,26 @@ from typing import List -from fastapi import APIRouter, HTTPException +from db.db_setup import generate_session +from fastapi import APIRouter, Depends, HTTPException from services.meal_services import MealPlan +from sqlalchemy.orm.session import Session from utils.snackbar import SnackResponse router = APIRouter(tags=["Meal Plan"]) @router.get("/api/meal-plan/all/", response_model=List[MealPlan]) -def get_all_meals(): +def get_all_meals(db: Session = Depends(generate_session)): """ Returns a list of all available Meal Plan """ - return MealPlan.get_all() + return MealPlan.get_all(db) @router.post("/api/meal-plan/create/") -def set_meal_plan(data: MealPlan): +def set_meal_plan(data: MealPlan, db: Session = Depends(generate_session)): """ Creates a meal plan database entry """ data.process_meals() - data.save_to_db() + data.save_to_db(db) # raise HTTPException( # status_code=404, @@ -29,10 +31,12 @@ def set_meal_plan(data: MealPlan): @router.post("/api/meal-plan/{plan_id}/update/") -def update_meal_plan(plan_id: str, meal_plan: MealPlan): +def update_meal_plan( + plan_id: str, meal_plan: MealPlan, db: Session = Depends(generate_session) +): """ Updates a meal plan based off ID """ meal_plan.process_meals() - meal_plan.update(plan_id) + meal_plan.update(db, plan_id) # try: # meal_plan.process_meals() # meal_plan.update(plan_id) @@ -46,10 +50,10 @@ def update_meal_plan(plan_id: str, meal_plan: MealPlan): @router.delete("/api/meal-plan/{plan_id}/delete/") -def delete_meal_plan(plan_id): +def delete_meal_plan(plan_id, db: Session = Depends(generate_session)): """ Removes a meal plan from the database """ - MealPlan.delete(plan_id) + MealPlan.delete(db, plan_id) return SnackResponse.success("Mealplan Deleted") @@ -58,17 +62,17 @@ def delete_meal_plan(plan_id): "/api/meal-plan/today/", tags=["Meal Plan"], ) -def get_today(): +def get_today(db: Session = Depends(generate_session)): """ Returns the recipe slug for the meal scheduled for today. If no meal is scheduled nothing is returned """ - return MealPlan.today() + return MealPlan.today(db) @router.get("/api/meal-plan/this-week/", response_model=MealPlan) -def get_this_week(): +def get_this_week(db: Session = Depends(generate_session)): """ Returns the meal plan data for this week """ - return MealPlan.this_week() + return MealPlan.this_week(db) diff --git a/mealie/routes/migration_routes.py b/mealie/routes/migration_routes.py index 49eccc2c8..5fb101f37 100644 --- a/mealie/routes/migration_routes.py +++ b/mealie/routes/migration_routes.py @@ -1,10 +1,12 @@ import shutil -from fastapi import APIRouter, File, HTTPException, UploadFile +from app_config import MIGRATION_DIR +from db.db_setup import create_session +from fastapi import APIRouter, Depends, File, HTTPException, UploadFile from models.migration_models import ChowdownURL from services.migrations.chowdown import chowdown_migrate as chowdow_migrate from services.migrations.nextcloud import migrate as nextcloud_migrate -from app_config import MIGRATION_DIR +from sqlalchemy.orm.session import Session from utils.snackbar import SnackResponse router = APIRouter(tags=["Migration"]) @@ -12,10 +14,10 @@ router = APIRouter(tags=["Migration"]) # Chowdown @router.post("/api/migration/chowdown/repo/") -def import_chowdown_recipes(repo: ChowdownURL): +def import_chowdown_recipes(repo: ChowdownURL, db: Session = Depends(create_session)): """ Import Chowsdown Recipes from Repo URL """ try: - report = chowdow_migrate(repo.url) + report = chowdow_migrate(db, repo.url) return SnackResponse.success( "Recipes Imported from Git Repo, see report for failures.", additional_data=report, @@ -44,10 +46,10 @@ def get_avaiable_nextcloud_imports(): @router.post("/api/migration/nextcloud/{selection}/import/") -def import_nextcloud_directory(selection: str): +def import_nextcloud_directory(selection: str, db: Session = Depends(create_session)): """ Imports all the recipes in a given directory """ - return nextcloud_migrate(selection) + return nextcloud_migrate(db, selection) @router.delete("/api/migration/{file_folder_name}/delete/") diff --git a/mealie/routes/recipe_routes.py b/mealie/routes/recipe_routes.py index eadf7046f..d6e5d0372 100644 --- a/mealie/routes/recipe_routes.py +++ b/mealie/routes/recipe_routes.py @@ -1,18 +1,24 @@ from typing import List, Optional -from fastapi import APIRouter, File, Form, HTTPException, Query +from db.db_setup import generate_session +from fastapi import APIRouter, Depends, File, Form, HTTPException, Query from fastapi.responses import FileResponse from models.recipe_models import AllRecipeRequest, RecipeURLIn from services.image_services import read_image, write_image from services.recipe_services import Recipe, read_requested_values from services.scrape_services import create_from_url +from sqlalchemy.orm.session import Session from utils.snackbar import SnackResponse router = APIRouter(tags=["Recipes"]) @router.get("/api/all-recipes/", response_model=List[dict]) -def get_all_recipes(keys: Optional[List[str]] = Query(...), num: Optional[int] = 100): +def get_all_recipes( + keys: Optional[List[str]] = Query(...), + num: Optional[int] = 100, + db: Session = Depends(generate_session), +): """ Returns key data for all recipes based off the query paramters provided. For example, if slug, image, and name are provided you will recieve a list of @@ -24,12 +30,14 @@ def get_all_recipes(keys: Optional[List[str]] = Query(...), num: Optional[int] = See the *Post* method for more details. """ - all_recipes = read_requested_values(keys, num) + all_recipes = read_requested_values(db, keys, num) return all_recipes @router.post("/api/all-recipes/", response_model=List[dict]) -def get_all_recipes_post(body: AllRecipeRequest): +def get_all_recipes_post( + body: AllRecipeRequest, db: Session = Depends(generate_session) +): """ Returns key data for all recipes based off the body data provided. For example, if slug, image, and name are provided you will recieve a list of @@ -39,15 +47,18 @@ def get_all_recipes_post(body: AllRecipeRequest): """ - all_recipes = read_requested_values(body.properties, body.limit) + all_recipes = read_requested_values(db, body.properties, body.limit) return all_recipes -@router.get("/api/recipe/{recipe_slug}/", response_model=Recipe) -def get_recipe(recipe_slug: str): +@router.get( + "/api/recipe/{recipe_slug}/", + response_model=Recipe, +) +def get_recipe(recipe_slug: str, db: Session = Depends(generate_session)): """ Takes in a recipe slug, returns all data for a recipe """ - recipe = Recipe.get_by_slug(recipe_slug) + recipe = Recipe.get_by_slug(db, recipe_slug) return recipe @@ -63,22 +74,22 @@ def get_recipe_img(recipe_slug: str): # Recipe Creations @router.post( "/api/recipe/create-url/", - tags=["Recipes"], status_code=201, response_model=str, ) -def parse_recipe_url(url: RecipeURLIn): +def parse_recipe_url(url: RecipeURLIn, db: Session = Depends(generate_session)): """ Takes in a URL and attempts to scrape data and load it into the database """ - slug = create_from_url(url.url) + recipe = create_from_url(url.url) + recipe.save_to_db(db) - return slug + return recipe.slug @router.post("/api/recipe/create/") -def create_from_json(data: Recipe) -> str: +def create_from_json(data: Recipe, db: Session = Depends(generate_session)) -> str: """ Takes in a JSON string and loads data into the database as a new entry""" - created_recipe = data.save_to_db() + created_recipe = data.save_to_db(db) return created_recipe @@ -95,20 +106,22 @@ def update_recipe_image( @router.post("/api/recipe/{recipe_slug}/update/") -def update_recipe(recipe_slug: str, data: Recipe): +def update_recipe( + recipe_slug: str, data: Recipe, db: Session = Depends(generate_session) +): """ Updates a recipe by existing slug and data. """ - new_slug = data.update(recipe_slug) + new_slug = data.update(db, recipe_slug) return new_slug @router.delete("/api/recipe/{recipe_slug}/delete/") -def delete_recipe(recipe_slug: str): +def delete_recipe(recipe_slug: str, db: Session = Depends(generate_session)): """ Deletes a recipe by slug """ try: - Recipe.delete(recipe_slug) + Recipe.delete(db, recipe_slug) except: raise HTTPException( status_code=404, detail=SnackResponse.error("Unable to Delete Recipe") diff --git a/mealie/routes/setting_routes.py b/mealie/routes/setting_routes.py index a30875c50..293d84f12 100644 --- a/mealie/routes/setting_routes.py +++ b/mealie/routes/setting_routes.py @@ -1,6 +1,8 @@ -from fastapi import APIRouter, HTTPException +from db.db_setup import generate_session +from fastapi import APIRouter, Depends, HTTPException from services.scheduler_services import post_webhooks from services.settings_services import SiteSettings, SiteTheme +from sqlalchemy.orm.session import Session from utils.global_scheduler import scheduler from utils.snackbar import SnackResponse @@ -8,10 +10,10 @@ router = APIRouter(tags=["Settings"]) @router.get("/api/site-settings/") -def get_main_settings(): +def get_main_settings(db: Session = Depends(generate_session)): """ Returns basic site settings """ - return SiteSettings.get_site_settings() + return SiteSettings.get_site_settings(db) @router.post("/api/site-settings/webhooks/test/") @@ -37,22 +39,22 @@ def update_settings(data: SiteSettings): @router.get("/api/site-settings/themes/", tags=["Themes"]) -def get_all_themes(): +def get_all_themes(db: Session = Depends(generate_session)): """ Returns all site themes """ - return SiteTheme.get_all() + return SiteTheme.get_all(db) @router.get("/api/site-settings/themes/{theme_name}/", tags=["Themes"]) -def get_single_theme(theme_name: str): +def get_single_theme(theme_name: str, db: Session = Depends(generate_session)): """ Returns a named theme """ - return SiteTheme.get_by_name(theme_name) + return SiteTheme.get_by_name(db, theme_name) @router.post("/api/site-settings/themes/create/", tags=["Themes"]) -def create_theme(data: SiteTheme): +def create_theme(data: SiteTheme, db: Session = Depends(generate_session)): """ Creates a site color theme database entry """ - data.save_to_db() + data.save_to_db(db) # try: # data.save_to_db() # except: @@ -64,9 +66,11 @@ def create_theme(data: SiteTheme): @router.post("/api/site-settings/themes/{theme_name}/update/", tags=["Themes"]) -def update_theme(theme_name: str, data: SiteTheme): +def update_theme( + theme_name: str, data: SiteTheme, db: Session = Depends(generate_session) +): """ Update a theme database entry """ - data.update_document() + data.update_document(db) # try: # except: @@ -78,9 +82,9 @@ def update_theme(theme_name: str, data: SiteTheme): @router.delete("/api/site-settings/themes/{theme_name}/delete/", tags=["Themes"]) -def delete_theme(theme_name: str): +def delete_theme(theme_name: str, db: Session = Depends(generate_session)): """ Deletes theme from the database """ - SiteTheme.delete_theme(theme_name) + SiteTheme.delete_theme(db, theme_name) # try: # SiteTheme.delete_theme(theme_name) # except: diff --git a/mealie/services/backups/exports.py b/mealie/services/backups/exports.py index 28ed19baf..f8e946c09 100644 --- a/mealie/services/backups/exports.py +++ b/mealie/services/backups/exports.py @@ -4,6 +4,7 @@ from datetime import datetime from pathlib import Path from app_config import BACKUP_DIR, IMG_DIR, TEMP_DIR, TEMPLATE_DIR +from db.db_setup import create_session from jinja2 import Template from services.meal_services import MealPlan from services.recipe_services import Recipe @@ -12,7 +13,7 @@ from utils.logger import logger class ExportDatabase: - def __init__(self, tag=None, templates=None) -> None: + def __init__(self, session, tag=None, templates=None) -> None: """Export a Mealie database. Export interacts directly with class objects and can be used with any supported backend database platform. By default tags are timestands, and no Jinja2 templates are rendered @@ -26,6 +27,7 @@ class ExportDatabase: else: export_tag = datetime.now().strftime("%Y-%b-%d") + self.session = session self.main_dir = TEMP_DIR.joinpath(export_tag) self.img_dir = self.main_dir.joinpath("images") self.recipe_dir = self.main_dir.joinpath("recipes") @@ -54,7 +56,7 @@ class ExportDatabase: dir.mkdir(parents=True, exist_ok=True) def export_recipes(self): - all_recipes = Recipe.get_all() + all_recipes = Recipe.get_all(self.session) for recipe in all_recipes: logger.info(f"Backing Up Recipes: {recipe}") @@ -86,12 +88,12 @@ class ExportDatabase: shutil.copy(file, self.img_dir.joinpath(file.name)) def export_settings(self): - all_settings = SiteSettings.get_site_settings() + all_settings = SiteSettings.get_site_settings(self.session) out_file = self.settings_dir.joinpath("settings.json") ExportDatabase._write_json_file(all_settings.dict(), out_file) def export_themes(self): - all_themes = SiteTheme.get_all() + all_themes = SiteTheme.get_all(self.session) if all_themes: all_themes = [x.dict() for x in all_themes] out_file = self.themes_dir.joinpath("themes.json") @@ -100,7 +102,7 @@ class ExportDatabase: def export_meals( self, ): #! Problem Parseing Datetime Objects... May come back to this - meal_plans = MealPlan.get_all() + meal_plans = MealPlan.get_all(self.session) if meal_plans: meal_plans = [x.dict() for x in meal_plans] @@ -124,13 +126,14 @@ class ExportDatabase: def backup_all( + session, tag=None, templates=None, export_recipes=True, export_settings=True, export_themes=True, ): - db_export = ExportDatabase(tag=tag, templates=templates) + db_export = ExportDatabase(session=session, tag=tag, templates=templates) if export_recipes: db_export.export_recipes() @@ -138,7 +141,7 @@ def backup_all( if export_settings: db_export.export_settings() - + if export_themes: db_export.export_themes() # db_export.export_meals() @@ -154,5 +157,6 @@ def auto_backup_job(): for template in TEMPLATE_DIR.iterdir(): templates.append(template) - backup_all(tag="Auto", templates=templates) + session = create_session() + backup_all(session=session, tag="Auto", templates=templates) logger.info("Auto Backup Called") diff --git a/mealie/services/backups/imports.py b/mealie/services/backups/imports.py index 940e3a92b..f47979b22 100644 --- a/mealie/services/backups/imports.py +++ b/mealie/services/backups/imports.py @@ -7,12 +7,14 @@ from typing import List from app_config import BACKUP_DIR, IMG_DIR, TEMP_DIR from services.recipe_services import Recipe from services.settings_services import SiteSettings, SiteTheme +from sqlalchemy.orm.session import Session from utils.logger import logger class ImportDatabase: def __init__( self, + session: Session, zip_archive: str, import_recipes: bool = True, import_settings: bool = True, @@ -33,7 +35,7 @@ class ImportDatabase: Raises: Exception: If the zip file does not exists an exception raise. """ - + self.session = session self.archive = BACKUP_DIR.joinpath(zip_archive) self.imp_recipes = import_recipes self.imp_settings = import_settings @@ -75,7 +77,7 @@ class ImportDatabase: recipe_dict = ImportDatabase._recipe_migration(recipe_dict) try: recipe_obj = Recipe(**recipe_dict) - recipe_obj.save_to_db() + recipe_obj.save_to_db(self.session) successful_imports.append(recipe.stem) logger.info(f"Imported: {recipe.stem}") except: @@ -114,7 +116,7 @@ class ImportDatabase: for theme in themes: new_theme = SiteTheme(**theme) try: - new_theme.save_to_db() + new_theme.save_to_db(self.session) except: logger.info(f"Unable Import Theme {new_theme.name}") @@ -126,7 +128,7 @@ class ImportDatabase: settings = SiteSettings(**settings) - settings.update() + settings.update(self.session) def clean_up(self): shutil.rmtree(TEMP_DIR) diff --git a/mealie/services/meal_services.py b/mealie/services/meal_services.py index 9e6b318b0..0f138b9c3 100644 --- a/mealie/services/meal_services.py +++ b/mealie/services/meal_services.py @@ -4,6 +4,7 @@ from typing import List, Optional from db.database import db from pydantic import BaseModel +from sqlalchemy.orm.session import Session from services.recipe_services import Recipe @@ -78,27 +79,29 @@ class MealPlan(BaseModel): self.meals = meals - def save_to_db(self): - db.meals.save_new(self.dict()) + def save_to_db(self, session: Session): + db.meals.save_new(session, self.dict()) @staticmethod - def get_all() -> List: + def get_all(session: Session) -> List: - all_meals = [MealPlan(**x) for x in db.meals.get_all(order_by="startDate")] + all_meals = [ + MealPlan(**x) for x in db.meals.get_all(session, order_by="startDate") + ] return all_meals - def update(self, uid): - db.meals.update(uid, self.dict()) + def update(self, session, uid): + db.meals.update(session, uid, self.dict()) @staticmethod - def delete(uid): - db.meals.delete(uid) + def delete(session, uid): + db.meals.delete(session, uid) @staticmethod - def today() -> str: + def today(session: Session) -> str: """ Returns the meal slug for Today """ - meal_plan = db.meals.get_all(limit=1, order_by="startDate") + meal_plan = db.meals.get_all(session, limit=1, order_by="startDate") meal_docs = [Meal(**meal) for meal in meal_plan["meals"]] @@ -109,7 +112,7 @@ class MealPlan(BaseModel): return "No Meal Today" @staticmethod - def this_week(): - meal_plan = db.meals.get_all(limit=1, order_by="startDate") + def this_week(session: Session): + meal_plan = db.meals.get_all(session, limit=1, order_by="startDate") return meal_plan diff --git a/mealie/services/migrations/chowdown.py b/mealie/services/migrations/chowdown.py index 24a3f85a5..65aca0f53 100644 --- a/mealie/services/migrations/chowdown.py +++ b/mealie/services/migrations/chowdown.py @@ -3,8 +3,9 @@ from pathlib import Path import git import yaml -from services.recipe_services import Recipe from app_config import IMG_DIR +from services.recipe_services import Recipe +from sqlalchemy.orm.session import Session try: from yaml import CLoader as Loader @@ -75,7 +76,7 @@ def read_chowdown_file(recipe_file: Path) -> Recipe: return new_recipe -def chowdown_migrate(repo): +def chowdown_migrate(session: Session, repo): recipe_dir, image_dir = pull_repo(repo) failed_images = [] @@ -89,7 +90,7 @@ def chowdown_migrate(repo): for recipe in recipe_dir.glob("*.md"): try: new_recipe = read_chowdown_file(recipe) - new_recipe.save_to_db() + new_recipe.save_to_db(session) except: failed_recipes.append(recipe.name) diff --git a/mealie/services/migrations/nextcloud.py b/mealie/services/migrations/nextcloud.py index 8d124b004..ae3efc83b 100644 --- a/mealie/services/migrations/nextcloud.py +++ b/mealie/services/migrations/nextcloud.py @@ -4,9 +4,9 @@ import shutil import zipfile from pathlib import Path +from app_config import IMG_DIR, TEMP_DIR from services.recipe_services import Recipe from services.scrape_services import normalize_data, process_recipe_data -from app_config import IMG_DIR, TEMP_DIR CWD = Path(__file__).parent MIGRTAION_DIR = CWD.parent.parent.joinpath("data", "migration") @@ -65,7 +65,7 @@ def cleanup(): shutil.rmtree(TEMP_DIR) -def migrate(selection: str): +def migrate(session, selection: str): prep() MIGRTAION_DIR.mkdir(exist_ok=True) selection = MIGRTAION_DIR.joinpath(selection) @@ -78,7 +78,7 @@ def migrate(selection: str): if dir.is_dir(): try: recipe = import_recipes(dir) - recipe.save_to_db() + recipe.save_to_db(session) successful_imports.append(recipe.name) except: logging.error(f"Failed Nextcloud Import: {dir.name}") diff --git a/mealie/services/recipe_services.py b/mealie/services/recipe_services.py index ed2242b48..61fa67eca 100644 --- a/mealie/services/recipe_services.py +++ b/mealie/services/recipe_services.py @@ -6,6 +6,7 @@ from typing import Any, List, Optional from db.database import db from pydantic import BaseModel, validator from slugify import slugify +from sqlalchemy.orm.session import Session from services.image_services import delete_image @@ -91,14 +92,14 @@ class Recipe(BaseModel): return cls(**document) @classmethod - def get_by_slug(cls, slug: str): + def get_by_slug(cls, session, slug: str): """ Returns a Recipe Object by Slug """ - document = db.recipes.get(slug, "slug") + document = db.recipes.get(session, slug, "slug") return cls(**document) - def save_to_db(self) -> str: + def save_to_db(self, session) -> str: recipe_dict = self.dict() try: @@ -113,21 +114,21 @@ class Recipe(BaseModel): # except: # pass - recipe_doc = db.recipes.save_new(recipe_dict) + recipe_doc = db.recipes.save_new(session, recipe_dict) recipe = Recipe(**recipe_doc) return recipe.slug @staticmethod - def delete(recipe_slug: str) -> str: + def delete(session: Session, recipe_slug: str) -> str: """ Removes the recipe from the database by slug """ delete_image(recipe_slug) - db.recipes.delete(recipe_slug) + db.recipes.delete(session, recipe_slug) return "Document Deleted" - def update(self, recipe_slug: str): + def update(self, session: Session, recipe_slug: str): """ Updates the recipe from the database by slug""" - updated_slug = db.recipes.update(recipe_slug, self.dict()) + updated_slug = db.recipes.update(session, recipe_slug, self.dict()) return updated_slug.get("slug") @staticmethod @@ -135,11 +136,13 @@ class Recipe(BaseModel): db.recipes.update_image(slug, extension) @staticmethod - def get_all(): - return db.recipes.get_all() + def get_all(session: Session): + return db.recipes.get_all(session) -def read_requested_values(keys: list, max_results: int = 0) -> List[dict]: +def read_requested_values( + session: Session, keys: list, max_results: int = 0 +) -> List[dict]: """ Pass in a list of key values to be run against the database. If a match is found it is then added to a dictionary inside of a list. If a key does not exist the @@ -152,7 +155,9 @@ def read_requested_values(keys: list, max_results: int = 0) -> List[dict]: """ recipe_list = [] - for recipe in db.recipes.get_all(limit=max_results, order_by="dateAdded"): + for recipe in db.recipes.get_all( + session=session, limit=max_results, order_by="dateAdded" + ): recipe_details = {} for key in keys: try: diff --git a/mealie/services/scheduler_services.py b/mealie/services/scheduler_services.py index be3fe04e0..96aa49d42 100644 --- a/mealie/services/scheduler_services.py +++ b/mealie/services/scheduler_services.py @@ -3,6 +3,7 @@ import json import requests from apscheduler.schedulers.background import BackgroundScheduler +from db.db_setup import create_session from utils.logger import logger from services.backups.exports import auto_backup_job @@ -40,7 +41,7 @@ class Scheduler: self.scheduler.add_job( auto_backup_job, trigger="cron", hour="3", max_instances=1 ) - settings = SiteSettings.get_site_settings() + settings = SiteSettings.get_site_settings(create_session()) time = cron_parser(settings.webhooks.webhookTime) self.webhook = self.scheduler.add_job( diff --git a/mealie/services/scrape_services.py b/mealie/services/scrape_services.py index 6d72c59ba..1980384e2 100644 --- a/mealie/services/scrape_services.py +++ b/mealie/services/scrape_services.py @@ -5,6 +5,7 @@ from typing import List, Tuple import extruct import requests import scrape_schema_recipe +from app_config import DEBUG_DIR from slugify import slugify from utils.logger import logger from w3lib.html import get_base_url @@ -13,7 +14,7 @@ from services.image_services import scrape_image from services.recipe_services import Recipe CWD = Path(__file__).parent -TEMP_FILE = CWD.parent.joinpath("data", "debug", "last_recipe.json") +TEMP_FILE = DEBUG_DIR.joinpath("last_recipe.json") def normalize_image_url(image) -> str: @@ -165,7 +166,7 @@ def process_recipe_url(url: str) -> dict: return new_recipe -def create_from_url(url: str) -> dict: +def create_from_url(url: str) -> Recipe: recipe_data = process_recipe_url(url) with open(TEMP_FILE, "w") as f: @@ -173,4 +174,4 @@ def create_from_url(url: str) -> dict: recipe = Recipe(**recipe_data) - return recipe.save_to_db() + return recipe diff --git a/mealie/services/settings_services.py b/mealie/services/settings_services.py index 25b5911d9..1f261f5d7 100644 --- a/mealie/services/settings_services.py +++ b/mealie/services/settings_services.py @@ -2,7 +2,9 @@ from typing import List, Optional from db.database import db from db.db_setup import sql_exists +from db.db_setup import create_session, generate_session from pydantic import BaseModel +from sqlalchemy.orm.session import Session from utils.logger import logger @@ -29,17 +31,19 @@ class SiteSettings(BaseModel): } @staticmethod - def get_all(): - db.settings.get_all() + def get_all(session: Session): + db.settings.get_all(session) @classmethod - def get_site_settings(cls): + def get_site_settings(cls, session: Session): try: - document = db.settings.get("main") + document = db.settings.get(session=session, match_value="main") except: webhooks = Webhooks() default_entry = SiteSettings(name="main", webhooks=webhooks) - document = db.settings.save_new(default_entry.dict(), webhooks.dict()) + document = db.settings.save_new( + session, default_entry.dict(), webhooks.dict() + ) return cls(**document) @@ -78,16 +82,16 @@ class SiteTheme(BaseModel): } @classmethod - def get_by_name(cls, theme_name): - db_entry = db.themes.get(theme_name) + def get_by_name(cls, session: Session, theme_name): + db_entry = db.themes.get(session, theme_name) name = db_entry.get("name") colors = Colors(**db_entry.get("colors")) return cls(name=name, colors=colors) @staticmethod - def get_all(): - all_themes = db.themes.get_all() + def get_all(session: Session): + all_themes = db.themes.get_all(session) for index, theme in enumerate(all_themes): name = theme.get("name") colors = Colors(**theme.get("colors")) @@ -96,16 +100,16 @@ class SiteTheme(BaseModel): return all_themes - def save_to_db(self): - db.themes.save_new(self.dict()) + def save_to_db(self, session: Session): + db.themes.save_new(session, self.dict()) - def update_document(self): - db.themes.update(self.dict()) + def update_document(self, session: Session): + db.themes.update(session, self.dict()) @staticmethod - def delete_theme(theme_name: str) -> str: + def delete_theme(session: Session, theme_name: str) -> str: """ Removes the theme by name """ - db.themes.delete(theme_name) + db.themes.delete(session, theme_name) def default_theme_init(): @@ -118,24 +122,25 @@ def default_theme_init(): "warning": "#FF4081", "error": "#EF5350", } - + session = create_session() try: - SiteTheme.get_by_name("default") + SiteTheme.get_by_name(session, "default") logger.info("Default theme exists... skipping generation") except: logger.info("Generating Default Theme") colors = Colors(**default_colors) default_theme = SiteTheme(name="default", colors=colors) - default_theme.save_to_db() + default_theme.save_to_db(session) def default_settings_init(): + session = create_session() try: - document = db.settings.get("main") + document = db.settings.get(session, "main") except: webhooks = Webhooks() default_entry = SiteSettings(name="main", webhooks=webhooks) - document = db.settings.save_new(default_entry.dict(), webhooks.dict()) + document = db.settings.save_new(session, default_entry.dict(), webhooks.dict()) if not sql_exists: diff --git a/mealie/tests/conftest.py b/mealie/tests/conftest.py index ec087fa03..33e9232dc 100644 --- a/mealie/tests/conftest.py +++ b/mealie/tests/conftest.py @@ -1,2 +1,25 @@ -import db.db_setup +from app import app +from app_config import SQLITE_DIR +from db.db_setup import generate_session, sql_global_init +from fastapi.testclient import TestClient from pytest import fixture + +SQLITE_FILE = SQLITE_DIR.joinpath("test.db") +SQLITE_FILE.unlink(missing_ok=True) + + +TestSessionLocal = sql_global_init(SQLITE_FILE, check_thread=False) + + +def override_get_db(): + try: + db = TestSessionLocal() + yield db + finally: + db.close() + + +@fixture +def api_client(): + app.dependency_overrides[generate_session] = override_get_db + return TestClient(app) diff --git a/mealie/tests/test_routes/test_recipe_routes.py b/mealie/tests/test_routes/test_recipe_routes.py new file mode 100644 index 000000000..dfaf93f9e --- /dev/null +++ b/mealie/tests/test_routes/test_recipe_routes.py @@ -0,0 +1,59 @@ +import json + +import pytest + + +class RecipeTestData: + def __init__(self, url, expected_slug) -> None: + self.url: str = url + self.expected_slug: str = expected_slug + + +test_data = [ + RecipeTestData( + url="https://www.bonappetit.com/recipe/rustic-shrimp-toasts", + expected_slug="rustic-shrimp-toasts", + ), + RecipeTestData( + url="https://www.allrecipes.com/recipe/282905/honey-garlic-shrimp/", + expected_slug="honey-garlic-shrimp", + ), +] + + +@pytest.mark.parametrize("recipe_data", test_data) +def test_create(api_client, recipe_data: RecipeTestData): + payload = json.dumps({"url": recipe_data.url}) + response = api_client.post("/api/recipe/create-url/", payload) + assert response.status_code == 201 + assert json.loads(response.text) == recipe_data.expected_slug + + +@pytest.mark.parametrize("recipe_data", test_data) +def test_read_update(api_client, recipe_data): + response = api_client.get(f"/api/recipe/{recipe_data.expected_slug}/") + assert response.status_code == 200 + + recipe = json.loads(response.content) + + recipe["notes"] = [ + {"title": "My Test Title1", "text": "My Test Text1"}, + {"title": "My Test Title2", "text": "My Test Text2"}, + ] + + recipe["categories"] = ["one", "two", "three"] + + payload = json.dumps(recipe) + + response = api_client.post( + f"/api/recipe/{recipe_data.expected_slug}/update/", payload + ) + + assert response.status_code == 200 + assert json.loads(response.text) == recipe_data.expected_slug + + +@pytest.mark.parametrize("recipe_data", test_data) +def test_delete(api_client, recipe_data): + response = api_client.delete(f"/api/recipe/{recipe_data.expected_slug}/delete/") + assert response.status_code == 200 diff --git a/mealie/utils/global_scheduler.py b/mealie/utils/global_scheduler.py index 77d574dd0..9b30ebb2e 100644 --- a/mealie/utils/global_scheduler.py +++ b/mealie/utils/global_scheduler.py @@ -8,4 +8,4 @@ def start_scheduler(): return scheduler -scheduler = start_scheduler() +scheduler = start_scheduler