
Enable importing relation extraction datasets

Branch: pull/1823/head
Author: Hironsan, 2 years ago
Commit: 56440c180e
11 changed files with 144 additions and 25 deletions
  1. backend/data_import/celery_tasks.py (10 changes)
  2. backend/data_import/pipeline/cleaners.py (8 changes)
  3. backend/data_import/pipeline/examples/relation_extraction/example.jsonl (1 change)
  4. backend/data_import/pipeline/factories.py (19 changes)
  5. backend/data_import/pipeline/labeled_examples.py (55 changes)
  6. backend/data_import/pipeline/labels.py (17 changes)
  7. backend/data_import/pipeline/readers.py (10 changes)
  8. backend/data_import/pipeline/writers.py (4 changes)
  9. backend/data_import/tests/data/relation_extraction/example.jsonl (1 change)
  10. backend/data_import/tests/test_builder.py (8 changes)
  11. backend/data_import/tests/test_tasks.py (36 changes)

backend/data_import/celery_tasks.py (10 changes)

@@ -10,7 +10,12 @@ from django_drf_filepond.models import TemporaryUpload
 from .pipeline.catalog import AudioFile, ImageFile
 from .pipeline.exceptions import FileTypeException, MaximumFileSizeException
-from .pipeline.factories import create_builder, create_cleaner, create_parser
+from .pipeline.factories import (
+    create_builder,
+    create_cleaner,
+    create_parser,
+    select_examples,
+)
 from .pipeline.readers import FileName, Reader
 from .pipeline.writers import Writer
 from projects.models import Project
@@ -64,7 +69,8 @@ def import_dataset(user_id, project_id, file_format: str, upload_ids: List[str],
     cleaner = create_cleaner(project)
     reader = Reader(filenames=filenames, parser=parser, builder=builder, cleaner=cleaner)
     writer = Writer(batch_size=settings.IMPORT_BATCH_SIZE)
-    writer.save(reader, project, user)
+    examples = select_examples(project)
+    writer.save(reader, project, user, examples)
     upload_to_store(temporary_uploads)
     return {"error": reader.errors + errors}

backend/data_import/pipeline/cleaners.py (8 changes)

@@ -25,14 +25,16 @@ class SpanCleaner(Cleaner):
         if self.allow_overlapping:
             return labels
-        labels.sort(key=lambda label: label.start_offset)
+        span_labels = [label for label in labels if isinstance(label, SpanLabel)]
+        other_labels = [label for label in labels if not isinstance(label, SpanLabel)]
+        span_labels.sort(key=lambda label: label.start_offset)
         last_offset = -1
         new_labels = []
-        for label in labels:
+        for label in span_labels:
             if label.start_offset >= last_offset:
                 last_offset = label.end_offset
                 new_labels.append(label)
-        return new_labels
+        return new_labels + other_labels

     @property
     def message(self) -> str:
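Note on the change above: SpanCleaner now sorts and de-overlaps only the span labels and passes every other label (notably relations) through untouched. A minimal standalone sketch of that filtering logic, using stand-in dataclasses rather than the real Label models, for illustration only:

    from dataclasses import dataclass

    @dataclass
    class FakeSpan:        # stand-in for SpanLabel
        start_offset: int
        end_offset: int

    @dataclass
    class FakeRelation:    # stand-in for RelationLabel
        from_id: int
        to_id: int

    def remove_overlaps(labels):
        spans = [l for l in labels if isinstance(l, FakeSpan)]
        others = [l for l in labels if not isinstance(l, FakeSpan)]
        spans.sort(key=lambda s: s.start_offset)
        last_offset, kept = -1, []
        for span in spans:
            if span.start_offset >= last_offset:  # keep only non-overlapping spans
                last_offset = span.end_offset
                kept.append(span)
        return kept + others                      # non-span labels survive cleaning

    labels = [FakeSpan(0, 6), FakeSpan(3, 9), FakeSpan(22, 39), FakeRelation(0, 1)]
    print(remove_overlaps(labels))  # drops the overlapping (3, 9) span, keeps the relation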

backend/data_import/pipeline/examples/relation_extraction/example.jsonl (1 change)

@@ -1,4 +1,3 @@
 {
     "text": "Google was founded on September 4, 1998, by Larry Page and Sergey Brin.",
     "entities": [

backend/data_import/pipeline/factories.py (19 changes)

@@ -1,4 +1,13 @@
-from . import builders, catalog, cleaners, data, labels, parsers, readers
+from . import (
+    builders,
+    catalog,
+    cleaners,
+    data,
+    labeled_examples,
+    labels,
+    parsers,
+    readers,
+)
 from projects.models import (
     DOCUMENT_CLASSIFICATION,
     IMAGE_CLASSIFICATION,
@@ -60,6 +69,14 @@ def create_cleaner(project):
     return cleaner_class(project)


+def select_examples(project):
+    use_relation = getattr(project, "use_relation", False)
+    if project.project_type == SEQUENCE_LABELING and use_relation:
+        return labeled_examples.RelationExamples
+    else:
+        return labeled_examples.LabeledExamples
+
+
 def create_builder(project, **kwargs):
     if not project.is_text_project:
         return builders.PlainBuilder(data_class=get_data_class(project.project_type))
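The new select_examples returns a class, not an instance; the writer instantiates it once per batch further down the pipeline. A hedged sketch of the dispatch with a stub in place of the Django Project model (the SEQUENCE_LABELING value and the stand-in classes are assumptions for illustration):

    from types import SimpleNamespace

    SEQUENCE_LABELING = "SequenceLabeling"   # assumed constant value, for illustration

    class LabeledExamples: ...               # stand-ins for the real container classes
    class RelationExamples(LabeledExamples): ...

    def select_examples(project):
        use_relation = getattr(project, "use_relation", False)
        if project.project_type == SEQUENCE_LABELING and use_relation:
            return RelationExamples
        return LabeledExamples

    ner_project = SimpleNamespace(project_type=SEQUENCE_LABELING, use_relation=False)
    rel_project = SimpleNamespace(project_type=SEQUENCE_LABELING, use_relation=True)
    assert select_examples(ner_project) is LabeledExamples
    assert select_examples(rel_project) is RelationExamples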

backend/data_import/pipeline/labeled_examples.py (55 changes)

@@ -1,14 +1,15 @@
 import itertools
 from collections import defaultdict
-from typing import Any, Dict, List, Type
+from typing import Any, Dict, List, Optional, Type

 from .cleaners import Cleaner
 from .data import BaseData
 from .exceptions import FileParseException
-from .labels import Label
+from .labels import Label, RelationLabel, SpanLabel
 from examples.models import Example
-from label_types.models import CategoryType, LabelType, SpanType
+from label_types.models import CategoryType, LabelType, RelationType, SpanType
 from labels.models import Label as LabelModel
+from labels.models import Relation, Span
 from projects.models import Project
@@ -53,8 +54,20 @@ class Record:
         labels = [label.create_type(project) for label in self._label]
         return list(filter(None, labels))

-    def create_label(self, user, example, mapping) -> List[LabelModel]:
-        return [label.create(user, example, mapping) for label in self._label]
+    def create_label(
+        self, user, example, mapping, label_class: Optional[Type[Label]] = None, **kwargs
+    ) -> List[LabelModel]:
+        if label_class is None:
+            return [label.create(user, example, mapping) for label in self._label]
+        else:
+            return [
+                label.create(user, example, mapping, **kwargs)
+                for label in self._label
+                if isinstance(label, label_class)
+            ]
+
+    def select_label(self, label_class: Type[Label]) -> List[Label]:
+        return [label for label in self._label if isinstance(label, label_class)]

     @property
     def label(self):
@@ -88,3 +101,35 @@ class LabeledExamples:
         )
         for label_class, instances in group_by_class(labels).items():
             label_class.objects.bulk_create(instances)
+
+
+class RelationExamples(LabeledExamples):
+    def create_label(self, project: Project, user, examples: List[Example]):
+        mapping = {}
+        label_types: List[Type[LabelType]] = [RelationType, SpanType]
+        for model in label_types:
+            for label in model.objects.filter(project=project):
+                mapping[label.text] = label
+
+        labels = list(
+            itertools.chain.from_iterable(
+                [data.create_label(user, example, mapping, SpanLabel) for data, example in zip(self.records, examples)]
+            )
+        )
+        Span.objects.bulk_create(labels)
+
+        # filter spans by uuid
+        original_spans = list(
+            itertools.chain.from_iterable(example.select_label(SpanLabel) for example in self.records)
+        )
+        spans = Span.objects.filter(uuid__in=[span.uuid for span in original_spans])
+        # create mapping from id to span
+        # this is needed to create the relation
+        span_mapping = {original_span.id: saved_span for saved_span, original_span in zip(spans, original_spans)}
+        # then, replace from_id and to_id with the span id
+        relations = itertools.chain.from_iterable(
+            [
+                data.create_label(user, example, mapping, RelationLabel, span_mapping=span_mapping)
+                for data, example in zip(self.records, examples)
+            ]
+        )
+        Relation.objects.bulk_create(relations)
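RelationExamples creates labels in two passes: spans are bulk-created first, then re-fetched by uuid so that the upload-local from_id/to_id integers can be translated into saved Span rows before the relations are created. The real code zips the queryset against the original list; the sketch below matches by uuid to make the idea explicit, with plain dictionaries standing in for the Django models (illustration only, not the project's API):

    import uuid

    # in-memory labels after parsing one example; ids are local to the uploaded file
    original_spans = [
        {"id": 0, "uuid": uuid.uuid4(), "label": "ORG"},     # "Google"
        {"id": 2, "uuid": uuid.uuid4(), "label": "PERSON"},  # "Larry Page"
    ]
    relations = [{"from_id": 0, "to_id": 2, "type": "foundedBy"}]

    # pretend these rows came back from a filter on uuid after bulk_create;
    # the database assigned primary keys, but the uuids still match the in-memory labels
    saved_spans = [
        {"pk": 101, "uuid": original_spans[0]["uuid"]},
        {"pk": 102, "uuid": original_spans[1]["uuid"]},
    ]

    # build the local-id -> saved-span lookup, which is the role span_mapping plays above
    by_uuid = {span["uuid"]: span for span in saved_spans}
    span_mapping = {orig["id"]: by_uuid[orig["uuid"]] for orig in original_spans}

    resolved = [
        {"from_span": span_mapping[r["from_id"]]["pk"],
         "to_span": span_mapping[r["to_id"]]["pk"],
         "type": r["type"]}
        for r in relations
    ]
    print(resolved)  # [{'from_span': 101, 'to_span': 102, 'type': 'foundedBy'}]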

backend/data_import/pipeline/labels.py (17 changes)

@@ -1,7 +1,8 @@
 import abc
+import uuid
 from typing import Any, Dict, Optional

-from pydantic import BaseModel, validator
+from pydantic import UUID4, BaseModel, validator

 from label_types.models import CategoryType, LabelType, RelationType, SpanType
 from labels.models import Category
@@ -12,6 +13,13 @@ from projects.models import Project

 class Label(BaseModel, abc.ABC):
+    id: int = -1
+    uuid: UUID4
+
+    def __init__(self, **data):
+        data["uuid"] = uuid.uuid4()
+        super().__init__(**data)
+
     @abc.abstractmethod
     def has_name(self) -> bool:
         raise NotImplementedError()
@@ -67,11 +75,10 @@ class CategoryLabel(Label):
         return CategoryType(text=self.label, project=project)

     def create(self, user, example, mapping: Dict[str, LabelType], **kwargs):
-        return Category(user=user, example=example, label=mapping[self.label])
+        return Category(uuid=self.uuid, user=user, example=example, label=mapping[self.label])


 class SpanLabel(Label):
-    id: int = -1
     label: str
     start_offset: int
     end_offset: int
@@ -99,6 +106,7 @@ class SpanLabel(Label):
     def create(self, user, example, mapping: Dict[str, LabelType], **kwargs):
         return Span(
+            uuid=self.uuid,
             user=user,
             example=example,
             start_offset=self.start_offset,
@@ -128,7 +136,7 @@ class TextLabel(Label):
         return None

     def create(self, user, example, mapping, **kwargs):
-        return TL(user=user, example=example, text=self.text)
+        return TL(uuid=self.uuid, user=user, example=example, text=self.text)


 class RelationLabel(Label):
@@ -155,6 +163,7 @@ class RelationLabel(Label):
     def create(self, user, example, mapping: Dict[str, LabelType], **kwargs):
         return Relation(
+            uuid=self.uuid,
             user=user,
             example=example,
             type=mapping[self.type],
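The common thread in labels.py is that every parsed label now gets a uuid at construction time and copies it onto the Django model in create(); after bulk_create, the imported rows can be found again by that uuid, which is exactly what RelationExamples relies on. A tiny self-contained sketch of the pydantic pattern (abstract methods omitted):

    import uuid
    from pydantic import UUID4, BaseModel

    class Label(BaseModel):
        id: int = -1
        uuid: UUID4

        def __init__(self, **data):
            data["uuid"] = uuid.uuid4()  # assign a fresh uuid to every parsed label
            super().__init__(**data)

    class SpanLabel(Label):
        label: str
        start_offset: int
        end_offset: int

    a = SpanLabel(label="ORG", start_offset=0, end_offset=6)
    b = SpanLabel(label="ORG", start_offset=0, end_offset=6)
    assert a.uuid != b.uuid  # identical content, distinct identities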

backend/data_import/pipeline/readers.py (10 changes)

@@ -1,7 +1,7 @@
 import abc
 import collections.abc
 import dataclasses
-from typing import Any, Dict, Iterator, List
+from typing import Any, Dict, Iterator, List, Type

 from .cleaners import Cleaner
 from .exceptions import FileParseException
@@ -29,7 +29,7 @@ class BaseReader(collections.abc.Iterable):
         raise NotImplementedError("Please implement this method in the subclass.")

     @abc.abstractmethod
-    def batch(self, batch_size: int) -> Iterator[LabeledExamples]:
+    def batch(self, batch_size: int, labeled_examples: Type[LabeledExamples]) -> Iterator[LabeledExamples]:
         raise NotImplementedError("Please implement this method in the subclass.")
@@ -84,15 +84,15 @@ class Reader(BaseReader):
             except FileParseException as e:
                 self._errors.append(e)

-    def batch(self, batch_size: int) -> Iterator[LabeledExamples]:
+    def batch(self, batch_size: int, labeled_examples: Type[LabeledExamples]) -> Iterator[LabeledExamples]:
         batch = []
         for record in self:
             batch.append(record)
             if len(batch) == batch_size:
-                yield LabeledExamples(batch)
+                yield labeled_examples(batch)
                 batch = []
         if batch:
-            yield LabeledExamples(batch)
+            yield labeled_examples(batch)

     @property
     def errors(self) -> List[FileParseException]:

backend/data_import/pipeline/writers.py (4 changes)

@@ -6,8 +6,8 @@ class Writer:
     def __init__(self, batch_size: int):
         self.batch_size = batch_size

-    def save(self, reader: BaseReader, project: Project, user):
-        for batch in reader.batch(self.batch_size):
+    def save(self, reader: BaseReader, project: Project, user, labeled_examples):
+        for batch in reader.batch(self.batch_size, labeled_examples):
             examples = batch.create_data(project)
             batch.create_label_type(project)
             batch.create_label(project, user, examples)
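Writer.save threads the selected container class down into Reader.batch, so the same batching loop can yield either LabeledExamples or RelationExamples. A small generic sketch of that pattern with stand-in classes (not the real reader):

    from typing import Iterable, Iterator, List, Type

    class Container:
        def __init__(self, records: List[int]):
            self.records = records

    class RelationContainer(Container):
        pass

    def batch(records: Iterable[int], batch_size: int, container: Type[Container]) -> Iterator[Container]:
        chunk: List[int] = []
        for record in records:
            chunk.append(record)
            if len(chunk) == batch_size:
                yield container(chunk)  # instantiate whichever class was selected upstream
                chunk = []
        if chunk:
            yield container(chunk)

    batches = list(batch(range(5), 2, RelationContainer))
    assert all(isinstance(b, RelationContainer) for b in batches)
    assert [b.records for b in batches] == [[0, 1], [2, 3], [4]]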

backend/data_import/tests/data/relation_extraction/example.jsonl (1 change)

@@ -0,0 +1 @@
+{"text": "Google was founded on September 4, 1998, by Larry Page and Sergey Brin.", "entities": [{"id": 0, "start_offset": 0, "end_offset": 6, "label": "ORG"}, {"id": 1, "start_offset": 22, "end_offset": 39, "label": "DATE"}, {"id": 2, "start_offset": 44, "end_offset": 54, "label": "PERSON"}, {"id": 3, "start_offset": 59, "end_offset": 70, "label": "PERSON"}], "relations": [{"from_id": 0, "to_id": 1, "type": "foundedAt"}, {"from_id": 0, "to_id": 2, "type": "foundedBy"}, {"from_id": 0, "to_id": 3, "type": "foundedBy"}]}

backend/data_import/tests/test_builder.py (8 changes)

@@ -10,8 +10,12 @@ from data_import.pipeline.readers import FileName
 class TestColumnBuilder(unittest.TestCase):
     def assert_record(self, actual, expected):
+        labels = actual.label
+        for label in labels:
+            label.pop("id")
+            label.pop("uuid")
         self.assertEqual(actual.data.text, expected["data"])
-        self.assertEqual(actual.label, expected["label"])
+        self.assertEqual(labels, expected["label"])

     def create_record(self, row, data_column: builders.DataColumn, label_columns: Optional[List[builders.Column]]):
         builder = builders.ColumnBuilder(data_column=data_column, label_columns=label_columns)
@@ -71,6 +75,6 @@ class TestColumnBuilder(unittest.TestCase):
         actual = self.create_record(row, data_column, label_columns)
         expected = {
             "data": "Text",
-            "label": [{"label": "Label"}, {"id": -1, "label": "LOC", "start_offset": 0, "end_offset": 1}],
+            "label": [{"label": "Label"}, {"label": "LOC", "start_offset": 0, "end_offset": 1}],
         }
         self.assert_record(actual, expected)

backend/data_import/tests/test_tasks.py (36 changes)

@@ -208,6 +208,42 @@ class TestImportSequenceLabelingData(TestImportData):
         self.assertEqual(len(response["error"]), 1)


+class TestImportRelationExtractionData(TestImportData):
+    task = SEQUENCE_LABELING
+
+    def setUp(self):
+        self.project = prepare_project(self.task, use_relation=True)
+        self.user = self.project.admin
+        self.data_path = pathlib.Path(__file__).parent / "data"
+        self.upload_id = _get_file_id()
+
+    def assert_examples(self, dataset):
+        self.assertEqual(Example.objects.count(), len(dataset))
+        for text, expected_spans in dataset:
+            example = Example.objects.get(text=text)
+            spans = [[span.start_offset, span.end_offset, span.label.text] for span in example.spans.all()]
+            self.assertEqual(spans, expected_spans)
+            self.assertEqual(example.relations.count(), 3)
+
+    def assert_parse_error(self, response):
+        self.assertGreaterEqual(len(response["error"]), 1)
+        self.assertEqual(Example.objects.count(), 0)
+        self.assertEqual(SpanType.objects.count(), 0)
+        self.assertEqual(Span.objects.count(), 0)
+
+    def test_jsonl(self):
+        filename = "relation_extraction/example.jsonl"
+        file_format = "JSONL"
+        dataset = [
+            (
+                "Google was founded on September 4, 1998, by Larry Page and Sergey Brin.",
+                [[0, 6, "ORG"], [22, 39, "DATE"], [44, 54, "PERSON"], [59, 70, "PERSON"]],
+            ),
+        ]
+        self.import_dataset(filename, file_format)
+        self.assert_examples(dataset)
+
+
 class TestImportSeq2seqData(TestImportData):
     task = SEQ2SEQ
