diff --git a/app/server/api.py b/app/server/api.py index 34e8b742..d798cc71 100644 --- a/app/server/api.py +++ b/app/server/api.py @@ -16,6 +16,7 @@ from rest_framework.views import APIView from rest_framework.parsers import MultiPartParser from .exceptions import FileParseException +from .filters import DocumentFilter from .models import Project, Label, Document from .models import SequenceAnnotation from .permissions import IsAdminUserAndWriteOnly, IsProjectUser, IsMyEntity, IsOwnAnnotation @@ -98,8 +99,8 @@ class DocumentList(generics.ListCreateAPIView): filter_backends = (DjangoFilterBackend, filters.SearchFilter, filters.OrderingFilter) search_fields = ('text', ) ordering_fields = ('created_at', 'updated_at', 'doc_annotations__updated_at', - 'seq_annotations__updated_at') - filter_fields = ('doc_annotations__label__id', 'seq_annotations__label__id') + 'seq_annotations__updated_at', 'seq2seq_annotations__updated_at') + filter_class = DocumentFilter permission_classes = (IsAuthenticated, IsProjectUser, IsAdminUserAndWriteOnly) def get_queryset(self): diff --git a/app/server/filters.py b/app/server/filters.py new file mode 100644 index 00000000..2b41169c --- /dev/null +++ b/app/server/filters.py @@ -0,0 +1,14 @@ +from django_filters.rest_framework import FilterSet, BooleanFilter +from .models import Document + + +class DocumentFilter(FilterSet): + seq_annotations__isnull = BooleanFilter(field_name='seq_annotations', lookup_expr='isnull') + doc_annotations__isnull = BooleanFilter(field_name='doc_annotations', lookup_expr='isnull') + seq2seq_annotations__isnull = BooleanFilter(field_name='seq2seq_annotations', lookup_expr='isnull') + + class Meta: + model = Document + fields = ('project', 'text', 'metadata', 'created_at', 'updated_at', + 'doc_annotations__label__id', 'seq_annotations__label__id', + 'doc_annotations__isnull', 'seq_annotations__isnull', 'seq2seq_annotations__isnull') diff --git a/app/server/tests/test_api.py b/app/server/tests/test_api.py index bd7da7b2..b2fe952a 100644 --- a/app/server/tests/test_api.py +++ b/app/server/tests/test_api.py @@ -758,11 +758,12 @@ class TestFilter(APITestCase): cls.project_member_pass = 'project_member_pass' project_member = User.objects.create_user(username=cls.project_member_name, password=cls.project_member_pass) - cls.main_project = mommy.make('server.TextClassificationProject', users=[project_member]) + cls.main_project = mommy.make('server.SequenceLabelingProject', users=[project_member]) cls.label1 = mommy.make('server.Label', project=cls.main_project) cls.label2 = mommy.make('server.Label', project=cls.main_project) doc1 = mommy.make('server.Document', project=cls.main_project) doc2 = mommy.make('server.Document', project=cls.main_project) + doc3 = mommy.make('server.Document', project=cls.main_project) mommy.make('server.SequenceAnnotation', document=doc1, user=project_member, label=cls.label1) mommy.make('server.SequenceAnnotation', document=doc2, user=project_member, label=cls.label2) cls.url = reverse(viewname='doc_list', args=[cls.main_project.id]) @@ -777,6 +778,26 @@ class TestFilter(APITestCase): for d1, d2 in zip(response.data['results'], docs): self.assertEqual(d1['id'], d2['id']) + def test_can_filter_doc_with_annotation(self): + params = {'seq_annotations__isnull': False} + self.client.login(username=self.project_member_name, + password=self.project_member_pass) + response = self.client.get(self.url, format='json', data=params) + docs = Document.objects.filter(project=self.main_project, seq_annotations__isnull=False).values() + self.assertEqual(response.data['count'], docs.count()) + for d1, d2 in zip(response.data['results'], docs): + self.assertEqual(d1['id'], d2['id']) + + def test_can_filter_doc_without_anotation(self): + params = {'seq_annotations__isnull': True} + self.client.login(username=self.project_member_name, + password=self.project_member_pass) + response = self.client.get(self.url, format='json', data=params) + docs = Document.objects.filter(project=self.main_project, seq_annotations__isnull=True).values() + self.assertEqual(response.data['count'], docs.count()) + for d1, d2 in zip(response.data['results'], docs): + self.assertEqual(d1['id'], d2['id']) + class TestUploader(APITestCase):