You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

101 lines
2.2 KiB

  1. import abc
  2. from typing import Any, Dict
  3. from pydantic import BaseModel
  4. class Label(BaseModel, abc.ABC):
  5. @abc.abstractmethod
  6. def has_name(self) -> bool:
  7. raise NotImplementedError()
  8. @property
  9. @abc.abstractmethod
  10. def name(self) -> str:
  11. raise NotImplementedError()
  12. @classmethod
  13. def parse(cls, obj: Any):
  14. raise NotImplementedError()
  15. @abc.abstractmethod
  16. def replace(self, mapping: Dict[str, int]) -> 'Label':
  17. raise NotImplementedError
  18. def __hash__(self):
  19. return hash(tuple(self.dict()))
  20. class CategoryLabel(Label):
  21. label: str
  22. def has_name(self) -> bool:
  23. return True
  24. @property
  25. def name(self) -> str:
  26. return self.label
  27. @classmethod
  28. def parse(cls, obj: Any):
  29. if isinstance(obj, str):
  30. return cls(label=obj)
  31. raise TypeError(f'{obj} is not str.')
  32. def replace(self, mapping: Dict[str, int]) -> 'Label':
  33. label = mapping.get(self.label, self.label)
  34. return CategoryLabel(label=label)
  35. class OffsetLabel(Label):
  36. label: str
  37. start_offset: int
  38. end_offset: int
  39. def has_name(self) -> bool:
  40. return True
  41. @property
  42. def name(self) -> str:
  43. return self.label
  44. @classmethod
  45. def parse(cls, obj: Any):
  46. if isinstance(obj, list):
  47. columns = ['label', 'start_offset', 'end_offset']
  48. obj = zip(columns, obj)
  49. return cls.parse_obj(obj)
  50. elif isinstance(obj, dict):
  51. return cls.parse_obj(obj)
  52. else:
  53. raise TypeError(f'{obj} is invalid type.')
  54. def replace(self, mapping: Dict[str, int]) -> 'Label':
  55. label = mapping.get(self.label, self.label)
  56. return OffsetLabel(
  57. label=label,
  58. start_offset=self.start_offset,
  59. end_offset=self.end_offset
  60. )
  61. class TextLabel(Label):
  62. text: str
  63. def has_name(self) -> bool:
  64. return False
  65. @property
  66. def name(self) -> str:
  67. return self.text
  68. @classmethod
  69. def parse(cls, obj: Any):
  70. if isinstance(obj, str):
  71. return cls(text=obj)
  72. else:
  73. raise TypeError(f'{obj} is not str.')
  74. def replace(self, mapping: Dict[str, str]) -> 'Label':
  75. return self