Browse Source

Enable importing of relation extraction datasets

pull/1823/head
Hironsan 2 years ago
parent
commit
56440c180e
11 changed files with 144 additions and 25 deletions
  1. 10
      backend/data_import/celery_tasks.py
  2. 8
      backend/data_import/pipeline/cleaners.py
  3. 1
      backend/data_import/pipeline/examples/relation_extraction/example.jsonl
  4. 19
      backend/data_import/pipeline/factories.py
  5. 55
      backend/data_import/pipeline/labeled_examples.py
  6. 17
      backend/data_import/pipeline/labels.py
  7. 10
      backend/data_import/pipeline/readers.py
  8. 4
      backend/data_import/pipeline/writers.py
  9. 1
      backend/data_import/tests/data/relation_extraction/example.jsonl
  10. 8
      backend/data_import/tests/test_builder.py
  11. 36
      backend/data_import/tests/test_tasks.py

10
backend/data_import/celery_tasks.py

@ -10,7 +10,12 @@ from django_drf_filepond.models import TemporaryUpload
from .pipeline.catalog import AudioFile, ImageFile
from .pipeline.exceptions import FileTypeException, MaximumFileSizeException
from .pipeline.factories import create_builder, create_cleaner, create_parser
from .pipeline.factories import (
create_builder,
create_cleaner,
create_parser,
select_examples,
)
from .pipeline.readers import FileName, Reader
from .pipeline.writers import Writer
from projects.models import Project
@ -64,7 +69,8 @@ def import_dataset(user_id, project_id, file_format: str, upload_ids: List[str],
cleaner = create_cleaner(project)
reader = Reader(filenames=filenames, parser=parser, builder=builder, cleaner=cleaner)
writer = Writer(batch_size=settings.IMPORT_BATCH_SIZE)
writer.save(reader, project, user)
examples = select_examples(project)
writer.save(reader, project, user, examples)
upload_to_store(temporary_uploads)
return {"error": reader.errors + errors}

8
backend/data_import/pipeline/cleaners.py

@ -25,14 +25,16 @@ class SpanCleaner(Cleaner):
if self.allow_overlapping:
return labels
labels.sort(key=lambda label: label.start_offset)
span_labels = [label for label in labels if isinstance(label, SpanLabel)]
other_labels = [label for label in labels if not isinstance(label, SpanLabel)]
span_labels.sort(key=lambda label: label.start_offset)
last_offset = -1
new_labels = []
for label in labels:
for label in span_labels:
if label.start_offset >= last_offset:
last_offset = label.end_offset
new_labels.append(label)
return new_labels
return new_labels + other_labels
@property
def message(self) -> str:

1
backend/data_import/pipeline/examples/relation_extraction/example.jsonl

@ -1,4 +1,3 @@
{
"text": "Google was founded on September 4, 1998, by Larry Page and Sergey Brin.",
"entities": [

19
backend/data_import/pipeline/factories.py

@ -1,4 +1,13 @@
from . import builders, catalog, cleaners, data, labels, parsers, readers
from . import (
builders,
catalog,
cleaners,
data,
labeled_examples,
labels,
parsers,
readers,
)
from projects.models import (
DOCUMENT_CLASSIFICATION,
IMAGE_CLASSIFICATION,
@ -60,6 +69,14 @@ def create_cleaner(project):
return cleaner_class(project)
def select_examples(project):
    """Return the labeled-examples container class suited to *project*.

    Sequence-labeling projects with relations enabled get
    ``RelationExamples`` (persists spans first, then relations); every
    other project gets the plain ``LabeledExamples``.
    """
    needs_relations = (
        project.project_type == SEQUENCE_LABELING
        and getattr(project, "use_relation", False)
    )
    if needs_relations:
        return labeled_examples.RelationExamples
    return labeled_examples.LabeledExamples
def create_builder(project, **kwargs):
if not project.is_text_project:
return builders.PlainBuilder(data_class=get_data_class(project.project_type))

55
backend/data_import/pipeline/labeled_examples.py

@ -1,14 +1,15 @@
import itertools
from collections import defaultdict
from typing import Any, Dict, List, Type
from typing import Any, Dict, List, Optional, Type
from .cleaners import Cleaner
from .data import BaseData
from .exceptions import FileParseException
from .labels import Label
from .labels import Label, RelationLabel, SpanLabel
from examples.models import Example
from label_types.models import CategoryType, LabelType, SpanType
from label_types.models import CategoryType, LabelType, RelationType, SpanType
from labels.models import Label as LabelModel
from labels.models import Relation, Span
from projects.models import Project
@ -53,8 +54,20 @@ class Record:
labels = [label.create_type(project) for label in self._label]
return list(filter(None, labels))
def create_label(self, user, example, mapping) -> List[LabelModel]:
return [label.create(user, example, mapping) for label in self._label]
def create_label(
    self, user, example, mapping, label_class: Optional[Type[Label]] = None, **kwargs
) -> List[LabelModel]:
    """Materialize label model instances for this record.

    With no *label_class*, every parsed label is materialized. When a
    *label_class* is given, only parsed labels of that type are
    materialized, and *kwargs* are forwarded to each ``create`` call.
    """
    if label_class is None:
        return [parsed.create(user, example, mapping) for parsed in self._label]
    materialized: List[LabelModel] = []
    for parsed in self._label:
        if isinstance(parsed, label_class):
            materialized.append(parsed.create(user, example, mapping, **kwargs))
    return materialized
def select_label(self, label_class: Type[Label]) -> List[Label]:
    """Return only the parsed labels that are instances of *label_class*."""
    return list(filter(lambda parsed: isinstance(parsed, label_class), self._label))
@property
def label(self):
@ -88,3 +101,35 @@ class LabeledExamples:
)
for label_class, instances in group_by_class(labels).items():
label_class.objects.bulk_create(instances)
class RelationExamples(LabeledExamples):
    """Labeled examples for relation extraction.

    Persists labels in two passes: spans are bulk-created first, then
    relations are created with their ``from_id``/``to_id`` remapped to the
    database ids of the freshly saved spans.
    """

    def create_label(self, project: Project, user, examples: List[Example]):
        # Build a text -> label-type mapping covering both relation and
        # span types defined on the project.
        mapping = {}
        label_types: List[Type[LabelType]] = [RelationType, SpanType]
        for model in label_types:
            for label in model.objects.filter(project=project):
                mapping[label.text] = label
        # Pass 1: materialize and save the span labels only.
        labels = list(
            itertools.chain.from_iterable(
                [data.create_label(user, example, mapping, SpanLabel) for data, example in zip(self.records, examples)]
            )
        )
        Span.objects.bulk_create(labels)
        # Re-fetch the saved spans by the uuids assigned at parse time
        # (bulk_create does not reliably populate pks on all backends).
        original_spans = list(
            itertools.chain.from_iterable(example.select_label(SpanLabel) for example in self.records)
        )
        spans = Span.objects.filter(uuid__in=[span.uuid for span in original_spans])
        # Create a mapping from the parsed span's file-local id to the saved
        # Span row; this is needed to create the relations below.
        # NOTE(review): this zip assumes the queryset yields rows in the same
        # order as ``original_spans`` — without an explicit order_by the DB
        # ordering is not guaranteed; verify.
        span_mapping = {original_span.id: saved_span for saved_span, original_span in zip(spans, original_spans)}
        # Pass 2: create relations, replacing from_id and to_id with the
        # saved span ids via span_mapping.
        relations = itertools.chain.from_iterable(
            [
                data.create_label(user, example, mapping, RelationLabel, span_mapping=span_mapping)
                for data, example in zip(self.records, examples)
            ]
        )
        Relation.objects.bulk_create(relations)

17
backend/data_import/pipeline/labels.py

@ -1,7 +1,8 @@
import abc
import uuid
from typing import Any, Dict, Optional
from pydantic import BaseModel, validator
from pydantic import UUID4, BaseModel, validator
from label_types.models import CategoryType, LabelType, RelationType, SpanType
from labels.models import Category
@ -12,6 +13,13 @@ from projects.models import Project
class Label(BaseModel, abc.ABC):
id: int = -1
uuid: UUID4
def __init__(self, **data):
    # Assign a fresh UUID to every parsed label so it can later be matched
    # to its saved model row (e.g. spans are re-fetched by uuid after
    # bulk_create).
    # NOTE(review): this overwrites any caller-supplied "uuid" in *data* —
    # confirm that is intended.
    data["uuid"] = uuid.uuid4()
    super().__init__(**data)
@abc.abstractmethod
def has_name(self) -> bool:
raise NotImplementedError()
@ -67,11 +75,10 @@ class CategoryLabel(Label):
return CategoryType(text=self.label, project=project)
def create(self, user, example, mapping: Dict[str, LabelType], **kwargs):
return Category(user=user, example=example, label=mapping[self.label])
return Category(uuid=self.uuid, user=user, example=example, label=mapping[self.label])
class SpanLabel(Label):
id: int = -1
label: str
start_offset: int
end_offset: int
@ -99,6 +106,7 @@ class SpanLabel(Label):
def create(self, user, example, mapping: Dict[str, LabelType], **kwargs):
return Span(
uuid=self.uuid,
user=user,
example=example,
start_offset=self.start_offset,
@ -128,7 +136,7 @@ class TextLabel(Label):
return None
def create(self, user, example, mapping, **kwargs):
return TL(user=user, example=example, text=self.text)
return TL(uuid=self.uuid, user=user, example=example, text=self.text)
class RelationLabel(Label):
@ -155,6 +163,7 @@ class RelationLabel(Label):
def create(self, user, example, mapping: Dict[str, LabelType], **kwargs):
return Relation(
uuid=self.uuid,
user=user,
example=example,
type=mapping[self.type],

10
backend/data_import/pipeline/readers.py

@ -1,7 +1,7 @@
import abc
import collections.abc
import dataclasses
from typing import Any, Dict, Iterator, List
from typing import Any, Dict, Iterator, List, Type
from .cleaners import Cleaner
from .exceptions import FileParseException
@ -29,7 +29,7 @@ class BaseReader(collections.abc.Iterable):
raise NotImplementedError("Please implement this method in the subclass.")
@abc.abstractmethod
def batch(self, batch_size: int) -> Iterator[LabeledExamples]:
def batch(self, batch_size: int, labeled_examples: Type[LabeledExamples]) -> Iterator[LabeledExamples]:
raise NotImplementedError("Please implement this method in the subclass.")
@ -84,15 +84,15 @@ class Reader(BaseReader):
except FileParseException as e:
self._errors.append(e)
def batch(self, batch_size: int) -> Iterator[LabeledExamples]:
def batch(self, batch_size: int, labeled_examples: Type[LabeledExamples]) -> Iterator[LabeledExamples]:
batch = []
for record in self:
batch.append(record)
if len(batch) == batch_size:
yield LabeledExamples(batch)
yield labeled_examples(batch)
batch = []
if batch:
yield LabeledExamples(batch)
yield labeled_examples(batch)
@property
def errors(self) -> List[FileParseException]:

4
backend/data_import/pipeline/writers.py

@ -6,8 +6,8 @@ class Writer:
def __init__(self, batch_size: int):
self.batch_size = batch_size
def save(self, reader: BaseReader, project: Project, user):
for batch in reader.batch(self.batch_size):
def save(self, reader: BaseReader, project: Project, user, labeled_examples):
for batch in reader.batch(self.batch_size, labeled_examples):
examples = batch.create_data(project)
batch.create_label_type(project)
batch.create_label(project, user, examples)

1
backend/data_import/tests/data/relation_extraction/example.jsonl

@ -0,0 +1 @@
{"text": "Google was founded on September 4, 1998, by Larry Page and Sergey Brin.", "entities": [{"id": 0, "start_offset": 0, "end_offset": 6, "label": "ORG"}, {"id": 1, "start_offset": 22, "end_offset": 39, "label": "DATE"}, {"id": 2, "start_offset": 44, "end_offset": 54, "label": "PERSON"}, {"id": 3, "start_offset": 59, "end_offset": 70, "label": "PERSON"}], "relations": [{"from_id": 0, "to_id": 1, "type": "foundedAt"}, {"from_id": 0, "to_id": 2, "type": "foundedBy"}, {"from_id": 0, "to_id": 3, "type": "foundedBy"}]}

8
backend/data_import/tests/test_builder.py

@ -10,8 +10,12 @@ from data_import.pipeline.readers import FileName
class TestColumnBuilder(unittest.TestCase):
def assert_record(self, actual, expected):
labels = actual.label
for label in labels:
label.pop("id")
label.pop("uuid")
self.assertEqual(actual.data.text, expected["data"])
self.assertEqual(actual.label, expected["label"])
self.assertEqual(labels, expected["label"])
def create_record(self, row, data_column: builders.DataColumn, label_columns: Optional[List[builders.Column]]):
builder = builders.ColumnBuilder(data_column=data_column, label_columns=label_columns)
@ -71,6 +75,6 @@ class TestColumnBuilder(unittest.TestCase):
actual = self.create_record(row, data_column, label_columns)
expected = {
"data": "Text",
"label": [{"label": "Label"}, {"id": -1, "label": "LOC", "start_offset": 0, "end_offset": 1}],
"label": [{"label": "Label"}, {"label": "LOC", "start_offset": 0, "end_offset": 1}],
}
self.assert_record(actual, expected)

36
backend/data_import/tests/test_tasks.py

@ -208,6 +208,42 @@ class TestImportSequenceLabelingData(TestImportData):
self.assertEqual(len(response["error"]), 1)
class TestImportRelationExtractionData(TestImportData):
    """Import tests for sequence-labeling projects with relations enabled."""

    task = SEQUENCE_LABELING

    def setUp(self):
        # use_relation=True routes the import through the relation-aware
        # examples container.
        self.project = prepare_project(self.task, use_relation=True)
        self.user = self.project.admin
        self.data_path = pathlib.Path(__file__).parent / "data"
        self.upload_id = _get_file_id()

    def assert_examples(self, dataset):
        """Assert each (text, spans) pair was imported with its relations."""
        self.assertEqual(Example.objects.count(), len(dataset))
        for text, expected_spans in dataset:
            example = Example.objects.get(text=text)
            spans = [[span.start_offset, span.end_offset, span.label.text] for span in example.spans.all()]
            self.assertEqual(spans, expected_spans)
            # the JSONL fixture defines exactly three relations per example
            self.assertEqual(example.relations.count(), 3)

    def assert_parse_error(self, response):
        """Assert a parse failure reports errors and persists nothing."""
        self.assertGreaterEqual(len(response["error"]), 1)
        self.assertEqual(Example.objects.count(), 0)
        self.assertEqual(SpanType.objects.count(), 0)
        self.assertEqual(Span.objects.count(), 0)

    def test_jsonl(self):
        """Import the JSONL relation-extraction fixture end to end."""
        filename = "relation_extraction/example.jsonl"
        file_format = "JSONL"
        dataset = [
            (
                "Google was founded on September 4, 1998, by Larry Page and Sergey Brin.",
                [[0, 6, "ORG"], [22, 39, "DATE"], [44, 54, "PERSON"], [59, 70, "PERSON"]],
            ),
        ]
        self.import_dataset(filename, file_format)
        self.assert_examples(dataset)
class TestImportSeq2seqData(TestImportData):
task = SEQ2SEQ

Loading…
Cancel
Save