From 703d6cb751e98a9dc652f2cfe991c2acce5838bc Mon Sep 17 00:00:00 2001 From: Hironsan Date: Sun, 22 May 2022 16:36:34 +0900 Subject: [PATCH] Replace get_by_text with __getitem__ --- backend/data_import/pipeline/label.py | 6 +++--- backend/data_import/pipeline/label_types.py | 6 +++--- backend/data_import/tests/test_label.py | 6 +++--- backend/data_import/tests/test_label_types.py | 4 +--- 4 files changed, 10 insertions(+), 12 deletions(-) diff --git a/backend/data_import/pipeline/label.py b/backend/data_import/pipeline/label.py index e38d3a04..bdff2d43 100644 --- a/backend/data_import/pipeline/label.py +++ b/backend/data_import/pipeline/label.py @@ -62,7 +62,7 @@ class CategoryLabel(Label): return CategoryType(text=self.label, project=project) def create(self, user, example: Example, types: LabelTypes, **kwargs): - return CategoryModel(uuid=self.uuid, user=user, example=example, label=types.get_by_text(self.label)) + return CategoryModel(uuid=self.uuid, user=user, example=example, label=types[self.label]) class SpanLabel(Label): @@ -100,7 +100,7 @@ class SpanLabel(Label): example=example, start_offset=self.start_offset, end_offset=self.end_offset, - label=types.get_by_text(self.label), + label=types[self.label], ) @@ -141,7 +141,7 @@ class RelationLabel(Label): uuid=self.uuid, user=user, example=example, - type=types.get_by_text(self.type), + type=types[self.type], from_id=kwargs["id_to_span"][self.from_id], to_id=kwargs["id_to_span"][self.to_id], ) diff --git a/backend/data_import/pipeline/label_types.py b/backend/data_import/pipeline/label_types.py index a83f1ef5..037df5ba 100644 --- a/backend/data_import/pipeline/label_types.py +++ b/backend/data_import/pipeline/label_types.py @@ -12,12 +12,12 @@ class LabelTypes: def __contains__(self, text: str) -> bool: return text in self.types + def __getitem__(self, text: str) -> LabelType: + return self.types[text] + def save(self, label_types: List[LabelType]): self.label_type_class.objects.bulk_create(label_types, ignore_conflicts=True) def update(self, project: Project): types = self.label_type_class.objects.filter(project=project) self.types = {label_type.text: label_type for label_type in types} - - def get_by_text(self, text: str) -> LabelType: - return self.types[text] diff --git a/backend/data_import/tests/test_label.py b/backend/data_import/tests/test_label.py index d0fe6923..da79ca49 100644 --- a/backend/data_import/tests/test_label.py +++ b/backend/data_import/tests/test_label.py @@ -55,7 +55,7 @@ class TestCategoryLabel(TestLabel): def test_create(self): category = CategoryLabel(label="A", example_uuid=uuid.uuid4()) types = MagicMock() - types.get_by_text.return_value = mommy.make(CategoryType, project=self.project.item) + types.__getitem__.return_value = mommy.make(CategoryType, project=self.project.item) category_model = category.create(self.user, self.example, types) self.assertIsInstance(category_model, CategoryModel) @@ -104,7 +104,7 @@ class TestSpanLabel(TestLabel): def test_create(self): span = SpanLabel(label="A", start_offset=0, end_offset=1, example_uuid=uuid.uuid4()) types = MagicMock() - types.get_by_text.return_value = mommy.make(SpanType, project=self.project.item) + types.__getitem__.return_value = mommy.make(SpanType, project=self.project.item) span_model = span.create(self.user, self.example, types) self.assertIsInstance(span_model, SpanModel) @@ -168,7 +168,7 @@ class TestRelationLabel(TestLabel): def test_create(self): relation = RelationLabel(type="A", from_id=0, to_id=1, example_uuid=uuid.uuid4()) types = MagicMock() - types.get_by_text.return_value = mommy.make(RelationType, project=self.project.item) + types.__getitem__.return_value = mommy.make(RelationType, project=self.project.item) id_to_span = { 0: mommy.make(SpanModel, start_offset=0, end_offset=1), 1: mommy.make(SpanModel, start_offset=2, end_offset=3), diff --git a/backend/data_import/tests/test_label_types.py b/backend/data_import/tests/test_label_types.py index 099f08c6..d1195fdf 100644 --- a/backend/data_import/tests/test_label_types.py +++ b/backend/data_import/tests/test_label_types.py @@ -22,10 +22,8 @@ class TestCategoryLabel(TestCase): def test_update(self): label_types = LabelTypes(CategoryType) - with self.assertRaises(KeyError): - label_types.get_by_text("A") category_types = [CategoryType(text="A", project=self.project.item)] label_types.save(category_types) label_types.update(self.project.item) - category_type = label_types.get_by_text("A") + category_type = label_types["A"] self.assertEqual(category_type.text, "A")