diff --git a/backend/examples/filters.py b/backend/examples/filters.py index fceaba12..120361ff 100644 --- a/backend/examples/filters.py +++ b/backend/examples/filters.py @@ -1,11 +1,12 @@ -from django.db.models import Count, Q -from django_filters.rest_framework import BooleanFilter, FilterSet +from django.db.models import Count, Q, QuerySet +from django_filters.rest_framework import BooleanFilter, CharFilter, FilterSet from .models import Example class ExampleFilter(FilterSet): confirmed = BooleanFilter(field_name="states", method="filter_by_state") + label = CharFilter(method="filter_by_label") def filter_by_state(self, queryset, field_name, is_confirmed: bool): queryset = queryset.annotate( @@ -21,6 +22,35 @@ class ExampleFilter(FilterSet): queryset = queryset.filter(num_confirm__lte=0) return queryset + def filter_by_label(self, queryset: QuerySet, field_name: str, label: str) -> QuerySet: + """Filter examples by a given label name. + + This performs filtering on all of the following labels at once: + - categories + - spans + - relations + - bboxes + - segmentations + + Todo: Consider project type to make filtering more efficient. + + Args: + queryset (QuerySet): QuerySet to filter. + field_name (str): This equals to `label`. + label (str): The label name to filter. + + Returns: + QuerySet: Filtered examples. + """ + queryset = queryset.filter( + Q(categories__label__text=label) + | Q(spans__label__text=label) + | Q(relations__type__text=label) + | Q(bboxes__label__text=label) + | Q(segmentations__label__text=label) + ) + return queryset + class Meta: model = Example - fields = ("project", "text", "created_at", "updated_at") + fields = ("project", "text", "created_at", "updated_at", "label") diff --git a/backend/examples/tests/test_filters.py b/backend/examples/tests/test_filters.py index 2d0a6df0..a6ae2e97 100644 --- a/backend/examples/tests/test_filters.py +++ b/backend/examples/tests/test_filters.py @@ -1,10 +1,12 @@ from unittest.mock import MagicMock from django.test import TestCase +from model_mommy import mommy from .utils import make_doc, make_example_state from examples.filters import ExampleFilter from examples.models import Example +from projects.models import ProjectType from projects.tests.utils import prepare_project @@ -48,6 +50,17 @@ class TestExampleFilter(TestFilterMixin): self.assert_filter(data={"confirmed": ""}, expected=1) +class TestLabelFilter(TestFilterMixin): + def setUp(self): + self.project = prepare_project(task=ProjectType.DOCUMENT_CLASSIFICATION) + self.prepare(project=self.project) + self.label_type = mommy.make("CategoryType", project=self.project.item, text="positive") + mommy.make("Category", example=self.example, label=self.label_type) + + def test_returns_example_with_positive_label(self): + self.assert_filter(data={"label": self.label_type.text}, expected=1) + + class TestExampleFilterOnCollaborative(TestFilterMixin): def setUp(self): self.project = prepare_project(task="DocumentClassification", collaborative_annotation=True)