import abc from typing import List, Type from django.contrib.auth.models import User from .models import DummyLabelType from .pipeline.catalog import RELATION_EXTRACTION, Format from .pipeline.data import BaseData, BinaryData, TextData from .pipeline.examples import Examples from .pipeline.exceptions import FileParseException from .pipeline.factories import create_parser from .pipeline.label import CategoryLabel, Label, RelationLabel, SpanLabel, TextLabel from .pipeline.label_types import LabelTypes from .pipeline.labels import Categories, Labels, Relations, Spans, Texts from .pipeline.makers import BinaryExampleMaker, ExampleMaker, LabelMaker from .pipeline.readers import ( DEFAULT_LABEL_COLUMN, DEFAULT_TEXT_COLUMN, FileName, Reader, ) from label_types.models import CategoryType, LabelType, RelationType, SpanType from projects.models import Project, ProjectType class Dataset(abc.ABC): def __init__(self, reader: Reader, project: Project, **kwargs): self.reader = reader self.project = project self.kwargs = kwargs def save(self, user: User, batch_size: int = 1000): raise NotImplementedError() @property def errors(self) -> List[FileParseException]: raise NotImplementedError() class PlainDataset(Dataset): def __init__(self, reader: Reader, project: Project, **kwargs): super().__init__(reader, project, **kwargs) self.example_maker = ExampleMaker(project=project, data_class=TextData) def save(self, user: User, batch_size: int = 1000): for records in self.reader.batch(batch_size): examples = Examples(self.example_maker.make(records)) examples.save() @property def errors(self) -> List[FileParseException]: return self.reader.errors + self.example_maker.errors class DatasetWithSingleLabelType(Dataset): data_class: Type[BaseData] label_class: Type[Label] label_type = LabelType labels_class = Labels def __init__(self, reader: Reader, project: Project, **kwargs): super().__init__(reader, project, **kwargs) self.types = LabelTypes(self.label_type) self.example_maker = ExampleMaker( project=project, data_class=self.data_class, column_data=kwargs.get("column_data") or DEFAULT_TEXT_COLUMN, exclude_columns=[kwargs.get("column_label") or DEFAULT_LABEL_COLUMN], ) self.label_maker = LabelMaker( column=kwargs.get("column_label") or DEFAULT_LABEL_COLUMN, label_class=self.label_class ) def save(self, user: User, batch_size: int = 1000): for records in self.reader.batch(batch_size): # create examples examples = Examples(self.example_maker.make(records)) examples.save() # create label types labels = self.labels_class(self.label_maker.make(records), self.types) labels.clean(self.project) labels.save_types(self.project) # create Labels labels.save(user, examples) @property def errors(self) -> List[FileParseException]: return self.reader.errors + self.example_maker.errors + self.label_maker.errors class BinaryDataset(Dataset): def __init__(self, reader: Reader, project: Project, **kwargs): super().__init__(reader, project, **kwargs) self.example_maker = BinaryExampleMaker(project=project, data_class=BinaryData) def save(self, user: User, batch_size: int = 1000): for records in self.reader.batch(batch_size): examples = Examples(self.example_maker.make(records)) examples.save() @property def errors(self) -> List[FileParseException]: return self.reader.errors + self.example_maker.errors class TextClassificationDataset(DatasetWithSingleLabelType): data_class = TextData label_class = CategoryLabel label_type = CategoryType labels_class = Categories class SequenceLabelingDataset(DatasetWithSingleLabelType): data_class = TextData label_class = SpanLabel label_type = SpanType labels_class = Spans class Seq2seqDataset(DatasetWithSingleLabelType): data_class = TextData label_class = TextLabel label_type = DummyLabelType labels_class = Texts class RelationExtractionDataset(Dataset): def __init__(self, reader: Reader, project: Project, **kwargs): super().__init__(reader, project, **kwargs) self.span_types = LabelTypes(SpanType) self.relation_types = LabelTypes(RelationType) self.example_maker = ExampleMaker( project=project, data_class=TextData, column_data=kwargs.get("column_data") or DEFAULT_TEXT_COLUMN, exclude_columns=["entities", "relations"], ) self.span_maker = LabelMaker(column="entities", label_class=SpanLabel) self.relation_maker = LabelMaker(column="relations", label_class=RelationLabel) def save(self, user: User, batch_size: int = 1000): for records in self.reader.batch(batch_size): # create examples examples = Examples(self.example_maker.make(records)) examples.save() # create label types spans = Spans(self.span_maker.make(records), self.span_types) spans.clean(self.project) spans.save_types(self.project) relations = Relations(self.relation_maker.make(records), self.relation_types) relations.clean(self.project) relations.save_types(self.project) # create Labels spans.save(user, examples) relations.save(user, examples, spans=spans) @property def errors(self) -> List[FileParseException]: return self.reader.errors + self.example_maker.errors + self.span_maker.errors + self.relation_maker.errors class CategoryAndSpanDataset(Dataset): def __init__(self, reader: Reader, project: Project, **kwargs): super().__init__(reader, project, **kwargs) self.category_types = LabelTypes(CategoryType) self.span_types = LabelTypes(SpanType) self.example_maker = ExampleMaker( project=project, data_class=TextData, column_data=kwargs.get("column_data") or DEFAULT_TEXT_COLUMN, exclude_columns=["cats", "entities"], ) self.category_maker = LabelMaker(column="cats", label_class=CategoryLabel) self.span_maker = LabelMaker(column="entities", label_class=SpanLabel) def save(self, user: User, batch_size: int = 1000): for records in self.reader.batch(batch_size): # create examples examples = Examples(self.example_maker.make(records)) examples.save() # create label types categories = Categories(self.category_maker.make(records), self.category_types) categories.clean(self.project) categories.save_types(self.project) spans = Spans(self.span_maker.make(records), self.span_types) spans.clean(self.project) spans.save_types(self.project) # create Labels categories.save(user, examples) spans.save(user, examples) @property def errors(self) -> List[FileParseException]: return self.reader.errors + self.example_maker.errors + self.category_maker.errors + self.span_maker.errors def select_dataset(project: Project, task: str, file_format: Format) -> Type[Dataset]: mapping = { ProjectType.DOCUMENT_CLASSIFICATION: TextClassificationDataset, ProjectType.SEQUENCE_LABELING: SequenceLabelingDataset, RELATION_EXTRACTION: RelationExtractionDataset, ProjectType.SEQ2SEQ: Seq2seqDataset, ProjectType.INTENT_DETECTION_AND_SLOT_FILLING: CategoryAndSpanDataset, ProjectType.IMAGE_CLASSIFICATION: BinaryDataset, ProjectType.IMAGE_CAPTIONING: BinaryDataset, ProjectType.BOUNDING_BOX: BinaryDataset, ProjectType.SEGMENTATION: BinaryDataset, ProjectType.SPEECH2TEXT: BinaryDataset, } if task not in mapping: task = project.project_type if project.is_text_project and file_format.is_plain_text(): return PlainDataset return mapping[task] def load_dataset(task: str, file_format: Format, data_files: List[FileName], project: Project, **kwargs) -> Dataset: parser = create_parser(file_format, **kwargs) reader = Reader(data_files, parser) dataset_class = select_dataset(project, task, file_format) return dataset_class(reader, project, **kwargs)