You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

172 lines
6.5 KiB

2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
  1. from typing import Dict, List, Type
  2. from django.db.models import QuerySet
  3. from . import writers
  4. from .catalog import CSV, JSON, JSONL, FastText
  5. from .comments import Comments
  6. from .formatters import (
  7. DictFormatter,
  8. FastTextCategoryFormatter,
  9. Formatter,
  10. JoinedCategoryFormatter,
  11. ListedCategoryFormatter,
  12. RenameFormatter,
  13. TupledSpanFormatter,
  14. )
  15. from .labels import BoundingBoxes, Categories, Labels, Relations, Segments, Spans, Texts
  16. from data_export.models import DATA, ExportedExample
  17. from projects.models import Project, ProjectType
  18. def create_writer(file_format: str) -> writers.Writer:
  19. mapping = {
  20. CSV.name: writers.CsvWriter(),
  21. JSON.name: writers.JsonWriter(),
  22. JSONL.name: writers.JsonlWriter(),
  23. FastText.name: writers.FastTextWriter(),
  24. }
  25. if file_format not in mapping:
  26. ValueError(f"Invalid format: {file_format}")
  27. return mapping[file_format]
  28. def create_formatter(project: Project, file_format: str) -> List[Formatter]:
  29. use_relation = getattr(project, "use_relation", False)
  30. # text tasks
  31. mapper_text_classification = {DATA: "text", Categories.column: "label"}
  32. mapper_sequence_labeling = {DATA: "text", Spans.column: "label"}
  33. mapper_seq2seq = {DATA: "text", Texts.column: "label"}
  34. mapper_intent_detection = {DATA: "text", Categories.column: "cats"}
  35. mapper_relation_extraction = {DATA: "text"}
  36. # image tasks
  37. mapper_image_classification = {DATA: "filename", Categories.column: "label"}
  38. mapper_bounding_box = {DATA: "filename", BoundingBoxes.column: "bbox"}
  39. mapper_segmentation = {DATA: "filename", BoundingBoxes.column: "segmentation"}
  40. mapper_image_captioning = {DATA: "filename", Texts.column: "label"}
  41. # audio tasks
  42. mapper_speech2text = {DATA: "filename", Texts.column: "label"}
  43. mapping: Dict[str, Dict[str, List[Formatter]]] = {
  44. ProjectType.DOCUMENT_CLASSIFICATION: {
  45. CSV.name: [
  46. JoinedCategoryFormatter(Categories.column),
  47. JoinedCategoryFormatter(Comments.column),
  48. RenameFormatter(**mapper_text_classification),
  49. ],
  50. JSON.name: [
  51. ListedCategoryFormatter(Categories.column),
  52. ListedCategoryFormatter(Comments.column),
  53. RenameFormatter(**mapper_text_classification),
  54. ],
  55. JSONL.name: [
  56. ListedCategoryFormatter(Categories.column),
  57. ListedCategoryFormatter(Comments.column),
  58. RenameFormatter(**mapper_text_classification),
  59. ],
  60. FastText.name: [FastTextCategoryFormatter(Categories.column)],
  61. },
  62. ProjectType.SEQUENCE_LABELING: {
  63. JSONL.name: [
  64. DictFormatter(Spans.column),
  65. DictFormatter(Relations.column),
  66. DictFormatter(Comments.column),
  67. RenameFormatter(**mapper_relation_extraction),
  68. ]
  69. if use_relation
  70. else [
  71. TupledSpanFormatter(Spans.column),
  72. ListedCategoryFormatter(Comments.column),
  73. RenameFormatter(**mapper_sequence_labeling),
  74. ]
  75. },
  76. ProjectType.SEQ2SEQ: {
  77. CSV.name: [
  78. JoinedCategoryFormatter(Texts.column),
  79. JoinedCategoryFormatter(Comments.column),
  80. RenameFormatter(**mapper_seq2seq),
  81. ],
  82. JSON.name: [
  83. ListedCategoryFormatter(Texts.column),
  84. ListedCategoryFormatter(Comments.column),
  85. RenameFormatter(**mapper_seq2seq),
  86. ],
  87. JSONL.name: [
  88. ListedCategoryFormatter(Texts.column),
  89. ListedCategoryFormatter(Comments.column),
  90. RenameFormatter(**mapper_seq2seq),
  91. ],
  92. },
  93. ProjectType.IMAGE_CLASSIFICATION: {
  94. JSONL.name: [
  95. ListedCategoryFormatter(Categories.column),
  96. ListedCategoryFormatter(Comments.column),
  97. RenameFormatter(**mapper_image_classification),
  98. ],
  99. },
  100. ProjectType.SPEECH2TEXT: {
  101. JSONL.name: [
  102. ListedCategoryFormatter(Texts.column),
  103. ListedCategoryFormatter(Comments.column),
  104. RenameFormatter(**mapper_speech2text),
  105. ],
  106. },
  107. ProjectType.INTENT_DETECTION_AND_SLOT_FILLING: {
  108. JSONL.name: [
  109. ListedCategoryFormatter(Categories.column),
  110. TupledSpanFormatter(Spans.column),
  111. ListedCategoryFormatter(Comments.column),
  112. RenameFormatter(**mapper_intent_detection),
  113. ]
  114. },
  115. ProjectType.BOUNDING_BOX: {
  116. JSONL.name: [
  117. DictFormatter(BoundingBoxes.column),
  118. DictFormatter(Comments.column),
  119. RenameFormatter(**mapper_bounding_box),
  120. ]
  121. },
  122. ProjectType.SEGMENTATION: {
  123. JSONL.name: [
  124. DictFormatter(Segments.column),
  125. DictFormatter(Comments.column),
  126. RenameFormatter(**mapper_segmentation),
  127. ]
  128. },
  129. ProjectType.IMAGE_CAPTIONING: {
  130. JSONL.name: [
  131. ListedCategoryFormatter(Texts.column),
  132. ListedCategoryFormatter(Comments.column),
  133. RenameFormatter(**mapper_image_captioning),
  134. ]
  135. },
  136. }
  137. return mapping[project.project_type][file_format]
  138. def select_label_collection(project: Project) -> List[Type[Labels]]:
  139. use_relation = getattr(project, "use_relation", False)
  140. mapping: Dict[str, List[Type[Labels]]] = {
  141. ProjectType.DOCUMENT_CLASSIFICATION: [Categories],
  142. ProjectType.SEQUENCE_LABELING: [Spans, Relations] if use_relation else [Spans],
  143. ProjectType.SEQ2SEQ: [Texts],
  144. ProjectType.IMAGE_CLASSIFICATION: [Categories],
  145. ProjectType.SPEECH2TEXT: [Texts],
  146. ProjectType.INTENT_DETECTION_AND_SLOT_FILLING: [Categories, Spans],
  147. ProjectType.BOUNDING_BOX: [BoundingBoxes],
  148. ProjectType.SEGMENTATION: [Segments],
  149. ProjectType.IMAGE_CAPTIONING: [Texts],
  150. }
  151. return mapping[project.project_type]
  152. def create_labels(project: Project, examples: QuerySet[ExportedExample], user=None) -> List[Labels]:
  153. label_collections = select_label_collection(project)
  154. labels = [label_collection(examples=examples, user=user) for label_collection in label_collections]
  155. return labels
  156. def create_comment(examples: QuerySet[ExportedExample], user=None) -> List[Comments]:
  157. return [Comments(examples=examples, user=user)]