From 481ce92d848e7239bddb2c81d7f4c9793181b944 Mon Sep 17 00:00:00 2001 From: Michael Genson <71845777+michael-genson@users.noreply.github.com> Date: Thu, 14 Aug 2025 07:21:40 -0500 Subject: [PATCH] fix: CONTAINS ALL doesn't contain all (#5900) --- mealie/schema/response/query_filter.py | 37 ++++++++++--------- .../repository_tests/test_pagination.py | 13 +++++++ 2 files changed, 32 insertions(+), 18 deletions(-) diff --git a/mealie/schema/response/query_filter.py b/mealie/schema/response/query_filter.py index 9e41490f3..ef093a89d 100644 --- a/mealie/schema/response/query_filter.py +++ b/mealie/schema/response/query_filter.py @@ -351,45 +351,46 @@ class QueryFilterBuilder: ) -> sa.ColumnElement: original_model_attr = model_attr model_attr = cls._transform_model_attr(model_attr, model_attr_type) + value = component.validate(model_attr_type) # Keywords if component.relationship is RelationalKeyword.IS: - element = model_attr.is_(component.validate(model_attr_type)) + element = model_attr.is_(value) elif component.relationship is RelationalKeyword.IS_NOT: - element = model_attr.is_not(component.validate(model_attr_type)) + element = model_attr.is_not(value) elif component.relationship is RelationalKeyword.IN: - element = model_attr.in_(component.validate(model_attr_type)) + element = model_attr.in_(value) elif component.relationship is RelationalKeyword.NOT_IN: - vals = component.validate(model_attr_type) if original_model_attr.parent.entity != model: - subq = query.with_only_columns(model.id).where(model_attr.in_(vals)) + subq = query.with_only_columns(model.id).where(model_attr.in_(value)) element = sa.not_(model.id.in_(subq)) else: - element = sa.not_(model_attr.in_(vals)) + element = sa.not_(model_attr.in_(value)) elif component.relationship is RelationalKeyword.CONTAINS_ALL: - primary_model_attr: InstrumentedAttribute = getattr(model, component.attribute_name.split(".")[0]) - element = sa.and_() - for v in component.validate(model_attr_type): - element = sa.and_(element, primary_model_attr.any(model_attr == v)) + if len(value) == 1: + element = model_attr.in_(value) + else: + primary_model_attr: InstrumentedAttribute = getattr(model, component.attribute_name.split(".")[0]) + element = sa.and_(*(primary_model_attr.any(model_attr == v) for v in value)) elif component.relationship is RelationalKeyword.LIKE: - element = model_attr.ilike(component.validate(model_attr_type)) + element = model_attr.ilike(value) elif component.relationship is RelationalKeyword.NOT_LIKE: - element = model_attr.not_ilike(component.validate(model_attr_type)) + element = model_attr.not_ilike(value) # Operators elif component.relationship is RelationalOperator.EQ: - element = model_attr == component.validate(model_attr_type) + element = model_attr == value elif component.relationship is RelationalOperator.NOTEQ: - element = model_attr != component.validate(model_attr_type) + element = model_attr != value elif component.relationship is RelationalOperator.GT: - element = model_attr > component.validate(model_attr_type) + element = model_attr > value elif component.relationship is RelationalOperator.LT: - element = model_attr < component.validate(model_attr_type) + element = model_attr < value elif component.relationship is RelationalOperator.GTE: - element = model_attr >= component.validate(model_attr_type) + element = model_attr >= value elif component.relationship is RelationalOperator.LTE: - element = model_attr <= component.validate(model_attr_type) + element = model_attr <= value else: raise ValueError(f"invalid relationship {component.relationship}") diff --git a/tests/unit_tests/repository_tests/test_pagination.py b/tests/unit_tests/repository_tests/test_pagination.py index 4d82af9e1..64c987a10 100644 --- a/tests/unit_tests/repository_tests/test_pagination.py +++ b/tests/unit_tests/repository_tests/test_pagination.py @@ -488,6 +488,19 @@ def test_pagination_filter_in_advanced(unique_user: TestUser): assert recipe_2.id not in recipe_ids assert recipe_1_2.id in recipe_ids + query = PaginationQuery( + page=1, + per_page=-1, + query_filter=f"tags.name CONTAINS ALL [{tag_1.name}]", + ) + recipe_results = database.recipes.page_all(query).items + assert len(recipe_results) == 2 + recipe_ids = {recipe.id for recipe in recipe_results} + assert recipe_0.id not in recipe_ids + assert recipe_1.id in recipe_ids + assert recipe_2.id not in recipe_ids + assert recipe_1_2.id in recipe_ids + def test_pagination_filter_like(query_units: tuple[RepositoryUnit, IngredientUnit, IngredientUnit, IngredientUnit]): units_repo, unit_1, unit_2, unit_3 = query_units