You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

136 lines
5.2 KiB

3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
  1. from typing import Dict, List, Type
  2. from django.db.models import QuerySet
  3. from . import writers
  4. from .catalog import CSV, JSON, JSONL, FastText
  5. from .formatters import (
  6. DictFormatter,
  7. FastTextCategoryFormatter,
  8. Formatter,
  9. JoinedCategoryFormatter,
  10. ListedCategoryFormatter,
  11. RenameFormatter,
  12. TupledSpanFormatter,
  13. )
  14. from .labels import BoundingBoxes, Categories, Labels, Relations, Segments, Spans, Texts
  15. from data_export.models import DATA, ExportedExample
  16. from projects.models import (
  17. BOUNDING_BOX,
  18. DOCUMENT_CLASSIFICATION,
  19. IMAGE_CAPTIONING,
  20. IMAGE_CLASSIFICATION,
  21. INTENT_DETECTION_AND_SLOT_FILLING,
  22. SEGMENTATION,
  23. SEQ2SEQ,
  24. SEQUENCE_LABELING,
  25. SPEECH2TEXT,
  26. Project,
  27. )
  28. def create_writer(file_format: str) -> writers.Writer:
  29. mapping = {
  30. CSV.name: writers.CsvWriter(),
  31. JSON.name: writers.JsonWriter(),
  32. JSONL.name: writers.JsonlWriter(),
  33. FastText.name: writers.FastTextWriter(),
  34. }
  35. if file_format not in mapping:
  36. ValueError(f"Invalid format: {file_format}")
  37. return mapping[file_format]
  38. def create_formatter(project: Project, file_format: str) -> List[Formatter]:
  39. use_relation = getattr(project, "use_relation", False)
  40. # text tasks
  41. mapper_text_classification = {DATA: "text", Categories.column: "label"}
  42. mapper_sequence_labeling = {DATA: "text", Spans.column: "label"}
  43. mapper_seq2seq = {DATA: "text", Texts.column: "label"}
  44. mapper_intent_detection = {DATA: "text", Categories.column: "cats"}
  45. mapper_relation_extraction = {DATA: "text"}
  46. # image tasks
  47. mapper_image_classification = {DATA: "filename", Categories.column: "label"}
  48. mapper_bounding_box = {DATA: "filename", BoundingBoxes.column: "bbox"}
  49. mapper_segmentation = {DATA: "filename", BoundingBoxes.column: "segmentation"}
  50. mapper_image_captioning = {DATA: "filename", Texts.column: "label"}
  51. # audio tasks
  52. mapper_speech2text = {DATA: "filename", Texts.column: "label"}
  53. mapping: Dict[str, Dict[str, List[Formatter]]] = {
  54. DOCUMENT_CLASSIFICATION: {
  55. CSV.name: [
  56. JoinedCategoryFormatter(Categories.column),
  57. RenameFormatter(**mapper_text_classification),
  58. ],
  59. JSON.name: [
  60. ListedCategoryFormatter(Categories.column),
  61. RenameFormatter(**mapper_text_classification),
  62. ],
  63. JSONL.name: [
  64. ListedCategoryFormatter(Categories.column),
  65. RenameFormatter(**mapper_text_classification),
  66. ],
  67. FastText.name: [FastTextCategoryFormatter(Categories.column)],
  68. },
  69. SEQUENCE_LABELING: {
  70. JSONL.name: [
  71. DictFormatter(Spans.column),
  72. DictFormatter(Relations.column),
  73. RenameFormatter(**mapper_relation_extraction),
  74. ]
  75. if use_relation
  76. else [TupledSpanFormatter(Spans.column), RenameFormatter(**mapper_sequence_labeling)]
  77. },
  78. SEQ2SEQ: {
  79. CSV.name: [JoinedCategoryFormatter(Texts.column), RenameFormatter(**mapper_seq2seq)],
  80. JSON.name: [ListedCategoryFormatter(Texts.column), RenameFormatter(**mapper_seq2seq)],
  81. JSONL.name: [ListedCategoryFormatter(Texts.column), RenameFormatter(**mapper_seq2seq)],
  82. },
  83. IMAGE_CLASSIFICATION: {
  84. JSONL.name: [
  85. ListedCategoryFormatter(Categories.column),
  86. RenameFormatter(**mapper_image_classification),
  87. ],
  88. },
  89. SPEECH2TEXT: {
  90. JSONL.name: [ListedCategoryFormatter(Texts.column), RenameFormatter(**mapper_speech2text)],
  91. },
  92. INTENT_DETECTION_AND_SLOT_FILLING: {
  93. JSONL.name: [
  94. ListedCategoryFormatter(Categories.column),
  95. TupledSpanFormatter(Spans.column),
  96. RenameFormatter(**mapper_intent_detection),
  97. ]
  98. },
  99. BOUNDING_BOX: {JSONL.name: [DictFormatter(BoundingBoxes.column), RenameFormatter(**mapper_bounding_box)]},
  100. SEGMENTATION: {JSONL.name: [DictFormatter(Segments.column), RenameFormatter(**mapper_segmentation)]},
  101. IMAGE_CAPTIONING: {
  102. JSONL.name: [ListedCategoryFormatter(Texts.column), RenameFormatter(**mapper_image_captioning)]
  103. },
  104. }
  105. return mapping[project.project_type][file_format]
  106. def select_label_collection(project: Project) -> List[Type[Labels]]:
  107. use_relation = getattr(project, "use_relation", False)
  108. mapping: Dict[str, List[Type[Labels]]] = {
  109. DOCUMENT_CLASSIFICATION: [Categories],
  110. SEQUENCE_LABELING: [Spans, Relations] if use_relation else [Spans],
  111. SEQ2SEQ: [Texts],
  112. IMAGE_CLASSIFICATION: [Categories],
  113. SPEECH2TEXT: [Texts],
  114. INTENT_DETECTION_AND_SLOT_FILLING: [Categories, Spans],
  115. BOUNDING_BOX: [BoundingBoxes],
  116. SEGMENTATION: [Segments],
  117. IMAGE_CAPTIONING: [Texts],
  118. }
  119. return mapping[project.project_type]
  120. def create_labels(project: Project, examples: QuerySet[ExportedExample], user=None) -> List[Labels]:
  121. label_collections = select_label_collection(project)
  122. labels = [label_collection(examples=examples, user=user) for label_collection in label_collections]
  123. return labels