diff --git a/backend/data_import/pipeline/labels.py b/backend/data_import/pipeline/labels.py index 1009e26f..5ffab4df 100644 --- a/backend/data_import/pipeline/labels.py +++ b/backend/data_import/pipeline/labels.py @@ -3,9 +3,11 @@ from typing import Any, Dict, Optional from pydantic import BaseModel, validator -from label_types.models import CategoryType, LabelType, SpanType -from labels.models import Category, Span -from labels.models import TextLabel as TL, Label as LabelModel +from label_types.models import CategoryType, LabelType, RelationType, SpanType +from labels.models import Category +from labels.models import Label as LabelModel +from labels.models import Relation, Span +from labels.models import TextLabel as TL from projects.models import Project @@ -28,7 +30,7 @@ class Label(BaseModel, abc.ABC): raise NotImplementedError() @abc.abstractmethod - def create(self, user, example, mapping) -> LabelModel: + def create(self, user, example, mapping, **kwargs) -> LabelModel: raise NotImplementedError def __hash__(self): @@ -64,11 +66,12 @@ class CategoryLabel(Label): def create_type(self, project: Project) -> Optional[LabelType]: return CategoryType(text=self.label, project=project) - def create(self, user, example, mapping: Dict[str, LabelType]): + def create(self, user, example, mapping: Dict[str, LabelType], **kwargs): return Category(user=user, example=example, label=mapping[self.label]) class SpanLabel(Label): + id: int = -1 label: str start_offset: int end_offset: int @@ -94,7 +97,7 @@ class SpanLabel(Label): def create_type(self, project: Project) -> Optional[LabelType]: return SpanType(text=self.label, project=project) - def create(self, user, example, mapping: Dict[str, LabelType]): + def create(self, user, example, mapping: Dict[str, LabelType], **kwargs): return Span( user=user, example=example, @@ -124,5 +127,37 @@ class TextLabel(Label): def create_type(self, project: Project) -> Optional[LabelType]: return None - def create(self, user, example, mapping): + def create(self, user, example, mapping, **kwargs): return TL(user=user, example=example, text=self.text) + + +class RelationLabel(Label): + from_id: int + to_id: int + type: str + + def has_name(self) -> bool: + return True + + @property + def name(self) -> str: + return self.type + + @classmethod + def parse(cls, obj: Any): + if isinstance(obj, dict): + return cls.parse_obj(obj) + else: + raise TypeError(f"{obj} is not dict.") + + def create_type(self, project: Project) -> Optional[LabelType]: + return RelationType(text=self.type, project=project) + + def create(self, user, example, mapping: Dict[str, LabelType], **kwargs): + return Relation( + user=user, + example=example, + type=mapping[self.type], + from_id=kwargs["span_mapping"][self.from_id], + to_id=kwargs["span_mapping"][self.to_id], + ) diff --git a/backend/data_import/tests/test_builder.py b/backend/data_import/tests/test_builder.py index 772c1daf..acba65ab 100644 --- a/backend/data_import/tests/test_builder.py +++ b/backend/data_import/tests/test_builder.py @@ -69,5 +69,8 @@ class TestColumnBuilder(unittest.TestCase): data_column = builders.DataColumn("text", TextData) label_columns = [builders.LabelColumn("cats", CategoryLabel), builders.LabelColumn("entities", SpanLabel)] actual = self.create_record(row, data_column, label_columns) - expected = {"data": "Text", "label": [{"label": "Label"}, {"label": "LOC", "start_offset": 0, "end_offset": 1}]} + expected = { + "data": "Text", + "label": [{"label": "Label"}, {"id": -1, "label": "LOC", "start_offset": 0, "end_offset": 1}], + } self.assert_record(actual, expected)