mirror of https://github.com/doccano/doccano.git
python, annotation-tool, datasets, active-learning, text-annotation, dataset, natural-language-processing, data-labeling, machine-learning
You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
119 lines
3.8 KiB
119 lines
3.8 KiB
import abc
|
|
import itertools
|
|
from collections import defaultdict
|
|
from typing import Any, Dict, List, Type
|
|
|
|
from django.conf import settings
|
|
|
|
from .exceptions import FileParseException
|
|
from .readers import BaseReader, Record
|
|
from examples.models import Example
|
|
from label_types.models import CategoryType, LabelType, SpanType
|
|
from projects.models import Project
|
|
|
|
|
|
class Writer(abc.ABC):
    """Interface for objects that persist the records produced by a reader.

    Concrete writers must implement :meth:`save`. ``errors`` is exposed as a
    read-only property so that the base class agrees with ``BulkWriter``,
    which already implements it as a property.
    """

    @abc.abstractmethod
    def save(self, reader: BaseReader, project: Project, user, cleaner):
        """Save the read contents to DB.

        Args:
            reader: Source of the records to store.
            project: Project the stored data belongs to.
            user: User the import is performed for.
            cleaner: Object handed to each record's ``clean`` step by
                implementations such as ``BulkWriter``.
        """
        raise NotImplementedError("Please implement this method in the subclass.")

    @property
    def errors(self) -> List[Dict[Any, Any]]:
        """Return errors.

        Deliberately not abstract: subclasses that do not override it can
        still be instantiated, and reading the attribute raises
        ``NotImplementedError``.
        """
        raise NotImplementedError("Please implement this method in the subclass.")
|
|
|
|
|
|
def group_by_class(instances):
    """Bucket ``instances`` by their ``__class__``.

    Returns a ``defaultdict(list)`` mapping each class to the list of its
    instances; both the classes and the members of each list keep the order
    in which they were first encountered.
    """
    by_class = defaultdict(list)
    for obj in instances:
        bucket = by_class[obj.__class__]
        bucket.append(obj)
    return by_class
|
|
|
|
|
|
class Examples:
    """Fixed-capacity buffer of parsed records that is written to the DB in bulk.

    Records are accumulated with :meth:`add`; once :meth:`is_full` reports
    that ``buffer_size`` has been reached, the caller is expected to persist
    the batch (labels, then examples, then annotations) and :meth:`clear` it.
    """

    def __init__(self, buffer_size: int = settings.IMPORT_BATCH_SIZE):
        """Create an empty buffer holding up to ``buffer_size`` records."""
        self.buffer_size = buffer_size
        self.buffer: List[Record] = []

    def __len__(self):
        """Return the number of buffered records."""
        return len(self.buffer)

    @property
    def data(self):
        """The underlying list of buffered records (not a copy)."""
        return self.buffer

    def add(self, data):
        """Append one record to the buffer."""
        self.buffer.append(data)

    def clear(self):
        """Drop all buffered records by rebinding to a fresh list."""
        self.buffer = []

    def is_full(self):
        """Return True once the buffer holds at least ``buffer_size`` records."""
        return len(self) >= self.buffer_size

    def is_empty(self):
        """Return True when no record is buffered."""
        return len(self) == 0

    def save_label(self, project: Project):
        """Bulk-create the label types referenced by the buffered records.

        Each record's ``create_label(project)`` yields label instances (falsy
        entries are discarded). Instances are grouped by model class so that
        each model needs a single ``bulk_create``; ``ignore_conflicts=True``
        lets rows that collide with existing ones be skipped instead of
        failing the batch.
        """
        labels = list(itertools.chain.from_iterable([example.create_label(project) for example in self.buffer]))
        labels = list(filter(None, labels))
        groups = group_by_class(labels)
        for klass, instances in groups.items():
            klass.objects.bulk_create(instances, ignore_conflicts=True)

    def save_data(self, project: Project) -> List[Example]:
        """Bulk-create one ``Example`` per buffered record and return them.

        The returned list is in buffer order, which :meth:`save_annotation`
        relies on when pairing records with their examples.
        """
        examples = [example.create_data(project) for example in self.buffer]
        return Example.objects.bulk_create(examples)

    def save_annotation(self, project: Project, user, examples):
        """Bulk-create the annotations of the buffered records.

        Args:
            project: Project whose label types are used to resolve label text.
            user: User recorded on the created annotations.
            examples: Saved examples, positionally aligned with the buffer
                (as returned by :meth:`save_data`).
        """
        # Todo: move annotation class
        mapping = {}
        label_types: List[Type[LabelType]] = [CategoryType, SpanType]
        for model in label_types:
            # Restrict the lookup to this project's label types. Loading every
            # label in the database would let a same-text label from another
            # project overwrite the correct entry and be attached to the
            # imported annotations (and loads needlessly many rows).
            for label in model.objects.filter(project=project):
                # NOTE: keyed by text only, so a SpanType with the same text
                # as a CategoryType replaces it (SpanType is iterated last).
                mapping[label.text] = label
        annotations = list(
            itertools.chain.from_iterable(
                [data.create_annotation(user, example, mapping) for data, example in zip(self.buffer, examples)]
            )
        )
        # One bulk_create per annotation model class.
        groups = group_by_class(annotations)
        for klass, instances in groups.items():
            klass.objects.bulk_create(instances)
|
|
|
|
|
|
class BulkWriter(Writer):
    """Writer that stores records in batches of ``batch_size`` via bulk inserts."""

    def __init__(self, batch_size: int):
        """Set up an empty batch buffer and an empty error list."""
        self._errors: List[FileParseException] = []
        self.examples = Examples(batch_size)

    def save(self, reader: BaseReader, project: Project, user, cleaner):
        """Clean every record from ``reader`` and persist them batch by batch.

        A record whose ``clean`` raises ``FileParseException`` has the error
        recorded but is still buffered and saved. After the reader is
        exhausted, any partial batch is flushed and the reader's own errors
        are appended to this writer's errors.
        """
        for record in reader:
            try:
                record.clean(cleaner)
            except FileParseException as parse_error:
                self._errors.append(parse_error)

            self.examples.add(record)
            if not self.examples.is_full():
                continue
            # Buffer reached capacity: write it out and start a new batch.
            self.create(project, user)
            self.examples.clear()

        # Flush whatever is left over from the last, partially filled batch.
        if not self.examples.is_empty():
            self.create(project, user)
            self.examples.clear()
        self._errors.extend(reader.errors)

    @property
    def errors(self) -> List[Dict[Any, Any]]:
        """Collected errors as dicts, ordered by line number.

        The internal error list is sorted in place on every access.
        """
        self._errors.sort(key=lambda err: err.line_num)
        return [err.dict() for err in self._errors]

    def create(self, project: Project, user):
        """Persist the current batch: labels first, then examples, then annotations."""
        batch = self.examples
        batch.save_label(project)
        saved_examples = batch.save_data(project)
        batch.save_annotation(project, user, saved_examples)
|