From 549e909070e3dbdc38d87e64841ca34a17e3d5a1 Mon Sep 17 00:00:00 2001
From: Hironsan
Date: Fri, 22 Apr 2022 11:43:28 +0900
Subject: [PATCH] Enable to accept multiple label types in export dataset

---
Review notes (text in this section is not part of the commit message):

* The line structure of the patch was restored; indentation inside the hunks
  was reconstructed from Python syntax.
* celery_tasks.py instantiates the result of create_formatter() once per
  label collection: `create_formatter(project, file_format)(target_column=...)`.
  create_formatter() therefore has to return a single formatter class. The
  SEQUENCE_LABELING entry previously mapped to a list of classes, which would
  raise "TypeError: 'list' object is not callable" on every sequence-labeling
  export. It now selects DictFormatter (relations enabled) or
  TupledSpanFormatter, so the relation case still yields one DictFormatter
  for "entities" and one for "relations".
* The post-image blob id on the factories.py index line predates this change.

 backend/data_export/celery_tasks.py        | 31 +++++++++++++++-------
 backend/data_export/models.py              | 10 ++++++-
 backend/data_export/pipeline/dataset.py    |  9 ++++---
 backend/data_export/pipeline/factories.py  | 11 ++++++--
 backend/data_export/pipeline/formatters.py |  4 +--
 backend/data_export/pipeline/labels.py     | 13 ++++++++-
 backend/data_export/pipeline/services.py   |  9 ++++---
 7 files changed, 65 insertions(+), 22 deletions(-)

diff --git a/backend/data_export/celery_tasks.py b/backend/data_export/celery_tasks.py
index eb6164f1..03faa53c 100644
--- a/backend/data_export/celery_tasks.py
+++ b/backend/data_export/celery_tasks.py
@@ -25,12 +25,16 @@ def create_collaborative_dataset(project: Project, file_format: str, confirmed_o
         is_collaborative=project.collaborative_annotation,
         confirmed_only=confirmed_only,
     )
-    label_collection = select_label_collection(project)
-    labels = create_labels(label_collection, examples=examples)
-    dataset = Dataset(examples, labels)
-    formatter = create_formatter(project, file_format)(target_column=label_collection.field_name)
     writer = create_writer(file_format)
-    service = ExportApplicationService(dataset, formatter, writer)
+    label_collections = select_label_collection(project)
+    formatters = [
+        create_formatter(project, file_format)(target_column=label_collection.field_name)
+        for label_collection in label_collections
+    ]
+    labels = [create_labels(label_collection, examples=examples) for label_collection in label_collections]
+    dataset = Dataset(examples, labels)
+
+    service = ExportApplicationService(dataset, formatters, writer)
     filepath = os.path.join(settings.MEDIA_URL, f"all.{writer.extension}")
     service.export(filepath)
     return filepath
@@ -46,12 +50,19 @@ def create_individual_dataset(project: Project, file_format: str, confirmed_only
             confirmed_only=confirmed_only,
             user=member.user,
         )
-        label_collection = select_label_collection(project)
-        labels = create_labels(label_collection, examples=examples, user=member.user)
-        dataset = Dataset(examples, labels)
-        formatter = create_formatter(project, file_format)(target_column=label_collection.field_name)
         writer = create_writer(file_format)
-        service = ExportApplicationService(dataset, formatter, writer)
+        label_collections = select_label_collection(project)
+        formatters = [
+            create_formatter(project, file_format)(target_column=label_collection.field_name)
+            for label_collection in label_collections
+        ]
+        labels = [
+            create_labels(label_collection, examples=examples, user=member.user)
+            for label_collection in label_collections
+        ]
+        dataset = Dataset(examples, labels)
+
+        service = ExportApplicationService(dataset, formatters, writer)
         filepath = os.path.join(settings.MEDIA_URL, f"{member.username}.{writer.extension}")
         service.export(filepath)
         files.append(filepath)
diff --git a/backend/data_export/models.py b/backend/data_export/models.py
index c1c8ea48..34b6c78b 100644
--- a/backend/data_export/models.py
+++ b/backend/data_export/models.py
@@ -2,7 +2,7 @@ from typing import Any, Dict, Protocol, Tuple
 
 from django.db import models
 
-from labels.models import Category, Span
+from labels.models import Category, Relation, Span
 
 
 class ExportedLabelManager(models.Manager):
@@ -44,3 +44,11 @@ class ExportedSpan(Span):
 
     class Meta:
         proxy = True
+
+
+class ExportedRelation(Relation):
+    def to_dict(self):
+        return {"id": self.id, "from_id": self.from_id.id, "to_id": self.to_id.id, "type": self.type.text}
+
+    class Meta:
+        proxy = True
diff --git a/backend/data_export/pipeline/dataset.py b/backend/data_export/pipeline/dataset.py
index 50e038a0..a69c7622 100644
--- a/backend/data_export/pipeline/dataset.py
+++ b/backend/data_export/pipeline/dataset.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Iterator
+from typing import Any, Dict, Iterator, List
 
 import pandas as pd
 from django.db.models.query import QuerySet
@@ -18,13 +18,16 @@ def filter_examples(examples: QuerySet[Example], is_collaborative=False, confirm
 
 
 class Dataset:
-    def __init__(self, examples: QuerySet[Example], labels: Labels):
+    def __init__(self, examples: QuerySet[Example], labels: List[Labels]):
         self.examples = examples
         self.labels = labels
 
     def __iter__(self) -> Iterator[Dict[str, Any]]:
         for example in self.examples:
-            yield {"id": example.id, "data": example.text, **example.meta, **self.labels.find_by(example.id)}
+            data = {"id": example.id, "data": example.text, **example.meta}
+            for labels in self.labels:
+                data.update(**labels.find_by(example.id))
+            yield data
 
     def to_dataframe(self) -> pd.DataFrame:
         return pd.DataFrame(self)
diff --git a/backend/data_export/pipeline/factories.py b/backend/data_export/pipeline/factories.py
index ed0673f6..2bd074a0 100644
--- a/backend/data_export/pipeline/factories.py
+++ b/backend/data_export/pipeline/factories.py
@@ -45,13 +45,16 @@ def create_writer(file_format: str) -> writers.Writer:
 
 
 def create_formatter(project, file_format: str):
+    use_relation = getattr(project, "use_relation", False)
     mapping = {
         DOCUMENT_CLASSIFICATION: {
             catalog.CSV.name: formatters.JoinedCategoryFormatter,
             catalog.JSON.name: formatters.ListedCategoryFormatter,
             catalog.JSONL.name: formatters.ListedCategoryFormatter,
         },
-        SEQUENCE_LABELING: {catalog.JSONL.name: formatters.TupledSpanFormatter},
+        SEQUENCE_LABELING: {
+            catalog.JSONL.name: formatters.DictFormatter if use_relation else formatters.TupledSpanFormatter
+        },
         SEQ2SEQ: {},
         IMAGE_CLASSIFICATION: {},
         SPEECH2TEXT: {},
@@ -61,7 +64,11 @@ def create_formatter(project, file_format: str):
 
 
 def select_label_collection(project):
-    mapping = {DOCUMENT_CLASSIFICATION: labels.Categories, SEQUENCE_LABELING: labels.Spans}
+    use_relation = getattr(project, "use_relation", False)
+    mapping = {
+        DOCUMENT_CLASSIFICATION: [labels.Categories],
+        SEQUENCE_LABELING: [labels.Spans, labels.Relations] if use_relation else [labels.Spans],
+    }
     return mapping[project.project_type]
 
 
diff --git a/backend/data_export/pipeline/formatters.py b/backend/data_export/pipeline/formatters.py
index 8d88884a..b72bf359 100644
--- a/backend/data_export/pipeline/formatters.py
+++ b/backend/data_export/pipeline/formatters.py
@@ -66,12 +66,12 @@ class TupledSpanFormatter(Formatter):
         return dataset
 
 
-class DictSpanFormatter(Formatter):
+class DictFormatter(Formatter):
     def format(self, dataset: pd.DataFrame) -> pd.DataFrame:
         if self.target_column not in dataset.columns:
             return dataset
         dataset[self.target_column] = dataset[self.target_column].apply(
-            lambda spans: [span.to_dict() for span in spans]
+            lambda labels: [label.to_dict() for label in labels]
         )
         return dataset
 
diff --git a/backend/data_export/pipeline/labels.py b/backend/data_export/pipeline/labels.py
index 4ca890c9..394e5bf3 100644
--- a/backend/data_export/pipeline/labels.py
+++ b/backend/data_export/pipeline/labels.py
@@ -7,7 +7,12 @@ from typing import Dict, List
 
 from django.db.models import QuerySet
 
-from data_export.models import ExportedCategory, ExportedLabel, ExportedSpan
+from data_export.models import (
+    ExportedCategory,
+    ExportedLabel,
+    ExportedRelation,
+    ExportedSpan,
+)
 from examples.models import Example
 
 
@@ -38,3 +43,9 @@ class Spans(Labels):
     label_class = ExportedSpan
     field_name = "entities"
     fields = ("example", "label")
+
+
+class Relations(Labels):
+    label_class = ExportedRelation
+    field_name = "relations"
+    fields = ("example", "type")
diff --git a/backend/data_export/pipeline/services.py b/backend/data_export/pipeline/services.py
index 0916956f..ed1f583f 100644
--- a/backend/data_export/pipeline/services.py
+++ b/backend/data_export/pipeline/services.py
@@ -1,16 +1,19 @@
+from typing import List
+
 from .dataset import Dataset
 from .formatters import Formatter
 from .writers import Writer
 
 
 class ExportApplicationService:
-    def __init__(self, dataset: Dataset, formatter: Formatter, writer: Writer):
+    def __init__(self, dataset: Dataset, formatters: List[Formatter], writer: Writer):
         self.dataset = dataset
-        self.formatter = formatter
+        self.formatters = formatters
         self.writer = writer
 
     def export(self, file):
         dataset = self.dataset.to_dataframe()
-        dataset = self.formatter.format(dataset)
+        for formatter in self.formatters:
+            dataset = formatter.format(dataset)
         self.writer.write(file, dataset)
         return file