diff --git a/backend/api/models.py b/backend/api/models.py index f5490562..a6ea556e 100644 --- a/backend/api/models.py +++ b/backend/api/models.py @@ -282,18 +282,29 @@ class Span(Annotation): start_offset = models.IntegerField() end_offset = models.IntegerField() - def clean(self): - if self.start_offset >= self.end_offset: - raise ValidationError('start_offset > end_offset') + def validate_unique(self, exclude=None): + allow_overlapping = getattr(self.example.project, 'allow_overlapping', False) + if allow_overlapping: + super().validate_unique(exclude=exclude) + else: + if Span.objects.filter(example=self.example).filter( + models.Q(start_offset__gte=self.start_offset, start_offset__lt=self.end_offset) | + models.Q(end_offset__gt=self.start_offset, end_offset__lte=self.end_offset) | + models.Q(start_offset__lte=self.start_offset, end_offset__gte=self.end_offset) + ).exists(): + raise ValidationError('This overlapping is not allowed in this project.') + + def save(self, force_insert=False, force_update=False, using=None, + update_fields=None): + self.full_clean() + super().save(force_insert, force_update, using, update_fields) class Meta: - unique_together = ( - 'example', - 'user', - 'label', - 'start_offset', - 'end_offset' - ) + constraints = [ + models.CheckConstraint(check=models.Q(start_offset__gte=0), name='startOffset >= 0'), + models.CheckConstraint(check=models.Q(end_offset__gte=0), name='endOffset >= 0'), + models.CheckConstraint(check=models.Q(start_offset__lt=models.F('end_offset')), name='start < end') + ] class TextLabel(Annotation): diff --git a/backend/api/tasks.py b/backend/api/tasks.py index 09c35d71..daf668c0 100644 --- a/backend/api/tasks.py +++ b/backend/api/tasks.py @@ -11,8 +11,8 @@ from .models import Example, Label, Project from .views.download.factory import create_repository, create_writer from .views.download.service import ExportApplicationService from .views.upload.exception import FileParseException, FileParseExceptions -from .views.upload.factory import (get_data_class, get_dataset_class, - get_label_class) +from .views.upload.factory import (create_cleaner, get_data_class, + get_dataset_class, get_label_class) from .views.upload.utils import append_field logger = get_task_logger(__name__) @@ -109,6 +109,7 @@ def ingest_data(user_id, project_id, filenames, format: str, **kwargs): label_class=Label, annotation_class=project.get_annotation_class() ) + cleaner = create_cleaner(project) while True: try: example = next(it) @@ -120,6 +121,10 @@ def ingest_data(user_id, project_id, filenames, format: str, **kwargs): except FileParseExceptions as err: response['error'].extend(list(err)) continue + try: + example.clean(cleaner) + except FileParseException as err: + response['error'].append(err.dict()) buffer.add(example) if buffer.is_full(): diff --git a/backend/api/tests/data/sequence_labeling/example_overlapping.jsonl b/backend/api/tests/data/sequence_labeling/example_overlapping.jsonl new file mode 100644 index 00000000..f9076195 --- /dev/null +++ b/backend/api/tests/data/sequence_labeling/example_overlapping.jsonl @@ -0,0 +1 @@ +{"text": "exampleA", "label": [[0, 1, "LOC"], [0, 1, "LOC"]], "meta": {"wikiPageID": 1}} diff --git a/backend/api/tests/test_models.py b/backend/api/tests/test_models.py index c7ac049a..be29e4eb 100644 --- a/backend/api/tests/test_models.py +++ b/backend/api/tests/test_models.py @@ -119,19 +119,36 @@ class TestCategory(TestCase): class TestSequenceAnnotation(TestCase): - def test_uniqueness(self): - a = mommy.make('Span') + def test_start_offset_is_not_negative(self): + with self.assertRaises(IntegrityError): + mommy.make('Span', start_offset=-1, end_offset=0) + + def test_end_offset_is_not_negative(self): + with self.assertRaises(IntegrityError): + mommy.make('Span', start_offset=-2, end_offset=-1) + + def test_start_offset_is_less_than_end_offset(self): with self.assertRaises(IntegrityError): - Span(example=a.example, - user=a.user, - label=a.label, - start_offset=a.start_offset, - end_offset=a.end_offset).save() + mommy.make('Span', start_offset=0, end_offset=0) - def test_position_constraint(self): + def test_overlapping(self): + project = mommy.make('SequenceLabelingProject', allow_overlapping=False) + example = mommy.make('Example', project=project) + mommy.make('Span', example=example, start_offset=5, end_offset=10) + with self.assertRaises(ValidationError): + mommy.make('Span', example=example, start_offset=5, end_offset=10) + with self.assertRaises(ValidationError): + mommy.make('Span', example=example, start_offset=5, end_offset=11) + with self.assertRaises(ValidationError): + mommy.make('Span', example=example, start_offset=4, end_offset=10) + with self.assertRaises(ValidationError): + mommy.make('Span', example=example, start_offset=6, end_offset=9) + with self.assertRaises(ValidationError): + mommy.make('Span', example=example, start_offset=9, end_offset=15) with self.assertRaises(ValidationError): - mommy.make('Span', - start_offset=1, end_offset=0).clean() + mommy.make('Span', example=example, start_offset=0, end_offset=6) + mommy.make('Span', example=example, start_offset=0, end_offset=5) + mommy.make('Span', example=example, start_offset=10, end_offset=15) class TestSeq2seqAnnotation(TestCase): diff --git a/backend/api/tests/test_tasks.py b/backend/api/tests/test_tasks.py index 0e08928d..f8672491 100644 --- a/backend/api/tests/test_tasks.py +++ b/backend/api/tests/test_tasks.py @@ -190,6 +190,12 @@ class TestIngestSequenceLabelingData(TestIngestData): response = self.ingest_data(filename, file_format) self.assert_parse_error(response) + def test_jsonl_with_overlapping(self): + filename = 'sequence_labeling/example_overlapping.jsonl' + file_format = 'JSONL' + response = self.ingest_data(filename, file_format) + self.assertEqual(len(response['error']), 1) + class TestIngestSeq2seqData(TestIngestData): task = SEQ2SEQ diff --git a/backend/api/views/upload/cleaners.py b/backend/api/views/upload/cleaners.py new file mode 100644 index 00000000..e19fa981 --- /dev/null +++ b/backend/api/views/upload/cleaners.py @@ -0,0 +1,46 @@ +from typing import List + +from ...models import Project +from .label import CategoryLabel, Label, OffsetLabel + + +class Cleaner: + + def __init__(self, project: Project): + pass + + def clean(self, labels: List[Label]) -> List[Label]: + return labels + + +class SpanCleaner(Cleaner): + + def __init__(self, project: Project): + super().__init__(project) + self.allow_overlapping = getattr(project, 'allow_overlapping', False) + + def clean(self, labels: List[OffsetLabel]) -> List[OffsetLabel]: + if self.allow_overlapping: + return labels + + labels.sort(key=lambda label: label.start_offset) + last_offset = -1 + new_labels = [] + for label in labels: + if label.start_offset >= last_offset: + last_offset = label.end_offset + new_labels.append(label) + return new_labels + + +class CategoryCleaner(Cleaner): + + def __init__(self, project: Project): + super().__init__(project) + self.exclusive = getattr(project, 'single_class_classification', False) + + def clean(self, labels: List[CategoryLabel]) -> List[CategoryLabel]: + if self.exclusive: + return labels[:1] + else: + return labels diff --git a/backend/api/views/upload/dataset.py b/backend/api/views/upload/dataset.py index bd5e33fd..364ffaaf 100644 --- a/backend/api/views/upload/dataset.py +++ b/backend/api/views/upload/dataset.py @@ -11,6 +11,7 @@ from chardet.universaldetector import UniversalDetector from pydantic import ValidationError from seqeval.scheme import BILOU, IOB2, IOBES, IOE2, Tokens +from .cleaners import Cleaner from .data import BaseData from .exception import FileParseException, FileParseExceptions from .label import Label @@ -30,6 +31,18 @@ class Record: def __str__(self): return f'{self._data}\t{self._label}' + def clean(self, cleaner: Cleaner): + label = cleaner.clean(self._label) + changed = len(label) != len(self.label) + self._label = label + if changed: + message = 'There are invalid labels. It\'s cleaned.' + raise FileParseException( + filename=self._data.filename, + line_num=-1, + message=message + ) + @property def data(self): return self._data.dict() diff --git a/backend/api/views/upload/factory.py b/backend/api/views/upload/factory.py index 9453943d..b09bd716 100644 --- a/backend/api/views/upload/factory.py +++ b/backend/api/views/upload/factory.py @@ -1,6 +1,6 @@ from ...models import (DOCUMENT_CLASSIFICATION, IMAGE_CLASSIFICATION, SEQ2SEQ, SEQUENCE_LABELING, SPEECH2TEXT) -from . import catalog, data, dataset, label +from . import catalog, cleaners, data, dataset, label def get_data_class(project_type: str): @@ -40,3 +40,15 @@ def get_label_class(project_type: str): if project_type not in mapping: ValueError(f'Invalid project type: {project_type}') return mapping[project_type] + + +def create_cleaner(project): + mapping = { + DOCUMENT_CLASSIFICATION: cleaners.CategoryCleaner, + SEQUENCE_LABELING: cleaners.SpanCleaner, + IMAGE_CLASSIFICATION: cleaners.CategoryCleaner + } + if project.project_type not in mapping: + ValueError(f'Invalid project type: {project.project_type}') + cleaner_class = mapping.get(project.project_type, cleaners.Cleaner) + return cleaner_class(project)