From 12573ec4bb129d02f795a7ef3d4c569ff297c9a7 Mon Sep 17 00:00:00 2001 From: Hironsan Date: Mon, 25 Apr 2022 07:00:53 +0900 Subject: [PATCH] Update factories of data export --- backend/data_export/celery_tasks.py | 32 +++---------- backend/data_export/pipeline/factories.py | 55 +++++++++++++---------- 2 files changed, 39 insertions(+), 48 deletions(-) diff --git a/backend/data_export/celery_tasks.py b/backend/data_export/celery_tasks.py index f11e72c9..44117379 100644 --- a/backend/data_export/celery_tasks.py +++ b/backend/data_export/celery_tasks.py @@ -6,12 +6,7 @@ from django.conf import settings from django.shortcuts import get_object_or_404 from .pipeline.dataset import Dataset -from .pipeline.factories import ( - create_labels, - select_formatter, - select_label_collection, - select_writer, -) +from .pipeline.factories import create_formatter, create_labels, create_writer from .pipeline.services import ExportApplicationService from .pipeline.writers import remove_files, zip_files from data_export.models import ExportedExample @@ -25,16 +20,11 @@ def create_collaborative_dataset(project: Project, file_format: str, confirmed_o examples = ExportedExample.objects.confirmed(is_collaborative=project.collaborative_annotation) else: examples = ExportedExample.objects.all() - writer = select_writer(file_format)() - label_collections = select_label_collection(project) - formatter_classes = select_formatter(project, file_format) - formatters = [ - formatter(target_column=label_collection.column) - for formatter, label_collection in zip(formatter_classes, label_collections) - ] - labels = [create_labels(label_collection, examples=examples) for label_collection in label_collections] + labels = create_labels(project, examples) dataset = Dataset(examples, labels) + formatters = create_formatter(project, file_format) + writer = create_writer(file_format) service = ExportApplicationService(dataset, formatters, writer) filepath = os.path.join(settings.MEDIA_URL, f"all.{writer.extension}") service.export(filepath) @@ -51,19 +41,11 @@ def create_individual_dataset(project: Project, file_format: str, confirmed_only ) else: examples = ExportedExample.objects.all() - writer = select_writer(file_format)() - label_collections = select_label_collection(project) - formatter_classes = select_formatter(project, file_format) - formatters = [ - formatter(target_column=label_collection.column) - for formatter, label_collection in zip(formatter_classes, label_collections) - ] - labels = [ - create_labels(label_collection, examples=examples, user=member.user) - for label_collection in label_collections - ] + labels = create_labels(project, examples, member.user) dataset = Dataset(examples, labels) + formatters = create_formatter(project, file_format) + writer = create_writer(file_format) service = ExportApplicationService(dataset, formatters, writer) filepath = os.path.join(settings.MEDIA_URL, f"{member.username}.{writer.extension}") service.export(filepath) diff --git a/backend/data_export/pipeline/factories.py b/backend/data_export/pipeline/factories.py index 134733f9..583b4adb 100644 --- a/backend/data_export/pipeline/factories.py +++ b/backend/data_export/pipeline/factories.py @@ -12,56 +12,63 @@ from projects.models import ( SEQ2SEQ, SEQUENCE_LABELING, SPEECH2TEXT, + Project, ) -def select_writer(file_format: str) -> Type[writers.Writer]: +def create_writer(file_format: str) -> writers.Writer: mapping = { - catalog.CSV.name: writers.CsvWriter, - catalog.JSON.name: writers.JsonWriter, - catalog.JSONL.name: writers.JsonlWriter, - catalog.FastText.name: writers.FastTextWriter, + catalog.CSV.name: writers.CsvWriter(), + catalog.JSON.name: writers.JsonWriter(), + catalog.JSONL.name: writers.JsonlWriter(), + catalog.FastText.name: writers.FastTextWriter(), } if file_format not in mapping: ValueError(f"Invalid format: {file_format}") return mapping[file_format] -def select_formatter(project, file_format: str) -> List[Type[formatters.Formatter]]: +def create_formatter(project: Project, file_format: str) -> List[formatters.Formatter]: use_relation = getattr(project, "use_relation", False) - mapping: Dict[str, Dict[str, List[Type[formatters.Formatter]]]] = { + mapping: Dict[str, Dict[str, List[formatters.Formatter]]] = { DOCUMENT_CLASSIFICATION: { - catalog.CSV.name: [formatters.JoinedCategoryFormatter], - catalog.JSON.name: [formatters.ListedCategoryFormatter], - catalog.JSONL.name: [formatters.ListedCategoryFormatter], - catalog.FastText.name: [formatters.FastTextCategoryFormatter], + catalog.CSV.name: [formatters.JoinedCategoryFormatter(labels.Categories.column)], + catalog.JSON.name: [formatters.ListedCategoryFormatter(labels.Categories.column)], + catalog.JSONL.name: [formatters.ListedCategoryFormatter(labels.Categories.column)], + catalog.FastText.name: [formatters.FastTextCategoryFormatter(labels.Categories.column)], }, SEQUENCE_LABELING: { - catalog.JSONL.name: [formatters.DictFormatter, formatters.DictFormatter] + catalog.JSONL.name: [ + formatters.DictFormatter(labels.Spans.column), + formatters.DictFormatter(labels.Relations.column), + ] if use_relation - else [formatters.TupledSpanFormatter] + else [formatters.TupledSpanFormatter(labels.Spans.column)] }, SEQ2SEQ: { - catalog.CSV.name: [formatters.JoinedCategoryFormatter], - catalog.JSON.name: [formatters.ListedCategoryFormatter], - catalog.JSONL.name: [formatters.ListedCategoryFormatter], + catalog.CSV.name: [formatters.JoinedCategoryFormatter(labels.Texts.column)], + catalog.JSON.name: [formatters.ListedCategoryFormatter(labels.Texts.column)], + catalog.JSONL.name: [formatters.ListedCategoryFormatter(labels.Texts.column)], }, IMAGE_CLASSIFICATION: { - catalog.JSONL.name: [formatters.ListedCategoryFormatter], + catalog.JSONL.name: [formatters.ListedCategoryFormatter(labels.Categories.column)], }, SPEECH2TEXT: { - catalog.JSONL.name: [formatters.ListedCategoryFormatter], + catalog.JSONL.name: [formatters.ListedCategoryFormatter(labels.Texts.column)], }, INTENT_DETECTION_AND_SLOT_FILLING: { - catalog.JSONL.name: [formatters.ListedCategoryFormatter, formatters.TupledSpanFormatter] + catalog.JSONL.name: [ + formatters.ListedCategoryFormatter(labels.Categories.column), + formatters.TupledSpanFormatter(labels.Spans.column), + ] }, } return mapping[project.project_type][file_format] -def select_label_collection(project): +def select_label_collection(project: Project) -> List[Type[Labels]]: use_relation = getattr(project, "use_relation", False) - mapping = { + mapping: Dict[str, List[Type[Labels]]] = { DOCUMENT_CLASSIFICATION: [labels.Categories], SEQUENCE_LABELING: [labels.Spans, labels.Relations] if use_relation else [labels.Spans], SEQ2SEQ: [labels.Texts], @@ -72,5 +79,7 @@ def select_label_collection(project): return mapping[project.project_type] -def create_labels(label_collection_class: Type[Labels], examples: QuerySet[ExportedExample], user=None): - return label_collection_class(examples, user) +def create_labels(project: Project, examples: QuerySet[ExportedExample], user=None) -> List[Labels]: + label_collections = select_label_collection(project) + labels = [label_collection(examples=examples, user=user) for label_collection in label_collections] + return labels