You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

183 lines
6.5 KiB

2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
  1. from typing import Dict, List, Type
  2. from django.db.models import QuerySet
  3. from . import writers
  4. from .catalog import CSV, JSON, JSONL, FastText
  5. from .comments import Comments
  6. from .formatters import (
  7. DictFormatter,
  8. FastTextCategoryFormatter,
  9. Formatter,
  10. JoinedCategoryFormatter,
  11. ListedCategoryFormatter,
  12. RenameFormatter,
  13. TupledSpanFormatter,
  14. )
  15. from .labels import BoundingBoxes, Categories, Labels, Relations, Segments, Spans, Texts
  16. from data_export.models import DATA, ExportedExample
  17. from projects.models import (
  18. BOUNDING_BOX,
  19. DOCUMENT_CLASSIFICATION,
  20. IMAGE_CAPTIONING,
  21. IMAGE_CLASSIFICATION,
  22. INTENT_DETECTION_AND_SLOT_FILLING,
  23. SEGMENTATION,
  24. SEQ2SEQ,
  25. SEQUENCE_LABELING,
  26. SPEECH2TEXT,
  27. Project,
  28. )
  29. def create_writer(file_format: str) -> writers.Writer:
  30. mapping = {
  31. CSV.name: writers.CsvWriter(),
  32. JSON.name: writers.JsonWriter(),
  33. JSONL.name: writers.JsonlWriter(),
  34. FastText.name: writers.FastTextWriter(),
  35. }
  36. if file_format not in mapping:
  37. ValueError(f"Invalid format: {file_format}")
  38. return mapping[file_format]
  39. def create_formatter(project: Project, file_format: str) -> List[Formatter]:
  40. use_relation = getattr(project, "use_relation", False)
  41. # text tasks
  42. mapper_text_classification = {DATA: "text", Categories.column: "label"}
  43. mapper_sequence_labeling = {DATA: "text", Spans.column: "label"}
  44. mapper_seq2seq = {DATA: "text", Texts.column: "label"}
  45. mapper_intent_detection = {DATA: "text", Categories.column: "cats"}
  46. mapper_relation_extraction = {DATA: "text"}
  47. # image tasks
  48. mapper_image_classification = {DATA: "filename", Categories.column: "label"}
  49. mapper_bounding_box = {DATA: "filename", BoundingBoxes.column: "bbox"}
  50. mapper_segmentation = {DATA: "filename", BoundingBoxes.column: "segmentation"}
  51. mapper_image_captioning = {DATA: "filename", Texts.column: "label"}
  52. # audio tasks
  53. mapper_speech2text = {DATA: "filename", Texts.column: "label"}
  54. mapping: Dict[str, Dict[str, List[Formatter]]] = {
  55. DOCUMENT_CLASSIFICATION: {
  56. CSV.name: [
  57. JoinedCategoryFormatter(Categories.column),
  58. JoinedCategoryFormatter(Comments.column),
  59. RenameFormatter(**mapper_text_classification),
  60. ],
  61. JSON.name: [
  62. ListedCategoryFormatter(Categories.column),
  63. ListedCategoryFormatter(Comments.column),
  64. RenameFormatter(**mapper_text_classification),
  65. ],
  66. JSONL.name: [
  67. ListedCategoryFormatter(Categories.column),
  68. ListedCategoryFormatter(Comments.column),
  69. RenameFormatter(**mapper_text_classification),
  70. ],
  71. FastText.name: [FastTextCategoryFormatter(Categories.column)],
  72. },
  73. SEQUENCE_LABELING: {
  74. JSONL.name: [
  75. DictFormatter(Spans.column),
  76. DictFormatter(Relations.column),
  77. DictFormatter(Comments.column),
  78. RenameFormatter(**mapper_relation_extraction),
  79. ]
  80. if use_relation
  81. else [
  82. TupledSpanFormatter(Spans.column),
  83. ListedCategoryFormatter(Comments.column),
  84. RenameFormatter(**mapper_sequence_labeling),
  85. ]
  86. },
  87. SEQ2SEQ: {
  88. CSV.name: [
  89. JoinedCategoryFormatter(Texts.column),
  90. JoinedCategoryFormatter(Comments.column),
  91. RenameFormatter(**mapper_seq2seq),
  92. ],
  93. JSON.name: [
  94. ListedCategoryFormatter(Texts.column),
  95. ListedCategoryFormatter(Comments.column),
  96. RenameFormatter(**mapper_seq2seq),
  97. ],
  98. JSONL.name: [
  99. ListedCategoryFormatter(Texts.column),
  100. ListedCategoryFormatter(Comments.column),
  101. RenameFormatter(**mapper_seq2seq),
  102. ],
  103. },
  104. IMAGE_CLASSIFICATION: {
  105. JSONL.name: [
  106. ListedCategoryFormatter(Categories.column),
  107. ListedCategoryFormatter(Comments.column),
  108. RenameFormatter(**mapper_image_classification),
  109. ],
  110. },
  111. SPEECH2TEXT: {
  112. JSONL.name: [
  113. ListedCategoryFormatter(Texts.column),
  114. ListedCategoryFormatter(Comments.column),
  115. RenameFormatter(**mapper_speech2text),
  116. ],
  117. },
  118. INTENT_DETECTION_AND_SLOT_FILLING: {
  119. JSONL.name: [
  120. ListedCategoryFormatter(Categories.column),
  121. TupledSpanFormatter(Spans.column),
  122. ListedCategoryFormatter(Comments.column),
  123. RenameFormatter(**mapper_intent_detection),
  124. ]
  125. },
  126. BOUNDING_BOX: {
  127. JSONL.name: [
  128. DictFormatter(BoundingBoxes.column),
  129. DictFormatter(Comments.column),
  130. RenameFormatter(**mapper_bounding_box),
  131. ]
  132. },
  133. SEGMENTATION: {
  134. JSONL.name: [
  135. DictFormatter(Segments.column),
  136. DictFormatter(Comments.column),
  137. RenameFormatter(**mapper_segmentation),
  138. ]
  139. },
  140. IMAGE_CAPTIONING: {
  141. JSONL.name: [
  142. ListedCategoryFormatter(Texts.column),
  143. ListedCategoryFormatter(Comments.column),
  144. RenameFormatter(**mapper_image_captioning),
  145. ]
  146. },
  147. }
  148. return mapping[project.project_type][file_format]
  149. def select_label_collection(project: Project) -> List[Type[Labels]]:
  150. use_relation = getattr(project, "use_relation", False)
  151. mapping: Dict[str, List[Type[Labels]]] = {
  152. DOCUMENT_CLASSIFICATION: [Categories],
  153. SEQUENCE_LABELING: [Spans, Relations] if use_relation else [Spans],
  154. SEQ2SEQ: [Texts],
  155. IMAGE_CLASSIFICATION: [Categories],
  156. SPEECH2TEXT: [Texts],
  157. INTENT_DETECTION_AND_SLOT_FILLING: [Categories, Spans],
  158. BOUNDING_BOX: [BoundingBoxes],
  159. SEGMENTATION: [Segments],
  160. IMAGE_CAPTIONING: [Texts],
  161. }
  162. return mapping[project.project_type]
  163. def create_labels(project: Project, examples: QuerySet[ExportedExample], user=None) -> List[Labels]:
  164. label_collections = select_label_collection(project)
  165. labels = [label_collection(examples=examples, user=user) for label_collection in label_collections]
  166. return labels
  167. def create_comment(examples: QuerySet[ExportedExample], user=None) -> List[Comments]:
  168. return [Comments(examples=examples, user=user)]