You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

229 lines
8.5 KiB

  1. import abc
  2. from typing import List, Type
  3. from django.contrib.auth.models import User
  4. from .models import DummyLabelType
  5. from .pipeline.catalog import RELATION_EXTRACTION, Format
  6. from .pipeline.data import BaseData, BinaryData, TextData
  7. from .pipeline.examples import Examples
  8. from .pipeline.exceptions import FileParseException
  9. from .pipeline.factories import create_parser
  10. from .pipeline.label import CategoryLabel, Label, RelationLabel, SpanLabel, TextLabel
  11. from .pipeline.label_types import LabelTypes
  12. from .pipeline.labels import Categories, Labels, Relations, Spans, Texts
  13. from .pipeline.makers import BinaryExampleMaker, ExampleMaker, LabelMaker
  14. from .pipeline.readers import (
  15. DEFAULT_LABEL_COLUMN,
  16. DEFAULT_TEXT_COLUMN,
  17. FileName,
  18. Reader,
  19. )
  20. from label_types.models import CategoryType, LabelType, RelationType, SpanType
  21. from projects.models import Project, ProjectType
  class Dataset(abc.ABC):
      """Common base for the dataset importers defined in this module.

      Keeps hold of the reader, the target project and any extra keyword
      options. ``save`` and ``errors`` raise ``NotImplementedError`` here, so
      concrete subclasses are expected to override both.
      """

      def __init__(self, reader: Reader, project: Project, **kwargs):
          self.reader, self.project, self.kwargs = reader, project, kwargs

      @property
      def errors(self) -> List[FileParseException]:
          """Return the collected parse errors (not implemented on the base)."""
          raise NotImplementedError()

      def save(self, user: User, batch_size: int = 1000):
          """Persist the dataset on behalf of ``user`` (not implemented on the base)."""
          raise NotImplementedError()
  class PlainDataset(Dataset):
      """Dataset that only creates examples; no labels are produced."""

      def __init__(self, reader: Reader, project: Project, **kwargs):
          super().__init__(reader, project, **kwargs)
          self.example_maker = ExampleMaker(project=project, data_class=TextData)

      def save(self, user: User, batch_size: int = 1000):
          """Save the examples built from each batch of records.

          ``user`` is part of the shared ``save`` signature but is not used
          by this method.
          """
          for batch in self.reader.batch(batch_size):
              Examples(self.example_maker.make(batch)).save()

      @property
      def errors(self) -> List[FileParseException]:
          """Errors reported by the reader followed by those of the example maker."""
          reader_errors = self.reader.errors
          maker_errors = self.example_maker.errors
          return reader_errors + maker_errors
  class DatasetWithSingleLabelType(Dataset):
      """Dataset whose records carry one kind of label in a single column.

      The class attributes below select the data class, the label class, the
      label-type model and the label collection used by ``__init__`` and
      ``save``.
      """

      data_class: Type[BaseData]
      label_class: Type[Label]
      label_type = LabelType
      labels_class = Labels

      def __init__(self, reader: Reader, project: Project, **kwargs):
          super().__init__(reader, project, **kwargs)
          self.types = LabelTypes(self.label_type)

          # Fall back to the default column names when the option is missing or falsy.
          text_column = kwargs.get("column_data") or DEFAULT_TEXT_COLUMN
          label_column = kwargs.get("column_label") or DEFAULT_LABEL_COLUMN

          self.example_maker = ExampleMaker(
              project=project,
              data_class=self.data_class,
              column_data=text_column,
              exclude_columns=[label_column],
          )
          self.label_maker = LabelMaker(column=label_column, label_class=self.label_class)

      def save(self, user: User, batch_size: int = 1000):
          """Save examples, label types and labels for every batch of records."""
          for batch in self.reader.batch(batch_size):
              # Examples are saved first; the labels are saved against them below.
              examples = Examples(self.example_maker.make(batch))
              examples.save()

              # Build the labels, then clean them and save their types for the project.
              labels = self.labels_class(self.label_maker.make(batch), self.types)
              labels.clean(self.project)
              labels.save_types(self.project)

              # Finally save the labels themselves.
              labels.save(user, examples)

      @property
      def errors(self) -> List[FileParseException]:
          """Errors from the reader, the example maker and the label maker, in that order."""
          collected = self.reader.errors + self.example_maker.errors
          return collected + self.label_maker.errors
  class BinaryDataset(Dataset):
      """Dataset that creates examples with ``BinaryExampleMaker`` and ``BinaryData``."""

      def __init__(self, reader: Reader, project: Project, **kwargs):
          super().__init__(reader, project, **kwargs)
          self.example_maker = BinaryExampleMaker(project=project, data_class=BinaryData)

      def save(self, user: User, batch_size: int = 1000):
          """Save the examples built from each batch of records.

          ``user`` is part of the shared ``save`` signature but is not used
          by this method.
          """
          for batch in self.reader.batch(batch_size):
              Examples(self.example_maker.make(batch)).save()

      @property
      def errors(self) -> List[FileParseException]:
          """Errors reported by the reader followed by those of the example maker."""
          reader_errors = self.reader.errors
          maker_errors = self.example_maker.errors
          return reader_errors + maker_errors
  class TextClassificationDataset(DatasetWithSingleLabelType):
      """Single-label-type dataset: ``TextData`` examples with ``CategoryLabel`` labels."""

      # Label-type model and label collection used for the category labels.
      label_type = CategoryType
      labels_class = Categories
      # Data and label classes handed to the example and label makers.
      data_class = TextData
      label_class = CategoryLabel
  class SequenceLabelingDataset(DatasetWithSingleLabelType):
      """Single-label-type dataset: ``TextData`` examples with ``SpanLabel`` labels."""

      # Label-type model and label collection used for the span labels.
      label_type = SpanType
      labels_class = Spans
      # Data and label classes handed to the example and label makers.
      data_class = TextData
      label_class = SpanLabel
  class Seq2seqDataset(DatasetWithSingleLabelType):
      """Single-label-type dataset: ``TextData`` examples with ``TextLabel`` labels."""

      # ``DummyLabelType`` is used as the label-type model for text labels.
      label_type = DummyLabelType
      labels_class = Texts
      # Data and label classes handed to the example and label makers.
      data_class = TextData
      label_class = TextLabel
  class RelationExtractionDataset(Dataset):
      """Dataset with span labels in ``entities`` and relation labels in ``relations``."""

      def __init__(self, reader: Reader, project: Project, **kwargs):
          super().__init__(reader, project, **kwargs)
          self.span_types = LabelTypes(SpanType)
          self.relation_types = LabelTypes(RelationType)

          # Fall back to the default text column when the option is missing or falsy.
          text_column = kwargs.get("column_data") or DEFAULT_TEXT_COLUMN
          self.example_maker = ExampleMaker(
              project=project,
              data_class=TextData,
              column_data=text_column,
              exclude_columns=["entities", "relations"],
          )
          self.span_maker = LabelMaker(column="entities", label_class=SpanLabel)
          self.relation_maker = LabelMaker(column="relations", label_class=RelationLabel)

      def save(self, user: User, batch_size: int = 1000):
          """Save examples, span/relation types and then the labels for every batch."""
          for batch in self.reader.batch(batch_size):
              # Examples are saved first; the labels are saved against them below.
              examples = Examples(self.example_maker.make(batch))
              examples.save()

              # Span labels: clean them and save their types.
              spans = Spans(self.span_maker.make(batch), self.span_types)
              spans.clean(self.project)
              spans.save_types(self.project)

              # Relation labels: clean them and save their types.
              relations = Relations(self.relation_maker.make(batch), self.relation_types)
              relations.clean(self.project)
              relations.save_types(self.project)

              # Spans are saved before relations, which receive the span collection.
              spans.save(user, examples)
              relations.save(user, examples, spans=spans)

      @property
      def errors(self) -> List[FileParseException]:
          """Errors from the reader, example maker, span maker and relation maker, in order."""
          collected = self.reader.errors + self.example_maker.errors
          collected = collected + self.span_maker.errors
          return collected + self.relation_maker.errors
  class CategoryAndSpanDataset(Dataset):
      """Dataset with category labels in ``cats`` and span labels in ``entities``."""

      def __init__(self, reader: Reader, project: Project, **kwargs):
          super().__init__(reader, project, **kwargs)
          self.category_types = LabelTypes(CategoryType)
          self.span_types = LabelTypes(SpanType)

          # Fall back to the default text column when the option is missing or falsy.
          text_column = kwargs.get("column_data") or DEFAULT_TEXT_COLUMN
          self.example_maker = ExampleMaker(
              project=project,
              data_class=TextData,
              column_data=text_column,
              exclude_columns=["cats", "entities"],
          )
          self.category_maker = LabelMaker(column="cats", label_class=CategoryLabel)
          self.span_maker = LabelMaker(column="entities", label_class=SpanLabel)

      def save(self, user: User, batch_size: int = 1000):
          """Save examples, category/span types and then the labels for every batch."""
          for batch in self.reader.batch(batch_size):
              # Examples are saved first; the labels are saved against them below.
              examples = Examples(self.example_maker.make(batch))
              examples.save()

              # Category labels: clean them and save their types.
              categories = Categories(self.category_maker.make(batch), self.category_types)
              categories.clean(self.project)
              categories.save_types(self.project)

              # Span labels: clean them and save their types.
              spans = Spans(self.span_maker.make(batch), self.span_types)
              spans.clean(self.project)
              spans.save_types(self.project)

              # Save the labels themselves, categories first.
              categories.save(user, examples)
              spans.save(user, examples)

      @property
      def errors(self) -> List[FileParseException]:
          """Errors from the reader, example maker, category maker and span maker, in order."""
          collected = self.reader.errors + self.example_maker.errors
          collected = collected + self.category_maker.errors
          return collected + self.span_maker.errors
  def select_dataset(project: Project, task: str, file_format: Format) -> Type[Dataset]:
      """Pick the dataset class for an import.

      ``task`` is used when it is a known key; otherwise the project's
      ``project_type`` is used instead. A text project importing a plain-text
      format always gets ``PlainDataset``.

      Raises:
          KeyError: if neither ``task`` nor the project's type is a known key
              and the plain-text shortcut does not apply.
      """
      # Every one of these project types is handled by BinaryDataset.
      binary_project_types = (
          ProjectType.IMAGE_CLASSIFICATION,
          ProjectType.IMAGE_CAPTIONING,
          ProjectType.BOUNDING_BOX,
          ProjectType.SEGMENTATION,
          ProjectType.SPEECH2TEXT,
      )
      dataset_by_task = {
          ProjectType.DOCUMENT_CLASSIFICATION: TextClassificationDataset,
          ProjectType.SEQUENCE_LABELING: SequenceLabelingDataset,
          RELATION_EXTRACTION: RelationExtractionDataset,
          ProjectType.SEQ2SEQ: Seq2seqDataset,
          ProjectType.INTENT_DETECTION_AND_SLOT_FILLING: CategoryAndSpanDataset,
          **dict.fromkeys(binary_project_types, BinaryDataset),
      }

      resolved_task = task if task in dataset_by_task else project.project_type

      if project.is_text_project and file_format.is_plain_text():
          return PlainDataset
      return dataset_by_task[resolved_task]
  def load_dataset(task: str, file_format: Format, data_files: List[FileName], project: Project, **kwargs) -> Dataset:
      """Build the dataset object for the given files.

      Creates a parser for ``file_format``, wraps ``data_files`` in a
      ``Reader``, selects the dataset class with ``select_dataset`` and
      instantiates it. ``kwargs`` are forwarded to both the parser factory and
      the dataset constructor.
      """
      reader = Reader(data_files, create_parser(file_format, **kwargs))
      dataset_cls = select_dataset(project, task, file_format)
      return dataset_cls(reader, project, **kwargs)