From aa13d5c97f5c372740a2268750f9830d435a3937 Mon Sep 17 00:00:00 2001 From: Hironsan Date: Fri, 22 Apr 2022 09:25:55 +0900 Subject: [PATCH] Enable to export dataset --- backend/data_export/celery_tasks.py | 64 +++++++++-- backend/data_export/pipeline/catalog.py | 4 +- backend/data_export/pipeline/dataset.py | 20 ++-- backend/data_export/pipeline/factories.py | 43 ++++++-- backend/data_export/pipeline/formatters.py | 9 ++ backend/data_export/pipeline/labels.py | 8 +- backend/data_export/pipeline/services.py | 19 ++-- backend/data_export/pipeline/writers.py | 12 ++- backend/data_export/tests/test_task.py | 117 +++++++++++++++++++++ backend/projects/tests/utils.py | 4 + 10 files changed, 264 insertions(+), 36 deletions(-) create mode 100644 backend/data_export/tests/test_task.py diff --git a/backend/data_export/celery_tasks.py b/backend/data_export/celery_tasks.py index 63a40283..eb6164f1 100644 --- a/backend/data_export/celery_tasks.py +++ b/backend/data_export/celery_tasks.py @@ -1,20 +1,70 @@ +import os + from celery import shared_task from celery.utils.log import get_task_logger from django.conf import settings from django.shortcuts import get_object_or_404 -from .pipeline.factories import create_repository, create_writer +from .pipeline.dataset import Dataset, filter_examples +from .pipeline.factories import ( + create_formatter, + create_labels, + create_writer, + select_label_collection, +) from .pipeline.services import ExportApplicationService -from projects.models import Project +from .pipeline.writers import zip_files +from projects.models import Member, Project logger = get_task_logger(__name__) +def create_collaborative_dataset(project: Project, file_format: str, confirmed_only: bool): + examples = filter_examples( + examples=project.examples.all(), + is_collaborative=project.collaborative_annotation, + confirmed_only=confirmed_only, + ) + label_collection = select_label_collection(project) + labels = create_labels(label_collection, examples=examples) + dataset = Dataset(examples, labels) + formatter = create_formatter(project, file_format)(target_column=label_collection.field_name) + writer = create_writer(file_format) + service = ExportApplicationService(dataset, formatter, writer) + filepath = os.path.join(settings.MEDIA_URL, f"all.{writer.extension}") + service.export(filepath) + return filepath + + +def create_individual_dataset(project: Project, file_format: str, confirmed_only: bool): + files = [] + members = Member.objects.filter(project=project) + for member in members: + examples = filter_examples( + examples=project.examples.all(), + is_collaborative=project.collaborative_annotation, + confirmed_only=confirmed_only, + user=member.user, + ) + label_collection = select_label_collection(project) + labels = create_labels(label_collection, examples=examples, user=member.user) + dataset = Dataset(examples, labels) + formatter = create_formatter(project, file_format)(target_column=label_collection.field_name) + writer = create_writer(file_format) + service = ExportApplicationService(dataset, formatter, writer) + filepath = os.path.join(settings.MEDIA_URL, f"{member.username}.{writer.extension}") + service.export(filepath) + files.append(filepath) + zip_file = zip_files(files, settings.MEDIA_URL) + for file in files: + os.remove(file) + return zip_file + + @shared_task def export_dataset(project_id, file_format: str, export_approved=False): project = get_object_or_404(Project, pk=project_id) - repository = create_repository(project, file_format) - writer = create_writer(file_format)(settings.MEDIA_ROOT) - service = ExportApplicationService(repository, writer) - filepath = service.export(export_approved) - return filepath + if project.collaborative_annotation: + return create_collaborative_dataset(project, file_format, export_approved) + else: + return create_individual_dataset(project, file_format, export_approved) diff --git a/backend/data_export/pipeline/catalog.py b/backend/data_export/pipeline/catalog.py index de384081..7a4cfdd2 100644 --- a/backend/data_export/pipeline/catalog.py +++ b/backend/data_export/pipeline/catalog.py @@ -86,7 +86,7 @@ Options.register(DOCUMENT_CLASSIFICATION, JSONL, OptionNone, examples.CATEGORY_J # Sequence Labeling Options.register(SEQUENCE_LABELING, JSONL, OptionNone, examples.SPAN_JSONL) -Options.register(SEQUENCE_LABELING, JSONLRelation, OptionNone, examples.ENTITY_AND_RELATION_JSONL) +Options.register(SEQUENCE_LABELING, JSONL, OptionNone, examples.ENTITY_AND_RELATION_JSONL) # Sequence to sequence Options.register(SEQ2SEQ, CSV, OptionDelimiter, examples.TEXT_CSV) @@ -94,7 +94,7 @@ Options.register(SEQ2SEQ, JSON, OptionNone, examples.TEXT_JSON) Options.register(SEQ2SEQ, JSONL, OptionNone, examples.TEXT_JSONL) # Intent detection and slot filling -Options.register(INTENT_DETECTION_AND_SLOT_FILLING, IntentAndSlot, OptionNone, examples.INTENT_JSONL) +Options.register(INTENT_DETECTION_AND_SLOT_FILLING, JSONL, OptionNone, examples.INTENT_JSONL) # Image Classification Options.register(IMAGE_CLASSIFICATION, JSONL, OptionNone, examples.CATEGORY_IMAGE_CLASSIFICATION) diff --git a/backend/data_export/pipeline/dataset.py b/backend/data_export/pipeline/dataset.py index 1bfe8817..50e038a0 100644 --- a/backend/data_export/pipeline/dataset.py +++ b/backend/data_export/pipeline/dataset.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Iterator, Type +from typing import Any, Dict, Iterator import pandas as pd from django.db.models.query import QuerySet @@ -7,16 +7,24 @@ from .labels import Labels from examples.models import Example +def filter_examples(examples: QuerySet[Example], is_collaborative=False, confirmed_only=False, user=None): + if is_collaborative and confirmed_only: + return examples.exclude(states=None) + elif not is_collaborative and confirmed_only: + assert user is not None + return examples.filter(states__confirmed_by=user) + else: + return examples + + class Dataset: - def __init__(self, examples: QuerySet[Example], user, label_collection_class: Type[Labels], confirmed_only=False): - if confirmed_only: - examples = examples.filter(states__confirmed_by=user) + def __init__(self, examples: QuerySet[Example], labels: Labels): self.examples = examples - self.labels = label_collection_class(examples, user) + self.labels = labels def __iter__(self) -> Iterator[Dict[str, Any]]: for example in self.examples: yield {"id": example.id, "data": example.text, **example.meta, **self.labels.find_by(example.id)} - def to_pandas(self) -> pd.DataFrame: + def to_dataframe(self) -> pd.DataFrame: return pd.DataFrame(self) diff --git a/backend/data_export/pipeline/factories.py b/backend/data_export/pipeline/factories.py index 9096a442..6bd43984 100644 --- a/backend/data_export/pipeline/factories.py +++ b/backend/data_export/pipeline/factories.py @@ -1,6 +1,10 @@ from typing import Type -from . import catalog, repositories, writers +from django.db.models import QuerySet + +from . import catalog, formatters, labels, repositories, writers +from .labels import Labels +from examples.models import Example from projects.models import ( DOCUMENT_CLASSIFICATION, IMAGE_CLASSIFICATION, @@ -28,15 +32,38 @@ def create_repository(project, file_format: str): return repository -def create_writer(file_format: str) -> Type[writers.BaseWriter]: +def create_writer(file_format: str) -> writers.Writer: mapping = { catalog.CSV.name: writers.CsvWriter, - catalog.JSON.name: writers.JSONWriter, - catalog.JSONL.name: writers.JSONLWriter, - catalog.FastText.name: writers.FastTextWriter, - catalog.IntentAndSlot.name: writers.IntentAndSlotWriter, - catalog.JSONLRelation.name: writers.EntityAndRelationWriter, + catalog.JSON.name: writers.JsonWriter, + catalog.JSONL.name: writers.JsonlWriter, + # catalog.FastText.name: writers.FastTextWriter, } if file_format not in mapping: ValueError(f"Invalid format: {file_format}") - return mapping[file_format] + return mapping[file_format]() + + +def create_formatter(project, file_format: str): + mapping = { + DOCUMENT_CLASSIFICATION: { + catalog.CSV.name: formatters.JoinedCategoryFormatter, + catalog.JSON.name: formatters.ListedCategoryFormatter, + catalog.JSONL.name: formatters.ListedCategoryFormatter, + }, + SEQUENCE_LABELING: {}, + SEQ2SEQ: {}, + IMAGE_CLASSIFICATION: {}, + SPEECH2TEXT: {}, + INTENT_DETECTION_AND_SLOT_FILLING: {}, + } + return mapping[project.project_type][file_format] + + +def select_label_collection(project): + mapping = {DOCUMENT_CLASSIFICATION: labels.Categories} + return mapping[project.project_type] + + +def create_labels(label_collection_class: Type[Labels], examples: QuerySet[Example], user=None): + return label_collection_class(examples, user) diff --git a/backend/data_export/pipeline/formatters.py b/backend/data_export/pipeline/formatters.py index 7945f681..3cdb4ed7 100644 --- a/backend/data_export/pipeline/formatters.py +++ b/backend/data_export/pipeline/formatters.py @@ -18,6 +18,9 @@ class Formatter(abc.ABC): class JoinedCategoryFormatter(Formatter): def format(self, dataset: pd.DataFrame) -> pd.DataFrame: """Format the label column to `LabelA#LabelB` format.""" + if self.target_column not in dataset.columns: + return dataset + dataset[self.target_column] = dataset[self.target_column].apply( lambda labels: "#".join(sorted(label.to_string() for label in labels)) ) @@ -27,6 +30,9 @@ class JoinedCategoryFormatter(Formatter): class ListedCategoryFormatter(Formatter): def format(self, dataset: pd.DataFrame) -> pd.DataFrame: """Format the label column to `['LabelA', 'LabelB']` format.""" + if self.target_column not in dataset.columns: + return dataset + dataset[self.target_column] = dataset[self.target_column].apply( lambda labels: sorted([label.to_string() for label in labels]) ) @@ -38,6 +44,9 @@ class FastTextCategoryFormatter(Formatter): """Format the label column to `__label__LabelA __label__LabelB` format. Also, drop the columns except for `data` and `self.target_column`. """ + if self.target_column not in dataset.columns: + return dataset + dataset = dataset[["data", self.target_column]] dataset[self.target_column] = dataset[self.target_column].apply( lambda labels: sorted(f"__label__{label.to_string()}" for label in labels) diff --git a/backend/data_export/pipeline/labels.py b/backend/data_export/pipeline/labels.py index 8b9983df..8c30735d 100644 --- a/backend/data_export/pipeline/labels.py +++ b/backend/data_export/pipeline/labels.py @@ -16,10 +16,12 @@ class Labels(abc.ABC): field_name = "labels" fields = ("example", "label") - def __init__(self, examples: QuerySet[Example], user): + def __init__(self, examples: QuerySet[Example], user=None): self.label_groups = defaultdict(list) - labels = self.label_class.objects.filter(example__in=examples, user=user).select_related(*self.fields) - for label in labels: + labels = self.label_class.objects.filter(example__in=examples) + if user: + labels = labels.filter(user=user) + for label in labels.select_related(*self.fields): self.label_groups[label.example.id].append(label) def find_by(self, example_id: int) -> Dict[str, List[ExportedLabel]]: diff --git a/backend/data_export/pipeline/services.py b/backend/data_export/pipeline/services.py index 3f94b7e7..0916956f 100644 --- a/backend/data_export/pipeline/services.py +++ b/backend/data_export/pipeline/services.py @@ -1,13 +1,16 @@ -from .repositories import BaseRepository -from .writers import BaseWriter +from .dataset import Dataset +from .formatters import Formatter +from .writers import Writer class ExportApplicationService: - def __init__(self, repository: BaseRepository, writer: BaseWriter): - self.repository = repository + def __init__(self, dataset: Dataset, formatter: Formatter, writer: Writer): + self.dataset = dataset + self.formatter = formatter self.writer = writer - def export(self, export_approved=False) -> str: - records = self.repository.list(export_approved=export_approved) - filepath = self.writer.write(records) - return filepath + def export(self, file): + dataset = self.dataset.to_dataframe() + dataset = self.formatter.format(dataset) + self.writer.write(file, dataset) + return file diff --git a/backend/data_export/pipeline/writers.py b/backend/data_export/pipeline/writers.py index 013eb0b6..dbe6be54 100644 --- a/backend/data_export/pipeline/writers.py +++ b/backend/data_export/pipeline/writers.py @@ -6,8 +6,8 @@ import zipfile import pandas as pd -def zip_files(files): - save_file = f"{uuid.uuid4()}.zip" +def zip_files(files, dirname): + save_file = os.path.join(dirname, f"{uuid.uuid4()}.zip") with zipfile.ZipFile(save_file, "w", compression=zipfile.ZIP_DEFLATED) as zf: for file in files: zf.write(filename=file, arcname=os.path.basename(file)) @@ -15,6 +15,8 @@ def zip_files(files): class Writer(abc.ABC): + extension = "" + @staticmethod @abc.abstractmethod def write(file, dataset: pd.DataFrame): @@ -22,18 +24,24 @@ class Writer(abc.ABC): class CsvWriter(Writer): + extension = "csv" + @staticmethod def write(file, dataset: pd.DataFrame): dataset.to_csv(file, index=False, encoding="utf-8") class JsonWriter(Writer): + extension = "json" + @staticmethod def write(file, dataset: pd.DataFrame): dataset.to_json(file, orient="records", force_ascii=False) class JsonlWriter(Writer): + extension = "jsonl" + @staticmethod def write(file, dataset: pd.DataFrame): dataset.to_json(file, orient="records", force_ascii=False, lines=True) diff --git a/backend/data_export/tests/test_task.py b/backend/data_export/tests/test_task.py new file mode 100644 index 00000000..5dac4295 --- /dev/null +++ b/backend/data_export/tests/test_task.py @@ -0,0 +1,117 @@ +import os +import zipfile + +import numpy as np +import pandas as pd +from django.test import TestCase, override_settings +from model_mommy import mommy +from pandas.testing import assert_frame_equal + +from ..celery_tasks import export_dataset +from projects.models import DOCUMENT_CLASSIFICATION +from projects.tests.utils import prepare_project + + +def read_zip_content(file): + datasets = {} + with zipfile.ZipFile(file) as z: + for file in z.filelist: + username = file.filename.split(".")[0] + with z.open(file) as f: + try: + df = pd.read_csv(f) + except pd.errors.EmptyDataError: + continue + datasets[username] = df + return datasets + + +@override_settings(MEDIA_URL=os.path.dirname(__file__)) +class TestExportTask(TestCase): + def prepare_data(self, collaborative=False): + self.project = prepare_project(DOCUMENT_CLASSIFICATION, collaborative_annotation=collaborative) + self.example1 = mommy.make("Example", project=self.project.item, text="confirmed") + self.example2 = mommy.make("Example", project=self.project.item, text="unconfirmed") + self.category1 = mommy.make("Category", example=self.example1, user=self.project.admin) + self.category2 = mommy.make("Category", example=self.example1, user=self.project.annotator) + mommy.make("ExampleState", example=self.example1, confirmed_by=self.project.admin) + + def test_unconfirmed_and_non_collaborative(self): + self.prepare_data() + file = export_dataset(self.project.id, "CSV", False) + datasets = read_zip_content(file) + os.remove(file) + expected_datasets = { + self.project.admin.username: pd.DataFrame( + [ + {"id": self.example1.id, "data": self.example1.text, "categories": self.category1.label.text}, + {"id": self.example2.id, "data": self.example2.text, "categories": np.nan}, + ] + ), + self.project.approver.username: pd.DataFrame( + [ + {"id": self.example1.id, "data": self.example1.text, "categories": np.nan}, + {"id": self.example2.id, "data": self.example2.text, "categories": np.nan}, + ] + ), + self.project.annotator.username: pd.DataFrame( + [ + {"id": self.example1.id, "data": self.example1.text, "categories": self.category2.label.text}, + {"id": self.example2.id, "data": self.example2.text, "categories": np.nan}, + ] + ), + } + for username, dataset in expected_datasets.items(): + assert_frame_equal(dataset, datasets[username]) + + def test_unconfirmed_and_collaborative(self): + self.prepare_data(collaborative=True) + file = export_dataset(self.project.id, "CSV", False) + dataset = pd.read_csv(file) + os.remove(file) + expected_dataset = pd.DataFrame( + [ + { + "id": self.example1.id, + "data": self.example1.text, + "categories": "#".join(sorted([self.category1.label.text, self.category2.label.text])), + }, + {"id": self.example2.id, "data": self.example2.text, "categories": np.nan}, + ] + ) + assert_frame_equal(dataset, expected_dataset) + + def test_confirmed_and_non_collaborative(self): + self.prepare_data() + file = export_dataset(self.project.id, "CSV", True) + datasets = read_zip_content(file) + os.remove(file) + expected_datasets = { + self.project.admin.username: pd.DataFrame( + [ + { + "id": self.example1.id, + "data": self.example1.text, + "categories": self.category1.label.text, + } + ] + ) + } + for username, dataset in expected_datasets.items(): + assert_frame_equal(dataset, datasets[username]) + + def test_confirmed_and_collaborative(self): + self.prepare_data(collaborative=True) + file = export_dataset(self.project.id, "CSV", True) + dataset = pd.read_csv(file) + os.remove(file) + expected_dataset = pd.DataFrame( + [ + { + "id": self.example1.id, + "data": self.example1.text, + "categories": "#".join(sorted([self.category1.label.text, self.category2.label.text])), + } + ] + ) + assert_frame_equal(dataset, expected_dataset) diff --git a/backend/projects/tests/utils.py b/backend/projects/tests/utils.py index ccd201c7..8ddf074f 100644 --- a/backend/projects/tests/utils.py +++ b/backend/projects/tests/utils.py @@ -22,6 +22,10 @@ class ProjectData: self.item = item self.members = members + @property + def id(self): + return self.item.id + @property def admin(self): return self.members[0]