Browse Source

Add data property to Example

pull/1650/head
Hironsan 2 years ago
parent
commit
a4e57f65b6
3 changed files with 24 additions and 11 deletions
  1. 7
      backend/api/models.py
  2. 17
      backend/api/tests/test_models.py
  3. 11
      backend/auto_labeling/views.py

7
backend/api/models.py

@ -257,6 +257,13 @@ class Example(models.Model):
def comment_count(self):
return Comment.objects.filter(example=self.id).count()
@property
def data(self):
if self.project.is_text_project:
return self.text
else:
return str(self.filename)
def is_labeled(self, is_collaborative, user):
if is_collaborative:
for model in Annotation.__subclasses__():

17
backend/api/tests/test_models.py

@ -3,8 +3,8 @@ from django.db.utils import IntegrityError
from django.test import TestCase
from model_mommy import mommy
from api.models import (SEQUENCE_LABELING, Category, CategoryType,
ExampleState, Span, SpanType, TextLabel,
from api.models import (IMAGE_CLASSIFICATION, SEQUENCE_LABELING, Category,
CategoryType, ExampleState, Span, SpanType, TextLabel,
generate_random_hex_color)
from .api.utils import prepare_project
@ -240,3 +240,16 @@ class TestLabelDistribution(TestCase):
expected[self.user.username][label_a.text] = 1
expected[self.user.username][label_b.text] = 1
self.assertEqual(distribution, expected)
class TestExample(TestCase):
def test_text_project_returns_text_as_data_property(self):
project = prepare_project(SEQUENCE_LABELING)
example = mommy.make('Example', project=project.item)
self.assertEqual(example.text, example.data)
def test_image_project_returns_filename_as_data_property(self):
project = prepare_project(IMAGE_CLASSIFICATION)
example = mommy.make('Example', project=project.item)
self.assertEqual(str(example.filename), example.data)

11
backend/auto_labeling/views.py

@ -215,21 +215,14 @@ class AutomatedDataLabeling(generics.CreateAPIView):
def perform_create(self, serializer):
serializer.save(user=self.request.user)
def get_example(self, project):
example = get_object_or_404(Example, pk=self.kwargs['example_id'])
if project.is_text_project:
return example.text
else:
return str(example.filename)
def extract(self):
project = get_object_or_404(Project, pk=self.kwargs['project_id'])
example = self.get_example(project)
example = get_object_or_404(Example, pk=self.kwargs['example_id'])
config = project.auto_labeling_config.first()
if not config:
raise AutoLabelingPermissionDenied()
return execute_pipeline(
text=example,
text=example.data,
project_type=project.project_type,
model_name=config.model_name,
model_attrs=config.model_attrs,

Loading…
Cancel
Save