From 56440c180e4b738a3cd911024ac25a3ed1388686 Mon Sep 17 00:00:00 2001
From: Hironsan
Date: Thu, 12 May 2022 12:31:11 +0900
Subject: [PATCH] Enable importing relation extraction datasets

---
 backend/data_import/celery_tasks.py              | 10 +++-
 backend/data_import/pipeline/cleaners.py         |  8 ++-
 .../relation_extraction/example.jsonl            |  1 -
 backend/data_import/pipeline/factories.py        | 19 ++++++-
 .../data_import/pipeline/labeled_examples.py     | 55 +++++++++++++++++--
 backend/data_import/pipeline/labels.py           | 17 ++++--
 backend/data_import/pipeline/readers.py          | 10 ++--
 backend/data_import/pipeline/writers.py          |  4 +-
 .../data/relation_extraction/example.jsonl       |  1 +
 backend/data_import/tests/test_builder.py        |  8 ++-
 backend/data_import/tests/test_tasks.py          | 36 ++++++++++++
 11 files changed, 144 insertions(+), 25 deletions(-)
 create mode 100644 backend/data_import/tests/data/relation_extraction/example.jsonl

diff --git a/backend/data_import/celery_tasks.py b/backend/data_import/celery_tasks.py
index e5109164..e7f8fb25 100644
--- a/backend/data_import/celery_tasks.py
+++ b/backend/data_import/celery_tasks.py
@@ -10,7 +10,12 @@ from django_drf_filepond.models import TemporaryUpload
 
 from .pipeline.catalog import AudioFile, ImageFile
 from .pipeline.exceptions import FileTypeException, MaximumFileSizeException
-from .pipeline.factories import create_builder, create_cleaner, create_parser
+from .pipeline.factories import (
+    create_builder,
+    create_cleaner,
+    create_parser,
+    select_examples,
+)
 from .pipeline.readers import FileName, Reader
 from .pipeline.writers import Writer
 from projects.models import Project
@@ -64,7 +69,8 @@ def import_dataset(user_id, project_id, file_format: str, upload_ids: List[str],
     cleaner = create_cleaner(project)
     reader = Reader(filenames=filenames, parser=parser, builder=builder, cleaner=cleaner)
     writer = Writer(batch_size=settings.IMPORT_BATCH_SIZE)
-    writer.save(reader, project, user)
+    examples = select_examples(project)
+    writer.save(reader, project, user, examples)
     upload_to_store(temporary_uploads)
     return {"error": reader.errors + errors}
 
diff --git a/backend/data_import/pipeline/cleaners.py b/backend/data_import/pipeline/cleaners.py
index b6d31d96..bf013a95 100644
--- a/backend/data_import/pipeline/cleaners.py
+++ b/backend/data_import/pipeline/cleaners.py
@@ -25,14 +25,16 @@ class SpanCleaner(Cleaner):
         if self.allow_overlapping:
             return labels
-        labels.sort(key=lambda label: label.start_offset)
+        span_labels = [label for label in labels if isinstance(label, SpanLabel)]
+        other_labels = [label for label in labels if not isinstance(label, SpanLabel)]
+        span_labels.sort(key=lambda label: label.start_offset)
         last_offset = -1
         new_labels = []
-        for label in labels:
+        for label in span_labels:
             if label.start_offset >= last_offset:
                 last_offset = label.end_offset
                 new_labels.append(label)
-        return new_labels
+        return new_labels + other_labels
 
     @property
     def message(self) -> str:
diff --git a/backend/data_import/pipeline/examples/relation_extraction/example.jsonl b/backend/data_import/pipeline/examples/relation_extraction/example.jsonl
index d3186d0e..6ba958ff 100644
--- a/backend/data_import/pipeline/examples/relation_extraction/example.jsonl
+++ b/backend/data_import/pipeline/examples/relation_extraction/example.jsonl
@@ -1,4 +1,3 @@
-
 {
   "text": "Google was founded on September 4, 1998, by Larry Page and Sergey Brin.",
   "entities": [
diff --git a/backend/data_import/pipeline/factories.py b/backend/data_import/pipeline/factories.py
index 61758a37..e723ad1d 100644
--- a/backend/data_import/pipeline/factories.py
+++ b/backend/data_import/pipeline/factories.py
@@ -1,4 +1,13 @@
-from . import builders, catalog, cleaners, data, labels, parsers, readers
+from . import (
+    builders,
+    catalog,
+    cleaners,
+    data,
+    labeled_examples,
+    labels,
+    parsers,
+    readers,
+)
 from projects.models import (
     DOCUMENT_CLASSIFICATION,
     IMAGE_CLASSIFICATION,
@@ -60,6 +69,14 @@ def create_cleaner(project):
     return cleaner_class(project)
 
 
+def select_examples(project):
+    use_relation = getattr(project, "use_relation", False)
+    if project.project_type == SEQUENCE_LABELING and use_relation:
+        return labeled_examples.RelationExamples
+    else:
+        return labeled_examples.LabeledExamples
+
+
 def create_builder(project, **kwargs):
     if not project.is_text_project:
         return builders.PlainBuilder(data_class=get_data_class(project.project_type))
diff --git a/backend/data_import/pipeline/labeled_examples.py b/backend/data_import/pipeline/labeled_examples.py
index 58c75a68..ac53755d 100644
--- a/backend/data_import/pipeline/labeled_examples.py
+++ b/backend/data_import/pipeline/labeled_examples.py
@@ -1,14 +1,15 @@
 import itertools
 from collections import defaultdict
-from typing import Any, Dict, List, Type
+from typing import Any, Dict, List, Optional, Type
 
 from .cleaners import Cleaner
 from .data import BaseData
 from .exceptions import FileParseException
-from .labels import Label
+from .labels import Label, RelationLabel, SpanLabel
 from examples.models import Example
-from label_types.models import CategoryType, LabelType, SpanType
+from label_types.models import CategoryType, LabelType, RelationType, SpanType
 from labels.models import Label as LabelModel
+from labels.models import Relation, Span
 from projects.models import Project
 
 
@@ -53,8 +54,20 @@ class Record:
         labels = [label.create_type(project) for label in self._label]
         return list(filter(None, labels))
 
-    def create_label(self, user, example, mapping) -> List[LabelModel]:
-        return [label.create(user, example, mapping) for label in self._label]
+    def create_label(
+        self, user, example, mapping, label_class: Optional[Type[Label]] = None, **kwargs
+    ) -> List[LabelModel]:
+        if label_class is None:
+            return [label.create(user, example, mapping) for label in self._label]
+        else:
+            return [
+                label.create(user, example, mapping, **kwargs)
+                for label in self._label
+                if isinstance(label, label_class)
+            ]
+
+    def select_label(self, label_class: Type[Label]) -> List[Label]:
+        return [label for label in self._label if isinstance(label, label_class)]
 
     @property
     def label(self):
@@ -88,3 +101,35 @@ class LabeledExamples:
         )
         for label_class, instances in group_by_class(labels).items():
             label_class.objects.bulk_create(instances)
+
+
+class RelationExamples(LabeledExamples):
+    def create_label(self, project: Project, user, examples: List[Example]):
+        mapping = {}
+        label_types: List[Type[LabelType]] = [RelationType, SpanType]
+        for model in label_types:
+            for label in model.objects.filter(project=project):
+                mapping[label.text] = label
+
+        labels = list(
+            itertools.chain.from_iterable(
+                [data.create_label(user, example, mapping, SpanLabel) for data, example in zip(self.records, examples)]
+            )
+        )
+        Span.objects.bulk_create(labels)
+        # filter spans by uuid
+        original_spans = list(
+            itertools.chain.from_iterable(example.select_label(SpanLabel) for example in self.records)
+        )
+        spans = Span.objects.filter(uuid__in=[span.uuid for span in original_spans])
+        # create mapping from id to span
+        # this is needed to create the relation
+        span_mapping = {original_span.id: saved_span for saved_span, original_span in zip(spans, original_spans)}
+        # then, replace from_id and to_id with the span id
+        relations = itertools.chain.from_iterable(
+            [
+                data.create_label(user, example, mapping, RelationLabel, span_mapping=span_mapping)
+                for data, example in zip(self.records, examples)
+            ]
+        )
+        Relation.objects.bulk_create(relations)
diff --git a/backend/data_import/pipeline/labels.py b/backend/data_import/pipeline/labels.py
index 5ffab4df..b490dabb 100644
--- a/backend/data_import/pipeline/labels.py
+++ b/backend/data_import/pipeline/labels.py
@@ -1,7 +1,8 @@
 import abc
+import uuid
 from typing import Any, Dict, Optional
 
-from pydantic import BaseModel, validator
+from pydantic import UUID4, BaseModel, validator
 
 from label_types.models import CategoryType, LabelType, RelationType, SpanType
 from labels.models import Category
@@ -12,6 +13,13 @@ from projects.models import Project
 
 
 class Label(BaseModel, abc.ABC):
+    id: int = -1
+    uuid: UUID4
+
+    def __init__(self, **data):
+        data["uuid"] = uuid.uuid4()
+        super().__init__(**data)
+
     @abc.abstractmethod
     def has_name(self) -> bool:
         raise NotImplementedError()
@@ -67,11 +75,10 @@ class CategoryLabel(Label):
         return CategoryType(text=self.label, project=project)
 
     def create(self, user, example, mapping: Dict[str, LabelType], **kwargs):
-        return Category(user=user, example=example, label=mapping[self.label])
+        return Category(uuid=self.uuid, user=user, example=example, label=mapping[self.label])
 
 
 class SpanLabel(Label):
-    id: int = -1
     label: str
     start_offset: int
     end_offset: int
@@ -99,6 +106,7 @@ class SpanLabel(Label):
 
     def create(self, user, example, mapping: Dict[str, LabelType], **kwargs):
         return Span(
+            uuid=self.uuid,
             user=user,
             example=example,
             start_offset=self.start_offset,
@@ -128,7 +136,7 @@ class TextLabel(Label):
         return None
 
     def create(self, user, example, mapping, **kwargs):
-        return TL(user=user, example=example, text=self.text)
+        return TL(uuid=self.uuid, user=user, example=example, text=self.text)
 
 
 class RelationLabel(Label):
@@ -155,6 +163,7 @@ class RelationLabel(Label):
 
     def create(self, user, example, mapping: Dict[str, LabelType], **kwargs):
         return Relation(
+            uuid=self.uuid,
             user=user,
             example=example,
             type=mapping[self.type],
diff --git a/backend/data_import/pipeline/readers.py b/backend/data_import/pipeline/readers.py
index 7a296616..86cf5680 100644
--- a/backend/data_import/pipeline/readers.py
+++ b/backend/data_import/pipeline/readers.py
@@ -1,7 +1,7 @@
 import abc
 import collections.abc
 import dataclasses
-from typing import Any, Dict, Iterator, List
+from typing import Any, Dict, Iterator, List, Type
 
 from .cleaners import Cleaner
 from .exceptions import FileParseException
@@ -29,7 +29,7 @@ class BaseReader(collections.abc.Iterable):
         raise NotImplementedError("Please implement this method in the subclass.")
 
     @abc.abstractmethod
-    def batch(self, batch_size: int) -> Iterator[LabeledExamples]:
+    def batch(self, batch_size: int, labeled_examples: Type[LabeledExamples]) -> Iterator[LabeledExamples]:
         raise NotImplementedError("Please implement this method in the subclass.")
 
 
@@ -84,15 +84,15 @@ class Reader(BaseReader):
             except FileParseException as e:
                 self._errors.append(e)
 
-    def batch(self, batch_size: int) -> Iterator[LabeledExamples]:
+    def batch(self, batch_size: int, labeled_examples: Type[LabeledExamples]) -> Iterator[LabeledExamples]:
         batch = []
         for record in self:
             batch.append(record)
             if len(batch) == batch_size:
-                yield LabeledExamples(batch)
+                yield labeled_examples(batch)
                 batch = []
         if batch:
-            yield LabeledExamples(batch)
+            yield labeled_examples(batch)
 
     @property
     def errors(self) -> List[FileParseException]:
diff --git a/backend/data_import/pipeline/writers.py b/backend/data_import/pipeline/writers.py
index 77b54b97..0b89761c 100644
--- a/backend/data_import/pipeline/writers.py
+++ b/backend/data_import/pipeline/writers.py
@@ -6,8 +6,8 @@ class Writer:
     def __init__(self, batch_size: int):
         self.batch_size = batch_size
 
-    def save(self, reader: BaseReader, project: Project, user):
-        for batch in reader.batch(self.batch_size):
+    def save(self, reader: BaseReader, project: Project, user, labeled_examples):
+        for batch in reader.batch(self.batch_size, labeled_examples):
             examples = batch.create_data(project)
             batch.create_label_type(project)
             batch.create_label(project, user, examples)
diff --git a/backend/data_import/tests/data/relation_extraction/example.jsonl b/backend/data_import/tests/data/relation_extraction/example.jsonl
new file mode 100644
index 00000000..fbb39fc9
--- /dev/null
+++ b/backend/data_import/tests/data/relation_extraction/example.jsonl
@@ -0,0 +1 @@
+{"text": "Google was founded on September 4, 1998, by Larry Page and Sergey Brin.", "entities": [{"id": 0, "start_offset": 0, "end_offset": 6, "label": "ORG"}, {"id": 1, "start_offset": 22, "end_offset": 39, "label": "DATE"}, {"id": 2, "start_offset": 44, "end_offset": 54, "label": "PERSON"}, {"id": 3, "start_offset": 59, "end_offset": 70, "label": "PERSON"}], "relations": [{"from_id": 0, "to_id": 1, "type": "foundedAt"}, {"from_id": 0, "to_id": 2, "type": "foundedBy"}, {"from_id": 0, "to_id": 3, "type": "foundedBy"}]}
diff --git a/backend/data_import/tests/test_builder.py b/backend/data_import/tests/test_builder.py
index acba65ab..dbecedb2 100644
--- a/backend/data_import/tests/test_builder.py
+++ b/backend/data_import/tests/test_builder.py
@@ -10,8 +10,12 @@ from data_import.pipeline.readers import FileName
 
 class TestColumnBuilder(unittest.TestCase):
     def assert_record(self, actual, expected):
+        labels = actual.label
+        for label in labels:
+            label.pop("id")
+            label.pop("uuid")
         self.assertEqual(actual.data.text, expected["data"])
-        self.assertEqual(actual.label, expected["label"])
+        self.assertEqual(labels, expected["label"])
 
     def create_record(self, row, data_column: builders.DataColumn, label_columns: Optional[List[builders.Column]]):
         builder = builders.ColumnBuilder(data_column=data_column, label_columns=label_columns)
@@ -71,6 +75,6 @@ class TestColumnBuilder(unittest.TestCase):
         actual = self.create_record(row, data_column, label_columns)
         expected = {
             "data": "Text",
-            "label": [{"label": "Label"}, {"id": -1, "label": "LOC", "start_offset": 0, "end_offset": 1}],
+            "label": [{"label": "Label"}, {"label": "LOC", "start_offset": 0, "end_offset": 1}],
         }
         self.assert_record(actual, expected)
diff --git a/backend/data_import/tests/test_tasks.py b/backend/data_import/tests/test_tasks.py
index 7e957702..be1275a4 100644
--- a/backend/data_import/tests/test_tasks.py
+++ b/backend/data_import/tests/test_tasks.py
@@ -208,6 +208,42 @@ class TestImportSequenceLabelingData(TestImportData):
         self.assertEqual(len(response["error"]), 1)
 
 
+class TestImportRelationExtractionData(TestImportData):
+    task = SEQUENCE_LABELING
+
+    def setUp(self):
+        self.project = prepare_project(self.task, use_relation=True)
+        self.user = self.project.admin
+        self.data_path = pathlib.Path(__file__).parent / "data"
+        self.upload_id = _get_file_id()
+
+    def assert_examples(self, dataset):
+        self.assertEqual(Example.objects.count(), len(dataset))
+        for text, expected_spans in dataset:
+            example = Example.objects.get(text=text)
+            spans = [[span.start_offset, span.end_offset, span.label.text] for span in example.spans.all()]
+            self.assertEqual(spans, expected_spans)
+            self.assertEqual(example.relations.count(), 3)
+
+    def assert_parse_error(self, response):
+        self.assertGreaterEqual(len(response["error"]), 1)
+        self.assertEqual(Example.objects.count(), 0)
+        self.assertEqual(SpanType.objects.count(), 0)
+        self.assertEqual(Span.objects.count(), 0)
+
+    def test_jsonl(self):
+        filename = "relation_extraction/example.jsonl"
+        file_format = "JSONL"
+        dataset = [
+            (
+                "Google was founded on September 4, 1998, by Larry Page and Sergey Brin.",
+                [[0, 6, "ORG"], [22, 39, "DATE"], [44, 54, "PERSON"], [59, 70, "PERSON"]],
+            ),
+        ]
+        self.import_dataset(filename, file_format)
+        self.assert_examples(dataset)
+
+
 class TestImportSeq2seqData(TestImportData):
     task = SEQ2SEQ
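
Note on the relation import flow: the key step in RelationExamples.create_label is that from_id/to_id in the uploaded JSONL refer to the client-side "id" of each entry in "entities", so the importer first bulk-creates the spans, re-fetches them by the uuid assigned in labels.py, and only then builds the mapping that lets relations point at saved Span rows. The following standalone Python sketch illustrates that remapping with plain dictionaries only; it is not doccano code, and the names SavedSpan, persist_spans, and build_relations are hypothetical helpers invented for this illustration.

    import uuid
    from dataclasses import dataclass
    from typing import Dict, List


    @dataclass
    class SavedSpan:
        pk: int            # primary key assigned on save (simulated here)
        uuid: str
        start_offset: int
        end_offset: int
        label: str


    def persist_spans(entities: List[dict]) -> Dict[int, SavedSpan]:
        """Simulate bulk-saving spans, then map each client-side entity id to its
        saved span, mirroring the uuid-keyed lookup built in the patch."""
        saved = [
            SavedSpan(
                pk=i,  # a real import reads this back from the database
                uuid=str(uuid.uuid4()),
                start_offset=entity["start_offset"],
                end_offset=entity["end_offset"],
                label=entity["label"],
            )
            for i, entity in enumerate(entities, start=1)
        ]
        # client-side id -> saved span (the patch builds the analogous mapping by
        # filtering Span objects on uuid after bulk_create)
        return {entity["id"]: span for entity, span in zip(entities, saved)}


    def build_relations(relations: List[dict], span_mapping: Dict[int, SavedSpan]) -> List[dict]:
        """Replace from_id/to_id (client-side ids) with saved span primary keys."""
        return [
            {"type": r["type"], "from_id": span_mapping[r["from_id"]].pk, "to_id": span_mapping[r["to_id"]].pk}
            for r in relations
        ]


    if __name__ == "__main__":
        line = {
            "text": "Google was founded on September 4, 1998, by Larry Page and Sergey Brin.",
            "entities": [
                {"id": 0, "start_offset": 0, "end_offset": 6, "label": "ORG"},
                {"id": 2, "start_offset": 44, "end_offset": 54, "label": "PERSON"},
            ],
            "relations": [{"from_id": 0, "to_id": 2, "type": "foundedBy"}],
        }
        mapping = persist_spans(line["entities"])
        print(build_relations(line["relations"], mapping))

This is also why the patch moves the uuid field onto the base Label class in labels.py: without a stable identifier carried from the parsed label to the persisted row, the importer could not reconnect the saved spans with the from_id/to_id references in the file.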