diff --git a/backend/api/models.py b/backend/api/models.py index 58699dd4..61c12cd6 100644 --- a/backend/api/models.py +++ b/backend/api/models.py @@ -1,3 +1,4 @@ +import abc import random import string import uuid @@ -39,7 +40,9 @@ class Project(PolymorphicModel): collaborative_annotation = models.BooleanField(default=False) single_class_classification = models.BooleanField(default=False) - def is_task_of(self, task: Literal['text', 'image', 'speech']): + @property + @abc.abstractmethod + def is_text_project(self) -> bool: raise NotImplementedError() def __str__(self): @@ -48,40 +51,46 @@ class Project(PolymorphicModel): class TextClassificationProject(Project): - def is_task_of(self, task: Literal['text', 'image', 'speech']): - return task == 'text' + @property + def is_text_project(self) -> bool: + return True class SequenceLabelingProject(Project): allow_overlapping = models.BooleanField(default=False) grapheme_mode = models.BooleanField(default=False) - def is_task_of(self, task: Literal['text', 'image', 'speech']): - return task == 'text' + @property + def is_text_project(self) -> bool: + return True class Seq2seqProject(Project): - def is_task_of(self, task: Literal['text', 'image', 'speech']): - return task == 'text' + @property + def is_text_project(self) -> bool: + return True class IntentDetectionAndSlotFillingProject(Project): - def is_task_of(self, task: Literal['text', 'image', 'speech']): - return task == 'text' + @property + def is_text_project(self) -> bool: + return True class Speech2textProject(Project): - def is_task_of(self, task: Literal['text', 'image', 'speech']): - return task == 'speech' + @property + def is_text_project(self) -> bool: + return False class ImageClassificationProject(Project): - def is_task_of(self, task: Literal['text', 'image', 'speech']): - return task == 'image' + @property + def is_text_project(self) -> bool: + return False def generate_random_hex_color(): diff --git a/backend/api/views/auto_labeling.py b/backend/api/views/auto_labeling.py index 6f9bbc31..137c0e17 100644 --- a/backend/api/views/auto_labeling.py +++ b/backend/api/views/auto_labeling.py @@ -142,7 +142,7 @@ class AutoLabelingConfigParameterTest(APIView): def prepare_example(self): text = self.request.data['text'] - if self.project.is_task_of('text'): + if self.project.is_text_project: return text else: tu = TemporaryUpload.objects.get(upload_id=text) @@ -221,7 +221,7 @@ class AutoLabelingAnnotation(generics.CreateAPIView): def get_example(self, project): example = get_object_or_404(Example, pk=self.kwargs['example_id']) - if project.is_task_of('text'): + if project.is_text_project: return example.text else: return str(example.filename)