You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

229 lines
8.5 KiB

import abc
from typing import List, Type
from django.contrib.auth.models import User
from .models import DummyLabelType
from .pipeline.catalog import RELATION_EXTRACTION, Format
from .pipeline.data import BaseData, BinaryData, TextData
from .pipeline.examples import Examples
from .pipeline.exceptions import FileParseException
from .pipeline.factories import create_parser
from .pipeline.label import CategoryLabel, Label, RelationLabel, SpanLabel, TextLabel
from .pipeline.label_types import LabelTypes
from .pipeline.labels import Categories, Labels, Relations, Spans, Texts
from .pipeline.makers import BinaryExampleMaker, ExampleMaker, LabelMaker
from .pipeline.readers import (
DEFAULT_LABEL_COLUMN,
DEFAULT_TEXT_COLUMN,
FileName,
Reader,
)
from label_types.models import CategoryType, LabelType, RelationType, SpanType
from projects.models import Project, ProjectType
class Dataset(abc.ABC):
    """Base class for dataset importers.

    Stores the reader, the target project and any extra keyword options.
    Both ``save`` and ``errors`` raise ``NotImplementedError`` here, so
    concrete subclasses are expected to override them.
    """

    def __init__(self, reader: Reader, project: Project, **kwargs):
        """Keep references to the reader, the project and the raw options.

        Args:
            reader: source of records.
            project: project associated with this dataset.
            **kwargs: extra options, stored untouched on ``self.kwargs``.
        """
        self.reader, self.project, self.kwargs = reader, project, kwargs

    def save(self, user: User, batch_size: int = 1000):
        """Persist the dataset; not implemented in the base class.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError()

    @property
    def errors(self) -> List[FileParseException]:
        """Errors gathered while importing; not implemented in the base class.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError()
class PlainDataset(Dataset):
    """Dataset that stores examples only; no labels are created.

    Examples are built by an ``ExampleMaker`` configured with ``TextData``.
    """

    def __init__(self, reader: Reader, project: Project, **kwargs):
        """Set up the base state and the example maker for ``project``."""
        super().__init__(reader, project, **kwargs)
        self.example_maker = ExampleMaker(project=project, data_class=TextData)

    def save(self, user: User, batch_size: int = 1000):
        """Build and save examples one reader batch at a time.

        Args:
            user: accepted for interface compatibility; not used here.
            batch_size: number passed to ``reader.batch``.
        """
        for batch in self.reader.batch(batch_size):
            made = self.example_maker.make(batch)
            Examples(made).save()

    @property
    def errors(self) -> List[FileParseException]:
        """Reader errors followed by example-maker errors."""
        reader_errors = self.reader.errors
        return reader_errors + self.example_maker.errors
class DatasetWithSingleLabelType(Dataset):
    """Dataset whose records carry one data column and one label column.

    Subclasses must provide ``data_class`` and ``label_class`` (they are
    only declared here) and may override ``label_type`` / ``labels_class``.
    """

    # Declared without a value: concrete subclasses supply these.
    data_class: Type[BaseData]
    label_class: Type[Label]
    # Defaults used unless a subclass overrides them.
    label_type = LabelType
    labels_class = Labels

    def __init__(self, reader: Reader, project: Project, **kwargs):
        """Configure the label-type holder and the example/label makers.

        Recognised options in ``kwargs``:
            column_data: name of the data column; a missing or falsy value
                falls back to ``DEFAULT_TEXT_COLUMN``.
            column_label: name of the label column; a missing or falsy value
                falls back to ``DEFAULT_LABEL_COLUMN``.
        """
        super().__init__(reader, project, **kwargs)
        data_column = kwargs.get("column_data") or DEFAULT_TEXT_COLUMN
        label_column = kwargs.get("column_label") or DEFAULT_LABEL_COLUMN

        self.types = LabelTypes(self.label_type)
        # The label column is excluded from the example maker's columns.
        self.example_maker = ExampleMaker(
            project=project,
            data_class=self.data_class,
            column_data=data_column,
            exclude_columns=[label_column],
        )
        self.label_maker = LabelMaker(column=label_column, label_class=self.label_class)

    def save(self, user: User, batch_size: int = 1000):
        """Save examples, label types and labels for every reader batch.

        Args:
            user: passed to ``labels.save`` together with the examples.
            batch_size: number passed to ``reader.batch``.
        """
        for batch in self.reader.batch(batch_size):
            # 1. Examples are saved before the labels that reference them.
            examples = Examples(self.example_maker.make(batch))
            examples.save()

            # 2. Clean the labels and save their label types.
            batch_labels = self.labels_class(self.label_maker.make(batch), self.types)
            batch_labels.clean(self.project)
            batch_labels.save_types(self.project)

            # 3. Save the labels themselves.
            batch_labels.save(user, examples)

    @property
    def errors(self) -> List[FileParseException]:
        """Reader, example-maker and label-maker errors, in that order."""
        combined = self.reader.errors + self.example_maker.errors
        return combined + self.label_maker.errors
class BinaryDataset(Dataset):
    """Dataset that stores examples only, built as ``BinaryData``.

    Examples come from a ``BinaryExampleMaker``; no labels are created.
    """

    def __init__(self, reader: Reader, project: Project, **kwargs):
        """Set up the base state and the binary example maker."""
        super().__init__(reader, project, **kwargs)
        self.example_maker = BinaryExampleMaker(project=project, data_class=BinaryData)

    def save(self, user: User, batch_size: int = 1000):
        """Build and save examples one reader batch at a time.

        Args:
            user: accepted for interface compatibility; not used here.
            batch_size: number passed to ``reader.batch``.
        """
        for batch in self.reader.batch(batch_size):
            made = self.example_maker.make(batch)
            Examples(made).save()

    @property
    def errors(self) -> List[FileParseException]:
        """Reader errors followed by example-maker errors."""
        reader_errors = self.reader.errors
        return reader_errors + self.example_maker.errors
class TextClassificationDataset(DatasetWithSingleLabelType):
    """Single-label dataset wired with ``TextData`` and category classes."""

    # Label side: CategoryLabel records, CategoryType types, Categories collection.
    label_class = CategoryLabel
    label_type = CategoryType
    labels_class = Categories
    # Data side.
    data_class = TextData
class SequenceLabelingDataset(DatasetWithSingleLabelType):
    """Single-label dataset wired with ``TextData`` and span classes."""

    # Label side: SpanLabel records, SpanType types, Spans collection.
    label_class = SpanLabel
    label_type = SpanType
    labels_class = Spans
    # Data side.
    data_class = TextData
class Seq2seqDataset(DatasetWithSingleLabelType):
    """Single-label dataset wired with ``TextData`` and text-label classes."""

    # Label side: TextLabel records, DummyLabelType as the type, Texts collection.
    label_class = TextLabel
    label_type = DummyLabelType
    labels_class = Texts
    # Data side.
    data_class = TextData
class RelationExtractionDataset(Dataset):
    """Dataset that saves examples plus span and relation labels.

    Span labels are read from the ``"entities"`` column and relation labels
    from the ``"relations"`` column.
    """

    def __init__(self, reader: Reader, project: Project, **kwargs):
        """Configure label-type holders and the example/span/relation makers.

        Recognised option in ``kwargs``:
            column_data: name of the data column; a missing or falsy value
                falls back to ``DEFAULT_TEXT_COLUMN``.
        """
        super().__init__(reader, project, **kwargs)
        self.span_types = LabelTypes(SpanType)
        self.relation_types = LabelTypes(RelationType)

        data_column = kwargs.get("column_data") or DEFAULT_TEXT_COLUMN
        # Both label columns are excluded from the example maker's columns.
        self.example_maker = ExampleMaker(
            project=project,
            data_class=TextData,
            column_data=data_column,
            exclude_columns=["entities", "relations"],
        )
        self.span_maker = LabelMaker(column="entities", label_class=SpanLabel)
        self.relation_maker = LabelMaker(column="relations", label_class=RelationLabel)

    def save(self, user: User, batch_size: int = 1000):
        """Save examples, label types and labels for every reader batch.

        Args:
            user: passed to the ``save`` call of each label collection.
            batch_size: number passed to ``reader.batch``.
        """
        for batch in self.reader.batch(batch_size):
            # 1. Examples first.
            examples = Examples(self.example_maker.make(batch))
            examples.save()

            # 2. Span label types, then relation label types.
            batch_spans = Spans(self.span_maker.make(batch), self.span_types)
            batch_spans.clean(self.project)
            batch_spans.save_types(self.project)

            batch_relations = Relations(self.relation_maker.make(batch), self.relation_types)
            batch_relations.clean(self.project)
            batch_relations.save_types(self.project)

            # 3. Labels: spans are saved before relations, which receive them.
            batch_spans.save(user, examples)
            batch_relations.save(user, examples, spans=batch_spans)

    @property
    def errors(self) -> List[FileParseException]:
        """Reader, example-maker, span-maker and relation-maker errors, in order."""
        combined = self.reader.errors + self.example_maker.errors
        combined = combined + self.span_maker.errors
        return combined + self.relation_maker.errors
class CategoryAndSpanDataset(Dataset):
    """Dataset that saves examples plus category and span labels.

    Category labels are read from the ``"cats"`` column and span labels from
    the ``"entities"`` column.
    """

    def __init__(self, reader: Reader, project: Project, **kwargs):
        """Configure label-type holders and the example/category/span makers.

        Recognised option in ``kwargs``:
            column_data: name of the data column; a missing or falsy value
                falls back to ``DEFAULT_TEXT_COLUMN``.
        """
        super().__init__(reader, project, **kwargs)
        self.category_types = LabelTypes(CategoryType)
        self.span_types = LabelTypes(SpanType)

        data_column = kwargs.get("column_data") or DEFAULT_TEXT_COLUMN
        # Both label columns are excluded from the example maker's columns.
        self.example_maker = ExampleMaker(
            project=project,
            data_class=TextData,
            column_data=data_column,
            exclude_columns=["cats", "entities"],
        )
        self.category_maker = LabelMaker(column="cats", label_class=CategoryLabel)
        self.span_maker = LabelMaker(column="entities", label_class=SpanLabel)

    def save(self, user: User, batch_size: int = 1000):
        """Save examples, label types and labels for every reader batch.

        Args:
            user: passed to the ``save`` call of each label collection.
            batch_size: number passed to ``reader.batch``.
        """
        for batch in self.reader.batch(batch_size):
            # 1. Examples first.
            examples = Examples(self.example_maker.make(batch))
            examples.save()

            # 2. Category label types, then span label types.
            batch_categories = Categories(self.category_maker.make(batch), self.category_types)
            batch_categories.clean(self.project)
            batch_categories.save_types(self.project)

            batch_spans = Spans(self.span_maker.make(batch), self.span_types)
            batch_spans.clean(self.project)
            batch_spans.save_types(self.project)

            # 3. Labels: categories, then spans.
            batch_categories.save(user, examples)
            batch_spans.save(user, examples)

    @property
    def errors(self) -> List[FileParseException]:
        """Reader, example-maker, category-maker and span-maker errors, in order."""
        combined = self.reader.errors + self.example_maker.errors
        combined = combined + self.category_maker.errors
        return combined + self.span_maker.errors
def select_dataset(project: Project, task: str, file_format: Format) -> Type[Dataset]:
    """Pick the dataset class to use for an import.

    Args:
        project: project being imported into; supplies the fallback task
            (``project_type``) and the ``is_text_project`` flag.
        task: requested task; replaced by ``project.project_type`` when it
            is not a known key.
        file_format: format of the uploaded files.

    Returns:
        ``PlainDataset`` when the project is a text project and the format
        is plain text; otherwise the class registered for the resolved task.

    Raises:
        KeyError: if the resolved task has no registered dataset class.
    """
    # Tasks that each have a dedicated dataset class.
    mapping = {
        ProjectType.DOCUMENT_CLASSIFICATION: TextClassificationDataset,
        ProjectType.SEQUENCE_LABELING: SequenceLabelingDataset,
        RELATION_EXTRACTION: RelationExtractionDataset,
        ProjectType.SEQ2SEQ: Seq2seqDataset,
        ProjectType.INTENT_DETECTION_AND_SLOT_FILLING: CategoryAndSpanDataset,
    }
    # Project types that all share BinaryDataset.
    binary_project_types = (
        ProjectType.IMAGE_CLASSIFICATION,
        ProjectType.IMAGE_CAPTIONING,
        ProjectType.BOUNDING_BOX,
        ProjectType.SEGMENTATION,
        ProjectType.SPEECH2TEXT,
    )
    mapping.update(dict.fromkeys(binary_project_types, BinaryDataset))

    resolved_task = task if task in mapping else project.project_type
    # Plain-text uploads into a text project bypass the task mapping entirely.
    if project.is_text_project and file_format.is_plain_text():
        return PlainDataset
    return mapping[resolved_task]
def load_dataset(task: str, file_format: Format, data_files: List[FileName], project: Project, **kwargs) -> Dataset:
    """Build a ready-to-save dataset for the given files.

    Args:
        task: requested task, forwarded to ``select_dataset``.
        file_format: format of the files; used to create the parser and to
            select the dataset class.
        data_files: files handed to the ``Reader``.
        project: project the dataset is created for.
        **kwargs: forwarded to both ``create_parser`` and the dataset class.

    Returns:
        An instance of the dataset class chosen by ``select_dataset``.
    """
    # The parser is created first, then the reader, then the class lookup.
    reader = Reader(data_files, create_parser(file_format, **kwargs))
    chosen_class = select_dataset(project, task, file_format)
    return chosen_class(reader, project, **kwargs)