You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

92 lines
3.1 KiB

3 years ago
3 years ago
3 years ago
3 years ago
  1. from api.models import (DOCUMENT_CLASSIFICATION, IMAGE_CLASSIFICATION,
  2. INTENT_DETECTION_AND_SLOT_FILLING, SEQ2SEQ,
  3. SEQUENCE_LABELING, SPEECH2TEXT)
  4. from . import builders, catalog, cleaners, data, labels, parsers, readers
  def get_data_class(project_type: str):
      """Pick the data wrapper class that matches a project type.

      Document classification, sequence labeling, seq2seq, and intent
      detection and slot filling projects get ``data.TextData``; every
      other project type gets ``data.FileData``.
      """
      text_based_types = (
          DOCUMENT_CLASSIFICATION,
          SEQUENCE_LABELING,
          SEQ2SEQ,
          INTENT_DETECTION_AND_SLOT_FILLING,
      )
      return data.TextData if project_type in text_based_types else data.FileData
  def create_parser(file_format: str, **kwargs):
      """Instantiate the parser registered for ``file_format``.

      Args:
          file_format: The ``name`` of one of the ``catalog`` formats
              listed in the table below.
          **kwargs: Passed through unchanged to the parser constructor.

      Returns:
          A new instance of the parser class registered for the format.

      Raises:
          ValueError: If ``file_format`` is not one of the known formats.
      """
      parser_classes = {
          catalog.TextFile.name: parsers.TextFileParser,
          catalog.TextLine.name: parsers.LineParser,
          catalog.CSV.name: parsers.CSVParser,
          catalog.JSONL.name: parsers.JSONLParser,
          catalog.JSON.name: parsers.JSONParser,
          catalog.FastText.name: parsers.FastTextParser,
          catalog.Excel.name: parsers.ExcelParser,
          catalog.CoNLL.name: parsers.CoNLLParser,
          # Image and audio files share the same plain parser.
          catalog.ImageFile.name: parsers.PlainParser,
          catalog.AudioFile.name: parsers.PlainParser,
      }
      parser_class = parser_classes.get(file_format)
      if parser_class is None:
          raise ValueError(f'Invalid format: {file_format}')
      return parser_class(**kwargs)
  def get_label_class(project_type: str):
      """Return the label class used for a project type.

      Document and image classification map to ``labels.CategoryLabel``,
      sequence labeling maps to ``labels.SpanLabel``, and seq2seq and
      speech-to-text map to ``labels.TextLabel``. Intent detection and
      slot filling has no entry in this table.

      Args:
          project_type: One of the project type constants from
              ``api.models``.

      Returns:
          The label class registered for ``project_type``.

      Raises:
          ValueError: If ``project_type`` has no registered label class.
      """
      mapping = {
          DOCUMENT_CLASSIFICATION: labels.CategoryLabel,
          SEQUENCE_LABELING: labels.SpanLabel,
          SEQ2SEQ: labels.TextLabel,
          IMAGE_CLASSIFICATION: labels.CategoryLabel,
          SPEECH2TEXT: labels.TextLabel,
      }
      if project_type not in mapping:
          # Raise explicitly so an unsupported type is reported with a
          # descriptive message rather than a bare KeyError from the
          # lookup below.
          raise ValueError(f'Invalid project type: {project_type}')
      return mapping[project_type]
  def create_cleaner(project):
      """Create the label cleaner that matches the project's type.

      Document and image classification use ``cleaners.CategoryCleaner``
      and sequence labeling uses ``cleaners.SpanCleaner``. Any other
      project type falls back to the generic ``cleaners.Cleaner``, so
      this function does not raise for project types missing from the
      table.

      Args:
          project: Object exposing ``project_type``; it is also passed
              to the cleaner's constructor.

      Returns:
          A cleaner instance constructed with ``project``.
      """
      mapping = {
          DOCUMENT_CLASSIFICATION: cleaners.CategoryCleaner,
          SEQUENCE_LABELING: cleaners.SpanCleaner,
          IMAGE_CLASSIFICATION: cleaners.CategoryCleaner,
      }
      # Unlisted project types intentionally use the base cleaner rather
      # than being treated as an error.
      cleaner_class = mapping.get(project.project_type, cleaners.Cleaner)
      return cleaner_class(project)
  def create_bulder(project, **kwargs):
      """Assemble the ``ColumnBuilder`` used for a project's uploads.

      Args:
          project: Object exposing ``project_type``.
          **kwargs: Optional ``column_data`` and ``column_label`` entries
              naming the data and label columns. A missing or falsy value
              falls back to ``readers.DEFAULT_TEXT_COLUMN`` or
              ``readers.DEFAULT_LABEL_COLUMN`` respectively.

      Returns:
          A ``builders.ColumnBuilder`` with one data column and the label
          column(s) for the project type.
      """
      project_type = project.project_type

      text_column = builders.DataColumn(
          name=kwargs.get('column_data') or readers.DEFAULT_TEXT_COLUMN,
          value_class=get_data_class(project_type),
      )

      if project_type == INTENT_DETECTION_AND_SLOT_FILLING:
          # Intent detection and slot filling always reads its labels from
          # two fixed columns: 'cats' and 'entities'. Any 'column_label'
          # entry in kwargs is not consulted on this path.
          columns_for_labels = [
              builders.LabelColumn(name='cats', value_class=labels.CategoryLabel),
              builders.LabelColumn(name='entities', value_class=labels.SpanLabel),
          ]
      else:
          label_name = kwargs.get('column_label') or readers.DEFAULT_LABEL_COLUMN
          columns_for_labels = [
              builders.LabelColumn(
                  name=label_name,
                  value_class=get_label_class(project_type),
              ),
          ]

      return builders.ColumnBuilder(
          data_column=text_column,
          label_columns=columns_for_labels,
      )