From a4e57f65b6e8d11a62dc3db4e6947c6de8a6b600 Mon Sep 17 00:00:00 2001 From: Hironsan Date: Fri, 21 Jan 2022 16:57:06 +0900 Subject: [PATCH] Add data property to Example --- backend/api/models.py | 7 +++++++ backend/api/tests/test_models.py | 17 +++++++++++++++-- backend/auto_labeling/views.py | 11 ++--------- 3 files changed, 24 insertions(+), 11 deletions(-) diff --git a/backend/api/models.py b/backend/api/models.py index 8ecae2b0..9d878f40 100644 --- a/backend/api/models.py +++ b/backend/api/models.py @@ -257,6 +257,13 @@ class Example(models.Model): def comment_count(self): return Comment.objects.filter(example=self.id).count() + @property + def data(self): + if self.project.is_text_project: + return self.text + else: + return str(self.filename) + def is_labeled(self, is_collaborative, user): if is_collaborative: for model in Annotation.__subclasses__(): diff --git a/backend/api/tests/test_models.py b/backend/api/tests/test_models.py index 8cda0bc8..8b73f27c 100644 --- a/backend/api/tests/test_models.py +++ b/backend/api/tests/test_models.py @@ -3,8 +3,8 @@ from django.db.utils import IntegrityError from django.test import TestCase from model_mommy import mommy -from api.models import (SEQUENCE_LABELING, Category, CategoryType, - ExampleState, Span, SpanType, TextLabel, +from api.models import (IMAGE_CLASSIFICATION, SEQUENCE_LABELING, Category, + CategoryType, ExampleState, Span, SpanType, TextLabel, generate_random_hex_color) from .api.utils import prepare_project @@ -240,3 +240,16 @@ class TestLabelDistribution(TestCase): expected[self.user.username][label_a.text] = 1 expected[self.user.username][label_b.text] = 1 self.assertEqual(distribution, expected) + + +class TestExample(TestCase): + + def test_text_project_returns_text_as_data_property(self): + project = prepare_project(SEQUENCE_LABELING) + example = mommy.make('Example', project=project.item) + self.assertEqual(example.text, example.data) + + def test_image_project_returns_filename_as_data_property(self): + project = prepare_project(IMAGE_CLASSIFICATION) + example = mommy.make('Example', project=project.item) + self.assertEqual(str(example.filename), example.data) diff --git a/backend/auto_labeling/views.py b/backend/auto_labeling/views.py index 96d27633..d1f79755 100644 --- a/backend/auto_labeling/views.py +++ b/backend/auto_labeling/views.py @@ -215,21 +215,14 @@ class AutomatedDataLabeling(generics.CreateAPIView): def perform_create(self, serializer): serializer.save(user=self.request.user) - def get_example(self, project): - example = get_object_or_404(Example, pk=self.kwargs['example_id']) - if project.is_text_project: - return example.text - else: - return str(example.filename) - def extract(self): project = get_object_or_404(Project, pk=self.kwargs['project_id']) - example = self.get_example(project) + example = get_object_or_404(Example, pk=self.kwargs['example_id']) config = project.auto_labeling_config.first() if not config: raise AutoLabelingPermissionDenied() return execute_pipeline( - text=example, + text=example.data, project_type=project.project_type, model_name=config.model_name, model_attrs=config.model_attrs,