From d0fea9c7e5f07abb5b5c0197a83c4efc4986afef Mon Sep 17 00:00:00 2001 From: Hironsan Date: Wed, 7 Apr 2021 08:10:59 +0900 Subject: [PATCH] Add replace method to label --- app/api/views/upload/label.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/app/api/views/upload/label.py b/app/api/views/upload/label.py index 19f4c259..9beec021 100644 --- a/app/api/views/upload/label.py +++ b/app/api/views/upload/label.py @@ -1,5 +1,5 @@ import abc -from typing import Any +from typing import Any, Dict from pydantic import BaseModel @@ -19,6 +19,10 @@ class Label(BaseModel, abc.ABC): def parse(cls, obj: Any): raise NotImplementedError() + @abc.abstractmethod + def replace(self, mapping: Dict[str, int]) -> 'Label': + raise NotImplementedError + def __hash__(self): return hash(tuple(self.dict())) @@ -39,6 +43,10 @@ class CategoryLabel(Label): return cls(label=obj) raise TypeError(f'{obj} is not str.') + def replace(self, mapping: Dict[str, int]) -> 'Label': + label = mapping.get(self.label, self.label) + return CategoryLabel(label=label) + class OffsetLabel(Label): label: str @@ -63,6 +71,14 @@ class OffsetLabel(Label): else: raise TypeError(f'{obj} is invalid type.') + def replace(self, mapping: Dict[str, int]) -> 'Label': + label = mapping.get(self.label, self.label) + return OffsetLabel( + label=label, + start_offset=self.start_offset, + end_offset=self.end_offset + ) + class TextLabel(Label): text: str @@ -80,3 +96,6 @@ class TextLabel(Label): return cls(text=obj) else: raise TypeError(f'{obj} is not str.') + + def replace(self, mapping: Dict[str, str]) -> 'Label': + return self