You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

101 lines
2.3 KiB

  1. import abc
  2. from typing import Any, Dict, Union
  3. from pydantic import BaseModel
  4. class Label(BaseModel, abc.ABC):
  5. @abc.abstractmethod
  6. def has_name(self) -> bool:
  7. raise NotImplementedError()
  8. @property
  9. @abc.abstractmethod
  10. def name(self) -> str:
  11. raise NotImplementedError()
  12. @classmethod
  13. def parse(cls, obj: Any):
  14. raise NotImplementedError()
  15. @abc.abstractmethod
  16. def replace(self, mapping: Dict[str, int]) -> 'Label':
  17. raise NotImplementedError
  18. def __hash__(self):
  19. return hash(tuple(self.dict()))
  20. class CategoryLabel(Label):
  21. label: Union[str, int]
  22. def has_name(self) -> bool:
  23. return True
  24. @property
  25. def name(self) -> str:
  26. return self.label
  27. @classmethod
  28. def parse(cls, obj: Any):
  29. if isinstance(obj, str):
  30. return cls(label=obj)
  31. raise TypeError(f'{obj} is not str.')
  32. def replace(self, mapping: Dict[str, int]) -> 'Label':
  33. label = mapping[self.label]
  34. return CategoryLabel(label=label)
  35. class OffsetLabel(Label):
  36. label: Union[str, int]
  37. start_offset: int
  38. end_offset: int
  39. def has_name(self) -> bool:
  40. return True
  41. @property
  42. def name(self) -> str:
  43. return self.label
  44. @classmethod
  45. def parse(cls, obj: Any):
  46. if isinstance(obj, list) or isinstance(obj, tuple):
  47. columns = ['start_offset', 'end_offset', 'label']
  48. obj = zip(columns, obj)
  49. return cls.parse_obj(obj)
  50. elif isinstance(obj, dict):
  51. return cls.parse_obj(obj)
  52. else:
  53. raise TypeError(f'{obj} is invalid type.')
  54. def replace(self, mapping: Dict[str, int]) -> 'Label':
  55. label = mapping[self.label]
  56. return OffsetLabel(
  57. label=label,
  58. start_offset=self.start_offset,
  59. end_offset=self.end_offset
  60. )
  61. class TextLabel(Label):
  62. text: str
  63. def has_name(self) -> bool:
  64. return False
  65. @property
  66. def name(self) -> str:
  67. return self.text
  68. @classmethod
  69. def parse(cls, obj: Any):
  70. if isinstance(obj, str):
  71. return cls(text=obj)
  72. else:
  73. raise TypeError(f'{obj} is not str.')
  74. def replace(self, mapping: Dict[str, str]) -> 'Label':
  75. return self