From 295775dbd5c21a1b39bf8ad07c33f0df76664d8d Mon Sep 17 00:00:00 2001 From: hay-kot Date: Wed, 5 May 2021 13:56:59 -0800 Subject: [PATCH] remvoe old typing --- mealie/db/db_base.py | 32 +++++++++++++++++++++++++------- 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/mealie/db/db_base.py b/mealie/db/db_base.py index 9d178adf4..5ad9750a0 100644 --- a/mealie/db/db_base.py +++ b/mealie/db/db_base.py @@ -1,4 +1,4 @@ -from typing import List +from typing import Union from mealie.core.root_logger import get_logger from mealie.db.models.model_base import SqlAlchemyBase @@ -20,7 +20,7 @@ class BaseDocument: # TODO: Improve Get All Query Functionality def get_all( self, session: Session, limit: int = None, order_by: str = None, start=0, end=9999, override_schema=None - ) -> List[dict]: + ) -> list[dict]: eff_schema = override_schema or self.schema if order_by: @@ -33,13 +33,13 @@ class BaseDocument: return [eff_schema.from_orm(x) for x in session.query(self.sql_model).offset(start).limit(limit).all()] - def get_all_limit_columns(self, session: Session, fields: List[str], limit: int = None) -> List[SqlAlchemyBase]: + def get_all_limit_columns(self, session: Session, fields: list[str], limit: int = None) -> list[SqlAlchemyBase]: """Queries the database for the selected model. Restricts return responses to the keys specified under "fields" Args: \n session (Session): Database Session Object - fields (List[str]): List of column names to query + fields (list[str]): list of column names to query limit (int): A limit of values to return Returns: @@ -47,7 +47,7 @@ class BaseDocument: """ return session.query(self.sql_model).options(load_only(*fields)).limit(limit).all() - def get_all_primary_keys(self, session: Session) -> List[str]: + def get_all_primary_keys(self, session: Session) -> list[str]: """Queries the database of the selected model and returns a list of all primary_key values @@ -79,7 +79,7 @@ class BaseDocument: def get( self, session: Session, match_value: str, match_key: str = None, limit=1, any_case=False - ) -> BaseModel or List[BaseModel]: + ) -> Union[BaseModel, list[BaseModel]]: """Retrieves an entry from the database by matching a key/value pair. If no key is provided the class objects primary key will be used to match against. @@ -120,6 +120,8 @@ class BaseDocument: Returns: dict: A dictionary representation of the database entry """ + document = document if isinstance(document, dict) else document.dict() + new_document = self.sql_model(session=session, **document) session.add(new_document) session.commit() @@ -136,6 +138,7 @@ class BaseDocument: Returns: dict: Returns a dictionary representation of the database entry """ + new_data = new_data if isinstance(new_data, dict) else new_data.dict() entry = self._query_one(session=session, match_value=match_value) entry.update(session=session, **new_data) @@ -144,6 +147,8 @@ class BaseDocument: return self.schema.from_orm(entry) def patch(self, session: Session, match_value: str, new_data: dict) -> BaseModel: + new_data = new_data if isinstance(new_data, dict) else new_data.dict() + entry = self._query_one(session=session, match_value=match_value) if not entry: @@ -168,8 +173,21 @@ class BaseDocument: session.commit() def count_all(self, session: Session, match_key=None, match_value=None) -> int: - if None in [match_key, match_value]: return session.query(self.sql_model).count() else: return session.query(self.sql_model).filter_by(**{match_key: match_value}).count() + + def _countr_attribute( + self, session: Session, attribute_name: str, attr_match: str = None, count=True, override_schema=None + ) -> Union[int, BaseModel]: + eff_schema = override_schema or self.schema + # attr_filter = getattr(self.sql_model, attribute_name) + + if count: + return session.query(self.sql_model).filter(attribute_name == attr_match).count() # noqa: 711 + else: + return [ + eff_schema.from_orm(x) + for x in session.query(self.sql_model).filter(attribute_name == attr_match).all() # noqa: 711 + ]