Browse Source

Merge pull request #1764 from mkmark/master

fix empty export in entity-relationship-labeling
pull/1777/head
Hiroki Nakayama 2 years ago
committed by GitHub
parent
commit
afcfd2ea0e
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 350 additions and 63 deletions
  1. 2
      backend/data_export/celery_tasks.py
  2. 4
      backend/data_export/pipeline/factories.py
  3. 32
      backend/data_export/pipeline/repositories.py
  4. 375
      backend/data_export/tests/test_repositories.py

2
backend/data_export/celery_tasks.py

@ -13,7 +13,7 @@ logger = get_task_logger(__name__)
@shared_task
def export_dataset(project_id, file_format: str, export_approved=False):
project = get_object_or_404(Project, pk=project_id)
repository = create_repository(project)
repository = create_repository(project, file_format)
writer = create_writer(file_format)(settings.MEDIA_ROOT)
service = ExportApplicationService(repository, writer)
filepath = service.export(export_approved)

4
backend/data_export/pipeline/factories.py

@ -11,8 +11,8 @@ from projects.models import (
)
def create_repository(project):
if getattr(project, "use_relation", False):
def create_repository(project, file_format: str):
if getattr(project, "use_relation", False) and file_format == catalog.JSONLRelation.name:
return repositories.RelationExtractionRepository(project)
mapping = {
DOCUMENT_CLASSIFICATION: repositories.TextClassificationRepository,

32
backend/data_export/pipeline/repositories.py

@ -1,7 +1,7 @@
import abc
import itertools
from collections import defaultdict
from typing import Dict, Iterator, List, Tuple, Union
from typing import Any, Dict, Iterator, List, Tuple
from .data import Record
from examples.models import Example
@ -10,13 +10,12 @@ from projects.models import Project
SpanType = Tuple[int, int, str]
class BaseRepository(abc.ABC):
class BaseRepository:
def __init__(self, project: Project):
self.project = project
@abc.abstractmethod
def list(self, export_approved=False) -> Iterator[Record]:
pass
raisen> <span class="ne">NotImplementedError()
class FileRepository(BaseRepository):
@ -54,7 +53,7 @@ class FileRepository(BaseRepository):
label_per_user[a.user.username].append(a.label.text)
return label_per_user
def reduce_user(self, label_per_user: Dict[str, List]):
def reduce_user(self, label_per_user: Dict[str, Any]):
value = list(itertools.chain(*label_per_user.values()))
return {"all": value}
@ -96,7 +95,7 @@ class TextRepository(BaseRepository):
def label_per_user(self, doc) -> Dict:
raise NotImplementedError()
def reduce_user(self, label_per_user: Dict[str, List]):
def reduce_user(self, label_per_user: Dict[str, Any]):
value = list(itertools.chain(*label_per_user.values()))
return {"all": value}
@ -161,6 +160,14 @@ class RelationExtractionRepository(TextRepository):
label_per_user[user]["entities"] = span
return label_per_user
def reduce_user(self, label_per_user: Dict[str, Any]):
entities = []
relations = []
for user, label in label_per_user.items():
entities.extend(label.get("entities", []))
relations.extend(label.get("relations", []))
return {"all": {"entities": entities, "relations": relations}}
class Seq2seqRepository(TextRepository):
@property
@ -184,7 +191,7 @@ class IntentDetectionSlotFillingRepository(TextRepository):
def label_per_user(self, doc) -> Dict:
category_per_user: Dict[str, List[str]] = defaultdict(list)
span_per_user: Dict[str, List[SpanType]] = defaultdict(list)
label_per_user: Dict[str, Dict[str, Union[List[str], List[SpanType]]]] = defaultdict(dict)
label_per_user: Dict[str, Dict[str, List]] = defaultdict(dict)
for a in doc.categories.all():
category_per_user[a.user.username].append(a.label.text)
for a in doc.spans.all():
@ -193,4 +200,15 @@ class IntentDetectionSlotFillingRepository(TextRepository):
label_per_user[user]["cats"] = cats
for user, span in span_per_user.items():
label_per_user[user]["entities"] = span
for label in label_per_user.values():
label.setdefault("cats", [])
label.setdefault("entities", [])
return label_per_user
def reduce_user(self, label_per_user: Dict[str, Any]):
cats = []
entities = []
for user, label in label_per_user.items():
cats.extend(label.get("cats", []))
entities.extend(label.get("entities", []))
return {"all": {"entities": entities, "cats": cats}}

375
backend/data_export/tests/test_repositories.py

@ -3,72 +3,341 @@ import unittest
from model_mommy import mommy
from ..pipeline.repositories import (
FileRepository,
IntentDetectionSlotFillingRepository,
RelationExtractionRepository,
Seq2seqRepository,
SequenceLabelingRepository,
Speech2TextRepository,
TextClassificationRepository,
)
from projects.models import (
DOCUMENT_CLASSIFICATION,
IMAGE_CLASSIFICATION,
INTENT_DETECTION_AND_SLOT_FILLING,
SEQ2SEQ,
SEQUENCE_LABELING,
SPEECH2TEXT,
)
from projects.models import INTENT_DETECTION_AND_SLOT_FILLING, SEQUENCE_LABELING
from projects.tests.utils import prepare_project
class TestCSVWriter(unittest.TestCase):
def setUp(self):
self.project = prepare_project(INTENT_DETECTION_AND_SLOT_FILLING)
class TestRepository(unittest.TestCase):
def assert_records(self, repository, expected):
records = list(repository.list())
self.assertEqual(len(records), len(expected))
for record, expect in zip(records, expected):
self.assertEqual(record.data, expect["data"])
self.assertEqual(record.label, expect["label"])
self.assertEqual(record.user, expect["user"])
class TestTextClassificationRepository(TestRepository):
def prepare_data(self, project):
self.example = mommy.make("Example", project=project.item, text="example")
self.category1 = mommy.make("Category", example=self.example, user=project.admin)
self.category2 = mommy.make("Category", example=self.example, user=project.annotator)
def test_list(self):
example = mommy.make("Example", project=self.project.item, text="example")
category = mommy.make("Category", example=example, user=self.project.admin)
span = mommy.make("Span", example=example, user=self.project.admin, start_offset=0, end_offset=1)
repository = IntentDetectionSlotFillingRepository(self.project.item)
project = prepare_project(DOCUMENT_CLASSIFICATION)
repository = TextClassificationRepository(project.item)
self.prepare_data(project)
expected = [
{
"data": example.text,
"data": self.example.text,
"label": [self.category1.label.text],
"user": project.admin.username,
},
{
"data": self.example.text,
"label": [self.category2.label.text],
"user": project.annotator.username,
},
]
self.assert_records(repository, expected)
def test_list_on_collaborative_annotation(self):
project = prepare_project(DOCUMENT_CLASSIFICATION, collaborative_annotation=True)
repository = TextClassificationRepository(project.item)
self.prepare_data(project)
expected = [
{
"data": self.example.text,
"label": [self.category1.label.text, self.category2.label.text],
"user": "all",
}
]
self.assert_records(repository, expected)
class TestSeq2seqRepository(TestRepository):
def prepare_data(self, project):
self.example = mommy.make("Example", project=project.item, text="example")
self.text1 = mommy.make("TextLabel", example=self.example, user=project.admin)
self.text2 = mommy.make("TextLabel", example=self.example, user=project.annotator)
def test_list(self):
project = prepare_project(SEQ2SEQ)
repository = Seq2seqRepository(project.item)
self.prepare_data(project)
expected = [
{
"data": self.example.text,
"label": [self.text1.text],
"user": project.admin.username,
},
{
"data": self.example.text,
"label": [self.text2.text],
"user": project.annotator.username,
},
]
self.assert_records(repository, expected)
def test_list_on_collaborative_annotation(self):
project = prepare_project(SEQ2SEQ, collaborative_annotation=True)
repository = Seq2seqRepository(project.item)
self.prepare_data(project)
expected = [
{
"data": self.example.text,
"label": [self.text1.text, self.text2.text],
"user": "all",
}
]
self.assert_records(repository, expected)
class TestIntentDetectionSlotFillingRepository(TestRepository):
def prepare_data(self, project):
self.example = mommy.make("Example", project=project.item, text="example")
self.category1 = mommy.make("Category", example=self.example, user=project.admin)
self.category2 = mommy.make("Category", example=self.example, user=project.annotator)
self.span = mommy.make("Span", example=self.example, user=project.admin, start_offset=0, end_offset=1)
def test_list(self):
project = prepare_project(INTENT_DETECTION_AND_SLOT_FILLING)
repository = IntentDetectionSlotFillingRepository(project.item)
self.prepare_data(project)
expected = [
{
"data": self.example.text,
"label": {
"cats": [self.category1.label.text],
"entities": [(self.span.start_offset, self.span.end_offset, self.span.label.text)],
},
"user": project.admin.username,
},
{
"data": self.example.text,
"label": {
"cats": [self.category2.label.text],
"entities": [],
},
"user": project.annotator.username,
},
]
self.assert_records(repository, expected)
def test_list_on_collaborative_annotation(self):
project = prepare_project(INTENT_DETECTION_AND_SLOT_FILLING, collaborative_annotation=True)
repository = IntentDetectionSlotFillingRepository(project.item)
self.prepare_data(project)
expected = [
{
"data": self.example.text,
"label": {
"cats": [category.label.text],
"entities": [(span.start_offset, span.end_offset, span.label.text)],
"cats": [self.category1.label.text, self.category2.label.text],
"entities": [(self.span.start_offset, self.span.end_offset, self.span.label.text)],
},
"user": "all",
}
]
records = list(repository.list())
self.assertEqual(len(records), len(expected))
for record, expect in zip(records, expected):
self.assertEqual(record.data, expect["data"])
self.assertEqual(record.label["cats"], expect["label"]["cats"])
self.assertEqual(record.label["entities"], expect["label"]["entities"])
class TestRelationExtractionRepository(unittest.TestCase):
def setUp(self):
self.project = prepare_project(SEQUENCE_LABELING, use_relation=True)
def test_label_per_user(self):
from_entity = mommy.make("Span", start_offset=0, end_offset=1, user=self.project.admin)
to_entity = mommy.make(
"Span", start_offset=1, end_offset=2, example=from_entity.example, user=self.project.admin
)
relation = mommy.make(
"Relation", from_id=from_entity, to_id=to_entity, example=from_entity.example, user=self.project.admin
)
repository = RelationExtractionRepository(self.project.item)
expected = {
"admin": {
"entities": [
{
"id": from_entity.id,
"start_offset": from_entity.start_offset,
"end_offset": from_entity.end_offset,
"label": from_entity.label.text,
},
{
"id": to_entity.id,
"start_offset": to_entity.start_offset,
"end_offset": to_entity.end_offset,
"label": to_entity.label.text,
},
],
"relations": [
{"id": relation.id, "from_id": from_entity.id, "to_id": to_entity.id, "type": relation.type.text}
self.assert_records(repository, expected)
class TestSequenceLabelingRepository(TestRepository):
def prepare_data(self, project):
self.example = mommy.make("Example", project=project.item, text="example")
self.span1 = mommy.make("Span", example=self.example, user=project.admin, start_offset=0, end_offset=1)
self.span2 = mommy.make("Span", example=self.example, user=project.annotator, start_offset=1, end_offset=2)
def test_list(self):
project = prepare_project(SEQUENCE_LABELING)
repository = SequenceLabelingRepository(project.item)
self.prepare_data(project)
expected = [
{
"data": self.example.text,
"label": [(self.span1.start_offset, self.span1.end_offset, self.span1.label.text)],
"user": project.admin.username,
},
{
"data": self.example.text,
"label": [(self.span2.start_offset, self.span2.end_offset, self.span2.label.text)],
"user": project.annotator.username,
},
]
self.assert_records(repository, expected)
def test_list_on_collaborative_annotation(self):
project = prepare_project(SEQUENCE_LABELING, collaborative_annotation=True)
repository = SequenceLabelingRepository(project.item)
self.prepare_data(project)
expected = [
{
"data": self.example.text,
"label": [
(self.span1.start_offset, self.span1.end_offset, self.span1.label.text),
(self.span2.start_offset, self.span2.end_offset, self.span2.label.text),
],
"user": "all",
}
]
self.assert_records(repository, expected)
class TestRelationExtractionRepository(TestRepository):
def test_list(self):
project = prepare_project(SEQUENCE_LABELING, use_relation=True)
example = mommy.make("Example", project=project.item, text="example")
span1 = mommy.make("Span", example=example, user=project.admin, start_offset=0, end_offset=1)
span2 = mommy.make("Span", example=example, user=project.admin, start_offset=1, end_offset=2)
relation = mommy.make("Relation", from_id=span1, to_id=span2, example=example, user=project.admin)
repository = RelationExtractionRepository(project.item)
expected = [
{
"data": example.text,
"label": {
"entities": [
{
"id": span1.id,
"start_offset": span1.start_offset,
"end_offset": span1.end_offset,
"label": span1.label.text,
},
{
"id": span2.id,
"start_offset": span2.start_offset,
"end_offset": span2.end_offset,
"label": span2.label.text,
},
],
"relations": [
{"id": relation.id, "from_id": span1.id, "to_id": span2.id, "type": relation.type.text}
],
},
"user": project.admin.username,
}
]
self.assert_records(repository, expected)
def test_list_on_collaborative_annotation(self):
project = prepare_project(SEQUENCE_LABELING, collaborative_annotation=True, use_relation=True)
example = mommy.make("Example", project=project.item, text="example")
span1 = mommy.make("Span", example=example, user=project.admin, start_offset=0, end_offset=1)
span2 = mommy.make("Span", example=example, user=project.annotator, start_offset=1, end_offset=2)
relation = mommy.make("Relation", from_id=span1, to_id=span2, example=example, user=project.admin)
repository = RelationExtractionRepository(project.item)
expected = [
{
"data": example.text,
"label": {
"entities": [
{
"id": span1.id,
"start_offset": span1.start_offset,
"end_offset": span1.end_offset,
"label": span1.label.text,
},
{
"id": span2.id,
"start_offset": span2.start_offset,
"end_offset": span2.end_offset,
"label": span2.label.text,
},
],
"relations": [
{"id": relation.id, "from_id": span1.id, "to_id": span2.id, "type": relation.type.text}
],
},
"user": "all",
}
]
self.assert_records(repository, expected)
class TestSpeech2TextRepository(TestRepository):
def prepare_data(self, project):
self.example = mommy.make("Example", project=project.item, text="example")
self.text1 = mommy.make("TextLabel", example=self.example, user=project.admin)
self.text2 = mommy.make("TextLabel", example=self.example, user=project.annotator)
def test_list(self):
project = prepare_project(SPEECH2TEXT)
repository = Speech2TextRepository(project.item)
self.prepare_data(project)
expected = [
{
"data": self.example.filename,
"label": [self.text1.text],
"user": project.admin.username,
},
{
"data": self.example.filename,
"label": [self.text2.text],
"user": project.annotator.username,
},
]
self.assert_records(repository, expected)
def test_list_on_collaborative_annotation(self):
project = prepare_project(SPEECH2TEXT, collaborative_annotation=True)
repository = Speech2TextRepository(project.item)
self.prepare_data(project)
expected = [
{
"data": self.example.filename,
"label": [self.text1.text, self.text2.text],
"user": "all",
}
]
self.assert_records(repository, expected)
class TestFileRepository(TestRepository):
def prepare_data(self, project):
self.example = mommy.make("Example", project=project.item, text="example")
self.category1 = mommy.make("Category", example=self.example, user=project.admin)
self.category2 = mommy.make("Category", example=self.example, user=project.annotator)
def test_list(self):
project = prepare_project(IMAGE_CLASSIFICATION)
repository = FileRepository(project.item)
self.prepare_data(project)
expected = [
{
"data": self.example.filename,
"label": [self.category1.label.text],
"user": project.admin.username,
},
{
"data": self.example.filename,
"label": [self.category2.label.text],
"user": project.annotator.username,
},
]
self.assert_records(repository, expected)
def test_list_on_collaborative_annotation(self):
project = prepare_project(IMAGE_CLASSIFICATION, collaborative_annotation=True)
repository = FileRepository(project.item)
self.prepare_data(project)
expected = [
{
"data": self.example.filename,
"label": [self.category1.label.text, self.category2.label.text],
"user": "all",
}
}
actual = repository.label_per_user(from_entity.example)
self.assertDictEqual(actual, expected)
]
self.assert_records(repository, expected)
Loading…
Cancel
Save