remvoe old typing

This commit is contained in:
hay-kot 2021-05-05 13:56:59 -08:00
commit 295775dbd5

View file

@ -1,4 +1,4 @@
from typing import List from typing import Union
from mealie.core.root_logger import get_logger from mealie.core.root_logger import get_logger
from mealie.db.models.model_base import SqlAlchemyBase from mealie.db.models.model_base import SqlAlchemyBase
@ -20,7 +20,7 @@ class BaseDocument:
# TODO: Improve Get All Query Functionality # TODO: Improve Get All Query Functionality
def get_all( def get_all(
self, session: Session, limit: int = None, order_by: str = None, start=0, end=9999, override_schema=None self, session: Session, limit: int = None, order_by: str = None, start=0, end=9999, override_schema=None
) -> List[dict]: ) -> list[dict]:
eff_schema = override_schema or self.schema eff_schema = override_schema or self.schema
if order_by: if order_by:
@ -33,13 +33,13 @@ class BaseDocument:
return [eff_schema.from_orm(x) for x in session.query(self.sql_model).offset(start).limit(limit).all()] return [eff_schema.from_orm(x) for x in session.query(self.sql_model).offset(start).limit(limit).all()]
def get_all_limit_columns(self, session: Session, fields: List[str], limit: int = None) -> List[SqlAlchemyBase]: def get_all_limit_columns(self, session: Session, fields: list[str], limit: int = None) -> list[SqlAlchemyBase]:
"""Queries the database for the selected model. Restricts return responses to the """Queries the database for the selected model. Restricts return responses to the
keys specified under "fields" keys specified under "fields"
Args: \n Args: \n
session (Session): Database Session Object session (Session): Database Session Object
fields (List[str]): List of column names to query fields (list[str]): list of column names to query
limit (int): A limit of values to return limit (int): A limit of values to return
Returns: Returns:
@ -47,7 +47,7 @@ class BaseDocument:
""" """
return session.query(self.sql_model).options(load_only(*fields)).limit(limit).all() return session.query(self.sql_model).options(load_only(*fields)).limit(limit).all()
def get_all_primary_keys(self, session: Session) -> List[str]: def get_all_primary_keys(self, session: Session) -> list[str]:
"""Queries the database of the selected model and returns a list """Queries the database of the selected model and returns a list
of all primary_key values of all primary_key values
@ -79,7 +79,7 @@ class BaseDocument:
def get( def get(
self, session: Session, match_value: str, match_key: str = None, limit=1, any_case=False self, session: Session, match_value: str, match_key: str = None, limit=1, any_case=False
) -> BaseModel or List[BaseModel]: ) -> Union[BaseModel, list[BaseModel]]:
"""Retrieves an entry from the database by matching a key/value pair. If no """Retrieves an entry from the database by matching a key/value pair. If no
key is provided the class objects primary key will be used to match against. key is provided the class objects primary key will be used to match against.
@ -120,6 +120,8 @@ class BaseDocument:
Returns: Returns:
dict: A dictionary representation of the database entry dict: A dictionary representation of the database entry
""" """
document = document if isinstance(document, dict) else document.dict()
new_document = self.sql_model(session=session, **document) new_document = self.sql_model(session=session, **document)
session.add(new_document) session.add(new_document)
session.commit() session.commit()
@ -136,6 +138,7 @@ class BaseDocument:
Returns: Returns:
dict: Returns a dictionary representation of the database entry dict: Returns a dictionary representation of the database entry
""" """
new_data = new_data if isinstance(new_data, dict) else new_data.dict()
entry = self._query_one(session=session, match_value=match_value) entry = self._query_one(session=session, match_value=match_value)
entry.update(session=session, **new_data) entry.update(session=session, **new_data)
@ -144,6 +147,8 @@ class BaseDocument:
return self.schema.from_orm(entry) return self.schema.from_orm(entry)
def patch(self, session: Session, match_value: str, new_data: dict) -> BaseModel: def patch(self, session: Session, match_value: str, new_data: dict) -> BaseModel:
new_data = new_data if isinstance(new_data, dict) else new_data.dict()
entry = self._query_one(session=session, match_value=match_value) entry = self._query_one(session=session, match_value=match_value)
if not entry: if not entry:
@ -168,8 +173,21 @@ class BaseDocument:
session.commit() session.commit()
def count_all(self, session: Session, match_key=None, match_value=None) -> int: def count_all(self, session: Session, match_key=None, match_value=None) -> int:
if None in [match_key, match_value]: if None in [match_key, match_value]:
return session.query(self.sql_model).count() return session.query(self.sql_model).count()
else: else:
return session.query(self.sql_model).filter_by(**{match_key: match_value}).count() return session.query(self.sql_model).filter_by(**{match_key: match_value}).count()
def _countr_attribute(
self, session: Session, attribute_name: str, attr_match: str = None, count=True, override_schema=None
) -> Union[int, BaseModel]:
eff_schema = override_schema or self.schema
# attr_filter = getattr(self.sql_model, attribute_name)
if count:
return session.query(self.sql_model).filter(attribute_name == attr_match).count() # noqa: 711
else:
return [
eff_schema.from_orm(x)
for x in session.query(self.sql_model).filter(attribute_name == attr_match).all() # noqa: 711
]