You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

37 lines
1.2 KiB

3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
  1. from ...models import DOCUMENT_CLASSIFICATION, SEQ2SEQ, SEQUENCE_LABELING
  2. from . import catalog, data, dataset, label
  def get_data_class(project_type: str):
      """Return the data class that matches ``project_type``.

      Document classification, sequence labeling and seq2seq projects use
      ``data.TextData``; any other project type uses ``data.FileData``.
      """
      is_text_project = project_type in (
          DOCUMENT_CLASSIFICATION,
          SEQUENCE_LABELING,
          SEQ2SEQ,
      )
      return data.TextData if is_text_project else data.FileData
  def get_dataset_class(format: str):
      """Return the dataset class registered for a catalog format name.

      Args:
          format: Name of a catalog format (compared against the ``name``
              attribute of the ``catalog`` format classes). The parameter
              name shadows the ``format`` builtin but is kept so existing
              keyword callers continue to work.

      Returns:
          The ``dataset`` class associated with ``format``.

      Raises:
          ValueError: If ``format`` is not one of the known format names.
      """
      mapping = {
          catalog.TextFile.name: dataset.TextFileDataset,
          catalog.TextLine.name: dataset.TextLineDataset,
          catalog.CSV.name: dataset.CsvDataset,
          catalog.JSONL.name: dataset.JSONLDataset,
          catalog.JSON.name: dataset.JSONDataset,
          catalog.FastText.name: dataset.FastTextDataset,
          catalog.Excel.name: dataset.ExcelDataset,
          catalog.CoNLL.name: dataset.CoNLLDataset,
      }
      if format not in mapping:
          # The exception used to be constructed but never raised, so an
          # unknown format escaped as a bare KeyError from the lookup below.
          raise ValueError(f'Invalid format: {format}')
      return mapping[format]
  def get_label_class(project_type: str):
      """Return the label class used by the given project type.

      Args:
          project_type: One of ``DOCUMENT_CLASSIFICATION``,
              ``SEQUENCE_LABELING`` or ``SEQ2SEQ``.

      Returns:
          The ``label`` class associated with ``project_type``.

      Raises:
          ValueError: If ``project_type`` has no associated label class.
      """
      mapping = {
          DOCUMENT_CLASSIFICATION: label.CategoryLabel,
          SEQUENCE_LABELING: label.OffsetLabel,
          SEQ2SEQ: label.TextLabel,
      }
      if project_type not in mapping:
          # The exception used to be constructed but never raised, so an
          # unknown project type escaped as a bare KeyError from the lookup.
          raise ValueError(f'Invalid project type: {project_type}')
      return mapping[project_type]