Browse Source

Update Span model

pull/1568/head
Hironsan 2 years ago
parent
commit
2a325904ec
8 changed files with 134 additions and 23 deletions
  1. 31
      backend/api/models.py
  2. 9
      backend/api/tasks.py
  3. 1
      backend/api/tests/data/sequence_labeling/example_overlapping.jsonl
  4. 37
      backend/api/tests/test_models.py
  5. 6
      backend/api/tests/test_tasks.py
  6. 46
      backend/api/views/upload/cleaners.py
  7. 13
      backend/api/views/upload/dataset.py
  8. 14
      backend/api/views/upload/factory.py

31
backend/api/models.py

@ -282,18 +282,29 @@ class Span(Annotation):
start_offset = models.IntegerField() start_offset = models.IntegerField()
end_offset = models.IntegerField() end_offset = models.IntegerField()
def clean(self):
    """Validate that the span's offsets are in increasing order.

    Raises:
        ValidationError: if ``start_offset`` is greater than or equal to
            ``end_offset``.
    """
    start, end = self.start_offset, self.end_offset
    if start >= end:
        raise ValidationError('start_offset > end_offset')
def validate_unique(self, exclude=None):
    """Run uniqueness validation, rejecting overlaps when the project forbids them.

    The project's ``allow_overlapping`` attribute is read with ``getattr``
    and treated as ``False`` when it is missing. When it is truthy, the check
    is delegated to the parent ``validate_unique``. Otherwise the spans
    already stored for the same example are searched for one whose offsets
    overlap this span's offsets.

    Args:
        exclude: Forwarded unchanged to the parent ``validate_unique`` when
            overlapping is allowed.

    Raises:
        ValidationError: if overlapping is not allowed and another span of
            the same example overlaps this one.
    """
    allow_overlapping = getattr(self.example.project, 'allow_overlapping', False)
    if allow_overlapping:
        super().validate_unique(exclude=exclude)
        return

    # Another span collides when it starts inside this span, ends inside
    # this span, or completely covers this span.
    overlapping = (
        models.Q(start_offset__gte=self.start_offset, start_offset__lt=self.end_offset) |
        models.Q(end_offset__gt=self.start_offset, end_offset__lte=self.end_offset) |
        models.Q(start_offset__lte=self.start_offset, end_offset__gte=self.end_offset)
    )
    candidates = Span.objects.filter(example=self.example).filter(overlapping)
    if self.pk is not None:
        # A span that is already stored always matches its own row; without
        # this exclusion every update of an existing span would be rejected.
        candidates = candidates.exclude(pk=self.pk)
    if candidates.exists():
        raise ValidationError('This overlapping is not allowed in this project.')
def save(self, force_insert=False, force_update=False, using=None,
         update_fields=None):
    """Validate the span with ``full_clean`` and then persist it.

    Args:
        force_insert: Forwarded to the parent ``save``.
        force_update: Forwarded to the parent ``save``.
        using: Forwarded to the parent ``save``.
        update_fields: Forwarded to the parent ``save``.

    Raises:
        ValidationError: propagated from ``full_clean`` when validation fails.
    """
    self.full_clean()
    # Forward by keyword so the call does not depend on the parent's
    # positional parameter order.
    super().save(
        force_insert=force_insert,
        force_update=force_update,
        using=using,
        update_fields=update_fields,
    )
class Meta: class Meta:
unique_together = (
'example',
'user',
'label',
'start_offset',
'end_offset'
)
constraints = [
models.CheckConstraint(check=models.Q(start_offset__gte=0), name='startOffset >= 0'),
models.CheckConstraint(check=models.Q(end_offset__gte=0), name='endOffset >= 0'),
models.CheckConstraint(check=models.Q(start_offset__lt=models.F('end_offset')), name='start < end')
]
class TextLabel(Annotation): class TextLabel(Annotation):

9
backend/api/tasks.py

@ -11,8 +11,8 @@ from .models import Example, Label, Project
from .views.download.factory import create_repository, create_writer from .views.download.factory import create_repository, create_writer
from .views.download.service import ExportApplicationService from .views.download.service import ExportApplicationService
from .views.upload.exception import FileParseException, FileParseExceptions from .views.upload.exception import FileParseException, FileParseExceptions
from .views.upload.factory import (get_data_class, get_dataset_class,
get_label_class)
from .views.upload.factory import (create_cleaner, get_data_class,
get_dataset_class, get_label_class)
from .views.upload.utils import append_field from .views.upload.utils import append_field
logger = get_task_logger(__name__) logger = get_task_logger(__name__)
@ -109,6 +109,7 @@ def ingest_data(user_id, project_id, filenames, format: str, **kwargs):
label_class=Label, label_class=Label,
annotation_class=project.get_annotation_class() annotation_class=project.get_annotation_class()
) )
cleaner = create_cleaner(project)
while True: while True:
try: try:
example = next(it) example = next(it)
@ -120,6 +121,10 @@ def ingest_data(user_id, project_id, filenames, format: str, **kwargs):
except FileParseExceptions as err: except FileParseExceptions as err:
response['error'].extend(list(err)) response['error'].extend(list(err))
continue continue
try:
example.clean(cleaner)
except FileParseException as err:
response['error'].append(err.dict())
buffer.add(example) buffer.add(example)
if buffer.is_full(): if buffer.is_full():

1
backend/api/tests/data/sequence_labeling/example_overlapping.jsonl

@ -0,0 +1 @@
{"text": "exampleA", "label": [[0, 1, "LOC"], [0, 1, "LOC"]], "meta": {"wikiPageID": 1}}

37
backend/api/tests/test_models.py

@ -119,19 +119,36 @@ class TestCategory(TestCase):
class TestSequenceAnnotation(TestCase): class TestSequenceAnnotation(TestCase):
def test_uniqueness(self):
a = mommy.make('Span')
def test_start_offset_is_not_negative(self):
    """Creating a Span with ``start_offset=-1`` must raise IntegrityError."""
    offsets = {'start_offset': -1, 'end_offset': 0}
    with self.assertRaises(IntegrityError):
        mommy.make('Span', **offsets)
def test_end_offset_is_not_negative(self):
    """Creating a Span with ``end_offset=-1`` must raise IntegrityError."""
    offsets = {'start_offset': -2, 'end_offset': -1}
    with self.assertRaises(IntegrityError):
        mommy.make('Span', **offsets)
def test_start_offset_is_less_than_end_offset(self):
with self.assertRaises(IntegrityError): with self.assertRaises(IntegrityError):
Span(example=a.example,
user=a.user,
label=a.label,
start_offset=a.start_offset,
end_offset=a.end_offset).save()
mommy.make('Span', start_offset=0, end_offset=0)
def test_position_constraint(self):
def test_overlapping(self):
project = mommy.make('SequenceLabelingProject', allow_overlapping=False)
example = mommy.make('Example', project=project)
mommy.make('Span', example=example, start_offset=5, end_offset=10)
with self.assertRaises(ValidationError):
mommy.make('Span', example=example, start_offset=5, end_offset=10)
with self.assertRaises(ValidationError):
mommy.make('Span', example=example, start_offset=5, end_offset=11)
with self.assertRaises(ValidationError):
mommy.make('Span', example=example, start_offset=4, end_offset=10)
with self.assertRaises(ValidationError):
mommy.make('Span', example=example, start_offset=6, end_offset=9)
with self.assertRaises(ValidationError):
mommy.make('Span', example=example, start_offset=9, end_offset=15)
with self.assertRaises(ValidationError): with self.assertRaises(ValidationError):
mommy.make('Span',
start_offset=1, end_offset=0).clean()
mommy.make('Span', example=example, start_offset=0, end_offset=6)
mommy.make('Span', example=example, start_offset=0, end_offset=5)
mommy.make('Span', example=example, start_offset=10, end_offset=15)
class TestSeq2seqAnnotation(TestCase): class TestSeq2seqAnnotation(TestCase):

6
backend/api/tests/test_tasks.py

@ -190,6 +190,12 @@ class TestIngestSequenceLabelingData(TestIngestData):
response = self.ingest_data(filename, file_format) response = self.ingest_data(filename, file_format)
self.assert_parse_error(response) self.assert_parse_error(response)
def test_jsonl_with_overlapping(self):
    """Ingesting the overlapping-span JSONL fixture must report exactly one error."""
    response = self.ingest_data(
        'sequence_labeling/example_overlapping.jsonl',
        'JSONL',
    )
    errors = response['error']
    self.assertEqual(len(errors), 1)
class TestIngestSeq2seqData(TestIngestData): class TestIngestSeq2seqData(TestIngestData):
task = SEQ2SEQ task = SEQ2SEQ

46
backend/api/views/upload/cleaners.py

@ -0,0 +1,46 @@
from typing import List
from ...models import Project
from .label import CategoryLabel, Label, OffsetLabel
class Cleaner:
    """Base label cleaner: hands back the labels it is given, untouched."""

    def __init__(self, project: Project):
        """Accept a project; the base cleaner keeps nothing from it."""

    def clean(self, labels: List[Label]) -> List[Label]:
        """Return ``labels`` as-is (the very same list object)."""
        return labels
class SpanCleaner(Cleaner):
    """Cleaner that removes overlapping spans unless the project allows them."""

    def __init__(self, project: Project):
        """Remember whether the project permits overlapping spans.

        ``allow_overlapping`` is read with ``getattr`` and defaults to
        ``False`` when the project has no such attribute.
        """
        super().__init__(project)
        self.allow_overlapping = getattr(project, 'allow_overlapping', False)

    def clean(self, labels: List[OffsetLabel]) -> List[OffsetLabel]:
        """Return the labels with overlapping spans dropped.

        When overlapping is allowed the input list itself is returned.
        Otherwise the labels are visited in ascending ``start_offset`` order
        and a label is kept only if it starts at or after the end of the
        previously kept label; the survivors are returned in a new list.

        Args:
            labels: Labels exposing ``start_offset`` and ``end_offset``.

        Returns:
            The input list (overlapping allowed) or a new list of the kept
            labels ordered by ``start_offset``.
        """
        if self.allow_overlapping:
            return labels

        kept = []
        last_offset = -1
        # sorted() rather than list.sort(): the caller's list must not be
        # reordered as a side effect of cleaning.
        for label in sorted(labels, key=lambda label: label.start_offset):
            if label.start_offset >= last_offset:
                last_offset = label.end_offset
                kept.append(label)
        return kept
class CategoryCleaner(Cleaner):
    """Cleaner that keeps at most one category label for exclusive projects."""

    def __init__(self, project: Project):
        """Record whether the project is flagged ``single_class_classification``.

        The flag is read with ``getattr`` and defaults to ``False`` when the
        project has no such attribute.
        """
        super().__init__(project)
        self.exclusive = getattr(project, 'single_class_classification', False)

    def clean(self, labels: List[CategoryLabel]) -> List[CategoryLabel]:
        """Return only the first label when exclusive, otherwise ``labels`` itself."""
        return labels[:1] if self.exclusive else labels

13
backend/api/views/upload/dataset.py

@ -11,6 +11,7 @@ from chardet.universaldetector import UniversalDetector
from pydantic import ValidationError from pydantic import ValidationError
from seqeval.scheme import BILOU, IOB2, IOBES, IOE2, Tokens from seqeval.scheme import BILOU, IOB2, IOBES, IOE2, Tokens
from .cleaners import Cleaner
from .data import BaseData from .data import BaseData
from .exception import FileParseException, FileParseExceptions from .exception import FileParseException, FileParseExceptions
from .label import Label from .label import Label
@ -30,6 +31,18 @@ class Record:
def __str__(self): def __str__(self):
return f'{self._data}\t{self._label}' return f'{self._data}\t{self._label}'
def clean(self, cleaner: Cleaner):
    """Replace this record's labels with the cleaner's output.

    Args:
        cleaner: Object whose ``clean`` method receives the current labels
            and returns the labels to keep.

    Raises:
        FileParseException: if the number of cleaned labels differs from
            ``len(self.label)``. The cleaned labels are stored before the
            exception is raised.
    """
    cleaned = cleaner.clean(self._label)
    # NOTE(review): ``self.label`` is defined outside this excerpt; it is
    # read here, before ``_label`` is replaced, exactly as the comparison
    # requires.
    cleaned_count, previous_count = len(cleaned), len(self.label)
    self._label = cleaned
    if cleaned_count == previous_count:
        return
    raise FileParseException(
        filename=self._data.filename,
        line_num=-1,
        message="There are invalid labels. It's cleaned."
    )
@property @property
def data(self): def data(self):
return self._data.dict() return self._data.dict()

14
backend/api/views/upload/factory.py

@ -1,6 +1,6 @@
from ...models import (DOCUMENT_CLASSIFICATION, IMAGE_CLASSIFICATION, SEQ2SEQ, from ...models import (DOCUMENT_CLASSIFICATION, IMAGE_CLASSIFICATION, SEQ2SEQ,
SEQUENCE_LABELING, SPEECH2TEXT) SEQUENCE_LABELING, SPEECH2TEXT)
from . import catalog, data, dataset, label
from . import catalog, cleaners, data, dataset, label
def get_data_class(project_type: str): def get_data_class(project_type: str):
@ -40,3 +40,15 @@ def get_label_class(project_type: str):
if project_type not in mapping: if project_type not in mapping:
ValueError(f'Invalid project type: {project_type}') ValueError(f'Invalid project type: {project_type}')
return mapping[project_type] return mapping[project_type]
def create_cleaner(project):
    """Build the label cleaner that matches the project's type.

    Document and image classification projects get a ``CategoryCleaner``,
    sequence labeling projects get a ``SpanCleaner``, and every other
    project type falls back to the base ``Cleaner``.

    Args:
        project: Object exposing ``project_type``; it is also passed to the
            cleaner's constructor.

    Returns:
        A cleaner instance constructed with ``project``.
    """
    mapping = {
        DOCUMENT_CLASSIFICATION: cleaners.CategoryCleaner,
        SEQUENCE_LABELING: cleaners.SpanCleaner,
        IMAGE_CLASSIFICATION: cleaners.CategoryCleaner
    }
    # Unknown project types are deliberately not an error: they receive the
    # base cleaner. (A bare, never-raised ValueError used to sit here and
    # wrongly suggested otherwise.)
    cleaner_class = mapping.get(project.project_type, cleaners.Cleaner)
    return cleaner_class(project)
Loading…
Cancel
Save