You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

69 lines
2.4 KiB

3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
  1. from typing import Type
  2. from django.db.models import QuerySet
  3. from . import catalog, formatters, labels, repositories, writers
  4. from .labels import Labels
  5. from examples.models import Example
  6. from projects.models import (
  7. DOCUMENT_CLASSIFICATION,
  8. IMAGE_CLASSIFICATION,
  9. INTENT_DETECTION_AND_SLOT_FILLING,
  10. SEQ2SEQ,
  11. SEQUENCE_LABELING,
  12. SPEECH2TEXT,
  13. )
  14. def create_repository(project, file_format: str):
  15. if getattr(project, "use_relation", False) and file_format == catalog.JSONLRelation.name:
  16. return repositories.RelationExtractionRepository(project)
  17. mapping = {
  18. DOCUMENT_CLASSIFICATION: repositories.TextClassificationRepository,
  19. SEQUENCE_LABELING: repositories.SequenceLabelingRepository,
  20. SEQ2SEQ: repositories.Seq2seqRepository,
  21. IMAGE_CLASSIFICATION: repositories.FileRepository,
  22. SPEECH2TEXT: repositories.Speech2TextRepository,
  23. INTENT_DETECTION_AND_SLOT_FILLING: repositories.IntentDetectionSlotFillingRepository,
  24. }
  25. if project.project_type not in mapping:
  26. ValueError(f"Invalid project type: {project.project_type}")
  27. repository = mapping[project.project_type](project)
  28. return repository
  29. def create_writer(file_format: str) -> writers.Writer:
  30. mapping = {
  31. catalog.CSV.name: writers.CsvWriter,
  32. catalog.JSON.name: writers.JsonWriter,
  33. catalog.JSONL.name: writers.JsonlWriter,
  34. # catalog.FastText.name: writers.FastTextWriter,
  35. }
  36. if file_format not in mapping:
  37. ValueError(f"Invalid format: {file_format}")
  38. return mapping[file_format]()
  39. def create_formatter(project, file_format: str):
  40. mapping = {
  41. DOCUMENT_CLASSIFICATION: {
  42. catalog.CSV.name: formatters.JoinedCategoryFormatter,
  43. catalog.JSON.name: formatters.ListedCategoryFormatter,
  44. catalog.JSONL.name: formatters.ListedCategoryFormatter,
  45. },
  46. SEQUENCE_LABELING: {},
  47. SEQ2SEQ: {},
  48. IMAGE_CLASSIFICATION: {},
  49. SPEECH2TEXT: {},
  50. INTENT_DETECTION_AND_SLOT_FILLING: {},
  51. }
  52. return mapping[project.project_type][file_format]
  53. def select_label_collection(project):
  54. mapping = {DOCUMENT_CLASSIFICATION: labels.Categories}
  55. return mapping[project.project_type]
  56. def create_labels(label_collection_class: Type[Labels], examples: QuerySet[Example], user=None):
  57. return label_collection_class(examples, user)