|
|
@ -12,56 +12,63 @@ from projects.models import ( |
|
|
|
SEQ2SEQ, |
|
|
|
SEQUENCE_LABELING, |
|
|
|
SPEECH2TEXT, |
|
|
|
Project, |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
def create_writer(file_format: str) -> writers.Writer:
    """Return the writer instance registered for an export file format.

    Args:
        file_format: Format name, compared against the ``name`` attribute of
            ``catalog.CSV``, ``catalog.JSON``, ``catalog.JSONL`` and
            ``catalog.FastText``.

    Returns:
        The writer instance mapped to ``file_format``.

    Raises:
        ValueError: If ``file_format`` is not one of the formats listed above.
    """
    mapping = {
        catalog.CSV.name: writers.CsvWriter(),
        catalog.JSON.name: writers.JsonWriter(),
        catalog.JSONL.name: writers.JsonlWriter(),
        catalog.FastText.name: writers.FastTextWriter(),
    }
    if file_format not in mapping:
        # The exception used to be constructed without `raise`, so execution
        # fell through to the lookup below and failed with a bare KeyError.
        raise ValueError(f"Invalid format: {file_format}")
    return mapping[file_format]
|
|
|
|
|
|
|
|
|
|
|
def create_formatter(project: Project, file_format: str) -> List[formatters.Formatter]:
    """Build the formatter instances for a project type and export format.

    Args:
        project: Project whose ``project_type`` selects the formatter table.
            Its optional ``use_relation`` attribute (treated as False when
            absent) switches the sequence-labeling JSONL output to
            dict-formatted spans plus relations.
        file_format: Format name, matched against ``catalog.<FORMAT>.name``.

    Returns:
        The list of formatter instances registered for the combination.

    Raises:
        KeyError: If the project type, or the file format within that project
            type, has no entry in the table.
    """
    # Label columns that several table entries share.
    category_column = labels.Categories.column
    text_column = labels.Texts.column
    span_column = labels.Spans.column

    if getattr(project, "use_relation", False):
        sequence_labeling_jsonl: List[formatters.Formatter] = [
            formatters.DictFormatter(span_column),
            formatters.DictFormatter(labels.Relations.column),
        ]
    else:
        sequence_labeling_jsonl = [formatters.TupledSpanFormatter(span_column)]

    table: Dict[str, Dict[str, List[formatters.Formatter]]] = {
        DOCUMENT_CLASSIFICATION: {
            catalog.CSV.name: [formatters.JoinedCategoryFormatter(category_column)],
            catalog.JSON.name: [formatters.ListedCategoryFormatter(category_column)],
            catalog.JSONL.name: [formatters.ListedCategoryFormatter(category_column)],
            catalog.FastText.name: [formatters.FastTextCategoryFormatter(category_column)],
        },
        SEQUENCE_LABELING: {
            catalog.JSONL.name: sequence_labeling_jsonl,
        },
        SEQ2SEQ: {
            catalog.CSV.name: [formatters.JoinedCategoryFormatter(text_column)],
            catalog.JSON.name: [formatters.ListedCategoryFormatter(text_column)],
            catalog.JSONL.name: [formatters.ListedCategoryFormatter(text_column)],
        },
        IMAGE_CLASSIFICATION: {
            catalog.JSONL.name: [formatters.ListedCategoryFormatter(category_column)],
        },
        SPEECH2TEXT: {
            catalog.JSONL.name: [formatters.ListedCategoryFormatter(text_column)],
        },
        INTENT_DETECTION_AND_SLOT_FILLING: {
            catalog.JSONL.name: [
                formatters.ListedCategoryFormatter(category_column),
                formatters.TupledSpanFormatter(span_column),
            ],
        },
    }
    formats_for_project = table[project.project_type]
    return formats_for_project[file_format]
|
|
|
|
|
|
|
|
|
|
|
def select_label_collection(project): |
|
|
|
def select_label_collection(project: Project) -> List[Type[Labels]]: |
|
|
|
use_relation = getattr(project, "use_relation", False) |
|
|
|
mapping = { |
|
|
|
mapping: Dict[str, List[Type[Labels]]] = { |
|
|
|
DOCUMENT_CLASSIFICATION: [labels.Categories], |
|
|
|
SEQUENCE_LABELING: [labels.Spans, labels.Relations] if use_relation else [labels.Spans], |
|
|
|
SEQ2SEQ: [labels.Texts], |
|
|
@ -72,5 +79,7 @@ def select_label_collection(project): |
|
|
|
return mapping[project.project_type] |
|
|
|
|
|
|
|
|
|
|
|
def create_labels(project: Project, examples: QuerySet[ExportedExample], user=None) -> List[Labels]:
    """Instantiate every label collection that applies to a project.

    Args:
        project: Project passed to ``select_label_collection`` to choose the
            label collection classes.
        examples: Exported examples handed to each collection as ``examples``.
        user: Forwarded to each collection as ``user``; defaults to None.

    Returns:
        One instance per selected class, in the order the classes are
        returned by ``select_label_collection``.
    """
    # Return the list directly rather than binding it to a local called
    # `labels`, which would shadow the module of the same name.
    return [
        collection_class(examples=examples, user=user)
        for collection_class in select_label_collection(project)
    ]