|
|
@ -1,10 +1,9 @@ |
|
|
|
import itertools |
|
|
|
from collections import defaultdict |
|
|
|
from typing import Any, Dict, List, Type |
|
|
|
from typing import List, Type |
|
|
|
|
|
|
|
from django.conf import settings |
|
|
|
|
|
|
|
from .exceptions import FileParseException |
|
|
|
from .readers import BaseReader, Record |
|
|
|
from examples.models import Example |
|
|
|
from label_types.models import CategoryType, LabelType, SpanType |
|
|
@ -45,19 +44,14 @@ class Examples: |
|
|
|
class Writer: |
|
|
|
def __init__(self, batch_size: int): |
|
|
|
self.examples = Examples(batch_size) |
|
|
|
self._errors: List[FileParseException] = [] |
|
|
|
|
|
|
|
def save(self, reader: BaseReader, project: Project, user, cleaner): |
|
|
|
def save(self, reader: BaseReader, project: Project, user): |
|
|
|
it = iter(reader) |
|
|
|
while True: |
|
|
|
try: |
|
|
|
example = next(it) |
|
|
|
except StopIteration: |
|
|
|
break |
|
|
|
try: |
|
|
|
example.clean(cleaner) |
|
|
|
except FileParseException as err: |
|
|
|
self._errors.append(err) |
|
|
|
|
|
|
|
self.examples.add(example) |
|
|
|
if self.examples.is_full(): |
|
|
@ -66,36 +60,30 @@ class Writer: |
|
|
|
if not self.examples.is_empty(): |
|
|
|
self.create(project, user) |
|
|
|
self.examples.clear() |
|
|
|
self._errors.extend(reader.errors) |
|
|
|
|
|
|
|
@property |
|
|
|
def errors(self) -> List[Dict[Any, Any]]: |
|
|
|
self._errors.sort(key=lambda e: e.line_num) |
|
|
|
return [error.dict() for error in self._errors] |
|
|
|
|
|
|
|
def create(self, project: Project, user): |
|
|
|
self.save_label(project) |
|
|
|
ids = self.save_data(project) |
|
|
|
self.save_annotation(project, user, ids) |
|
|
|
self.save_label_type(project) |
|
|
|
ids = self.save_example(project) |
|
|
|
self.save_label(project, user, ids) |
|
|
|
|
|
|
|
def save_label(self, project: Project): |
|
|
|
def save_label_type(self, project: Project): |
|
|
|
labels = list(itertools.chain.from_iterable([example.create_label(project) for example in self.examples])) |
|
|
|
labels = list(filter(None, labels)) |
|
|
|
groups = group_by_class(labels) |
|
|
|
for klass, instances in groups.items(): |
|
|
|
klass.objects.bulk_create(instances, ignore_conflicts=True) |
|
|
|
|
|
|
|
def save_data(self, project: Project) -> List[Example]: |
|
|
|
def save_example(self, project: Project) -> List[Example]: |
|
|
|
examples = [example.create_data(project) for example in self.examples] |
|
|
|
return Example.objects.bulk_create(examples) |
|
|
|
|
|
|
|
def save_annotation(self, project: Project, user, examples): |
|
|
|
# Todo: move annotation class |
|
|
|
def save_label(self, project: Project, user, examples): |
|
|
|
mapping = {} |
|
|
|
label_types: List[Type[LabelType]] = [CategoryType, SpanType] |
|
|
|
for model in label_types: |
|
|
|
for label in model.objects.filter(project=project): |
|
|
|
mapping[label.text] = label |
|
|
|
|
|
|
|
annotations = list( |
|
|
|
itertools.chain.from_iterable( |
|
|
|
[data.create_annotation(user, example, mapping) for data, example in zip(self.examples, examples)] |
|
|
|