You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

240 lines
8.6 KiB

  1. import abc
  2. from typing import List, Type
  3. from django.contrib.auth.models import User
  4. from .models import DummyLabelType
  5. from .pipeline.catalog import RELATION_EXTRACTION, Format
  6. from .pipeline.data import BaseData, BinaryData, TextData
  7. from .pipeline.examples import Examples
  8. from .pipeline.exceptions import FileParseException
  9. from .pipeline.factories import create_parser
  10. from .pipeline.label import CategoryLabel, Label, RelationLabel, SpanLabel, TextLabel
  11. from .pipeline.label_types import LabelTypes
  12. from .pipeline.labels import Categories, Labels, Relations, Spans, Texts
  13. from .pipeline.makers import BinaryExampleMaker, ExampleMaker, LabelMaker
  14. from .pipeline.readers import (
  15. DEFAULT_LABEL_COLUMN,
  16. DEFAULT_TEXT_COLUMN,
  17. FileName,
  18. Reader,
  19. )
  20. from label_types.models import CategoryType, LabelType, RelationType, SpanType
  21. from projects.models import (
  22. BOUNDING_BOX,
  23. DOCUMENT_CLASSIFICATION,
  24. IMAGE_CAPTIONING,
  25. IMAGE_CLASSIFICATION,
  26. INTENT_DETECTION_AND_SLOT_FILLING,
  27. SEGMENTATION,
  28. SEQ2SEQ,
  29. SEQUENCE_LABELING,
  30. SPEECH2TEXT,
  31. Project,
  32. )
  33. class Dataset(abc.ABC):
  34. def __init__(self, reader: Reader, project: Project, **kwargs):
  35. self.reader = reader
  36. self.project = project
  37. self.kwargs = kwargs
  38. def save(self, user: User, batch_size: int = 1000):
  39. raise NotImplementedError()
  40. @property
  41. def errors(self) -> List[FileParseException]:
  42. raise NotImplementedError()
  43. class PlainDataset(Dataset):
  44. def __init__(self, reader: Reader, project: Project, **kwargs):
  45. super().__init__(reader, project, **kwargs)
  46. self.example_maker = ExampleMaker(project=project, data_class=TextData)
  47. def save(self, user: User, batch_size: int = 1000):
  48. for records in self.reader.batch(batch_size):
  49. examples = Examples(self.example_maker.make(records))
  50. examples.save()
  51. @property
  52. def errors(self) -> List[FileParseException]:
  53. return self.reader.errors + self.example_maker.errors
  54. class DatasetWithSingleLabelType(Dataset):
  55. data_class: Type[BaseData]
  56. label_class: Type[Label]
  57. label_type = LabelType
  58. labels_class = Labels
  59. def __init__(self, reader: Reader, project: Project, **kwargs):
  60. super().__init__(reader, project, **kwargs)
  61. self.types = LabelTypes(self.label_type)
  62. self.example_maker = ExampleMaker(
  63. project=project,
  64. data_class=self.data_class,
  65. column_data=kwargs.get("column_data") or DEFAULT_TEXT_COLUMN,
  66. exclude_columns=[kwargs.get("column_label") or DEFAULT_LABEL_COLUMN],
  67. )
  68. self.label_maker = LabelMaker(
  69. column=kwargs.get("column_label") or DEFAULT_LABEL_COLUMN, label_class=self.label_class
  70. )
  71. def save(self, user: User, batch_size: int = 1000):
  72. for records in self.reader.batch(batch_size):
  73. # create examples
  74. examples = Examples(self.example_maker.make(records))
  75. examples.save()
  76. # create label types
  77. labels = self.labels_class(self.label_maker.make(records), self.types)
  78. labels.clean(self.project)
  79. labels.save_types(self.project)
  80. # create Labels
  81. labels.save(user, examples)
  82. @property
  83. def errors(self) -> List[FileParseException]:
  84. return self.reader.errors + self.example_maker.errors + self.label_maker.errors
  85. class BinaryDataset(Dataset):
  86. def __init__(self, reader: Reader, project: Project, **kwargs):
  87. super().__init__(reader, project, **kwargs)
  88. self.example_maker = BinaryExampleMaker(project=project, data_class=BinaryData)
  89. def save(self, user: User, batch_size: int = 1000):
  90. for records in self.reader.batch(batch_size):
  91. examples = Examples(self.example_maker.make(records))
  92. examples.save()
  93. @property
  94. def errors(self) -> List[FileParseException]:
  95. return self.reader.errors + self.example_maker.errors
  96. class TextClassificationDataset(DatasetWithSingleLabelType):
  97. data_class = TextData
  98. label_class = CategoryLabel
  99. label_type = CategoryType
  100. labels_class = Categories
  101. class SequenceLabelingDataset(DatasetWithSingleLabelType):
  102. data_class = TextData
  103. label_class = SpanLabel
  104. label_type = SpanType
  105. labels_class = Spans
  106. class Seq2seqDataset(DatasetWithSingleLabelType):
  107. data_class = TextData
  108. label_class = TextLabel
  109. label_type = DummyLabelType
  110. labels_class = Texts
  111. class RelationExtractionDataset(Dataset):
  112. def __init__(self, reader: Reader, project: Project, **kwargs):
  113. super().__init__(reader, project, **kwargs)
  114. self.span_types = LabelTypes(SpanType)
  115. self.relation_types = LabelTypes(RelationType)
  116. self.example_maker = ExampleMaker(
  117. project=project,
  118. data_class=TextData,
  119. column_data=kwargs.get("column_data") or DEFAULT_TEXT_COLUMN,
  120. exclude_columns=["entities", "relations"],
  121. )
  122. self.span_maker = LabelMaker(column="entities", label_class=SpanLabel)
  123. self.relation_maker = LabelMaker(column="relations", label_class=RelationLabel)
  124. def save(self, user: User, batch_size: int = 1000):
  125. for records in self.reader.batch(batch_size):
  126. # create examples
  127. examples = Examples(self.example_maker.make(records))
  128. examples.save()
  129. # create label types
  130. spans = Spans(self.span_maker.make(records), self.span_types)
  131. spans.clean(self.project)
  132. spans.save_types(self.project)
  133. relations = Relations(self.relation_maker.make(records), self.relation_types)
  134. relations.clean(self.project)
  135. relations.save_types(self.project)
  136. # create Labels
  137. spans.save(user, examples)
  138. relations.save(user, examples, spans=spans)
  139. @property
  140. def errors(self) -> List[FileParseException]:
  141. return self.reader.errors + self.example_maker.errors + self.span_maker.errors + self.relation_maker.errors
  142. class CategoryAndSpanDataset(Dataset):
  143. def __init__(self, reader: Reader, project: Project, **kwargs):
  144. super().__init__(reader, project, **kwargs)
  145. self.category_types = LabelTypes(CategoryType)
  146. self.span_types = LabelTypes(SpanType)
  147. self.example_maker = ExampleMaker(
  148. project=project,
  149. data_class=TextData,
  150. column_data=kwargs.get("column_data") or DEFAULT_TEXT_COLUMN,
  151. exclude_columns=["cats", "entities"],
  152. )
  153. self.category_maker = LabelMaker(column="cats", label_class=CategoryLabel)
  154. self.span_maker = LabelMaker(column="entities", label_class=SpanLabel)
  155. def save(self, user: User, batch_size: int = 1000):
  156. for records in self.reader.batch(batch_size):
  157. # create examples
  158. examples = Examples(self.example_maker.make(records))
  159. examples.save()
  160. # create label types
  161. categories = Categories(self.category_maker.make(records), self.category_types)
  162. categories.clean(self.project)
  163. categories.save_types(self.project)
  164. spans = Spans(self.span_maker.make(records), self.span_types)
  165. spans.clean(self.project)
  166. spans.save_types(self.project)
  167. # create Labels
  168. categories.save(user, examples)
  169. spans.save(user, examples)
  170. @property
  171. def errors(self) -> List[FileParseException]:
  172. return self.reader.errors + self.example_maker.errors + self.category_maker.errors + self.span_maker.errors
  173. def select_dataset(project: Project, task: str, file_format: Format) -> Type[Dataset]:
  174. mapping = {
  175. DOCUMENT_CLASSIFICATION: TextClassificationDataset,
  176. SEQUENCE_LABELING: SequenceLabelingDataset,
  177. RELATION_EXTRACTION: RelationExtractionDataset,
  178. SEQ2SEQ: Seq2seqDataset,
  179. INTENT_DETECTION_AND_SLOT_FILLING: CategoryAndSpanDataset,
  180. IMAGE_CLASSIFICATION: BinaryDataset,
  181. IMAGE_CAPTIONING: BinaryDataset,
  182. BOUNDING_BOX: BinaryDataset,
  183. SEGMENTATION: BinaryDataset,
  184. SPEECH2TEXT: BinaryDataset,
  185. }
  186. if task not in mapping:
  187. task = project.project_type
  188. if project.is_text_project and file_format.is_plain_text():
  189. return PlainDataset
  190. return mapping[task]
  191. def load_dataset(task: str, file_format: Format, data_files: List[FileName], project: Project, **kwargs) -> Dataset:
  192. parser = create_parser(file_format, **kwargs)
  193. reader = Reader(data_files, parser)
  194. dataset_class = select_dataset(project, task, file_format)
  195. return dataset_class(reader, project, **kwargs)