test: refactor database init to support tests

This commit is contained in:
Hayden 2021-01-18 15:11:01 -09:00
commit 548e652a82
23 changed files with 304 additions and 178 deletions

View file

@ -3,7 +3,7 @@ from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
# import utils.startup as startup # import utils.startup as startup
from app_config import PORT, PRODUCTION, WEB_PATH, docs_url, redoc_url from app_config import PORT, PRODUCTION, SQLITE_FILE, WEB_PATH, docs_url, redoc_url
from routes import ( from routes import (
backup_routes, backup_routes,
meal_routes, meal_routes,
@ -13,6 +13,7 @@ from routes import (
static_routes, static_routes,
user_routes, user_routes,
) )
# from utils.api_docs import generate_api_docs # from utils.api_docs import generate_api_docs
from utils.logger import logger from utils.logger import logger
@ -25,6 +26,7 @@ app = FastAPI(
) )
def mount_static_files(): def mount_static_files():
app.mount("/static", StaticFiles(directory=WEB_PATH, html=True)) app.mount("/static", StaticFiles(directory=WEB_PATH, html=True))

View file

@ -1,5 +1,4 @@
from db.db_base import BaseDocument from db.db_base import BaseDocument
from db.sql.db_session import create_session
from db.sql.meal_models import MealPlanModel from db.sql.meal_models import MealPlanModel
from db.sql.recipe_models import RecipeModel from db.sql.recipe_models import RecipeModel
from db.sql.settings_models import SiteSettingsModel from db.sql.settings_models import SiteSettingsModel
@ -10,11 +9,12 @@ from db.sql.theme_models import SiteThemeModel
- [ ] Abstract Classes to use save_new, and update from base models - [ ] Abstract Classes to use save_new, and update from base models
- [ ] Create Category and Tags Table with Many to Many relationship - [ ] Create Category and Tags Table with Many to Many relationship
""" """
class _Recipes(BaseDocument): class _Recipes(BaseDocument):
def __init__(self) -> None: def __init__(self) -> None:
self.primary_key = "slug" self.primary_key = "slug"
self.sql_model = RecipeModel self.sql_model = RecipeModel
self.create_session = create_session
def update_image(self, slug: str, extension: str) -> None: def update_image(self, slug: str, extension: str) -> None:
pass pass
@ -24,17 +24,14 @@ class _Meals(BaseDocument):
def __init__(self) -> None: def __init__(self) -> None:
self.primary_key = "uid" self.primary_key = "uid"
self.sql_model = MealPlanModel self.sql_model = MealPlanModel
self.create_session = create_session
class _Settings(BaseDocument): class _Settings(BaseDocument):
def __init__(self) -> None: def __init__(self) -> None:
self.primary_key = "name" self.primary_key = "name"
self.sql_model = SiteSettingsModel self.sql_model = SiteSettingsModel
self.create_session = create_session
def save_new(self, main: dict, webhooks: dict) -> str: def save_new(self, session, main: dict, webhooks: dict) -> str:
session = create_session()
new_settings = self.sql_model(main.get("name"), webhooks) new_settings = self.sql_model(main.get("name"), webhooks)
session.add(new_settings) session.add(new_settings)
@ -47,16 +44,14 @@ class _Themes(BaseDocument):
def __init__(self) -> None: def __init__(self) -> None:
self.primary_key = "name" self.primary_key = "name"
self.sql_model = SiteThemeModel self.sql_model = SiteThemeModel
self.create_session = create_session
def update(self, data: dict) -> dict: def update(self, session, data: dict) -> dict:
session, theme_model = self._query_one( theme_model = self._query_one(
match_value=data["name"], match_key="name" session=session, match_value=data["name"], match_key="name"
) )
theme_model.update(**data) theme_model.update(**data)
session.commit() session.commit()
session.close()
class Database: class Database:

View file

@ -2,7 +2,6 @@ from typing import Union
from sqlalchemy.orm.session import Session from sqlalchemy.orm.session import Session
from db.sql.db_session import create_session
from db.sql.model_base import SqlAlchemyBase from db.sql.model_base import SqlAlchemyBase
@ -11,12 +10,9 @@ class BaseDocument:
self.primary_key: str self.primary_key: str
self.store: str self.store: str
self.sql_model: SqlAlchemyBase self.sql_model: SqlAlchemyBase
self.create_session = create_session
def get_all(self, limit: int = None, order_by: str = None): def get_all(self, session: Session, limit: int = None, order_by: str = None):
session = create_session()
list = [x.dict() for x in session.query(self.sql_model).all()] list = [x.dict() for x in session.query(self.sql_model).all()]
session.close()
if limit == 1: if limit == 1:
return list[0] return list[0]
@ -24,7 +20,7 @@ class BaseDocument:
return list return list
def _query_one( def _query_one(
self, match_value: str, match_key: str = None self, session: Session, match_value: str, match_key: str = None
) -> Union[Session, SqlAlchemyBase]: ) -> Union[Session, SqlAlchemyBase]:
"""Query the sql database for one item an return the sql alchemy model """Query the sql database for one item an return the sql alchemy model
object. If no match key is provided the primary_key attribute will be used. object. If no match key is provided the primary_key attribute will be used.
@ -36,8 +32,6 @@ class BaseDocument:
Returns: Returns:
Union[Session, SqlAlchemyBase]: Will return both the session and found model Union[Session, SqlAlchemyBase]: Will return both the session and found model
""" """
session = self.create_session()
if match_key == None: if match_key == None:
match_key = self.primary_key match_key = self.primary_key
@ -45,10 +39,10 @@ class BaseDocument:
session.query(self.sql_model).filter_by(**{match_key: match_value}).one() session.query(self.sql_model).filter_by(**{match_key: match_value}).one()
) )
return session, result return result
def get( def get(
self, match_value: str, match_key: str = None, limit=1 self, session: Session, match_value: str, match_key: str = None, limit=1
) -> dict or list[dict]: ) -> dict or list[dict]:
"""Retrieves an entry from the database by matching a key/value pair. If no """Retrieves an entry from the database by matching a key/value pair. If no
key is provided the class objects primary key will be used to match against. key is provided the class objects primary key will be used to match against.
@ -65,39 +59,30 @@ class BaseDocument:
if match_key == None: if match_key == None:
match_key = self.primary_key match_key = self.primary_key
session = self.create_session()
result = ( result = (
session.query(self.sql_model).filter_by(**{match_key: match_value}).one() session.query(self.sql_model).filter_by(**{match_key: match_value}).one()
) )
db_entry = result.dict() db_entry = result.dict()
session.close()
return db_entry return db_entry
def save_new(self, document: dict) -> dict: def save_new(self, session: Session, document: dict) -> dict:
session = self.create_session()
new_document = self.sql_model(**document) new_document = self.sql_model(**document)
session.add(new_document) session.add(new_document)
return_data = new_document.dict() return_data = new_document.dict()
session.commit() session.commit()
session.close()
return return_data return return_data
def update(self, match_value, new_data) -> dict: def update(self, session: Session, match_value, new_data) -> dict:
session, entry = self._query_one(match_value=match_value) entry = self._query_one(session=session, match_value=match_value)
entry.update(session=session, **new_data) entry.update(session=session, **new_data)
return_data = entry.dict() return_data = entry.dict()
session.commit() session.commit()
session.close()
return return_data return return_data
def delete(self, primary_key_value) -> dict: def delete(self, session: Session, primary_key_value) -> dict:
session = create_session()
result = ( result = (
session.query(self.sql_model) session.query(self.sql_model)
.filter_by(**{self.primary_key: primary_key_value}) .filter_by(**{self.primary_key: primary_key_value})
@ -107,4 +92,3 @@ class BaseDocument:
session.delete(result) session.delete(result)
session.commit() session.commit()
session.close()

View file

@ -1,14 +1,26 @@
from app_config import SQLITE_FILE, USE_SQL from app_config import SQLITE_FILE, USE_SQL
from sqlalchemy.orm.session import Session
from db.sql.db_session import globa_init as sql_global_init from db.sql.db_session import sql_global_init
sql_exists = True sql_exists = True
if USE_SQL: if USE_SQL:
sql_exists = SQLITE_FILE.is_file() sql_exists = SQLITE_FILE.is_file()
sql_global_init(SQLITE_FILE) SessionLocal = sql_global_init(SQLITE_FILE)
pass
else: else:
raise Exception("Cannot identify database type") raise Exception("Cannot identify database type")
def create_session() -> Session:
global SessionLocal
return SessionLocal()
def generate_session() -> Session:
global SessionLocal
db = SessionLocal()
try:
yield db
finally:
db.close()

View file

@ -1,29 +1,25 @@
from pathlib import Path from pathlib import Path
import sqlalchemy as sa import sqlalchemy as sa
import sqlalchemy.orm as orm
from db.sql.model_base import SqlAlchemyBase from db.sql.model_base import SqlAlchemyBase
from sqlalchemy.orm.session import Session from sqlalchemy.orm import sessionmaker
__factory = None
def globa_init(db_file: Path): def sql_global_init(db_file: Path, check_thread=False):
global __factory
if __factory: SQLALCHEMY_DATABASE_URL = "sqlite:///" + str(db_file.absolute())
return # SQLALCHEMY_DATABASE_URL = "postgresql://user:password@postgresserver/db"
conn_str = "sqlite:///" + str(db_file.absolute())
engine = sa.create_engine(conn_str, echo=False) engine = sa.create_engine(
SQLALCHEMY_DATABASE_URL,
echo=False,
connect_args={"check_same_thread": check_thread},
)
__factory = orm.sessionmaker(bind=engine) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
import db.sql._all_models import db.sql._all_models
SqlAlchemyBase.metadata.create_all(engine) SqlAlchemyBase.metadata.create_all(engine)
return SessionLocal
def create_session() -> Session:
global __factory
return __factory()

View file

@ -50,6 +50,10 @@ class Note(SqlAlchemyBase):
title = sa.Column(sa.String) title = sa.Column(sa.String)
text = sa.Column(sa.String) text = sa.Column(sa.String)
def __init__(self, title, text) -> None:
self.title = title
self.text = text
def dict(self): def dict(self):
return {"title": self.title, "text": self.text} return {"title": self.title, "text": self.text}
@ -169,7 +173,7 @@ class RecipeModel(SqlAlchemyBase, BaseMixins):
self.categories = [Category(name=cat) for cat in categories] self.categories = [Category(name=cat) for cat in categories]
self.tags = [Tag(name=tag) for tag in tags] self.tags = [Tag(name=tag) for tag in tags]
self.dateAdded = dateAdded self.dateAdded = dateAdded
self.notes = [Note(note) for note in notes] self.notes = [Note(**note) for note in notes]
self.rating = rating self.rating = rating
self.orgURL = orgURL self.orgURL = orgURL
self.extras = [ApiExtras(key=key, value=value) for key, value in extras.items()] self.extras = [ApiExtras(key=key, value=value) for key, value in extras.items()]

View file

@ -1,10 +1,12 @@
import operator import operator
from app_config import BACKUP_DIR, TEMPLATE_DIR from app_config import BACKUP_DIR, TEMPLATE_DIR
from fastapi import APIRouter, HTTPException from db.db_setup import generate_session
from fastapi import APIRouter, Depends, HTTPException
from models.backup_models import BackupJob, ImportJob, Imports, LocalBackup from models.backup_models import BackupJob, ImportJob, Imports, LocalBackup
from services.backups.exports import backup_all from services.backups.exports import backup_all
from services.backups.imports import ImportDatabase from services.backups.imports import ImportDatabase
from sqlalchemy.orm.session import Session
from utils.snackbar import SnackResponse from utils.snackbar import SnackResponse
router = APIRouter(tags=["Import / Export"]) router = APIRouter(tags=["Import / Export"])
@ -28,9 +30,10 @@ def available_imports():
@router.post("/api/backups/export/database/", status_code=201) @router.post("/api/backups/export/database/", status_code=201)
def export_database(data: BackupJob): def export_database(data: BackupJob, db: Session = Depends(generate_session)):
"""Generates a backup of the recipe database in json format.""" """Generates a backup of the recipe database in json format."""
export_path = backup_all( export_path = backup_all(
session=db,
tag=data.tag, tag=data.tag,
templates=data.templates, templates=data.templates,
export_recipes=data.options.recipes, export_recipes=data.options.recipes,
@ -47,10 +50,13 @@ def export_database(data: BackupJob):
@router.post("/api/backups/{file_name}/import/", status_code=200) @router.post("/api/backups/{file_name}/import/", status_code=200)
def import_database(file_name: str, import_data: ImportJob): def import_database(
file_name: str, import_data: ImportJob, db: Session = Depends(generate_session)
):
""" Import a database backup file generated from Mealie. """ """ Import a database backup file generated from Mealie. """
import_db = ImportDatabase( import_db = ImportDatabase(
session=db,
zip_archive=import_data.name, zip_archive=import_data.name,
import_recipes=import_data.recipes, import_recipes=import_data.recipes,
force_import=import_data.force, force_import=import_data.force,

View file

@ -1,24 +1,26 @@
from typing import List from typing import List
from fastapi import APIRouter, HTTPException from db.db_setup import generate_session
from fastapi import APIRouter, Depends, HTTPException
from services.meal_services import MealPlan from services.meal_services import MealPlan
from sqlalchemy.orm.session import Session
from utils.snackbar import SnackResponse from utils.snackbar import SnackResponse
router = APIRouter(tags=["Meal Plan"]) router = APIRouter(tags=["Meal Plan"])
@router.get("/api/meal-plan/all/", response_model=List[MealPlan]) @router.get("/api/meal-plan/all/", response_model=List[MealPlan])
def get_all_meals(): def get_all_meals(db: Session = Depends(generate_session)):
""" Returns a list of all available Meal Plan """ """ Returns a list of all available Meal Plan """
return MealPlan.get_all() return MealPlan.get_all(db)
@router.post("/api/meal-plan/create/") @router.post("/api/meal-plan/create/")
def set_meal_plan(data: MealPlan): def set_meal_plan(data: MealPlan, db: Session = Depends(generate_session)):
""" Creates a meal plan database entry """ """ Creates a meal plan database entry """
data.process_meals() data.process_meals()
data.save_to_db() data.save_to_db(db)
# raise HTTPException( # raise HTTPException(
# status_code=404, # status_code=404,
@ -29,10 +31,12 @@ def set_meal_plan(data: MealPlan):
@router.post("/api/meal-plan/{plan_id}/update/") @router.post("/api/meal-plan/{plan_id}/update/")
def update_meal_plan(plan_id: str, meal_plan: MealPlan): def update_meal_plan(
plan_id: str, meal_plan: MealPlan, db: Session = Depends(generate_session)
):
""" Updates a meal plan based off ID """ """ Updates a meal plan based off ID """
meal_plan.process_meals() meal_plan.process_meals()
meal_plan.update(plan_id) meal_plan.update(db, plan_id)
# try: # try:
# meal_plan.process_meals() # meal_plan.process_meals()
# meal_plan.update(plan_id) # meal_plan.update(plan_id)
@ -46,10 +50,10 @@ def update_meal_plan(plan_id: str, meal_plan: MealPlan):
@router.delete("/api/meal-plan/{plan_id}/delete/") @router.delete("/api/meal-plan/{plan_id}/delete/")
def delete_meal_plan(plan_id): def delete_meal_plan(plan_id, db: Session = Depends(generate_session)):
""" Removes a meal plan from the database """ """ Removes a meal plan from the database """
MealPlan.delete(plan_id) MealPlan.delete(db, plan_id)
return SnackResponse.success("Mealplan Deleted") return SnackResponse.success("Mealplan Deleted")
@ -58,17 +62,17 @@ def delete_meal_plan(plan_id):
"/api/meal-plan/today/", "/api/meal-plan/today/",
tags=["Meal Plan"], tags=["Meal Plan"],
) )
def get_today(): def get_today(db: Session = Depends(generate_session)):
""" """
Returns the recipe slug for the meal scheduled for today. Returns the recipe slug for the meal scheduled for today.
If no meal is scheduled nothing is returned If no meal is scheduled nothing is returned
""" """
return MealPlan.today() return MealPlan.today(db)
@router.get("/api/meal-plan/this-week/", response_model=MealPlan) @router.get("/api/meal-plan/this-week/", response_model=MealPlan)
def get_this_week(): def get_this_week(db: Session = Depends(generate_session)):
""" Returns the meal plan data for this week """ """ Returns the meal plan data for this week """
return MealPlan.this_week() return MealPlan.this_week(db)

View file

@ -1,10 +1,12 @@
import shutil import shutil
from fastapi import APIRouter, File, HTTPException, UploadFile from app_config import MIGRATION_DIR
from db.db_setup import create_session
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile
from models.migration_models import ChowdownURL from models.migration_models import ChowdownURL
from services.migrations.chowdown import chowdown_migrate as chowdow_migrate from services.migrations.chowdown import chowdown_migrate as chowdow_migrate
from services.migrations.nextcloud import migrate as nextcloud_migrate from services.migrations.nextcloud import migrate as nextcloud_migrate
from app_config import MIGRATION_DIR from sqlalchemy.orm.session import Session
from utils.snackbar import SnackResponse from utils.snackbar import SnackResponse
router = APIRouter(tags=["Migration"]) router = APIRouter(tags=["Migration"])
@ -12,10 +14,10 @@ router = APIRouter(tags=["Migration"])
# Chowdown # Chowdown
@router.post("/api/migration/chowdown/repo/") @router.post("/api/migration/chowdown/repo/")
def import_chowdown_recipes(repo: ChowdownURL): def import_chowdown_recipes(repo: ChowdownURL, db: Session = Depends(create_session)):
""" Import Chowsdown Recipes from Repo URL """ """ Import Chowsdown Recipes from Repo URL """
try: try:
report = chowdow_migrate(repo.url) report = chowdow_migrate(db, repo.url)
return SnackResponse.success( return SnackResponse.success(
"Recipes Imported from Git Repo, see report for failures.", "Recipes Imported from Git Repo, see report for failures.",
additional_data=report, additional_data=report,
@ -44,10 +46,10 @@ def get_avaiable_nextcloud_imports():
@router.post("/api/migration/nextcloud/{selection}/import/") @router.post("/api/migration/nextcloud/{selection}/import/")
def import_nextcloud_directory(selection: str): def import_nextcloud_directory(selection: str, db: Session = Depends(create_session)):
""" Imports all the recipes in a given directory """ """ Imports all the recipes in a given directory """
return nextcloud_migrate(selection) return nextcloud_migrate(db, selection)
@router.delete("/api/migration/{file_folder_name}/delete/") @router.delete("/api/migration/{file_folder_name}/delete/")

View file

@ -1,18 +1,24 @@
from typing import List, Optional from typing import List, Optional
from fastapi import APIRouter, File, Form, HTTPException, Query from db.db_setup import generate_session
from fastapi import APIRouter, Depends, File, Form, HTTPException, Query
from fastapi.responses import FileResponse from fastapi.responses import FileResponse
from models.recipe_models import AllRecipeRequest, RecipeURLIn from models.recipe_models import AllRecipeRequest, RecipeURLIn
from services.image_services import read_image, write_image from services.image_services import read_image, write_image
from services.recipe_services import Recipe, read_requested_values from services.recipe_services import Recipe, read_requested_values
from services.scrape_services import create_from_url from services.scrape_services import create_from_url
from sqlalchemy.orm.session import Session
from utils.snackbar import SnackResponse from utils.snackbar import SnackResponse
router = APIRouter(tags=["Recipes"]) router = APIRouter(tags=["Recipes"])
@router.get("/api/all-recipes/", response_model=List[dict]) @router.get("/api/all-recipes/", response_model=List[dict])
def get_all_recipes(keys: Optional[List[str]] = Query(...), num: Optional[int] = 100): def get_all_recipes(
keys: Optional[List[str]] = Query(...),
num: Optional[int] = 100,
db: Session = Depends(generate_session),
):
""" """
Returns key data for all recipes based off the query paramters provided. Returns key data for all recipes based off the query paramters provided.
For example, if slug, image, and name are provided you will recieve a list of For example, if slug, image, and name are provided you will recieve a list of
@ -24,12 +30,14 @@ def get_all_recipes(keys: Optional[List[str]] = Query(...), num: Optional[int] =
See the *Post* method for more details. See the *Post* method for more details.
""" """
all_recipes = read_requested_values(keys, num) all_recipes = read_requested_values(db, keys, num)
return all_recipes return all_recipes
@router.post("/api/all-recipes/", response_model=List[dict]) @router.post("/api/all-recipes/", response_model=List[dict])
def get_all_recipes_post(body: AllRecipeRequest): def get_all_recipes_post(
body: AllRecipeRequest, db: Session = Depends(generate_session)
):
""" """
Returns key data for all recipes based off the body data provided. Returns key data for all recipes based off the body data provided.
For example, if slug, image, and name are provided you will recieve a list of For example, if slug, image, and name are provided you will recieve a list of
@ -39,15 +47,18 @@ def get_all_recipes_post(body: AllRecipeRequest):
""" """
all_recipes = read_requested_values(body.properties, body.limit) all_recipes = read_requested_values(db, body.properties, body.limit)
return all_recipes return all_recipes
@router.get("/api/recipe/{recipe_slug}/", response_model=Recipe) @router.get(
def get_recipe(recipe_slug: str): "/api/recipe/{recipe_slug}/",
response_model=Recipe,
)
def get_recipe(recipe_slug: str, db: Session = Depends(generate_session)):
""" Takes in a recipe slug, returns all data for a recipe """ """ Takes in a recipe slug, returns all data for a recipe """
recipe = Recipe.get_by_slug(recipe_slug) recipe = Recipe.get_by_slug(db, recipe_slug)
return recipe return recipe
@ -63,22 +74,22 @@ def get_recipe_img(recipe_slug: str):
# Recipe Creations # Recipe Creations
@router.post( @router.post(
"/api/recipe/create-url/", "/api/recipe/create-url/",
tags=["Recipes"],
status_code=201, status_code=201,
response_model=str, response_model=str,
) )
def parse_recipe_url(url: RecipeURLIn): def parse_recipe_url(url: RecipeURLIn, db: Session = Depends(generate_session)):
""" Takes in a URL and attempts to scrape data and load it into the database """ """ Takes in a URL and attempts to scrape data and load it into the database """
slug = create_from_url(url.url) recipe = create_from_url(url.url)
recipe.save_to_db(db)
return slug return recipe.slug
@router.post("/api/recipe/create/") @router.post("/api/recipe/create/")
def create_from_json(data: Recipe) -> str: def create_from_json(data: Recipe, db: Session = Depends(generate_session)) -> str:
""" Takes in a JSON string and loads data into the database as a new entry""" """ Takes in a JSON string and loads data into the database as a new entry"""
created_recipe = data.save_to_db() created_recipe = data.save_to_db(db)
return created_recipe return created_recipe
@ -95,20 +106,22 @@ def update_recipe_image(
@router.post("/api/recipe/{recipe_slug}/update/") @router.post("/api/recipe/{recipe_slug}/update/")
def update_recipe(recipe_slug: str, data: Recipe): def update_recipe(
recipe_slug: str, data: Recipe, db: Session = Depends(generate_session)
):
""" Updates a recipe by existing slug and data. """ """ Updates a recipe by existing slug and data. """
new_slug = data.update(recipe_slug) new_slug = data.update(db, recipe_slug)
return new_slug return new_slug
@router.delete("/api/recipe/{recipe_slug}/delete/") @router.delete("/api/recipe/{recipe_slug}/delete/")
def delete_recipe(recipe_slug: str): def delete_recipe(recipe_slug: str, db: Session = Depends(generate_session)):
""" Deletes a recipe by slug """ """ Deletes a recipe by slug """
try: try:
Recipe.delete(recipe_slug) Recipe.delete(db, recipe_slug)
except: except:
raise HTTPException( raise HTTPException(
status_code=404, detail=SnackResponse.error("Unable to Delete Recipe") status_code=404, detail=SnackResponse.error("Unable to Delete Recipe")

View file

@ -1,6 +1,8 @@
from fastapi import APIRouter, HTTPException from db.db_setup import generate_session
from fastapi import APIRouter, Depends, HTTPException
from services.scheduler_services import post_webhooks from services.scheduler_services import post_webhooks
from services.settings_services import SiteSettings, SiteTheme from services.settings_services import SiteSettings, SiteTheme
from sqlalchemy.orm.session import Session
from utils.global_scheduler import scheduler from utils.global_scheduler import scheduler
from utils.snackbar import SnackResponse from utils.snackbar import SnackResponse
@ -8,10 +10,10 @@ router = APIRouter(tags=["Settings"])
@router.get("/api/site-settings/") @router.get("/api/site-settings/")
def get_main_settings(): def get_main_settings(db: Session = Depends(generate_session)):
""" Returns basic site settings """ """ Returns basic site settings """
return SiteSettings.get_site_settings() return SiteSettings.get_site_settings(db)
@router.post("/api/site-settings/webhooks/test/") @router.post("/api/site-settings/webhooks/test/")
@ -37,22 +39,22 @@ def update_settings(data: SiteSettings):
@router.get("/api/site-settings/themes/", tags=["Themes"]) @router.get("/api/site-settings/themes/", tags=["Themes"])
def get_all_themes(): def get_all_themes(db: Session = Depends(generate_session)):
""" Returns all site themes """ """ Returns all site themes """
return SiteTheme.get_all() return SiteTheme.get_all(db)
@router.get("/api/site-settings/themes/{theme_name}/", tags=["Themes"]) @router.get("/api/site-settings/themes/{theme_name}/", tags=["Themes"])
def get_single_theme(theme_name: str): def get_single_theme(theme_name: str, db: Session = Depends(generate_session)):
""" Returns a named theme """ """ Returns a named theme """
return SiteTheme.get_by_name(theme_name) return SiteTheme.get_by_name(db, theme_name)
@router.post("/api/site-settings/themes/create/", tags=["Themes"]) @router.post("/api/site-settings/themes/create/", tags=["Themes"])
def create_theme(data: SiteTheme): def create_theme(data: SiteTheme, db: Session = Depends(generate_session)):
""" Creates a site color theme database entry """ """ Creates a site color theme database entry """
data.save_to_db() data.save_to_db(db)
# try: # try:
# data.save_to_db() # data.save_to_db()
# except: # except:
@ -64,9 +66,11 @@ def create_theme(data: SiteTheme):
@router.post("/api/site-settings/themes/{theme_name}/update/", tags=["Themes"]) @router.post("/api/site-settings/themes/{theme_name}/update/", tags=["Themes"])
def update_theme(theme_name: str, data: SiteTheme): def update_theme(
theme_name: str, data: SiteTheme, db: Session = Depends(generate_session)
):
""" Update a theme database entry """ """ Update a theme database entry """
data.update_document() data.update_document(db)
# try: # try:
# except: # except:
@ -78,9 +82,9 @@ def update_theme(theme_name: str, data: SiteTheme):
@router.delete("/api/site-settings/themes/{theme_name}/delete/", tags=["Themes"]) @router.delete("/api/site-settings/themes/{theme_name}/delete/", tags=["Themes"])
def delete_theme(theme_name: str): def delete_theme(theme_name: str, db: Session = Depends(generate_session)):
""" Deletes theme from the database """ """ Deletes theme from the database """
SiteTheme.delete_theme(theme_name) SiteTheme.delete_theme(db, theme_name)
# try: # try:
# SiteTheme.delete_theme(theme_name) # SiteTheme.delete_theme(theme_name)
# except: # except:

View file

@ -4,6 +4,7 @@ from datetime import datetime
from pathlib import Path from pathlib import Path
from app_config import BACKUP_DIR, IMG_DIR, TEMP_DIR, TEMPLATE_DIR from app_config import BACKUP_DIR, IMG_DIR, TEMP_DIR, TEMPLATE_DIR
from db.db_setup import create_session
from jinja2 import Template from jinja2 import Template
from services.meal_services import MealPlan from services.meal_services import MealPlan
from services.recipe_services import Recipe from services.recipe_services import Recipe
@ -12,7 +13,7 @@ from utils.logger import logger
class ExportDatabase: class ExportDatabase:
def __init__(self, tag=None, templates=None) -> None: def __init__(self, session, tag=None, templates=None) -> None:
"""Export a Mealie database. Export interacts directly with class objects and can be used """Export a Mealie database. Export interacts directly with class objects and can be used
with any supported backend database platform. By default tags are timestands, and no Jinja2 templates are rendered with any supported backend database platform. By default tags are timestands, and no Jinja2 templates are rendered
@ -26,6 +27,7 @@ class ExportDatabase:
else: else:
export_tag = datetime.now().strftime("%Y-%b-%d") export_tag = datetime.now().strftime("%Y-%b-%d")
self.session = session
self.main_dir = TEMP_DIR.joinpath(export_tag) self.main_dir = TEMP_DIR.joinpath(export_tag)
self.img_dir = self.main_dir.joinpath("images") self.img_dir = self.main_dir.joinpath("images")
self.recipe_dir = self.main_dir.joinpath("recipes") self.recipe_dir = self.main_dir.joinpath("recipes")
@ -54,7 +56,7 @@ class ExportDatabase:
dir.mkdir(parents=True, exist_ok=True) dir.mkdir(parents=True, exist_ok=True)
def export_recipes(self): def export_recipes(self):
all_recipes = Recipe.get_all() all_recipes = Recipe.get_all(self.session)
for recipe in all_recipes: for recipe in all_recipes:
logger.info(f"Backing Up Recipes: {recipe}") logger.info(f"Backing Up Recipes: {recipe}")
@ -86,12 +88,12 @@ class ExportDatabase:
shutil.copy(file, self.img_dir.joinpath(file.name)) shutil.copy(file, self.img_dir.joinpath(file.name))
def export_settings(self): def export_settings(self):
all_settings = SiteSettings.get_site_settings() all_settings = SiteSettings.get_site_settings(self.session)
out_file = self.settings_dir.joinpath("settings.json") out_file = self.settings_dir.joinpath("settings.json")
ExportDatabase._write_json_file(all_settings.dict(), out_file) ExportDatabase._write_json_file(all_settings.dict(), out_file)
def export_themes(self): def export_themes(self):
all_themes = SiteTheme.get_all() all_themes = SiteTheme.get_all(self.session)
if all_themes: if all_themes:
all_themes = [x.dict() for x in all_themes] all_themes = [x.dict() for x in all_themes]
out_file = self.themes_dir.joinpath("themes.json") out_file = self.themes_dir.joinpath("themes.json")
@ -100,7 +102,7 @@ class ExportDatabase:
def export_meals( def export_meals(
self, self,
): #! Problem Parseing Datetime Objects... May come back to this ): #! Problem Parseing Datetime Objects... May come back to this
meal_plans = MealPlan.get_all() meal_plans = MealPlan.get_all(self.session)
if meal_plans: if meal_plans:
meal_plans = [x.dict() for x in meal_plans] meal_plans = [x.dict() for x in meal_plans]
@ -124,13 +126,14 @@ class ExportDatabase:
def backup_all( def backup_all(
session,
tag=None, tag=None,
templates=None, templates=None,
export_recipes=True, export_recipes=True,
export_settings=True, export_settings=True,
export_themes=True, export_themes=True,
): ):
db_export = ExportDatabase(tag=tag, templates=templates) db_export = ExportDatabase(session=session, tag=tag, templates=templates)
if export_recipes: if export_recipes:
db_export.export_recipes() db_export.export_recipes()
@ -154,5 +157,6 @@ def auto_backup_job():
for template in TEMPLATE_DIR.iterdir(): for template in TEMPLATE_DIR.iterdir():
templates.append(template) templates.append(template)
backup_all(tag="Auto", templates=templates) session = create_session()
backup_all(session=session, tag="Auto", templates=templates)
logger.info("Auto Backup Called") logger.info("Auto Backup Called")

View file

@ -7,12 +7,14 @@ from typing import List
from app_config import BACKUP_DIR, IMG_DIR, TEMP_DIR from app_config import BACKUP_DIR, IMG_DIR, TEMP_DIR
from services.recipe_services import Recipe from services.recipe_services import Recipe
from services.settings_services import SiteSettings, SiteTheme from services.settings_services import SiteSettings, SiteTheme
from sqlalchemy.orm.session import Session
from utils.logger import logger from utils.logger import logger
class ImportDatabase: class ImportDatabase:
def __init__( def __init__(
self, self,
session: Session,
zip_archive: str, zip_archive: str,
import_recipes: bool = True, import_recipes: bool = True,
import_settings: bool = True, import_settings: bool = True,
@ -33,7 +35,7 @@ class ImportDatabase:
Raises: Raises:
Exception: If the zip file does not exists an exception raise. Exception: If the zip file does not exists an exception raise.
""" """
self.session = session
self.archive = BACKUP_DIR.joinpath(zip_archive) self.archive = BACKUP_DIR.joinpath(zip_archive)
self.imp_recipes = import_recipes self.imp_recipes = import_recipes
self.imp_settings = import_settings self.imp_settings = import_settings
@ -75,7 +77,7 @@ class ImportDatabase:
recipe_dict = ImportDatabase._recipe_migration(recipe_dict) recipe_dict = ImportDatabase._recipe_migration(recipe_dict)
try: try:
recipe_obj = Recipe(**recipe_dict) recipe_obj = Recipe(**recipe_dict)
recipe_obj.save_to_db() recipe_obj.save_to_db(self.session)
successful_imports.append(recipe.stem) successful_imports.append(recipe.stem)
logger.info(f"Imported: {recipe.stem}") logger.info(f"Imported: {recipe.stem}")
except: except:
@ -114,7 +116,7 @@ class ImportDatabase:
for theme in themes: for theme in themes:
new_theme = SiteTheme(**theme) new_theme = SiteTheme(**theme)
try: try:
new_theme.save_to_db() new_theme.save_to_db(self.session)
except: except:
logger.info(f"Unable Import Theme {new_theme.name}") logger.info(f"Unable Import Theme {new_theme.name}")
@ -126,7 +128,7 @@ class ImportDatabase:
settings = SiteSettings(**settings) settings = SiteSettings(**settings)
settings.update() settings.update(self.session)
def clean_up(self): def clean_up(self):
shutil.rmtree(TEMP_DIR) shutil.rmtree(TEMP_DIR)

View file

@ -4,6 +4,7 @@ from typing import List, Optional
from db.database import db from db.database import db
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy.orm.session import Session
from services.recipe_services import Recipe from services.recipe_services import Recipe
@ -78,27 +79,29 @@ class MealPlan(BaseModel):
self.meals = meals self.meals = meals
def save_to_db(self): def save_to_db(self, session: Session):
db.meals.save_new(self.dict()) db.meals.save_new(session, self.dict())
@staticmethod @staticmethod
def get_all() -> List: def get_all(session: Session) -> List:
all_meals = [MealPlan(**x) for x in db.meals.get_all(order_by="startDate")] all_meals = [
MealPlan(**x) for x in db.meals.get_all(session, order_by="startDate")
]
return all_meals return all_meals
def update(self, uid): def update(self, session, uid):
db.meals.update(uid, self.dict()) db.meals.update(session, uid, self.dict())
@staticmethod @staticmethod
def delete(uid): def delete(session, uid):
db.meals.delete(uid) db.meals.delete(session, uid)
@staticmethod @staticmethod
def today() -> str: def today(session: Session) -> str:
""" Returns the meal slug for Today """ """ Returns the meal slug for Today """
meal_plan = db.meals.get_all(limit=1, order_by="startDate") meal_plan = db.meals.get_all(session, limit=1, order_by="startDate")
meal_docs = [Meal(**meal) for meal in meal_plan["meals"]] meal_docs = [Meal(**meal) for meal in meal_plan["meals"]]
@ -109,7 +112,7 @@ class MealPlan(BaseModel):
return "No Meal Today" return "No Meal Today"
@staticmethod @staticmethod
def this_week(): def this_week(session: Session):
meal_plan = db.meals.get_all(limit=1, order_by="startDate") meal_plan = db.meals.get_all(session, limit=1, order_by="startDate")
return meal_plan return meal_plan

View file

@ -3,8 +3,9 @@ from pathlib import Path
import git import git
import yaml import yaml
from services.recipe_services import Recipe
from app_config import IMG_DIR from app_config import IMG_DIR
from services.recipe_services import Recipe
from sqlalchemy.orm.session import Session
try: try:
from yaml import CLoader as Loader from yaml import CLoader as Loader
@ -75,7 +76,7 @@ def read_chowdown_file(recipe_file: Path) -> Recipe:
return new_recipe return new_recipe
def chowdown_migrate(repo): def chowdown_migrate(session: Session, repo):
recipe_dir, image_dir = pull_repo(repo) recipe_dir, image_dir = pull_repo(repo)
failed_images = [] failed_images = []
@ -89,7 +90,7 @@ def chowdown_migrate(repo):
for recipe in recipe_dir.glob("*.md"): for recipe in recipe_dir.glob("*.md"):
try: try:
new_recipe = read_chowdown_file(recipe) new_recipe = read_chowdown_file(recipe)
new_recipe.save_to_db() new_recipe.save_to_db(session)
except: except:
failed_recipes.append(recipe.name) failed_recipes.append(recipe.name)

View file

@ -4,9 +4,9 @@ import shutil
import zipfile import zipfile
from pathlib import Path from pathlib import Path
from app_config import IMG_DIR, TEMP_DIR
from services.recipe_services import Recipe from services.recipe_services import Recipe
from services.scrape_services import normalize_data, process_recipe_data from services.scrape_services import normalize_data, process_recipe_data
from app_config import IMG_DIR, TEMP_DIR
CWD = Path(__file__).parent CWD = Path(__file__).parent
MIGRTAION_DIR = CWD.parent.parent.joinpath("data", "migration") MIGRTAION_DIR = CWD.parent.parent.joinpath("data", "migration")
@ -65,7 +65,7 @@ def cleanup():
shutil.rmtree(TEMP_DIR) shutil.rmtree(TEMP_DIR)
def migrate(selection: str): def migrate(session, selection: str):
prep() prep()
MIGRTAION_DIR.mkdir(exist_ok=True) MIGRTAION_DIR.mkdir(exist_ok=True)
selection = MIGRTAION_DIR.joinpath(selection) selection = MIGRTAION_DIR.joinpath(selection)
@ -78,7 +78,7 @@ def migrate(selection: str):
if dir.is_dir(): if dir.is_dir():
try: try:
recipe = import_recipes(dir) recipe = import_recipes(dir)
recipe.save_to_db() recipe.save_to_db(session)
successful_imports.append(recipe.name) successful_imports.append(recipe.name)
except: except:
logging.error(f"Failed Nextcloud Import: {dir.name}") logging.error(f"Failed Nextcloud Import: {dir.name}")

View file

@ -6,6 +6,7 @@ from typing import Any, List, Optional
from db.database import db from db.database import db
from pydantic import BaseModel, validator from pydantic import BaseModel, validator
from slugify import slugify from slugify import slugify
from sqlalchemy.orm.session import Session
from services.image_services import delete_image from services.image_services import delete_image
@ -91,14 +92,14 @@ class Recipe(BaseModel):
return cls(**document) return cls(**document)
@classmethod @classmethod
def get_by_slug(cls, slug: str): def get_by_slug(cls, session, slug: str):
""" Returns a Recipe Object by Slug """ """ Returns a Recipe Object by Slug """
document = db.recipes.get(slug, "slug") document = db.recipes.get(session, slug, "slug")
return cls(**document) return cls(**document)
def save_to_db(self) -> str: def save_to_db(self, session) -> str:
recipe_dict = self.dict() recipe_dict = self.dict()
try: try:
@ -113,21 +114,21 @@ class Recipe(BaseModel):
# except: # except:
# pass # pass
recipe_doc = db.recipes.save_new(recipe_dict) recipe_doc = db.recipes.save_new(session, recipe_dict)
recipe = Recipe(**recipe_doc) recipe = Recipe(**recipe_doc)
return recipe.slug return recipe.slug
@staticmethod @staticmethod
def delete(recipe_slug: str) -> str: def delete(session: Session, recipe_slug: str) -> str:
""" Removes the recipe from the database by slug """ """ Removes the recipe from the database by slug """
delete_image(recipe_slug) delete_image(recipe_slug)
db.recipes.delete(recipe_slug) db.recipes.delete(session, recipe_slug)
return "Document Deleted" return "Document Deleted"
def update(self, recipe_slug: str): def update(self, session: Session, recipe_slug: str):
""" Updates the recipe from the database by slug""" """ Updates the recipe from the database by slug"""
updated_slug = db.recipes.update(recipe_slug, self.dict()) updated_slug = db.recipes.update(session, recipe_slug, self.dict())
return updated_slug.get("slug") return updated_slug.get("slug")
@staticmethod @staticmethod
@ -135,11 +136,13 @@ class Recipe(BaseModel):
db.recipes.update_image(slug, extension) db.recipes.update_image(slug, extension)
@staticmethod @staticmethod
def get_all(): def get_all(session: Session):
return db.recipes.get_all() return db.recipes.get_all(session)
def read_requested_values(keys: list, max_results: int = 0) -> List[dict]: def read_requested_values(
session: Session, keys: list, max_results: int = 0
) -> List[dict]:
""" """
Pass in a list of key values to be run against the database. If a match is found Pass in a list of key values to be run against the database. If a match is found
it is then added to a dictionary inside of a list. If a key does not exist the it is then added to a dictionary inside of a list. If a key does not exist the
@ -152,7 +155,9 @@ def read_requested_values(keys: list, max_results: int = 0) -> List[dict]:
""" """
recipe_list = [] recipe_list = []
for recipe in db.recipes.get_all(limit=max_results, order_by="dateAdded"): for recipe in db.recipes.get_all(
session=session, limit=max_results, order_by="dateAdded"
):
recipe_details = {} recipe_details = {}
for key in keys: for key in keys:
try: try:

View file

@ -3,6 +3,7 @@ import json
import requests import requests
from apscheduler.schedulers.background import BackgroundScheduler from apscheduler.schedulers.background import BackgroundScheduler
from db.db_setup import create_session
from utils.logger import logger from utils.logger import logger
from services.backups.exports import auto_backup_job from services.backups.exports import auto_backup_job
@ -40,7 +41,7 @@ class Scheduler:
self.scheduler.add_job( self.scheduler.add_job(
auto_backup_job, trigger="cron", hour="3", max_instances=1 auto_backup_job, trigger="cron", hour="3", max_instances=1
) )
settings = SiteSettings.get_site_settings() settings = SiteSettings.get_site_settings(create_session())
time = cron_parser(settings.webhooks.webhookTime) time = cron_parser(settings.webhooks.webhookTime)
self.webhook = self.scheduler.add_job( self.webhook = self.scheduler.add_job(

View file

@ -5,6 +5,7 @@ from typing import List, Tuple
import extruct import extruct
import requests import requests
import scrape_schema_recipe import scrape_schema_recipe
from app_config import DEBUG_DIR
from slugify import slugify from slugify import slugify
from utils.logger import logger from utils.logger import logger
from w3lib.html import get_base_url from w3lib.html import get_base_url
@ -13,7 +14,7 @@ from services.image_services import scrape_image
from services.recipe_services import Recipe from services.recipe_services import Recipe
CWD = Path(__file__).parent CWD = Path(__file__).parent
TEMP_FILE = CWD.parent.joinpath("data", "debug", "last_recipe.json") TEMP_FILE = DEBUG_DIR.joinpath("last_recipe.json")
def normalize_image_url(image) -> str: def normalize_image_url(image) -> str:
@ -165,7 +166,7 @@ def process_recipe_url(url: str) -> dict:
return new_recipe return new_recipe
def create_from_url(url: str) -> dict: def create_from_url(url: str) -> Recipe:
recipe_data = process_recipe_url(url) recipe_data = process_recipe_url(url)
with open(TEMP_FILE, "w") as f: with open(TEMP_FILE, "w") as f:
@ -173,4 +174,4 @@ def create_from_url(url: str) -> dict:
recipe = Recipe(**recipe_data) recipe = Recipe(**recipe_data)
return recipe.save_to_db() return recipe

View file

@ -2,7 +2,9 @@ from typing import List, Optional
from db.database import db from db.database import db
from db.db_setup import sql_exists from db.db_setup import sql_exists
from db.db_setup import create_session, generate_session
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy.orm.session import Session
from utils.logger import logger from utils.logger import logger
@ -29,17 +31,19 @@ class SiteSettings(BaseModel):
} }
@staticmethod @staticmethod
def get_all(): def get_all(session: Session):
db.settings.get_all() db.settings.get_all(session)
@classmethod @classmethod
def get_site_settings(cls): def get_site_settings(cls, session: Session):
try: try:
document = db.settings.get("main") document = db.settings.get(session=session, match_value="main")
except: except:
webhooks = Webhooks() webhooks = Webhooks()
default_entry = SiteSettings(name="main", webhooks=webhooks) default_entry = SiteSettings(name="main", webhooks=webhooks)
document = db.settings.save_new(default_entry.dict(), webhooks.dict()) document = db.settings.save_new(
session, default_entry.dict(), webhooks.dict()
)
return cls(**document) return cls(**document)
@ -78,16 +82,16 @@ class SiteTheme(BaseModel):
} }
@classmethod @classmethod
def get_by_name(cls, theme_name): def get_by_name(cls, session: Session, theme_name):
db_entry = db.themes.get(theme_name) db_entry = db.themes.get(session, theme_name)
name = db_entry.get("name") name = db_entry.get("name")
colors = Colors(**db_entry.get("colors")) colors = Colors(**db_entry.get("colors"))
return cls(name=name, colors=colors) return cls(name=name, colors=colors)
@staticmethod @staticmethod
def get_all(): def get_all(session: Session):
all_themes = db.themes.get_all() all_themes = db.themes.get_all(session)
for index, theme in enumerate(all_themes): for index, theme in enumerate(all_themes):
name = theme.get("name") name = theme.get("name")
colors = Colors(**theme.get("colors")) colors = Colors(**theme.get("colors"))
@ -96,16 +100,16 @@ class SiteTheme(BaseModel):
return all_themes return all_themes
def save_to_db(self): def save_to_db(self, session: Session):
db.themes.save_new(self.dict()) db.themes.save_new(session, self.dict())
def update_document(self): def update_document(self, session: Session):
db.themes.update(self.dict()) db.themes.update(session, self.dict())
@staticmethod @staticmethod
def delete_theme(theme_name: str) -> str: def delete_theme(session: Session, theme_name: str) -> str:
""" Removes the theme by name """ """ Removes the theme by name """
db.themes.delete(theme_name) db.themes.delete(session, theme_name)
def default_theme_init(): def default_theme_init():
@ -118,24 +122,25 @@ def default_theme_init():
"warning": "#FF4081", "warning": "#FF4081",
"error": "#EF5350", "error": "#EF5350",
} }
session = create_session()
try: try:
SiteTheme.get_by_name("default") SiteTheme.get_by_name(session, "default")
logger.info("Default theme exists... skipping generation") logger.info("Default theme exists... skipping generation")
except: except:
logger.info("Generating Default Theme") logger.info("Generating Default Theme")
colors = Colors(**default_colors) colors = Colors(**default_colors)
default_theme = SiteTheme(name="default", colors=colors) default_theme = SiteTheme(name="default", colors=colors)
default_theme.save_to_db() default_theme.save_to_db(session)
def default_settings_init(): def default_settings_init():
session = create_session()
try: try:
document = db.settings.get("main") document = db.settings.get(session, "main")
except: except:
webhooks = Webhooks() webhooks = Webhooks()
default_entry = SiteSettings(name="main", webhooks=webhooks) default_entry = SiteSettings(name="main", webhooks=webhooks)
document = db.settings.save_new(default_entry.dict(), webhooks.dict()) document = db.settings.save_new(session, default_entry.dict(), webhooks.dict())
if not sql_exists: if not sql_exists:

View file

@ -1,2 +1,25 @@
import db.db_setup from app import app
from app_config import SQLITE_DIR
from db.db_setup import generate_session, sql_global_init
from fastapi.testclient import TestClient
from pytest import fixture from pytest import fixture
SQLITE_FILE = SQLITE_DIR.joinpath("test.db")
SQLITE_FILE.unlink(missing_ok=True)
TestSessionLocal = sql_global_init(SQLITE_FILE, check_thread=False)
def override_get_db():
try:
db = TestSessionLocal()
yield db
finally:
db.close()
@fixture
def api_client():
app.dependency_overrides[generate_session] = override_get_db
return TestClient(app)

View file

@ -0,0 +1,59 @@
import json
import pytest
class RecipeTestData:
def __init__(self, url, expected_slug) -> None:
self.url: str = url
self.expected_slug: str = expected_slug
test_data = [
RecipeTestData(
url="https://www.bonappetit.com/recipe/rustic-shrimp-toasts",
expected_slug="rustic-shrimp-toasts",
),
RecipeTestData(
url="https://www.allrecipes.com/recipe/282905/honey-garlic-shrimp/",
expected_slug="honey-garlic-shrimp",
),
]
@pytest.mark.parametrize("recipe_data", test_data)
def test_create(api_client, recipe_data: RecipeTestData):
payload = json.dumps({"url": recipe_data.url})
response = api_client.post("/api/recipe/create-url/", payload)
assert response.status_code == 201
assert json.loads(response.text) == recipe_data.expected_slug
@pytest.mark.parametrize("recipe_data", test_data)
def test_read_update(api_client, recipe_data):
response = api_client.get(f"/api/recipe/{recipe_data.expected_slug}/")
assert response.status_code == 200
recipe = json.loads(response.content)
recipe["notes"] = [
{"title": "My Test Title1", "text": "My Test Text1"},
{"title": "My Test Title2", "text": "My Test Text2"},
]
recipe["categories"] = ["one", "two", "three"]
payload = json.dumps(recipe)
response = api_client.post(
f"/api/recipe/{recipe_data.expected_slug}/update/", payload
)
assert response.status_code == 200
assert json.loads(response.text) == recipe_data.expected_slug
@pytest.mark.parametrize("recipe_data", test_data)
def test_delete(api_client, recipe_data):
response = api_client.delete(f"/api/recipe/{recipe_data.expected_slug}/delete/")
assert response.status_code == 200

View file

@ -8,4 +8,4 @@ def start_scheduler():
return scheduler return scheduler
scheduler = start_scheduler() scheduler = start_scheduler