Browse Source

Update factories of data export

pull/1799/head
Hironsan 3 years ago
parent
commit
12573ec4bb
2 changed files with 39 additions and 48 deletions
  1. 32
      backend/data_export/celery_tasks.py
  2. 55
      backend/data_export/pipeline/factories.py

32
backend/data_export/celery_tasks.py

@ -6,12 +6,7 @@ from django.conf import settings
from django.shortcuts import get_object_or_404
from .pipeline.dataset import Dataset
from .pipeline.factories import (
create_labels,
select_formatter,
select_label_collection,
select_writer,
)
from .pipeline.factories import create_formatter, create_labels, create_writer
from .pipeline.services import ExportApplicationService
from .pipeline.writers import remove_files, zip_files
from data_export.models import ExportedExample
@ -25,16 +20,11 @@ def create_collaborative_dataset(project: Project, file_format: str, confirmed_o
examples = ExportedExample.objects.confirmed(is_collaborative=project.collaborative_annotation)
else:
examples = ExportedExample.objects.all()
writer = select_writer(file_format)()
label_collections = select_label_collection(project)
formatter_classes = select_formatter(project, file_format)
formatters = [
formatter(target_column=label_collection.column)
for formatter, label_collection in zip(formatter_classes, label_collections)
]
labels = [create_labels(label_collection, examples=examples) for label_collection in label_collections]
labels = create_labels(project, examples)
dataset = Dataset(examples, labels)
formatters = create_formatter(project, file_format)
writer = create_writer(file_format)
service = ExportApplicationService(dataset, formatters, writer)
filepath = os.path.join(settings.MEDIA_URL, f"all.{writer.extension}")
service.export(filepath)
@ -51,19 +41,11 @@ def create_individual_dataset(project: Project, file_format: str, confirmed_only
)
else:
examples = ExportedExample.objects.all()
writer = select_writer(file_format)()
label_collections = select_label_collection(project)
formatter_classes = select_formatter(project, file_format)
formatters = [
formatter(target_column=label_collection.column)
for formatter, label_collection in zip(formatter_classes, label_collections)
]
labels = [
create_labels(label_collection, examples=examples, user=member.user)
for label_collection in label_collections
]
labels = create_labels(project, examples, member.user)
dataset = Dataset(examples, labels)
formatters = create_formatter(project, file_format)
writer = create_writer(file_format)
service = ExportApplicationService(dataset, formatters, writer)
filepath = os.path.join(settings.MEDIA_URL, f"{member.username}.{writer.extension}")
service.export(filepath)

55
backend/data_export/pipeline/factories.py

@ -12,56 +12,63 @@ from projects.models import (
SEQ2SEQ,
SEQUENCE_LABELING,
SPEECH2TEXT,
Project,
)
def create_writer(file_format: str) -> writers.Writer:
    """Create the writer used to serialize an export in ``file_format``.

    Args:
        file_format: Name of an export format, matched against the ``name``
            of the ``catalog`` formats CSV, JSON, JSONL and FastText.

    Returns:
        A newly constructed writer for the requested format.

    Raises:
        ValueError: If ``file_format`` has no registered writer.
    """
    # Map to classes rather than instances so that only the writer that is
    # actually requested gets constructed.
    mapping: Dict[str, Type[writers.Writer]] = {
        catalog.CSV.name: writers.CsvWriter,
        catalog.JSON.name: writers.JsonWriter,
        catalog.JSONL.name: writers.JsonlWriter,
        catalog.FastText.name: writers.FastTextWriter,
    }
    if file_format not in mapping:
        # The exception was previously constructed but never raised, so an
        # unknown format surfaced as a bare KeyError from the lookup below.
        raise ValueError(f"Invalid format: {file_format}")
    return mapping[file_format]()
def create_formatter(project: Project, file_format: str) -> List[formatters.Formatter]:
    """Create the formatters used to export ``project`` in ``file_format``.

    Each formatter is constructed with the ``column`` of the label collection
    it formats.

    Args:
        project: Project being exported. Its ``project_type`` selects the
            formatter table; an optional ``use_relation`` attribute (treated
            as ``False`` when absent) switches sequence labeling to the
            span + relation output.
        file_format: Name of an export format from ``catalog``.

    Returns:
        The formatter instances for this project type and format.

    Raises:
        KeyError: If the project type is not in the table, or the format is
            not offered for that project type.
    """
    use_relation = getattr(project, "use_relation", False)
    category_column = labels.Categories.column
    span_column = labels.Spans.column
    text_column = labels.Texts.column

    # Sequence labeling is the only project type whose formatters depend on a
    # project setting, so resolve it before building the table.
    if use_relation:
        sequence_labeling_jsonl = [
            formatters.DictFormatter(span_column),
            formatters.DictFormatter(labels.Relations.column),
        ]
    else:
        sequence_labeling_jsonl = [formatters.TupledSpanFormatter(span_column)]

    mapping: Dict[str, Dict[str, List[formatters.Formatter]]] = {
        DOCUMENT_CLASSIFICATION: {
            catalog.CSV.name: [formatters.JoinedCategoryFormatter(category_column)],
            catalog.JSON.name: [formatters.ListedCategoryFormatter(category_column)],
            catalog.JSONL.name: [formatters.ListedCategoryFormatter(category_column)],
            catalog.FastText.name: [formatters.FastTextCategoryFormatter(category_column)],
        },
        SEQUENCE_LABELING: {
            catalog.JSONL.name: sequence_labeling_jsonl,
        },
        SEQ2SEQ: {
            catalog.CSV.name: [formatters.JoinedCategoryFormatter(text_column)],
            catalog.JSON.name: [formatters.ListedCategoryFormatter(text_column)],
            catalog.JSONL.name: [formatters.ListedCategoryFormatter(text_column)],
        },
        IMAGE_CLASSIFICATION: {
            catalog.JSONL.name: [formatters.ListedCategoryFormatter(category_column)],
        },
        SPEECH2TEXT: {
            catalog.JSONL.name: [formatters.ListedCategoryFormatter(text_column)],
        },
        INTENT_DETECTION_AND_SLOT_FILLING: {
            catalog.JSONL.name: [
                formatters.ListedCategoryFormatter(category_column),
                formatters.TupledSpanFormatter(span_column),
            ],
        },
    }
    return mapping[project.project_type][file_format]
def select_label_collection(project):
def select_label_collection(project: Project) -> List[Type[Labels]]:
use_relation = getattr(project, "use_relation", False)
mapping = {
mapping: Dict[str, List[Type[Labels]]] = {
DOCUMENT_CLASSIFICATION: [labels.Categories],
SEQUENCE_LABELING: [labels.Spans, labels.Relations] if use_relation else [labels.Spans],
SEQ2SEQ: [labels.Texts],
@ -72,5 +79,7 @@ def select_label_collection(project):
return mapping[project.project_type]
def create_labels(project: Project, examples: QuerySet[ExportedExample], user=None) -> List[Labels]:
    """Instantiate every label collection selected for ``project``.

    Args:
        project: Project whose label collection classes are looked up via
            ``select_label_collection``.
        examples: Exported examples handed to each label collection.
        user: Optional user forwarded unchanged to each label collection
            (presumably limits labels to that user's annotations; confirm
            against the ``Labels`` classes).

    Returns:
        One label collection instance per selected class, in the order
        returned by ``select_label_collection``.
    """
    # Return the comprehension directly: a local named ``labels`` would shadow
    # the imported ``labels`` module inside this function.
    return [
        label_collection(examples=examples, user=user)
        for label_collection in select_label_collection(project)
    ]
Loading…
Cancel
Save