You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

234 lines
8.4 KiB

  1. import abc
  2. from typing import List, Type
  3. from django.contrib.auth.models import User
  4. from .models import DummyLabelType
  5. from .pipeline.catalog import RELATION_EXTRACTION, Format
  6. from .pipeline.data import BaseData, BinaryData, TextData
  7. from .pipeline.exceptions import FileParseException
  8. from .pipeline.factories import create_parser
  9. from .pipeline.label import CategoryLabel, Label, RelationLabel, SpanLabel, TextLabel
  10. from .pipeline.label_types import LabelTypes
  11. from .pipeline.labels import Categories, Labels, Relations, Spans, Texts
  12. from .pipeline.makers import BinaryExampleMaker, ExampleMaker, LabelMaker
  13. from .pipeline.readers import (
  14. DEFAULT_LABEL_COLUMN,
  15. DEFAULT_TEXT_COLUMN,
  16. FileName,
  17. Reader,
  18. )
  19. from examples.models import Example
  20. from label_types.models import CategoryType, LabelType, RelationType, SpanType
  21. from projects.models import (
  22. DOCUMENT_CLASSIFICATION,
  23. IMAGE_CLASSIFICATION,
  24. INTENT_DETECTION_AND_SLOT_FILLING,
  25. SEQ2SEQ,
  26. SEQUENCE_LABELING,
  27. SPEECH2TEXT,
  28. Project,
  29. )
  30. class Dataset(abc.ABC):
  31. def __init__(self, reader: Reader, project: Project, **kwargs):
  32. self.reader = reader
  33. self.project = project
  34. self.kwargs = kwargs
  35. def save(self, user: User, batch_size: int = 1000):
  36. raise NotImplementedError()
  37. @property
  38. def errors(self) -> List[FileParseException]:
  39. raise NotImplementedError()
  40. class PlainDataset(Dataset):
  41. def __init__(self, reader: Reader, project: Project, **kwargs):
  42. super().__init__(reader, project, **kwargs)
  43. self.example_maker = ExampleMaker(project=project, data_class=TextData)
  44. def save(self, user: User, batch_size: int = 1000):
  45. for records in self.reader.batch(batch_size):
  46. examples = self.example_maker.make(records)
  47. Example.objects.bulk_create(examples)
  48. @property
  49. def errors(self) -> List[FileParseException]:
  50. return self.reader.errors + self.example_maker.errors
  51. class DatasetWithSingleLabelType(Dataset):
  52. data_class: Type[BaseData]
  53. label_class: Type[Label]
  54. label_type = LabelType
  55. labels_class = Labels
  56. def __init__(self, reader: Reader, project: Project, **kwargs):
  57. super().__init__(reader, project, **kwargs)
  58. self.types = LabelTypes(self.label_type)
  59. self.example_maker = ExampleMaker(
  60. project=project,
  61. data_class=self.data_class,
  62. column_data=kwargs.get("column_data") or DEFAULT_TEXT_COLUMN,
  63. exclude_columns=[kwargs.get("column_label") or DEFAULT_LABEL_COLUMN],
  64. )
  65. self.label_maker = LabelMaker(
  66. column=kwargs.get("column_label") or DEFAULT_LABEL_COLUMN, label_class=self.label_class
  67. )
  68. def save(self, user: User, batch_size: int = 1000):
  69. for records in self.reader.batch(batch_size):
  70. # create examples
  71. examples = self.example_maker.make(records)
  72. Example.objects.bulk_create(examples)
  73. # create label types
  74. labels = self.labels_class(self.label_maker.make(records), self.types)
  75. labels.clean(self.project)
  76. labels.save_types(self.project)
  77. # create Labels
  78. labels.save(user)
  79. @property
  80. def errors(self) -> List[FileParseException]:
  81. return self.reader.errors + self.example_maker.errors + self.label_maker.errors
  82. class BinaryDataset(Dataset):
  83. def __init__(self, reader: Reader, project: Project, **kwargs):
  84. super().__init__(reader, project, **kwargs)
  85. self.example_maker = BinaryExampleMaker(project=project, data_class=BinaryData)
  86. def save(self, user: User, batch_size: int = 1000):
  87. for records in self.reader.batch(batch_size):
  88. examples = self.example_maker.make(records)
  89. Example.objects.bulk_create(examples)
  90. @property
  91. def errors(self) -> List[FileParseException]:
  92. return self.reader.errors + self.example_maker.errors
  93. class TextClassificationDataset(DatasetWithSingleLabelType):
  94. data_class = TextData
  95. label_class = CategoryLabel
  96. label_type = CategoryType
  97. labels_class = Categories
  98. class SequenceLabelingDataset(DatasetWithSingleLabelType):
  99. data_class = TextData
  100. label_class = SpanLabel
  101. label_type = SpanType
  102. labels_class = Spans
  103. class Seq2seqDataset(DatasetWithSingleLabelType):
  104. data_class = TextData
  105. label_class = TextLabel
  106. label_type = DummyLabelType
  107. labels_class = Texts
  108. class RelationExtractionDataset(Dataset):
  109. def __init__(self, reader: Reader, project: Project, **kwargs):
  110. super().__init__(reader, project, **kwargs)
  111. self.span_types = LabelTypes(SpanType)
  112. self.relation_types = LabelTypes(RelationType)
  113. self.example_maker = ExampleMaker(
  114. project=project,
  115. data_class=TextData,
  116. column_data=kwargs.get("column_data") or DEFAULT_TEXT_COLUMN,
  117. exclude_columns=["entities", "relations"],
  118. )
  119. self.span_maker = LabelMaker(column="entities", label_class=SpanLabel)
  120. self.relation_maker = LabelMaker(column="relations", label_class=RelationLabel)
  121. def save(self, user: User, batch_size: int = 1000):
  122. for records in self.reader.batch(batch_size):
  123. # create examples
  124. examples = self.example_maker.make(records)
  125. Example.objects.bulk_create(examples)
  126. # create label types
  127. spans = Spans(self.span_maker.make(records), self.span_types)
  128. spans.clean(self.project)
  129. spans.save_types(self.project)
  130. relations = Relations(self.relation_maker.make(records), self.relation_types)
  131. relations.clean(self.project)
  132. relations.save_types(self.project)
  133. # create Labels
  134. spans.save(user)
  135. relations.save(user, spans=spans)
  136. @property
  137. def errors(self) -> List[FileParseException]:
  138. return self.reader.errors + self.example_maker.errors + self.span_maker.errors + self.relation_maker.errors
  139. class CategoryAndSpanDataset(Dataset):
  140. def __init__(self, reader: Reader, project: Project, **kwargs):
  141. super().__init__(reader, project, **kwargs)
  142. self.category_types = LabelTypes(CategoryType)
  143. self.span_types = LabelTypes(SpanType)
  144. self.example_maker = ExampleMaker(
  145. project=project,
  146. data_class=TextData,
  147. column_data=kwargs.get("column_data") or DEFAULT_TEXT_COLUMN,
  148. exclude_columns=["cats", "entities"],
  149. )
  150. self.category_maker = LabelMaker(column="cats", label_class=CategoryLabel)
  151. self.span_maker = LabelMaker(column="entities", label_class=SpanLabel)
  152. def save(self, user: User, batch_size: int = 1000):
  153. for records in self.reader.batch(batch_size):
  154. # create examples
  155. examples = self.example_maker.make(records)
  156. Example.objects.bulk_create(examples)
  157. # create label types
  158. categories = Categories(self.category_maker.make(records), self.category_types)
  159. categories.clean(self.project)
  160. categories.save_types(self.project)
  161. spans = Spans(self.span_maker.make(records), self.span_types)
  162. spans.clean(self.project)
  163. spans.save_types(self.project)
  164. # create Labels
  165. categories.save(user)
  166. spans.save(user)
  167. @property
  168. def errors(self) -> List[FileParseException]:
  169. return self.reader.errors + self.example_maker.errors + self.category_maker.errors + self.span_maker.errors
  170. def select_dataset(project: Project, task: str, file_format: Format) -> Type[Dataset]:
  171. mapping = {
  172. DOCUMENT_CLASSIFICATION: TextClassificationDataset,
  173. SEQUENCE_LABELING: SequenceLabelingDataset,
  174. RELATION_EXTRACTION: RelationExtractionDataset,
  175. SEQ2SEQ: Seq2seqDataset,
  176. INTENT_DETECTION_AND_SLOT_FILLING: CategoryAndSpanDataset,
  177. IMAGE_CLASSIFICATION: BinaryDataset,
  178. SPEECH2TEXT: BinaryDataset,
  179. }
  180. if task not in mapping:
  181. task = project.project_type
  182. if project.is_text_project and file_format.is_plain_text():
  183. return PlainDataset
  184. return mapping[task]
  185. def load_dataset(task: str, file_format: Format, data_files: List[FileName], project: Project, **kwargs) -> Dataset:
  186. parser = create_parser(file_format, **kwargs)
  187. reader = Reader(data_files, parser)
  188. dataset_class = select_dataset(project, task, file_format)
  189. return dataset_class(reader, project, **kwargs)