From 1ba663142592dcf74220e38996cb62c9429873f2 Mon Sep 17 00:00:00 2001 From: Hironsan Date: Wed, 7 Jun 2023 10:38:43 +0900 Subject: [PATCH] Add ProjectType as a enumeration type --- .../migrations/0003_fill_task_type.py | 8 +-- backend/auto_labeling/tests/test_views.py | 18 +++--- backend/data_export/pipeline/catalog.py | 42 +++++--------- backend/data_export/pipeline/factories.py | 49 ++++++---------- backend/data_export/tests/test_catalog.py | 25 +------- backend/data_export/tests/test_labels.py | 4 +- backend/data_export/tests/test_task.py | 36 +++++------- backend/data_export/tests/test_views.py | 4 +- backend/data_import/datasets.py | 31 ++++------ backend/data_import/pipeline/catalog.py | 52 ++++++++--------- backend/data_import/tests/test_catalog.py | 19 +----- backend/data_import/tests/test_examples.py | 4 +- backend/data_import/tests/test_label.py | 10 ++-- backend/data_import/tests/test_label_types.py | 4 +- backend/data_import/tests/test_labels.py | 10 ++-- backend/data_import/tests/test_tasks.py | 26 ++++----- backend/data_import/tests/test_views.py | 4 +- backend/examples/tests/test_document.py | 10 ++-- backend/examples/tests/test_models.py | 8 +-- backend/label_types/tests/test_views.py | 12 ++-- backend/label_types/tests/utils.py | 7 ++- backend/labels/tests/test_category.py | 4 +- backend/labels/tests/test_relation.py | 4 +- backend/labels/tests/test_span.py | 16 ++--- backend/labels/tests/test_text_label.py | 4 +- backend/labels/tests/test_views.py | 58 +++++++++---------- backend/labels/tests/utils.py | 15 ++--- backend/metrics/tests.py | 10 ++-- backend/projects/models.py | 33 ++++------- backend/projects/tests/test_project.py | 7 +-- backend/projects/tests/utils.py | 32 ++++------ 31 files changed, 228 insertions(+), 338 deletions(-) diff --git a/backend/auto_labeling/migrations/0003_fill_task_type.py b/backend/auto_labeling/migrations/0003_fill_task_type.py index 3fbfa1eb..4bb36665 100644 --- a/backend/auto_labeling/migrations/0003_fill_task_type.py +++ b/backend/auto_labeling/migrations/0003_fill_task_type.py @@ -1,17 +1,17 @@ from django.db import migrations -from projects.models import DOCUMENT_CLASSIFICATION, SEQUENCE_LABELING, SEQ2SEQ, SPEECH2TEXT, IMAGE_CLASSIFICATION +from projects.models import ProjectType def fill_task_type(apps, schema_editor): AutoLabelingConfig = apps.get_model("auto_labeling", "AutoLabelingConfig") for config in AutoLabelingConfig.objects.all(): project = config.project - if project.project_type in [DOCUMENT_CLASSIFICATION, IMAGE_CLASSIFICATION]: + if project.project_type in [ProjectType.DOCUMENT_CLASSIFICATION, ProjectType.IMAGE_CLASSIFICATION]: config.task_type = "Category" - elif project.project_type in [SEQ2SEQ, SPEECH2TEXT]: + elif project.project_type in [ProjectType.SEQ2SEQ, ProjectType.SPEECH2TEXT]: config.task_type = "Text" - elif project.project_type in [SEQUENCE_LABELING]: + elif project.project_type in [ProjectType.SEQUENCE_LABELING]: config.task_type = "Span" else: config.task_type = "Category" diff --git a/backend/auto_labeling/tests/test_views.py b/backend/auto_labeling/tests/test_views.py index 969c2f15..8518a599 100644 --- a/backend/auto_labeling/tests/test_views.py +++ b/backend/auto_labeling/tests/test_views.py @@ -11,7 +11,7 @@ from api.tests.utils import CRUDMixin from auto_labeling.pipeline.labels import Categories, Spans, Texts from examples.tests.utils import make_doc from labels.models import Category, Span, TextLabel -from projects.models import DOCUMENT_CLASSIFICATION, SEQ2SEQ, SEQUENCE_LABELING +from projects.models import ProjectType from projects.tests.utils import prepare_project data_dir = pathlib.Path(__file__).parent / "data" @@ -19,7 +19,7 @@ data_dir = pathlib.Path(__file__).parent / "data" class TestTemplateList(CRUDMixin): def setUp(self): - self.project = prepare_project(task=DOCUMENT_CLASSIFICATION) + self.project = prepare_project(task=ProjectType.DOCUMENT_CLASSIFICATION) self.url = reverse(viewname="auto_labeling_templates", args=[self.project.item.id]) def test_allow_admin_to_fetch_template_list(self): @@ -47,7 +47,7 @@ class TestTemplateList(CRUDMixin): class TestConfigParameter(CRUDMixin): def setUp(self): - self.project = prepare_project(task=DOCUMENT_CLASSIFICATION) + self.project = prepare_project(task=ProjectType.DOCUMENT_CLASSIFICATION) self.data = { "model_name": "GCP Entity Analysis", "model_attrs": {"key": "hoge", "type": "PLAIN_TEXT", "language": "en"}, @@ -78,7 +78,7 @@ class TestConfigParameter(CRUDMixin): class TestTemplateMapping(CRUDMixin): def setUp(self): - self.project = prepare_project(task=DOCUMENT_CLASSIFICATION) + self.project = prepare_project(task=ProjectType.DOCUMENT_CLASSIFICATION) self.data = { "response": { "Sentiment": "NEUTRAL", @@ -106,7 +106,7 @@ class TestTemplateMapping(CRUDMixin): class TestLabelMapping(CRUDMixin): def setUp(self): - self.project = prepare_project(task=DOCUMENT_CLASSIFICATION) + self.project = prepare_project(task=ProjectType.DOCUMENT_CLASSIFICATION) self.data = { "response": [{"label": "NEGATIVE"}], "label_mapping": {"NEGATIVE": "Negative"}, @@ -122,7 +122,7 @@ class TestLabelMapping(CRUDMixin): class TestConfigCreation(CRUDMixin): def setUp(self): - self.project = prepare_project(task=DOCUMENT_CLASSIFICATION) + self.project = prepare_project(task=ProjectType.DOCUMENT_CLASSIFICATION) self.data = { "model_name": "Amazon Comprehend Sentiment Analysis", "model_attrs": { @@ -149,7 +149,7 @@ class TestConfigCreation(CRUDMixin): class TestAutomatedLabeling(CRUDMixin): def setUp(self): - self.project = prepare_project(task=DOCUMENT_CLASSIFICATION, single_class_classification=False) + self.project = prepare_project(task=ProjectType.DOCUMENT_CLASSIFICATION, single_class_classification=False) self.example = make_doc(self.project.item) self.category_pos = mommy.make("CategoryType", project=self.project.item, text="POS") self.category_neg = mommy.make("CategoryType", project=self.project.item, text="NEG") @@ -215,7 +215,7 @@ class TestAutomatedLabeling(CRUDMixin): class TestAutomatedSpanLabeling(CRUDMixin): def setUp(self): - self.project = prepare_project(task=SEQUENCE_LABELING) + self.project = prepare_project(task=ProjectType.SEQUENCE_LABELING) self.example = make_doc(self.project.item) self.loc = mommy.make("SpanType", project=self.project.item, text="LOC") self.url = reverse(viewname="auto_labeling", args=[self.project.item.id]) @@ -237,7 +237,7 @@ class TestAutomatedSpanLabeling(CRUDMixin): class TestAutomatedTextLabeling(CRUDMixin): def setUp(self): - self.project = prepare_project(task=SEQ2SEQ) + self.project = prepare_project(task=ProjectType.SEQ2SEQ) self.example = make_doc(self.project.item) self.url = reverse(viewname="auto_labeling", args=[self.project.item.id]) self.url += f"?example={self.example.id}" diff --git a/backend/data_export/pipeline/catalog.py b/backend/data_export/pipeline/catalog.py index 7b196836..93ffe7b0 100644 --- a/backend/data_export/pipeline/catalog.py +++ b/backend/data_export/pipeline/catalog.py @@ -2,17 +2,7 @@ from collections import defaultdict from pathlib import Path from typing import Dict, List, Type -from projects.models import ( - BOUNDING_BOX, - DOCUMENT_CLASSIFICATION, - IMAGE_CAPTIONING, - IMAGE_CLASSIFICATION, - INTENT_DETECTION_AND_SLOT_FILLING, - SEGMENTATION, - SEQ2SEQ, - SEQUENCE_LABELING, - SPEECH2TEXT, -) +from projects.models import ProjectType EXAMPLE_DIR = Path(__file__).parent.resolve() / "examples" @@ -68,40 +58,40 @@ class Options: # Text Classification TEXT_CLASSIFICATION_DIR = EXAMPLE_DIR / "text_classification" -Options.register(DOCUMENT_CLASSIFICATION, CSV, TEXT_CLASSIFICATION_DIR / "example.csv") -Options.register(DOCUMENT_CLASSIFICATION, FastText, TEXT_CLASSIFICATION_DIR / "example.txt") -Options.register(DOCUMENT_CLASSIFICATION, JSON, TEXT_CLASSIFICATION_DIR / "example.json") -Options.register(DOCUMENT_CLASSIFICATION, JSONL, TEXT_CLASSIFICATION_DIR / "example.jsonl") +Options.register(ProjectType.DOCUMENT_CLASSIFICATION, CSV, TEXT_CLASSIFICATION_DIR / "example.csv") +Options.register(ProjectType.DOCUMENT_CLASSIFICATION, FastText, TEXT_CLASSIFICATION_DIR / "example.txt") +Options.register(ProjectType.DOCUMENT_CLASSIFICATION, JSON, TEXT_CLASSIFICATION_DIR / "example.json") +Options.register(ProjectType.DOCUMENT_CLASSIFICATION, JSONL, TEXT_CLASSIFICATION_DIR / "example.jsonl") # Sequence Labeling SEQUENCE_LABELING_DIR = EXAMPLE_DIR / "sequence_labeling" RELATION_EXTRACTION_DIR = EXAMPLE_DIR / "relation_extraction" -Options.register(SEQUENCE_LABELING, JSONL, SEQUENCE_LABELING_DIR / "example.jsonl") -Options.register(SEQUENCE_LABELING, JSONL, RELATION_EXTRACTION_DIR / "example.jsonl", True) +Options.register(ProjectType.SEQUENCE_LABELING, JSONL, SEQUENCE_LABELING_DIR / "example.jsonl") +Options.register(ProjectType.SEQUENCE_LABELING, JSONL, RELATION_EXTRACTION_DIR / "example.jsonl", True) # Sequence to sequence SEQ2SEQ_DIR = EXAMPLE_DIR / "sequence_to_sequence" -Options.register(SEQ2SEQ, CSV, SEQ2SEQ_DIR / "example.csv") -Options.register(SEQ2SEQ, JSON, SEQ2SEQ_DIR / "example.json") -Options.register(SEQ2SEQ, JSONL, SEQ2SEQ_DIR / "example.jsonl") +Options.register(ProjectType.SEQ2SEQ, CSV, SEQ2SEQ_DIR / "example.csv") +Options.register(ProjectType.SEQ2SEQ, JSON, SEQ2SEQ_DIR / "example.json") +Options.register(ProjectType.SEQ2SEQ, JSONL, SEQ2SEQ_DIR / "example.jsonl") # Intent detection and slot filling INTENT_DETECTION_DIR = EXAMPLE_DIR / "intent_detection" -Options.register(INTENT_DETECTION_AND_SLOT_FILLING, JSONL, INTENT_DETECTION_DIR / "example.jsonl") +Options.register(ProjectType.INTENT_DETECTION_AND_SLOT_FILLING, JSONL, INTENT_DETECTION_DIR / "example.jsonl") # Image Classification IMAGE_CLASSIFICATION_DIR = EXAMPLE_DIR / "image_classification" -Options.register(IMAGE_CLASSIFICATION, JSONL, IMAGE_CLASSIFICATION_DIR / "example.jsonl") +Options.register(ProjectType.IMAGE_CLASSIFICATION, JSONL, IMAGE_CLASSIFICATION_DIR / "example.jsonl") BOUNDING_BOX_DIR = EXAMPLE_DIR / "bounding_box" -Options.register(BOUNDING_BOX, JSONL, BOUNDING_BOX_DIR / "example.jsonl") +Options.register(ProjectType.BOUNDING_BOX, JSONL, BOUNDING_BOX_DIR / "example.jsonl") SEGMENTATION_DIR = EXAMPLE_DIR / "segmentation" -Options.register(SEGMENTATION, JSONL, SEGMENTATION_DIR / "example.jsonl") +Options.register(ProjectType.SEGMENTATION, JSONL, SEGMENTATION_DIR / "example.jsonl") IMAGE_CAPTIONING_DIR = EXAMPLE_DIR / "image_captioning" -Options.register(IMAGE_CAPTIONING, JSONL, IMAGE_CAPTIONING_DIR / "example.jsonl") +Options.register(ProjectType.IMAGE_CAPTIONING, JSONL, IMAGE_CAPTIONING_DIR / "example.jsonl") # Speech to Text SPEECH2TEXT_DIR = EXAMPLE_DIR / "speech_to_text" -Options.register(SPEECH2TEXT, JSONL, SPEECH2TEXT_DIR / "example.jsonl") +Options.register(ProjectType.SPEECH2TEXT, JSONL, SPEECH2TEXT_DIR / "example.jsonl") diff --git a/backend/data_export/pipeline/factories.py b/backend/data_export/pipeline/factories.py index 65b226d8..bd99ee6b 100644 --- a/backend/data_export/pipeline/factories.py +++ b/backend/data_export/pipeline/factories.py @@ -16,18 +16,7 @@ from .formatters import ( ) from .labels import BoundingBoxes, Categories, Labels, Relations, Segments, Spans, Texts from data_export.models import DATA, ExportedExample -from projects.models import ( - BOUNDING_BOX, - DOCUMENT_CLASSIFICATION, - IMAGE_CAPTIONING, - IMAGE_CLASSIFICATION, - INTENT_DETECTION_AND_SLOT_FILLING, - SEGMENTATION, - SEQ2SEQ, - SEQUENCE_LABELING, - SPEECH2TEXT, - Project, -) +from projects.models import Project, ProjectType def create_writer(file_format: str) -> writers.Writer: @@ -61,7 +50,7 @@ def create_formatter(project: Project, file_format: str) -> List[Formatter]: mapper_speech2text = {DATA: "filename", Texts.column: "label"} mapping: Dict[str, Dict[str, List[Formatter]]] = { - DOCUMENT_CLASSIFICATION: { + ProjectType.DOCUMENT_CLASSIFICATION: { CSV.name: [ JoinedCategoryFormatter(Categories.column), JoinedCategoryFormatter(Comments.column), @@ -79,7 +68,7 @@ def create_formatter(project: Project, file_format: str) -> List[Formatter]: ], FastText.name: [FastTextCategoryFormatter(Categories.column)], }, - SEQUENCE_LABELING: { + ProjectType.SEQUENCE_LABELING: { JSONL.name: [ DictFormatter(Spans.column), DictFormatter(Relations.column), @@ -93,7 +82,7 @@ def create_formatter(project: Project, file_format: str) -> List[Formatter]: RenameFormatter(**mapper_sequence_labeling), ] }, - SEQ2SEQ: { + ProjectType.SEQ2SEQ: { CSV.name: [ JoinedCategoryFormatter(Texts.column), JoinedCategoryFormatter(Comments.column), @@ -110,21 +99,21 @@ def create_formatter(project: Project, file_format: str) -> List[Formatter]: RenameFormatter(**mapper_seq2seq), ], }, - IMAGE_CLASSIFICATION: { + ProjectType.IMAGE_CLASSIFICATION: { JSONL.name: [ ListedCategoryFormatter(Categories.column), ListedCategoryFormatter(Comments.column), RenameFormatter(**mapper_image_classification), ], }, - SPEECH2TEXT: { + ProjectType.SPEECH2TEXT: { JSONL.name: [ ListedCategoryFormatter(Texts.column), ListedCategoryFormatter(Comments.column), RenameFormatter(**mapper_speech2text), ], }, - INTENT_DETECTION_AND_SLOT_FILLING: { + ProjectType.INTENT_DETECTION_AND_SLOT_FILLING: { JSONL.name: [ ListedCategoryFormatter(Categories.column), TupledSpanFormatter(Spans.column), @@ -132,21 +121,21 @@ def create_formatter(project: Project, file_format: str) -> List[Formatter]: RenameFormatter(**mapper_intent_detection), ] }, - BOUNDING_BOX: { + ProjectType.BOUNDING_BOX: { JSONL.name: [ DictFormatter(BoundingBoxes.column), DictFormatter(Comments.column), RenameFormatter(**mapper_bounding_box), ] }, - SEGMENTATION: { + ProjectType.SEGMENTATION: { JSONL.name: [ DictFormatter(Segments.column), DictFormatter(Comments.column), RenameFormatter(**mapper_segmentation), ] }, - IMAGE_CAPTIONING: { + ProjectType.IMAGE_CAPTIONING: { JSONL.name: [ ListedCategoryFormatter(Texts.column), ListedCategoryFormatter(Comments.column), @@ -160,15 +149,15 @@ def create_formatter(project: Project, file_format: str) -> List[Formatter]: def select_label_collection(project: Project) -> List[Type[Labels]]: use_relation = getattr(project, "use_relation", False) mapping: Dict[str, List[Type[Labels]]] = { - DOCUMENT_CLASSIFICATION: [Categories], - SEQUENCE_LABELING: [Spans, Relations] if use_relation else [Spans], - SEQ2SEQ: [Texts], - IMAGE_CLASSIFICATION: [Categories], - SPEECH2TEXT: [Texts], - INTENT_DETECTION_AND_SLOT_FILLING: [Categories, Spans], - BOUNDING_BOX: [BoundingBoxes], - SEGMENTATION: [Segments], - IMAGE_CAPTIONING: [Texts], + ProjectType.DOCUMENT_CLASSIFICATION: [Categories], + ProjectType.SEQUENCE_LABELING: [Spans, Relations] if use_relation else [Spans], + ProjectType.SEQ2SEQ: [Texts], + ProjectType.IMAGE_CLASSIFICATION: [Categories], + ProjectType.SPEECH2TEXT: [Texts], + ProjectType.INTENT_DETECTION_AND_SLOT_FILLING: [Categories, Spans], + ProjectType.BOUNDING_BOX: [BoundingBoxes], + ProjectType.SEGMENTATION: [Segments], + ProjectType.IMAGE_CAPTIONING: [Texts], } return mapping[project.project_type] diff --git a/backend/data_export/tests/test_catalog.py b/backend/data_export/tests/test_catalog.py index f2c4c288..b0dcd724 100644 --- a/backend/data_export/tests/test_catalog.py +++ b/backend/data_export/tests/test_catalog.py @@ -1,33 +1,12 @@ import unittest from ..pipeline.catalog import Options -from projects.models import ( - BOUNDING_BOX, - DOCUMENT_CLASSIFICATION, - IMAGE_CAPTIONING, - IMAGE_CLASSIFICATION, - INTENT_DETECTION_AND_SLOT_FILLING, - SEGMENTATION, - SEQ2SEQ, - SEQUENCE_LABELING, - SPEECH2TEXT, -) +from projects.models import ProjectType class TestOptions(unittest.TestCase): def test_return_at_least_one_option(self): - tasks = [ - BOUNDING_BOX, - DOCUMENT_CLASSIFICATION, - IMAGE_CAPTIONING, - IMAGE_CLASSIFICATION, - INTENT_DETECTION_AND_SLOT_FILLING, - SEGMENTATION, - SEQ2SEQ, - SEQUENCE_LABELING, - SPEECH2TEXT, - ] - for task in tasks: + for task in ProjectType: with self.subTest(task=task): options = Options.filter_by_task(task) self.assertGreaterEqual(len(options), 1) diff --git a/backend/data_export/tests/test_labels.py b/backend/data_export/tests/test_labels.py index 5bbc6c44..6511dae2 100644 --- a/backend/data_export/tests/test_labels.py +++ b/backend/data_export/tests/test_labels.py @@ -3,13 +3,13 @@ from model_mommy import mommy from ..pipeline.labels import Categories from data_export.models import ExportedExample -from projects.models import DOCUMENT_CLASSIFICATION +from projects.models import ProjectType from projects.tests.utils import prepare_project class TestLabels(TestCase): def setUp(self): - self.project = prepare_project(task=DOCUMENT_CLASSIFICATION) + self.project = prepare_project(task=ProjectType.DOCUMENT_CLASSIFICATION) self.example1 = mommy.make("ExportedExample", project=self.project.item) self.example2 = mommy.make("ExportedExample", project=self.project.item) self.category1 = mommy.make("ExportedCategory", example=self.example1, user=self.project.admin) diff --git a/backend/data_export/tests/test_task.py b/backend/data_export/tests/test_task.py index 31b65687..97f9a1fd 100644 --- a/backend/data_export/tests/test_task.py +++ b/backend/data_export/tests/test_task.py @@ -7,17 +7,7 @@ from model_mommy import mommy from ..celery_tasks import export_dataset from data_export.models import DATA -from projects.models import ( - BOUNDING_BOX, - DOCUMENT_CLASSIFICATION, - IMAGE_CAPTIONING, - IMAGE_CLASSIFICATION, - INTENT_DETECTION_AND_SLOT_FILLING, - SEGMENTATION, - SEQ2SEQ, - SEQUENCE_LABELING, - SPEECH2TEXT, -) +from projects.models import ProjectType from projects.tests.utils import prepare_project @@ -59,7 +49,7 @@ class TestExport(TestCase): class TestExportCategory(TestExport): def prepare_data(self, collaborative=False): - self.project = prepare_project(DOCUMENT_CLASSIFICATION, collaborative_annotation=collaborative) + self.project = prepare_project(ProjectType.DOCUMENT_CLASSIFICATION, collaborative_annotation=collaborative) self.example1 = mommy.make("ExportedExample", project=self.project.item, text="example1") self.example2 = mommy.make("ExportedExample", project=self.project.item, text="example2") self.category1 = mommy.make("ExportedCategory", example=self.example1, user=self.project.admin) @@ -129,7 +119,7 @@ class TestExportCategory(TestExport): class TestExportSeq2seq(TestExport): def prepare_data(self, collaborative=False): - self.project = prepare_project(SEQ2SEQ, collaborative_annotation=collaborative) + self.project = prepare_project(ProjectType.SEQ2SEQ, collaborative_annotation=collaborative) self.example1 = mommy.make("ExportedExample", project=self.project.item, text="confirmed") self.example2 = mommy.make("ExportedExample", project=self.project.item, text="unconfirmed") self.text1 = mommy.make("TextLabel", example=self.example1, user=self.project.admin) @@ -201,7 +191,9 @@ class TestExportSeq2seq(TestExport): class TestExportIntentDetectionAndSlotFilling(TestExport): def prepare_data(self, collaborative=False): - self.project = prepare_project(INTENT_DETECTION_AND_SLOT_FILLING, collaborative_annotation=collaborative) + self.project = prepare_project( + ProjectType.INTENT_DETECTION_AND_SLOT_FILLING, collaborative_annotation=collaborative + ) self.example1 = mommy.make("ExportedExample", project=self.project.item, text="confirmed") self.example2 = mommy.make("ExportedExample", project=self.project.item, text="unconfirmed") self.category1 = mommy.make("ExportedCategory", example=self.example1, user=self.project.admin) @@ -293,7 +285,7 @@ class TestExportIntentDetectionAndSlotFilling(TestExport): class TestExportSequenceLabeling(TestExport): def prepare_data(self, collaborative=False): - self.project = prepare_project(SEQUENCE_LABELING, collaborative_annotation=collaborative) + self.project = prepare_project(ProjectType.SEQUENCE_LABELING, collaborative_annotation=collaborative) self.example1 = mommy.make("ExportedExample", project=self.project.item, text="confirmed") self.span1 = mommy.make( "ExportedSpan", example=self.example1, user=self.project.admin, start_offset=0, end_offset=1 @@ -369,7 +361,7 @@ class TestExportSequenceLabeling(TestExport): class TestExportSpeechToText(TestExport): def prepare_data(self, collaborative=False): - self.project = prepare_project(SPEECH2TEXT, collaborative_annotation=collaborative) + self.project = prepare_project(ProjectType.SPEECH2TEXT, collaborative_annotation=collaborative) self.example1 = mommy.make("ExportedExample", project=self.project.item, text="confirmed") self.example2 = mommy.make("ExportedExample", project=self.project.item, text="unconfirmed") self.text1 = mommy.make("TextLabel", example=self.example1, user=self.project.admin) @@ -441,7 +433,7 @@ class TestExportSpeechToText(TestExport): class TestExportImageClassification(TestExport): def prepare_data(self, collaborative=False): - self.project = prepare_project(IMAGE_CLASSIFICATION, collaborative_annotation=collaborative) + self.project = prepare_project(ProjectType.IMAGE_CLASSIFICATION, collaborative_annotation=collaborative) self.example1 = mommy.make("ExportedExample", project=self.project.item, text="confirmed") self.example2 = mommy.make("ExportedExample", project=self.project.item, text="unconfirmed") self.category1 = mommy.make("ExportedCategory", example=self.example1, user=self.project.admin) @@ -511,7 +503,7 @@ class TestExportImageClassification(TestExport): class TestExportBoundingBox(TestExport): def prepare_data(self, collaborative=False): - self.project = prepare_project(BOUNDING_BOX, collaborative_annotation=collaborative) + self.project = prepare_project(ProjectType.BOUNDING_BOX, collaborative_annotation=collaborative) self.example1 = mommy.make("ExportedExample", project=self.project.item, text="confirmed") self.example2 = mommy.make("ExportedExample", project=self.project.item, text="unconfirmed") self.comment1 = mommy.make("ExportedComment", example=self.example1, user=self.project.admin) @@ -589,7 +581,7 @@ class TestExportBoundingBox(TestExport): class TestExportSegmentation(TestExport): def prepare_data(self, collaborative=False): - self.project = prepare_project(SEGMENTATION, collaborative_annotation=collaborative) + self.project = prepare_project(ProjectType.SEGMENTATION, collaborative_annotation=collaborative) self.example1 = mommy.make("ExportedExample", project=self.project.item, text="confirmed") self.example2 = mommy.make("ExportedExample", project=self.project.item, text="unconfirmed") self.comment1 = mommy.make("ExportedComment", example=self.example1, user=self.project.admin) @@ -662,7 +654,7 @@ class TestExportSegmentation(TestExport): class TestExportImageCaptioning(TestExport): def prepare_data(self, collaborative=False): - self.project = prepare_project(IMAGE_CAPTIONING, collaborative_annotation=collaborative) + self.project = prepare_project(ProjectType.IMAGE_CAPTIONING, collaborative_annotation=collaborative) self.example1 = mommy.make("ExportedExample", project=self.project.item, text="confirmed") self.example2 = mommy.make("ExportedExample", project=self.project.item, text="unconfirmed") self.comment1 = mommy.make("ExportedComment", example=self.example1, user=self.project.admin) @@ -735,7 +727,9 @@ class TestExportImageCaptioning(TestExport): class TestExportRelation(TestExport): def prepare_data(self, collaborative=False): - self.project = prepare_project(SEQUENCE_LABELING, use_relation=True, collaborative_annotation=collaborative) + self.project = prepare_project( + ProjectType.SEQUENCE_LABELING, use_relation=True, collaborative_annotation=collaborative + ) self.example1 = mommy.make("ExportedExample", project=self.project.item, text="example") self.example2 = mommy.make("ExportedExample", project=self.project.item, text="unconfirmed") self.span1 = mommy.make( diff --git a/backend/data_export/tests/test_views.py b/backend/data_export/tests/test_views.py index ca3ad8a1..1ec814a0 100644 --- a/backend/data_export/tests/test_views.py +++ b/backend/data_export/tests/test_views.py @@ -2,13 +2,13 @@ from rest_framework import status from rest_framework.reverse import reverse from api.tests.utils import CRUDMixin -from projects.models import DOCUMENT_CLASSIFICATION +from projects.models import ProjectType from projects.tests.utils import prepare_project class TestDownloadCatalog(CRUDMixin): def setUp(self): - self.project = prepare_project(task=DOCUMENT_CLASSIFICATION) + self.project = prepare_project(task=ProjectType.DOCUMENT_CLASSIFICATION) self.url = reverse(viewname="download-format", args=[self.project.item.id]) def test_allows_project_admin_to_list_catalog(self): diff --git a/backend/data_import/datasets.py b/backend/data_import/datasets.py index 473c00e6..cf0cc0bc 100644 --- a/backend/data_import/datasets.py +++ b/backend/data_import/datasets.py @@ -20,18 +20,7 @@ from .pipeline.readers import ( Reader, ) from label_types.models import CategoryType, LabelType, RelationType, SpanType -from projects.models import ( - BOUNDING_BOX, - DOCUMENT_CLASSIFICATION, - IMAGE_CAPTIONING, - IMAGE_CLASSIFICATION, - INTENT_DETECTION_AND_SLOT_FILLING, - SEGMENTATION, - SEQ2SEQ, - SEQUENCE_LABELING, - SPEECH2TEXT, - Project, -) +from projects.models import Project, ProjectType class Dataset(abc.ABC): @@ -215,16 +204,16 @@ class CategoryAndSpanDataset(Dataset): def select_dataset(project: Project, task: str, file_format: Format) -> Type[Dataset]: mapping = { - DOCUMENT_CLASSIFICATION: TextClassificationDataset, - SEQUENCE_LABELING: SequenceLabelingDataset, + ProjectType.DOCUMENT_CLASSIFICATION: TextClassificationDataset, + ProjectType.SEQUENCE_LABELING: SequenceLabelingDataset, RELATION_EXTRACTION: RelationExtractionDataset, - SEQ2SEQ: Seq2seqDataset, - INTENT_DETECTION_AND_SLOT_FILLING: CategoryAndSpanDataset, - IMAGE_CLASSIFICATION: BinaryDataset, - IMAGE_CAPTIONING: BinaryDataset, - BOUNDING_BOX: BinaryDataset, - SEGMENTATION: BinaryDataset, - SPEECH2TEXT: BinaryDataset, + ProjectType.SEQ2SEQ: Seq2seqDataset, + ProjectType.INTENT_DETECTION_AND_SLOT_FILLING: CategoryAndSpanDataset, + ProjectType.IMAGE_CLASSIFICATION: BinaryDataset, + ProjectType.IMAGE_CAPTIONING: BinaryDataset, + ProjectType.BOUNDING_BOX: BinaryDataset, + ProjectType.SEGMENTATION: BinaryDataset, + ProjectType.SPEECH2TEXT: BinaryDataset, } if task not in mapping: task = project.project_type diff --git a/backend/data_import/pipeline/catalog.py b/backend/data_import/pipeline/catalog.py index af2e2c42..6acc794f 100644 --- a/backend/data_import/pipeline/catalog.py +++ b/backend/data_import/pipeline/catalog.py @@ -7,17 +7,7 @@ from pydantic import BaseModel from typing_extensions import Literal from .exceptions import FileFormatException -from projects.models import ( - BOUNDING_BOX, - DOCUMENT_CLASSIFICATION, - IMAGE_CAPTIONING, - IMAGE_CLASSIFICATION, - INTENT_DETECTION_AND_SLOT_FILLING, - SEGMENTATION, - SEQ2SEQ, - SEQUENCE_LABELING, - SPEECH2TEXT, -) +from projects.models import ProjectType # Define the example directories EXAMPLE_DIR = Path(__file__).parent.resolve() / "examples" @@ -287,7 +277,12 @@ class Options: # Text tasks -text_tasks = [DOCUMENT_CLASSIFICATION, SEQUENCE_LABELING, SEQ2SEQ, INTENT_DETECTION_AND_SLOT_FILLING] +text_tasks = [ + ProjectType.DOCUMENT_CLASSIFICATION, + ProjectType.SEQUENCE_LABELING, + ProjectType.SEQ2SEQ, + ProjectType.INTENT_DETECTION_AND_SLOT_FILLING, +] for task_id in text_tasks: Options.register( Option( @@ -312,7 +307,7 @@ for task_id in text_tasks: Options.register( Option( display_name=CSV.name, - task_id=DOCUMENT_CLASSIFICATION, + task_id=ProjectType.DOCUMENT_CLASSIFICATION, file_format=CSV, arg=ArgDelimiter, file=TEXT_CLASSIFICATION_DIR / "example.csv", @@ -321,7 +316,7 @@ Options.register( Options.register( Option( display_name=FastText.name, - task_id=DOCUMENT_CLASSIFICATION, + task_id=ProjectType.DOCUMENT_CLASSIFICATION, file_format=FastText, arg=ArgEncoding, file=TEXT_CLASSIFICATION_DIR / "example.txt", @@ -330,7 +325,7 @@ Options.register( Options.register( Option( display_name=JSON.name, - task_id=DOCUMENT_CLASSIFICATION, + task_id=ProjectType.DOCUMENT_CLASSIFICATION, file_format=JSON, arg=ArgColumn, file=TEXT_CLASSIFICATION_DIR / "example.json", @@ -339,7 +334,7 @@ Options.register( Options.register( Option( display_name=JSONL.name, - task_id=DOCUMENT_CLASSIFICATION, + task_id=ProjectType.DOCUMENT_CLASSIFICATION, file_format=JSONL, arg=ArgColumn, file=TEXT_CLASSIFICATION_DIR / "example.jsonl", @@ -348,7 +343,7 @@ Options.register( Options.register( Option( display_name=Excel.name, - task_id=DOCUMENT_CLASSIFICATION, + task_id=ProjectType.DOCUMENT_CLASSIFICATION, file_format=Excel, arg=ArgColumn, file=TEXT_CLASSIFICATION_DIR / "example.csv", @@ -359,7 +354,7 @@ Options.register( Options.register( Option( display_name=JSONL.name, - task_id=SEQUENCE_LABELING, + task_id=ProjectType.SEQUENCE_LABELING, file_format=JSONL, arg=ArgColumn, file=SEQUENCE_LABELING_DIR / "example.jsonl", @@ -368,7 +363,7 @@ Options.register( Options.register( Option( display_name=CoNLL.name, - task_id=SEQUENCE_LABELING, + task_id=ProjectType.SEQUENCE_LABELING, file_format=CoNLL, arg=ArgCoNLL, file=SEQUENCE_LABELING_DIR / "example.txt", @@ -390,7 +385,7 @@ Options.register( Options.register( Option( display_name=CSV.name, - task_id=SEQ2SEQ, + task_id=ProjectType.SEQ2SEQ, file_format=CSV, arg=ArgDelimiter, file=SEQ2SEQ_DIR / "example.csv", @@ -399,7 +394,7 @@ Options.register( Options.register( Option( display_name=JSON.name, - task_id=SEQ2SEQ, + task_id=ProjectType.SEQ2SEQ, file_format=JSON, arg=ArgColumn, file=SEQ2SEQ_DIR / "example.json", @@ -408,7 +403,7 @@ Options.register( Options.register( Option( display_name=JSONL.name, - task_id=SEQ2SEQ, + task_id=ProjectType.SEQ2SEQ, file_format=JSONL, arg=ArgColumn, file=SEQ2SEQ_DIR / "example.jsonl", @@ -417,7 +412,7 @@ Options.register( Options.register( Option( display_name=Excel.name, - task_id=SEQ2SEQ, + task_id=ProjectType.SEQ2SEQ, file_format=Excel, arg=ArgColumn, file=SEQ2SEQ_DIR / "example.csv", @@ -428,7 +423,7 @@ Options.register( Options.register( Option( display_name=JSONL.name, - task_id=INTENT_DETECTION_AND_SLOT_FILLING, + task_id=ProjectType.INTENT_DETECTION_AND_SLOT_FILLING, file_format=JSONL, arg=ArgNone, file=INTENT_DETECTION_DIR / "example.jsonl", @@ -436,7 +431,12 @@ Options.register( ) # Image tasks -image_tasks = [IMAGE_CLASSIFICATION, IMAGE_CAPTIONING, BOUNDING_BOX, SEGMENTATION] +image_tasks = [ + ProjectType.IMAGE_CLASSIFICATION, + ProjectType.IMAGE_CAPTIONING, + ProjectType.BOUNDING_BOX, + ProjectType.SEGMENTATION, +] for task_name in image_tasks: Options.register( Option( @@ -452,7 +452,7 @@ for task_name in image_tasks: Options.register( Option( display_name=AudioFile.name, - task_id=SPEECH2TEXT, + task_id=ProjectType.SPEECH2TEXT, file_format=AudioFile, arg=ArgNone, file=SPEECH_TO_TEXT_DIR / "audio_files.txt", diff --git a/backend/data_import/tests/test_catalog.py b/backend/data_import/tests/test_catalog.py index e702cd81..dc10e615 100644 --- a/backend/data_import/tests/test_catalog.py +++ b/backend/data_import/tests/test_catalog.py @@ -1,27 +1,12 @@ import unittest from data_import.pipeline.catalog import Options -from projects.models import ( - DOCUMENT_CLASSIFICATION, - IMAGE_CLASSIFICATION, - INTENT_DETECTION_AND_SLOT_FILLING, - SEQ2SEQ, - SEQUENCE_LABELING, - SPEECH2TEXT, -) +from projects.models import ProjectType class TestOptions(unittest.TestCase): def test_return_at_least_one_option(self): - tasks = [ - DOCUMENT_CLASSIFICATION, - IMAGE_CLASSIFICATION, - INTENT_DETECTION_AND_SLOT_FILLING, - SEQ2SEQ, - SEQUENCE_LABELING, - SPEECH2TEXT, - ] - for task in tasks: + for task in ProjectType: with self.subTest(task=task): options = Options.filter_by_task(task) self.assertGreaterEqual(len(options), 1) diff --git a/backend/data_import/tests/test_examples.py b/backend/data_import/tests/test_examples.py index 4ce574d3..d356c414 100644 --- a/backend/data_import/tests/test_examples.py +++ b/backend/data_import/tests/test_examples.py @@ -4,13 +4,13 @@ from django.test import TestCase from data_import.pipeline.examples import Examples from examples.models import Example -from projects.models import DOCUMENT_CLASSIFICATION +from projects.models import ProjectType from projects.tests.utils import prepare_project class TestExamples(TestCase): def setUp(self): - self.project = prepare_project(DOCUMENT_CLASSIFICATION) + self.project = prepare_project(ProjectType.DOCUMENT_CLASSIFICATION) self.example_uuid = uuid.uuid4() example = Example(uuid=self.example_uuid, text="A", project=self.project.item) self.examples = Examples([example]) diff --git a/backend/data_import/tests/test_label.py b/backend/data_import/tests/test_label.py index 362c2230..3717dc91 100644 --- a/backend/data_import/tests/test_label.py +++ b/backend/data_import/tests/test_label.py @@ -15,7 +15,7 @@ from labels.models import Category as CategoryModel from labels.models import Relation as RelationModel from labels.models import Span as SpanModel from labels.models import TextLabel as TextModel -from projects.models import DOCUMENT_CLASSIFICATION, SEQ2SEQ, SEQUENCE_LABELING +from projects.models import ProjectType from projects.tests.utils import prepare_project @@ -29,7 +29,7 @@ class TestLabel(TestCase): class TestCategoryLabel(TestLabel): - task = DOCUMENT_CLASSIFICATION + task = ProjectType.DOCUMENT_CLASSIFICATION def test_comparison(self): category1 = CategoryLabel(label="A", example_uuid=uuid.uuid4()) @@ -61,7 +61,7 @@ class TestCategoryLabel(TestLabel): class TestSpanLabel(TestLabel): - task = SEQUENCE_LABELING + task = ProjectType.SEQUENCE_LABELING def test_comparison(self): span1 = SpanLabel(label="A", start_offset=0, end_offset=1, example_uuid=uuid.uuid4()) @@ -110,7 +110,7 @@ class TestSpanLabel(TestLabel): class TestTextLabel(TestLabel): - task = SEQ2SEQ + task = ProjectType.SEQ2SEQ def test_comparison(self): text1 = TextLabel(text="A", example_uuid=uuid.uuid4()) @@ -140,7 +140,7 @@ class TestTextLabel(TestLabel): class TestRelationLabel(TestLabel): - task = SEQUENCE_LABELING + task = ProjectType.SEQUENCE_LABELING def test_comparison(self): relation1 = RelationLabel(type="A", from_id=0, to_id=1, example_uuid=uuid.uuid4()) diff --git a/backend/data_import/tests/test_label_types.py b/backend/data_import/tests/test_label_types.py index d1195fdf..af0a8535 100644 --- a/backend/data_import/tests/test_label_types.py +++ b/backend/data_import/tests/test_label_types.py @@ -3,13 +3,13 @@ from model_mommy import mommy from data_import.pipeline.label_types import LabelTypes from label_types.models import CategoryType -from projects.models import DOCUMENT_CLASSIFICATION +from projects.models import ProjectType from projects.tests.utils import prepare_project class TestCategoryLabel(TestCase): def setUp(self): - self.project = prepare_project(DOCUMENT_CLASSIFICATION) + self.project = prepare_project(ProjectType.DOCUMENT_CLASSIFICATION) self.user = self.project.admin self.example = mommy.make("Example", project=self.project.item) diff --git a/backend/data_import/tests/test_labels.py b/backend/data_import/tests/test_labels.py index 3ed395ea..1d398052 100644 --- a/backend/data_import/tests/test_labels.py +++ b/backend/data_import/tests/test_labels.py @@ -16,14 +16,14 @@ from data_import.pipeline.labels import Categories, Relations, Spans, Texts from label_types.models import CategoryType, RelationType, SpanType from labels.models import Category, Relation, Span from labels.models import TextLabel as TextLabelModel -from projects.models import DOCUMENT_CLASSIFICATION, SEQUENCE_LABELING +from projects.models import ProjectType from projects.tests.utils import prepare_project class TestCategories(TestCase): def setUp(self): self.types = LabelTypes(CategoryType) - self.project = prepare_project(DOCUMENT_CLASSIFICATION) + self.project = prepare_project(ProjectType.DOCUMENT_CLASSIFICATION) self.user = self.project.admin example_uuid = uuid.uuid4() labels = [ @@ -59,7 +59,7 @@ class TestCategories(TestCase): class TestSpans(TestCase): def setUp(self): self.types = LabelTypes(SpanType) - self.project = prepare_project(SEQUENCE_LABELING, allow_overlapping=True) + self.project = prepare_project(ProjectType.SEQUENCE_LABELING, allow_overlapping=True) self.user = self.project.admin example_uuid = uuid.uuid4() labels = [ @@ -113,7 +113,7 @@ class TestSpans(TestCase): class TestTexts(TestCase): def setUp(self): self.types = LabelTypes(DummyLabelType) - self.project = prepare_project(SEQUENCE_LABELING) + self.project = prepare_project(ProjectType.SEQUENCE_LABELING) self.user = self.project.admin example_uuid = uuid.uuid4() labels = [ @@ -143,7 +143,7 @@ class TestTexts(TestCase): class TestRelations(TestCase): def setUp(self): self.types = LabelTypes(RelationType) - self.project = prepare_project(SEQUENCE_LABELING, use_relation=True) + self.project = prepare_project(ProjectType.SEQUENCE_LABELING, use_relation=True) self.user = self.project.admin example_uuid = uuid.uuid4() example = mommy.make("Example", project=self.project.item, uuid=example_uuid, text="hello world") diff --git a/backend/data_import/tests/test_tasks.py b/backend/data_import/tests/test_tasks.py index d6696916..e3c98cb5 100644 --- a/backend/data_import/tests/test_tasks.py +++ b/backend/data_import/tests/test_tasks.py @@ -12,13 +12,7 @@ from data_import.pipeline.catalog import RELATION_EXTRACTION from examples.models import Example from label_types.models import SpanType from labels.models import Category, Span -from projects.models import ( - DOCUMENT_CLASSIFICATION, - IMAGE_CLASSIFICATION, - INTENT_DETECTION_AND_SLOT_FILLING, - SEQ2SEQ, - SEQUENCE_LABELING, -) +from projects.models import ProjectType from projects.tests.utils import prepare_project @@ -57,7 +51,7 @@ class TestImportData(TestCase): @override_settings(MAX_UPLOAD_SIZE=0) class TestMaxFileSize(TestImportData): - task = DOCUMENT_CLASSIFICATION + task = ProjectType.DOCUMENT_CLASSIFICATION def test_jsonl(self): filename = "text_classification/example.jsonl" @@ -69,7 +63,7 @@ class TestMaxFileSize(TestImportData): class TestInvalidFileFormat(TestImportData): - task = DOCUMENT_CLASSIFICATION + task = ProjectType.DOCUMENT_CLASSIFICATION def test_invalid_file_format(self): filename = "text_classification/example.csv" @@ -79,7 +73,7 @@ class TestInvalidFileFormat(TestImportData): class TestImportClassificationData(TestImportData): - task = DOCUMENT_CLASSIFICATION + task = ProjectType.DOCUMENT_CLASSIFICATION def assert_examples(self, dataset): with self.subTest(): @@ -180,7 +174,7 @@ class TestImportClassificationData(TestImportData): class TestImportSequenceLabelingData(TestImportData): - task = SEQUENCE_LABELING + task = ProjectType.SEQUENCE_LABELING def assert_examples(self, dataset): self.assertEqual(Example.objects.count(), len(dataset)) @@ -223,7 +217,7 @@ class TestImportSequenceLabelingData(TestImportData): class TestImportRelationExtractionData(TestImportData): - task = SEQUENCE_LABELING + task = ProjectType.SEQUENCE_LABELING def setUp(self): self.project = prepare_project(self.task, use_relation=True) @@ -259,7 +253,7 @@ class TestImportRelationExtractionData(TestImportData): class TestImportSeq2seqData(TestImportData): - task = SEQ2SEQ + task = ProjectType.SEQ2SEQ def assert_examples(self, dataset): self.assertEqual(Example.objects.count(), len(dataset)) @@ -291,7 +285,7 @@ class TestImportSeq2seqData(TestImportData): class TestImportIntentDetectionAndSlotFillingData(TestImportData): - task = INTENT_DETECTION_AND_SLOT_FILLING + task = ProjectType.INTENT_DETECTION_AND_SLOT_FILLING def assert_examples(self, dataset): self.assertEqual(Example.objects.count(), len(dataset)) @@ -316,7 +310,7 @@ class TestImportIntentDetectionAndSlotFillingData(TestImportData): class TestImportImageClassificationData(TestImportData): - task = IMAGE_CLASSIFICATION + task = ProjectType.IMAGE_CLASSIFICATION def test_example(self): filename = "images/1500x500.jpeg" @@ -327,7 +321,7 @@ class TestImportImageClassificationData(TestImportData): @override_settings(ENABLE_FILE_TYPE_CHECK=True) class TestFileTypeChecking(TestImportData): - task = IMAGE_CLASSIFICATION + task = ProjectType.IMAGE_CLASSIFICATION def test_example(self): filename = "images/example.ico" diff --git a/backend/data_import/tests/test_views.py b/backend/data_import/tests/test_views.py index d0f49912..0b00b9b5 100644 --- a/backend/data_import/tests/test_views.py +++ b/backend/data_import/tests/test_views.py @@ -2,13 +2,13 @@ from rest_framework import status from rest_framework.reverse import reverse from api.tests.utils import CRUDMixin -from projects.models import DOCUMENT_CLASSIFICATION +from projects.models import ProjectType from projects.tests.utils import prepare_project class TestImportCatalog(CRUDMixin): def setUp(self): - self.project = prepare_project(task=DOCUMENT_CLASSIFICATION) + self.project = prepare_project(task=ProjectType.DOCUMENT_CLASSIFICATION) self.url = reverse(viewname="catalog", args=[self.project.item.id]) def test_allows_project_admin_to_list_catalog(self): diff --git a/backend/examples/tests/test_document.py b/backend/examples/tests/test_document.py index ccb67f49..639cb686 100644 --- a/backend/examples/tests/test_document.py +++ b/backend/examples/tests/test_document.py @@ -5,14 +5,14 @@ from rest_framework.reverse import reverse from .utils import make_doc, make_example_state from api.tests.utils import CRUDMixin -from projects.models import DOCUMENT_CLASSIFICATION +from projects.models import ProjectType from projects.tests.utils import assign_user_to_role, prepare_project from users.tests.utils import make_user class TestExampleListAPI(CRUDMixin): def setUp(self): - self.project = prepare_project(task=DOCUMENT_CLASSIFICATION) + self.project = prepare_project(task=ProjectType.DOCUMENT_CLASSIFICATION) self.non_member = make_user() self.example = make_doc(self.project.item) self.data = {"text": "example"} @@ -60,7 +60,7 @@ class TestExampleListAPI(CRUDMixin): class TestExampleListCollaborative(CRUDMixin): def setUp(self): - self.project = prepare_project(task=DOCUMENT_CLASSIFICATION) + self.project = prepare_project(task=ProjectType.DOCUMENT_CLASSIFICATION) self.example = make_doc(self.project.item) self.url = reverse(viewname="example_list", args=[self.project.item.id]) @@ -89,7 +89,7 @@ class TestExampleListCollaborative(CRUDMixin): class TestExampleListFilter(CRUDMixin): def setUp(self): - self.project = prepare_project(task=DOCUMENT_CLASSIFICATION) + self.project = prepare_project(task=ProjectType.DOCUMENT_CLASSIFICATION) self.example = make_doc(self.project.item) make_example_state(self.example, self.project.admin) @@ -129,7 +129,7 @@ class TestExampleListFilter(CRUDMixin): class TestExampleDetail(CRUDMixin): def setUp(self): - self.project = prepare_project(task=DOCUMENT_CLASSIFICATION) + self.project = prepare_project(task=ProjectType.DOCUMENT_CLASSIFICATION) self.non_member = make_user() doc = make_doc(self.project.item) self.data = {"text": "example"} diff --git a/backend/examples/tests/test_models.py b/backend/examples/tests/test_models.py index c7b70808..0b406a6a 100644 --- a/backend/examples/tests/test_models.py +++ b/backend/examples/tests/test_models.py @@ -2,13 +2,13 @@ from django.test import TestCase from model_mommy import mommy from examples.models import ExampleState -from projects.models import IMAGE_CLASSIFICATION, SEQUENCE_LABELING +from projects.models import ProjectType from projects.tests.utils import prepare_project class TestExampleState(TestCase): def setUp(self): - self.project = prepare_project(SEQUENCE_LABELING) + self.project = prepare_project(ProjectType.SEQUENCE_LABELING) self.example = mommy.make("Example", project=self.project.item) self.other = mommy.make("Example", project=self.project.item) self.examples = self.project.item.examples.all() @@ -61,11 +61,11 @@ class TestExampleState(TestCase): class TestExample(TestCase): def test_text_project_returns_text_as_data_property(self): - project = prepare_project(SEQUENCE_LABELING) + project = prepare_project(ProjectType.SEQUENCE_LABELING) example = mommy.make("Example", project=project.item) self.assertEqual(example.text, example.data) def test_image_project_returns_filename_as_data_property(self): - project = prepare_project(IMAGE_CLASSIFICATION) + project = prepare_project(ProjectType.IMAGE_CLASSIFICATION) example = mommy.make("Example", project=project.item) self.assertEqual(str(example.filename), example.data) diff --git a/backend/label_types/tests/test_views.py b/backend/label_types/tests/test_views.py index 0b46d12a..00198e9c 100644 --- a/backend/label_types/tests/test_views.py +++ b/backend/label_types/tests/test_views.py @@ -7,7 +7,7 @@ from rest_framework.test import APITestCase from .utils import make_label from api.tests.utils import CRUDMixin -from projects.models import DOCUMENT_CLASSIFICATION +from projects.models import ProjectType from projects.tests.utils import make_project, prepare_project from users.tests.utils import make_user @@ -18,7 +18,7 @@ class TestLabelList(CRUDMixin): @classmethod def setUpTestData(cls): cls.non_member = make_user() - cls.project_a = prepare_project(DOCUMENT_CLASSIFICATION) + cls.project_a = prepare_project(ProjectType.DOCUMENT_CLASSIFICATION) cls.label = make_label(cls.project_a.item) cls.url = reverse(viewname="category_types", args=[cls.project_a.item.id]) @@ -41,7 +41,7 @@ class TestLabelList(CRUDMixin): class TestLabelSearch(CRUDMixin): def setUp(self): - self.project = prepare_project(DOCUMENT_CLASSIFICATION) + self.project = prepare_project(ProjectType.DOCUMENT_CLASSIFICATION) make_label(self.project.item) self.url = reverse(viewname="category_types", args=[self.project.item.id]) @@ -54,7 +54,7 @@ class TestLabelSearch(CRUDMixin): class TestLabelCreate(CRUDMixin): @classmethod def setUpTestData(cls): - cls.non_member = make_user(DOCUMENT_CLASSIFICATION) + cls.non_member = make_user(ProjectType.DOCUMENT_CLASSIFICATION) cls.project = prepare_project() cls.url = reverse(viewname="category_types", args=[cls.project.item.id]) cls.data = {"text": "example"} @@ -77,7 +77,7 @@ class TestLabelDetailAPI(CRUDMixin): @classmethod def setUpTestData(cls): cls.non_member = make_user() - cls.project = prepare_project(DOCUMENT_CLASSIFICATION) + cls.project = prepare_project(ProjectType.DOCUMENT_CLASSIFICATION) cls.label = make_label(cls.project.item) cls.url = reverse(viewname="category_type", args=[cls.project.item.id, cls.label.id]) cls.data = {"text": "example"} @@ -125,7 +125,7 @@ class TestLabelUploadAPI(APITestCase): @classmethod def setUpTestData(cls): cls.non_member = make_user() - cls.project = prepare_project(DOCUMENT_CLASSIFICATION) + cls.project = prepare_project(ProjectType.DOCUMENT_CLASSIFICATION) cls.url = reverse(viewname="category_type_upload", args=[cls.project.item.id]) def assert_upload_file(self, filename, user=None, expected_status=status.HTTP_403_FORBIDDEN): diff --git a/backend/label_types/tests/utils.py b/backend/label_types/tests/utils.py index a6a5ca5c..19985b21 100644 --- a/backend/label_types/tests/utils.py +++ b/backend/label_types/tests/utils.py @@ -1,10 +1,13 @@ from model_mommy import mommy -from projects.models import BOUNDING_BOX, SEGMENTATION +from projects.models import ProjectType def make_label(project, **kwargs): - if project.project_type.endswith("Classification") or project.project_type in {BOUNDING_BOX, SEGMENTATION}: + if project.project_type.endswith("Classification") or project.project_type in { + ProjectType.BOUNDING_BOX, + ProjectType.SEGMENTATION, + }: return mommy.make("CategoryType", project=project, **kwargs) else: return mommy.make("SpanType", project=project, **kwargs) diff --git a/backend/labels/tests/test_category.py b/backend/labels/tests/test_category.py index 4cdb6b27..3bcfbdfd 100644 --- a/backend/labels/tests/test_category.py +++ b/backend/labels/tests/test_category.py @@ -5,7 +5,7 @@ from django.test import TestCase from model_mommy import mommy from labels.models import Category -from projects.models import DOCUMENT_CLASSIFICATION +from projects.models import ProjectType from projects.tests.utils import prepare_project @@ -16,7 +16,7 @@ class TestCategoryLabeling(abc.ABC, TestCase): @classmethod def setUpTestData(cls): cls.project = prepare_project( - DOCUMENT_CLASSIFICATION, + ProjectType.DOCUMENT_CLASSIFICATION, single_class_classification=cls.exclusive, collaborative_annotation=cls.collaborative, ) diff --git a/backend/labels/tests/test_relation.py b/backend/labels/tests/test_relation.py index 738a2b7e..1612b913 100644 --- a/backend/labels/tests/test_relation.py +++ b/backend/labels/tests/test_relation.py @@ -2,14 +2,14 @@ from django.core.exceptions import ValidationError from django.test import TestCase from model_mommy import mommy -from projects.models import SEQUENCE_LABELING +from projects.models import ProjectType from projects.tests.utils import prepare_project class TestRelationLabeling(TestCase): @classmethod def setUpTestData(cls): - cls.project = prepare_project(SEQUENCE_LABELING) + cls.project = prepare_project(ProjectType.SEQUENCE_LABELING) cls.example = mommy.make("Example", project=cls.project.item) cls.label_type = mommy.make("RelationType", project=cls.project.item) cls.user = cls.project.admin diff --git a/backend/labels/tests/test_span.py b/backend/labels/tests/test_span.py index 715fd2e5..1213ac8e 100644 --- a/backend/labels/tests/test_span.py +++ b/backend/labels/tests/test_span.py @@ -6,7 +6,7 @@ from model_mommy import mommy from label_types.models import SpanType from labels.models import Span -from projects.models import SEQUENCE_LABELING +from projects.models import ProjectType from projects.tests.utils import prepare_project @@ -17,7 +17,7 @@ class TestSpanLabeling(abc.ABC, TestCase): @classmethod def setUpTestData(cls): cls.project = prepare_project( - SEQUENCE_LABELING, allow_overlapping=cls.overlapping, collaborative_annotation=cls.collaborative + ProjectType.SEQUENCE_LABELING, allow_overlapping=cls.overlapping, collaborative_annotation=cls.collaborative ) cls.example = mommy.make("Example", project=cls.project.item) cls.label_type = mommy.make("SpanType", project=cls.project.item) @@ -136,7 +136,7 @@ class TestCollaborativeOverlappingSpanLabeling(TestSpanLabeling): class TestSpan(TestCase): def setUp(self): - self.project = prepare_project(SEQUENCE_LABELING, allow_overlapping=False) + self.project = prepare_project(ProjectType.SEQUENCE_LABELING, allow_overlapping=False) self.example = mommy.make("Example", project=self.project.item) self.user = self.project.admin @@ -167,7 +167,7 @@ class TestSpan(TestCase): ) def test_unique_constraint_if_overlapping_is_allowed(self): - project = prepare_project(SEQUENCE_LABELING, allow_overlapping=True) + project = prepare_project(ProjectType.SEQUENCE_LABELING, allow_overlapping=True) example = mommy.make("Example", project=project.item) user = project.admin mommy.make("Span", example=example, start_offset=5, end_offset=10, user=user) @@ -183,7 +183,7 @@ class TestSpan(TestCase): class TestSpanWithoutCollaborativeMode(TestCase): def setUp(self): - self.project = prepare_project(SEQUENCE_LABELING, False, allow_overlapping=False) + self.project = prepare_project(ProjectType.SEQUENCE_LABELING, False, allow_overlapping=False) self.example = mommy.make("Example", project=self.project.item) def test_allow_users_to_create_same_spans(self): @@ -193,14 +193,14 @@ class TestSpanWithoutCollaborativeMode(TestCase): class TestSpanWithCollaborativeMode(TestCase): def test_deny_users_to_create_same_spans(self): - project = prepare_project(SEQUENCE_LABELING, True, allow_overlapping=False) + project = prepare_project(ProjectType.SEQUENCE_LABELING, True, allow_overlapping=False) example = mommy.make("Example", project=project.item) mommy.make("Span", example=example, start_offset=5, end_offset=10, user=project.admin) with self.assertRaises(ValidationError): mommy.make("Span", example=example, start_offset=5, end_offset=10, user=project.approver) def test_allow_users_to_create_same_spans_if_overlapping_is_allowed(self): - project = prepare_project(SEQUENCE_LABELING, True, allow_overlapping=True) + project = prepare_project(ProjectType.SEQUENCE_LABELING, True, allow_overlapping=True) example = mommy.make("Example", project=project.item) mommy.make("Span", example=example, start_offset=5, end_offset=10, user=project.admin) mommy.make("Span", example=example, start_offset=5, end_offset=10, user=project.approver) @@ -208,7 +208,7 @@ class TestSpanWithCollaborativeMode(TestCase): class TestLabelDistribution(TestCase): def setUp(self): - self.project = prepare_project(SEQUENCE_LABELING, allow_overlapping=False) + self.project = prepare_project(ProjectType.SEQUENCE_LABELING, allow_overlapping=False) self.example = mommy.make("Example", project=self.project.item) self.user = self.project.admin diff --git a/backend/labels/tests/test_text_label.py b/backend/labels/tests/test_text_label.py index 096e5440..71cbddd6 100644 --- a/backend/labels/tests/test_text_label.py +++ b/backend/labels/tests/test_text_label.py @@ -5,7 +5,7 @@ from django.test import TestCase from model_mommy import mommy from labels.models import TextLabel -from projects.models import SEQ2SEQ +from projects.models import ProjectType from projects.tests.utils import prepare_project @@ -14,7 +14,7 @@ class TestTextLabeling(abc.ABC, TestCase): @classmethod def setUpTestData(cls): - cls.project = prepare_project(SEQ2SEQ, collaborative_annotation=cls.collaborative) + cls.project = prepare_project(ProjectType.SEQ2SEQ, collaborative_annotation=cls.collaborative) cls.example = mommy.make("Example", project=cls.project.item) cls.user = cls.project.admin cls.another_user = cls.project.approver diff --git a/backend/labels/tests/test_views.py b/backend/labels/tests/test_views.py index 98fe93c8..77ec1439 100644 --- a/backend/labels/tests/test_views.py +++ b/backend/labels/tests/test_views.py @@ -9,20 +9,14 @@ from api.tests.utils import CRUDMixin from examples.tests.utils import make_doc from label_types.tests.utils import make_label from labels.models import BoundingBox, Category, Segmentation, Span, TextLabel -from projects.models import ( - BOUNDING_BOX, - DOCUMENT_CLASSIFICATION, - SEGMENTATION, - SEQ2SEQ, - SEQUENCE_LABELING, -) +from projects.models import ProjectType from projects.tests.utils import prepare_project from users.tests.utils import make_user class TestLabelList: model = Category - task = DOCUMENT_CLASSIFICATION + task = ProjectType.DOCUMENT_CLASSIFICATION view_name = "annotation_list" @classmethod @@ -57,13 +51,13 @@ class TestLabelList: class TestCategoryList(TestLabelList, CRUDMixin): model = Category - task = DOCUMENT_CLASSIFICATION + task = ProjectType.DOCUMENT_CLASSIFICATION view_name = "category_list" class TestSpanList(TestLabelList, CRUDMixin): model = Span - task = SEQUENCE_LABELING + task = ProjectType.SEQUENCE_LABELING view_name = "span_list" @classmethod @@ -73,7 +67,7 @@ class TestSpanList(TestLabelList, CRUDMixin): class TestBBoxList(TestLabelList, CRUDMixin): model = BoundingBox - task = BOUNDING_BOX + task = ProjectType.BOUNDING_BOX view_name = "bbox_list" @classmethod @@ -83,7 +77,7 @@ class TestBBoxList(TestLabelList, CRUDMixin): class TestSegmentationList(TestLabelList, CRUDMixin): model = Segmentation - task = SEGMENTATION + task = ProjectType.SEGMENTATION view_name = "segmentation_list" @classmethod @@ -93,13 +87,13 @@ class TestSegmentationList(TestLabelList, CRUDMixin): class TestTextList(TestLabelList, CRUDMixin): model = TextLabel - task = SEQ2SEQ + task = ProjectType.SEQ2SEQ view_name = "text_list" class TestSharedLabelList: model = Category - task = DOCUMENT_CLASSIFICATION + task = ProjectType.DOCUMENT_CLASSIFICATION view_name = "annotation_list" @classmethod @@ -127,13 +121,13 @@ class TestSharedLabelList: class TestSharedCategoryList(TestSharedLabelList, CRUDMixin): model = Category - task = DOCUMENT_CLASSIFICATION + task = ProjectType.DOCUMENT_CLASSIFICATION view_name = "category_list" class TestSharedSpanList(TestSharedLabelList, CRUDMixin): model = Span - task = SEQUENCE_LABELING + task = ProjectType.SEQUENCE_LABELING view_name = "span_list" start_offset = 0 @@ -145,12 +139,12 @@ class TestSharedSpanList(TestSharedLabelList, CRUDMixin): class TestSharedTextList(TestSharedLabelList, CRUDMixin): model = TextLabel - task = SEQ2SEQ + task = ProjectType.SEQ2SEQ view_name = "text_list" class TestDataLabeling: - task = DOCUMENT_CLASSIFICATION + task = ProjectType.DOCUMENT_CLASSIFICATION view_name = "annotation_list" def setUp(self): @@ -180,7 +174,7 @@ class TestCategoryCreation(TestDataLabeling, CRUDMixin): class TestSpanCreation(TestDataLabeling, CRUDMixin): - task = SEQUENCE_LABELING + task = ProjectType.SEQUENCE_LABELING view_name = "span_list" def create_data(self): @@ -189,7 +183,7 @@ class TestSpanCreation(TestDataLabeling, CRUDMixin): class TestRelationCreation(TestDataLabeling, CRUDMixin): - task = SEQUENCE_LABELING + task = ProjectType.SEQUENCE_LABELING view_name = "relation_list" def create_data(self): @@ -200,7 +194,7 @@ class TestRelationCreation(TestDataLabeling, CRUDMixin): class TestTextLabelCreation(TestDataLabeling, CRUDMixin): - task = SEQ2SEQ + task = ProjectType.SEQ2SEQ view_name = "text_list" def create_data(self): @@ -208,7 +202,7 @@ class TestTextLabelCreation(TestDataLabeling, CRUDMixin): class TestBoundingBoxCreation(TestDataLabeling, CRUDMixin): - task = BOUNDING_BOX + task = ProjectType.BOUNDING_BOX view_name = "bbox_list" def create_data(self): @@ -222,7 +216,7 @@ class TestBoundingBoxCreation(TestDataLabeling, CRUDMixin): class TestSegmentationCreation(TestDataLabeling, CRUDMixin): - task = SEGMENTATION + task = ProjectType.SEGMENTATION view_name = "segmentation_list" def create_data(self): @@ -236,7 +230,7 @@ class TestSegmentationCreation(TestDataLabeling, CRUDMixin): class TestLabelDetail: - task = SEQUENCE_LABELING + task = ProjectType.SEQUENCE_LABELING view_name = "annotation_detail" def setUp(self): @@ -286,7 +280,7 @@ class TestLabelDetail: class TestCategoryDetail(TestLabelDetail, CRUDMixin): - task = DOCUMENT_CLASSIFICATION + task = ProjectType.DOCUMENT_CLASSIFICATION view_name = "category_detail" def create_annotation_data(self, doc): @@ -294,12 +288,12 @@ class TestCategoryDetail(TestLabelDetail, CRUDMixin): class TestSpanDetail(TestLabelDetail, CRUDMixin): - task = SEQUENCE_LABELING + task = ProjectType.SEQUENCE_LABELING view_name = "span_detail" class TestTextDetail(TestLabelDetail, CRUDMixin): - task = SEQ2SEQ + task = ProjectType.SEQ2SEQ view_name = "text_detail" def setUp(self): @@ -311,7 +305,7 @@ class TestTextDetail(TestLabelDetail, CRUDMixin): class TestBBoxDetail(TestLabelDetail, CRUDMixin): - task = BOUNDING_BOX + task = ProjectType.BOUNDING_BOX view_name = "bbox_detail" def create_annotation_data(self, doc): @@ -319,7 +313,7 @@ class TestBBoxDetail(TestLabelDetail, CRUDMixin): class TestSegmentationDetail(TestLabelDetail, CRUDMixin): - task = SEGMENTATION + task = ProjectType.SEGMENTATION view_name = "segmentation_detail" def create_annotation_data(self, doc): @@ -327,7 +321,7 @@ class TestSegmentationDetail(TestLabelDetail, CRUDMixin): class TestSharedLabelDetail: - task = DOCUMENT_CLASSIFICATION + task = ProjectType.DOCUMENT_CLASSIFICATION view_name = "annotation_detail" def setUp(self): @@ -358,7 +352,7 @@ class TestSharedCategoryDetail(TestSharedLabelDetail, CRUDMixin): class TestSharedSpanDetail(TestSharedLabelDetail, CRUDMixin): - task = SEQUENCE_LABELING + task = ProjectType.SEQUENCE_LABELING view_name = "span_detail" def make_annotation(self, doc, member): @@ -366,7 +360,7 @@ class TestSharedSpanDetail(TestSharedLabelDetail, CRUDMixin): class TestSharedTextDetail(TestSharedLabelDetail, CRUDMixin): - task = SEQ2SEQ + task = ProjectType.SEQ2SEQ view_name = "text_detail" def setUp(self): diff --git a/backend/labels/tests/utils.py b/backend/labels/tests/utils.py index 00908a6d..2bfe5301 100644 --- a/backend/labels/tests/utils.py +++ b/backend/labels/tests/utils.py @@ -1,18 +1,13 @@ from model_mommy import mommy -from projects.models import ( - DOCUMENT_CLASSIFICATION, - SEQ2SEQ, - SEQUENCE_LABELING, - SPEECH2TEXT, -) +from projects.models import ProjectType def make_annotation(task, doc, user, **kwargs): annotation_model = { - DOCUMENT_CLASSIFICATION: "Category", - SEQUENCE_LABELING: "Span", - SEQ2SEQ: "TextLabel", - SPEECH2TEXT: "TextLabel", + ProjectType.DOCUMENT_CLASSIFICATION: "Category", + ProjectType.SEQUENCE_LABELING: "Span", + ProjectType.SEQ2SEQ: "TextLabel", + ProjectType.SPEECH2TEXT: "TextLabel", }.get(task) return mommy.make(annotation_model, example=doc, user=user, **kwargs) diff --git a/backend/metrics/tests.py b/backend/metrics/tests.py index eadad24b..d64777a0 100644 --- a/backend/metrics/tests.py +++ b/backend/metrics/tests.py @@ -5,13 +5,13 @@ from rest_framework.reverse import reverse from api.tests.utils import CRUDMixin from examples.tests.utils import make_doc from label_types.tests.utils import make_label -from projects.models import DOCUMENT_CLASSIFICATION +from projects.models import ProjectType from projects.tests.utils import prepare_project class TestMemberProgress(CRUDMixin): def setUp(self): - self.project = prepare_project(DOCUMENT_CLASSIFICATION) + self.project = prepare_project(ProjectType.DOCUMENT_CLASSIFICATION) self.example = make_doc(self.project.item) self.url = reverse(viewname="member_progress", args=[self.project.item.id]) @@ -32,7 +32,9 @@ class TestProgressHelper(CRUDMixin): collaborative_annotation = False def setUp(self): - self.project = prepare_project(DOCUMENT_CLASSIFICATION, collaborative_annotation=self.collaborative_annotation) + self.project = prepare_project( + ProjectType.DOCUMENT_CLASSIFICATION, collaborative_annotation=self.collaborative_annotation + ) self.example = make_doc(self.project.item) mommy.make("ExampleState", example=self.example, confirmed_by=self.project.admin) self.url = reverse(viewname="progress", args=[self.project.item.id]) @@ -65,7 +67,7 @@ class TestProgressOnCollaborativeAnnotation(TestProgressHelper): class TestCategoryDistribution(CRUDMixin): def setUp(self): - self.project = prepare_project(DOCUMENT_CLASSIFICATION) + self.project = prepare_project(ProjectType.DOCUMENT_CLASSIFICATION) self.example = make_doc(self.project.item) self.label = make_label(self.project.item, text="label") mommy.make("Category", example=self.example, label=self.label, user=self.project.admin) diff --git a/backend/projects/models.py b/backend/projects/models.py index 607d2354..ed5d4218 100644 --- a/backend/projects/models.py +++ b/backend/projects/models.py @@ -10,26 +10,17 @@ from polymorphic.models import PolymorphicModel from roles.models import Role -DOCUMENT_CLASSIFICATION = "DocumentClassification" -SEQUENCE_LABELING = "SequenceLabeling" -SEQ2SEQ = "Seq2seq" -SPEECH2TEXT = "Speech2text" -IMAGE_CLASSIFICATION = "ImageClassification" -BOUNDING_BOX = "BoundingBox" -SEGMENTATION = "Segmentation" -IMAGE_CAPTIONING = "ImageCaptioning" -INTENT_DETECTION_AND_SLOT_FILLING = "IntentDetectionAndSlotFilling" -PROJECT_CHOICES = ( - (DOCUMENT_CLASSIFICATION, "document classification"), - (SEQUENCE_LABELING, "sequence labeling"), - (SEQ2SEQ, "sequence to sequence"), - (INTENT_DETECTION_AND_SLOT_FILLING, "intent detection and slot filling"), - (SPEECH2TEXT, "speech to text"), - (IMAGE_CLASSIFICATION, "image classification"), - (BOUNDING_BOX, "bounding box"), - (SEGMENTATION, "segmentation"), - (IMAGE_CAPTIONING, "image captioning"), -) + +class ProjectType(models.TextChoices): + DOCUMENT_CLASSIFICATION = "DocumentClassification" + SEQUENCE_LABELING = "SequenceLabeling" + SEQ2SEQ = "Seq2seq" + INTENT_DETECTION_AND_SLOT_FILLING = "IntentDetectionAndSlotFilling" + SPEECH2TEXT = "Speech2text" + IMAGE_CLASSIFICATION = "ImageClassification" + BOUNDING_BOX = "BoundingBox" + SEGMENTATION = "Segmentation" + IMAGE_CAPTIONING = "ImageCaptioning" class Project(PolymorphicModel): @@ -43,7 +34,7 @@ class Project(PolymorphicModel): on_delete=models.SET_NULL, null=True, ) - project_type = models.CharField(max_length=30, choices=PROJECT_CHOICES) + project_type = models.CharField(max_length=30, choices=ProjectType.choices) random_order = models.BooleanField(default=False) collaborative_annotation = models.BooleanField(default=False) single_class_classification = models.BooleanField(default=False) diff --git a/backend/projects/tests/test_project.py b/backend/projects/tests/test_project.py index dc68e193..0c57dde9 100644 --- a/backend/projects/tests/test_project.py +++ b/backend/projects/tests/test_project.py @@ -6,7 +6,7 @@ from rest_framework.reverse import reverse from api.tests.utils import CRUDMixin from examples.tests.utils import make_doc from label_types.tests.utils import make_label -from projects.models import DOCUMENT_CLASSIFICATION, Member, Project +from projects.models import Member, Project, ProjectType from projects.tests.utils import prepare_project from roles.tests.utils import create_default_roles from users.tests.utils import make_user @@ -141,12 +141,9 @@ class TestProjectModel(TestCase): class TestCloneProject(CRUDMixin): - task = DOCUMENT_CLASSIFICATION - view_name = "annotation_list" - @classmethod def setUpTestData(cls): - project = prepare_project(task=DOCUMENT_CLASSIFICATION) + project = prepare_project(task=ProjectType.DOCUMENT_CLASSIFICATION) cls.project = project.item cls.user = project.admin make_doc(cls.project) diff --git a/backend/projects/tests/utils.py b/backend/projects/tests/utils.py index f3a30b20..fe55bad3 100644 --- a/backend/projects/tests/utils.py +++ b/backend/projects/tests/utils.py @@ -3,19 +3,7 @@ from typing import List from django.conf import settings from model_mommy import mommy -from projects.models import ( - BOUNDING_BOX, - DOCUMENT_CLASSIFICATION, - IMAGE_CAPTIONING, - IMAGE_CLASSIFICATION, - INTENT_DETECTION_AND_SLOT_FILLING, - SEGMENTATION, - SEQ2SEQ, - SEQUENCE_LABELING, - SPEECH2TEXT, - Member, - Role, -) +from projects.models import Member, ProjectType, Role from roles.tests.utils import create_default_roles from users.tests.utils import make_user @@ -65,15 +53,15 @@ def make_project(task: str, users: List[str], roles: List[str], collaborative_an # create a project. project_model = { - DOCUMENT_CLASSIFICATION: "TextClassificationProject", - SEQUENCE_LABELING: "SequenceLabelingProject", - SEQ2SEQ: "Seq2seqProject", - SPEECH2TEXT: "Speech2TextProject", - IMAGE_CLASSIFICATION: "ImageClassificationProject", - INTENT_DETECTION_AND_SLOT_FILLING: "IntentDetectionAndSlotFillingProject", - BOUNDING_BOX: "BoundingBoxProject", - SEGMENTATION: "SegmentationProject", - IMAGE_CAPTIONING: "ImageCaptioningProject", + ProjectType.DOCUMENT_CLASSIFICATION: "TextClassificationProject", + ProjectType.SEQUENCE_LABELING: "SequenceLabelingProject", + ProjectType.SEQ2SEQ: "Seq2seqProject", + ProjectType.SPEECH2TEXT: "Speech2TextProject", + ProjectType.IMAGE_CLASSIFICATION: "ImageClassificationProject", + ProjectType.INTENT_DETECTION_AND_SLOT_FILLING: "IntentDetectionAndSlotFillingProject", + ProjectType.BOUNDING_BOX: "BoundingBoxProject", + ProjectType.SEGMENTATION: "SegmentationProject", + ProjectType.IMAGE_CAPTIONING: "ImageCaptioningProject", }.get(task, "Project") project = mommy.make( _model=project_model,