Browse Source

Replace get_by_text with __getitem__

pull/1832/head
Hironsan 2 years ago
parent
commit
703d6cb751
4 changed files with 10 additions and 12 deletions
  1. 6
      backend/data_import/pipeline/label.py
  2. 6
      backend/data_import/pipeline/label_types.py
  3. 6
      backend/data_import/tests/test_label.py
  4. 4
      backend/data_import/tests/test_label_types.py

6
backend/data_import/pipeline/label.py

@ -62,7 +62,7 @@ class CategoryLabel(Label):
return CategoryType(text=self.label, project=project)
def create(self, user, example: Example, types: LabelTypes, **kwargs):
return CategoryModel(uuid=self.uuid, user=user, example=example, label=types.get_by_text(self.label))
return CategoryModel(uuid=self.uuid, user=user, example=example, label=types[self.label])
class SpanLabel(Label):
@ -100,7 +100,7 @@ class SpanLabel(Label):
example=example,
start_offset=self.start_offset,
end_offset=self.end_offset,
label=types.get_by_text(self.label),
label=types[self.label],
)
@ -141,7 +141,7 @@ class RelationLabel(Label):
uuid=self.uuid,
user=user,
example=example,
type=types.get_by_text(self.type),
type=types[self.type],
from_id=kwargs["id_to_span"][self.from_id],
to_id=kwargs["id_to_span"][self.to_id],
)

6
backend/data_import/pipeline/label_types.py

@ -12,12 +12,12 @@ class LabelTypes:
def __contains__(self, text: str) -> bool:
return text in self.types
def __getitem__(self, text: str) -> LabelType:
return self.types[text]
def save(self, label_types: List[LabelType]):
self.label_type_class.objects.bulk_create(label_types, ignore_conflicts=True)
def update(self, project: Project):
types = self.label_type_class.objects.filter(project=project)
self.types = {label_type.text: label_type for label_type in types}
def get_by_text(self, text: str) -> LabelType:
return self.types[text]

6
backend/data_import/tests/test_label.py

@ -55,7 +55,7 @@ class TestCategoryLabel(TestLabel):
def test_create(self):
category = CategoryLabel(label="A", example_uuid=uuid.uuid4())
types = MagicMock()
types.get_by_text.return_value = mommy.make(CategoryType, project=self.project.item)
types.__getitem__.return_value = mommy.make(CategoryType, project=self.project.item)
category_model = category.create(self.user, self.example, types)
self.assertIsInstance(category_model, CategoryModel)
@ -104,7 +104,7 @@ class TestSpanLabel(TestLabel):
def test_create(self):
span = SpanLabel(label="A", start_offset=0, end_offset=1, example_uuid=uuid.uuid4())
types = MagicMock()
types.get_by_text.return_value = mommy.make(SpanType, project=self.project.item)
types.__getitem__.return_value = mommy.make(SpanType, project=self.project.item)
span_model = span.create(self.user, self.example, types)
self.assertIsInstance(span_model, SpanModel)
@ -168,7 +168,7 @@ class TestRelationLabel(TestLabel):
def test_create(self):
relation = RelationLabel(type="A", from_id=0, to_id=1, example_uuid=uuid.uuid4())
types = MagicMock()
types.get_by_text.return_value = mommy.make(RelationType, project=self.project.item)
types.__getitem__.return_value = mommy.make(RelationType, project=self.project.item)
id_to_span = {
0: mommy.make(SpanModel, start_offset=0, end_offset=1),
1: mommy.make(SpanModel, start_offset=2, end_offset=3),

4
backend/data_import/tests/test_label_types.py

@ -22,10 +22,8 @@ class TestCategoryLabel(TestCase):
def test_update(self):
label_types = LabelTypes(CategoryType)
with self.assertRaises(KeyError):
label_types.get_by_text("A")
category_types = [CategoryType(text="A", project=self.project.item)]
label_types.save(category_types)
label_types.update(self.project.item)
category_type = label_types.get_by_text("A")
category_type = label_types["A"]
self.assertEqual(category_type.text, "A")
Loading…
Cancel
Save