You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

54 lines
1.9 KiB

3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
  1. from ...models import (DOCUMENT_CLASSIFICATION, IMAGE_CLASSIFICATION, SEQ2SEQ,
  2. SEQUENCE_LABELING, SPEECH2TEXT)
  3. from . import catalog, cleaners, data, dataset, label
  4. def get_data_class(project_type: str):
  5. text_projects = [DOCUMENT_CLASSIFICATION, SEQUENCE_LABELING, SEQ2SEQ]
  6. if project_type in text_projects:
  7. return data.TextData
  8. else:
  9. return data.FileData
  10. def get_dataset_class(format: str):
  11. mapping = {
  12. catalog.TextFile.name: dataset.TextFileDataset,
  13. catalog.TextLine.name: dataset.TextLineDataset,
  14. catalog.CSV.name: dataset.CsvDataset,
  15. catalog.JSONL.name: dataset.JSONLDataset,
  16. catalog.JSON.name: dataset.JSONDataset,
  17. catalog.FastText.name: dataset.FastTextDataset,
  18. catalog.Excel.name: dataset.ExcelDataset,
  19. catalog.CoNLL.name: dataset.CoNLLDataset,
  20. catalog.ImageFile.name: dataset.FileBaseDataset,
  21. catalog.AudioFile.name: dataset.FileBaseDataset,
  22. }
  23. if format not in mapping:
  24. ValueError(f'Invalid format: {format}')
  25. return mapping[format]
  26. def get_label_class(project_type: str):
  27. mapping = {
  28. DOCUMENT_CLASSIFICATION: label.CategoryLabel,
  29. SEQUENCE_LABELING: label.OffsetLabel,
  30. SEQ2SEQ: label.TextLabel,
  31. IMAGE_CLASSIFICATION: label.CategoryLabel,
  32. SPEECH2TEXT: label.TextLabel,
  33. }
  34. if project_type not in mapping:
  35. ValueError(f'Invalid project type: {project_type}')
  36. return mapping[project_type]
  37. def create_cleaner(project):
  38. mapping = {
  39. DOCUMENT_CLASSIFICATION: cleaners.CategoryCleaner,
  40. SEQUENCE_LABELING: cleaners.SpanCleaner,
  41. IMAGE_CLASSIFICATION: cleaners.CategoryCleaner
  42. }
  43. if project.project_type not in mapping:
  44. ValueError(f'Invalid project type: {project.project_type}')
  45. cleaner_class = mapping.get(project.project_type, cleaners.Cleaner)
  46. return cleaner_class(project)