You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

85 lines
3.0 KiB

2 years ago
2 years ago
2 years ago
2 years ago
3 years ago
2 years ago
3 years ago
2 years ago
3 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
  1. from . import builders, catalog, cleaners, data, labels, parsers, readers
  2. from projects.models import (
  3. DOCUMENT_CLASSIFICATION,
  4. IMAGE_CLASSIFICATION,
  5. INTENT_DETECTION_AND_SLOT_FILLING,
  6. SEQ2SEQ,
  7. SEQUENCE_LABELING,
  8. SPEECH2TEXT,
  9. )
  10. def get_data_class(project_type: str):
  11. text_projects = [DOCUMENT_CLASSIFICATION, SEQUENCE_LABELING, SEQ2SEQ, INTENT_DETECTION_AND_SLOT_FILLING]
  12. if project_type in text_projects:
  13. return data.TextData
  14. else:
  15. return data.FileData
  16. def create_parser(file_format: str, **kwargs):
  17. mapping = {
  18. catalog.TextFile.name: parsers.TextFileParser,
  19. catalog.TextLine.name: parsers.LineParser,
  20. catalog.CSV.name: parsers.CSVParser,
  21. catalog.JSONL.name: parsers.JSONLParser,
  22. catalog.JSON.name: parsers.JSONParser,
  23. catalog.FastText.name: parsers.FastTextParser,
  24. catalog.Excel.name: parsers.ExcelParser,
  25. catalog.CoNLL.name: parsers.CoNLLParser,
  26. catalog.ImageFile.name: parsers.PlainParser,
  27. catalog.AudioFile.name: parsers.PlainParser,
  28. }
  29. if file_format not in mapping:
  30. raise ValueError(f"Invalid format: {file_format}")
  31. return mapping[file_format](**kwargs)
  32. def get_label_class(project_type: str):
  33. mapping = {
  34. DOCUMENT_CLASSIFICATION: labels.CategoryLabel,
  35. SEQUENCE_LABELING: labels.SpanLabel,
  36. SEQ2SEQ: labels.TextLabel,
  37. IMAGE_CLASSIFICATION: labels.CategoryLabel,
  38. SPEECH2TEXT: labels.TextLabel,
  39. }
  40. if project_type not in mapping:
  41. ValueError(f"Invalid project type: {project_type}")
  42. return mapping[project_type]
  43. def create_cleaner(project):
  44. mapping = {
  45. DOCUMENT_CLASSIFICATION: cleaners.CategoryCleaner,
  46. SEQUENCE_LABELING: cleaners.SpanCleaner,
  47. IMAGE_CLASSIFICATION: cleaners.CategoryCleaner,
  48. }
  49. if project.project_type not in mapping:
  50. return cleaners.Cleaner(project)
  51. cleaner_class = mapping.get(project.project_type, cleaners.Cleaner)
  52. return cleaner_class(project)
  53. def create_builder(project, **kwargs):
  54. if not project.is_text_project:
  55. return builders.PlainBuilder(data_class=get_data_class(project.project_type))
  56. data_column = builders.DataColumn(
  57. name=kwargs.get("column_data") or readers.DEFAULT_TEXT_COLUMN, value_class=get_data_class(project.project_type)
  58. )
  59. # If project is intent detection and slot filling,
  60. # column names are fixed: entities, cats
  61. if project.project_type == INTENT_DETECTION_AND_SLOT_FILLING:
  62. label_columns = [
  63. builders.LabelColumn(name="cats", value_class=labels.CategoryLabel),
  64. builders.LabelColumn(name="entities", value_class=labels.SpanLabel),
  65. ]
  66. else:
  67. label_columns = [
  68. builders.LabelColumn(
  69. name=kwargs.get("column_label") or readers.DEFAULT_LABEL_COLUMN,
  70. value_class=get_label_class(project.project_type),
  71. )
  72. ]
  73. builder = builders.ColumnBuilder(data_column=data_column, label_columns=label_columns)
  74. return builder