Browse Source

Enable to accept multiple label types in export dataset

pull/1799/head
Hironsan 2 years ago
parent
commit
549e909070
7 changed files with 67 additions and 22 deletions
  1. 31
      backend/data_export/celery_tasks.py
  2. 10
      backend/data_export/models.py
  3. 9
      backend/data_export/pipeline/dataset.py
  4. 13
      backend/data_export/pipeline/factories.py
  5. 4
      backend/data_export/pipeline/formatters.py
  6. 13
      backend/data_export/pipeline/labels.py
  7. 9
      backend/data_export/pipeline/services.py

31
backend/data_export/celery_tasks.py

@ -25,12 +25,16 @@ def create_collaborative_dataset(project: Project, file_format: str, confirmed_o
is_collaborative=project.collaborative_annotation,
confirmed_only=confirmed_only,
)
label_collection = select_label_collection(project)
labels = create_labels(label_collection, examples=examples)
dataset = Dataset(examples, labels)
formatter = create_formatter(project, file_format)(target_column=label_collection.field_name)
writer = create_writer(file_format)
service = ExportApplicationService(dataset, formatter, writer)
label_collections = select_label_collection(project)
formatters = [
create_formatter(project, file_format)(target_column=label_collection.field_name)
for label_collection in label_collections
]
labels = [create_labels(label_collection, examples=examples) for label_collection in label_collections]
dataset = Dataset(examples, labels)
service = ExportApplicationService(dataset, formatters, writer)
filepath = os.path.join(settings.MEDIA_URL, f"all.{writer.extension}")
service.export(filepath)
return filepath
@ -46,12 +50,19 @@ def create_individual_dataset(project: Project, file_format: str, confirmed_only
confirmed_only=confirmed_only,
user=member.user,
)
label_collection = select_label_collection(project)
labels = create_labels(label_collection, examples=examples, user=member.user)
dataset = Dataset(examples, labels)
formatter = create_formatter(project, file_format)(target_column=label_collection.field_name)
writer = create_writer(file_format)
service = ExportApplicationService(dataset, formatter, writer)
label_collections = select_label_collection(project)
formatters = [
create_formatter(project, file_format)(target_column=label_collection.field_name)
for label_collection in label_collections
]
labels = [
create_labels(label_collection, examples=examples, user=member.user)
for label_collection in label_collections
]
dataset = Dataset(examples, labels)
service = ExportApplicationService(dataset, formatters, writer)
filepath = os.path.join(settings.MEDIA_URL, f"{member.username}.{writer.extension}")
service.export(filepath)
files.append(filepath)

10
backend/data_export/models.py

@ -2,7 +2,7 @@ from typing import Any, Dict, Protocol, Tuple
from django.db import models
from labels.models import Category, Span
from labels.models import Category, Relation, Span
class ExportedLabelManager(models.Manager):
@ -44,3 +44,11 @@ class ExportedSpan(Span):
class Meta:
proxy = True
class ExportedRelation(Relation):
def to_dict(self):
return {"id": self.id, "from_id": self.from_id.id, "to_id": self.to_id.id, "type": self.type.text}
class Meta:
proxy = True

9
backend/data_export/pipeline/dataset.py

@ -1,4 +1,4 @@
from typing import Any, Dict, Iterator
from typing import Any, Dict, Iterator, List
import pandas as pd
from django.db.models.query import QuerySet
@ -18,13 +18,16 @@ def filter_examples(examples: QuerySet[Example], is_collaborative=False, confirm
class Dataset:
def __init__(self, examples: QuerySet[Example], labels: Labels):
def __init__(self, examples: QuerySet[Example], labels: List[Labels]):
self.examples = examples
self.labels = labels
def __iter__(self) -> Iterator[Dict[str, Any]]:
for example in self.examples:
yield {"id": example.id, "data": example.text, **example.meta, **self.labels.find_by(example.id)}
data = {"id": example.id, "data": example.text, **example.meta}
for labels in self.labels:
data.update(**labels.find_by(example.id))
yield data
def to_dataframe(self) -> pd.DataFrame:
return pd.DataFrame(self)

13
backend/data_export/pipeline/factories.py

@ -45,13 +45,18 @@ def create_writer(file_format: str) -> writers.Writer:
def create_formatter(project, file_format: str):
use_relation = getattr(project, "use_relation", False)
mapping = {
DOCUMENT_CLASSIFICATION: {
catalog.CSV.name: formatters.JoinedCategoryFormatter,
catalog.JSON.name: formatters.ListedCategoryFormatter,
catalog.JSONL.name: formatters.ListedCategoryFormatter,
},
SEQUENCE_LABELING: {catalog.JSONL.name: formatters.TupledSpanFormatter},
SEQUENCE_LABELING: {
catalog.JSONL.name: [formatters.DictFormatter, formatters.DictFormatter]
if use_relation
else [formatters.TupledSpanFormatter]
},
SEQ2SEQ: {},
IMAGE_CLASSIFICATION: {},
SPEECH2TEXT: {},
@ -61,7 +66,11 @@ def create_formatter(project, file_format: str):
def select_label_collection(project):
mapping = {DOCUMENT_CLASSIFICATION: labels.Categories, SEQUENCE_LABELING: labels.Spans}
use_relation = getattr(project, "use_relation", False)
mapping = {
DOCUMENT_CLASSIFICATION: [labels.Categories],
SEQUENCE_LABELING: [labels.Spans, labels.Relations] if use_relation else [labels.Spans],
}
return mapping[project.project_type]

4
backend/data_export/pipeline/formatters.py

@ -66,12 +66,12 @@ class TupledSpanFormatter(Formatter):
return dataset
class DictSpanFormatter(Formatter):
class DictFormatter(Formatter):
def format(self, dataset: pd.DataFrame) -> pd.DataFrame:
if self.target_column not in dataset.columns:
return dataset
dataset[self.target_column] = dataset[self.target_column].apply(
lambda spans: [span.to_dict() for span in spans]
lambda labels: [label.to_dict() for label in labels]
)
return dataset

13
backend/data_export/pipeline/labels.py

@ -7,7 +7,12 @@ from typing import Dict, List
from django.db.models import QuerySet
from data_export.models import ExportedCategory, ExportedLabel, ExportedSpan
from data_export.models import (
ExportedCategory,
ExportedLabel,
ExportedRelation,
ExportedSpan,
)
from examples.models import Example
@ -38,3 +43,9 @@ class Spans(Labels):
label_class = ExportedSpan
field_name = "entities"
fields = ("example", "label")
class Relations(Labels):
label_class = ExportedRelation
field_name = "relations"
fields = ("example", "type")

9
backend/data_export/pipeline/services.py

@ -1,16 +1,19 @@
from typing import List
from .dataset import Dataset
from .formatters import Formatter
from .writers import Writer
class ExportApplicationService:
def __init__(self, dataset: Dataset, formatter: Formatter, writer: Writer):
def __init__(self, dataset: Dataset, formatters: List[Formatter], writer: Writer):
self.dataset = dataset
self.formatter = formatter
self.formatters = formatters
self.writer = writer
def export(self, file):
dataset = self.dataset.to_dataframe()
dataset = self.formatter.format(dataset)
for formatter in self.formatters:
dataset = formatter.format(dataset)
self.writer.write(file, dataset)
return file
Loading…
Cancel
Save