Browse Source

Simplify parse method in label

pull/1823/head
Hironsan 2 years ago
parent
commit
3bb844fa1e
2 changed files with 16 additions and 28 deletions
  1. 32
      backend/data_import/pipeline/label.py
  2. 12
      backend/data_import/tests/test_label.py

32
backend/data_import/pipeline/label.py

@ -2,7 +2,6 @@ import abc
import uuid
from typing import Any, Optional
import pydantic.error_wrappers
from pydantic import UUID4, BaseModel, validator
from .label_types import LabelTypes
@ -60,10 +59,7 @@ class CategoryLabel(Label):
@classmethod
def parse(cls, example_uuid: UUID4, obj: Any):
try:
return cls(example_uuid=example_uuid, label=obj)
except pydantic.error_wrappers.ValidationError:
return None
return cls(example_uuid=example_uuid, label=obj)
def create_type(self, project: Project) -> Optional[LabelType]:
return CategoryType(text=self.label, project=project)
@ -82,15 +78,13 @@ class SpanLabel(Label):
@classmethod
def parse(cls, example_uuid: UUID4, obj: Any):
try:
if isinstance(obj, list) or isinstance(obj, tuple):
columns = ["start_offset", "end_offset", "label"]
obj = zip(columns, obj)
return cls(example_uuid=example_uuid, **dict(obj))
elif isinstance(obj, dict):
return cls(example_uuid=example_uuid, **obj)
except pydantic.error_wrappers.ValidationError:
return None
if isinstance(obj, list) or isinstance(obj, tuple):
columns = ["start_offset", "end_offset", "label"]
obj = zip(columns, obj)
return cls(example_uuid=example_uuid, **dict(obj))
elif isinstance(obj, dict):
return cls(example_uuid=example_uuid, **obj)
raise ValueError("SpanLabel.parse()")
def create_type(self, project: Project) -> Optional[LabelType]:
return SpanType(text=self.label, project=project)
@ -114,10 +108,7 @@ class TextLabel(Label):
@classmethod
def parse(cls, example_uuid: UUID4, obj: Any):
try:
return cls(example_uuid=example_uuid, text=obj)
except pydantic.error_wrappers.ValidationError:
return None
return cls(example_uuid=example_uuid, text=obj)
def create_type(self, project: Project) -> Optional[LabelType]:
return None
@ -136,10 +127,7 @@ class RelationLabel(Label):
@classmethod
def parse(cls, example_uuid: UUID4, obj: Any):
try:
return cls(example_uuid=example_uuid, **obj)
except pydantic.error_wrappers.ValidationError:
return None
return cls(example_uuid=example_uuid, **obj)
def create_type(self, project: Project) -> Optional[LabelType]:
return RelationType(text=self.type, project=project)

12
backend/data_import/tests/test_label.py

@ -84,8 +84,8 @@ class TestSpanLabel(TestLabel):
def test_parse_invalid_dict(self):
example_uuid = uuid.uuid4()
span = SpanLabel.parse(example_uuid, obj={"label": "A", "start_offset": 0})
self.assertEqual(span, None)
with self.assertRaises(ValueError):
SpanLabel.parse(example_uuid, obj={"label": "A", "start_offset": 0})
def test_create_type(self):
span = SpanLabel(label="A", start_offset=0, end_offset=1, example_uuid=uuid.uuid4())
@ -116,8 +116,8 @@ class TestTextLabel(TestLabel):
def test_parse_invalid_data(self):
example_uuid = uuid.uuid4()
text = TextLabel.parse(example_uuid, obj=[])
self.assertEqual(text, None)
with self.assertRaises(ValueError):
TextLabel.parse(example_uuid, obj=[])
def test_create_type(self):
text = TextLabel(text="A", example_uuid=uuid.uuid4())
@ -148,8 +148,8 @@ class TestRelationLabel(TestLabel):
def test_parse_invalid_data(self):
example_uuid = uuid.uuid4()
relation = RelationLabel.parse(example_uuid, obj={"type": "A", "from_id": 0})
self.assertEqual(relation, None)
with self.assertRaises(ValueError):
RelationLabel.parse(example_uuid, obj={"type": "A", "from_id": 0})
def test_create_type(self):
relation = RelationLabel(type="A", from_id=0, to_id=1, example_uuid=uuid.uuid4())

Loading…
Cancel
Save