You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

136 lines
3.3 KiB

  1. import abc
  2. from typing import Any, Dict, Optional, Union
  3. from pydantic import BaseModel, validator
  4. from api.models import Project
  5. from label_types.models import LabelType, CategoryType, SpanType
  6. from labels.models import Category, Span, TextLabel as TL
  7. class Label(BaseModel, abc.ABC):
  8. @abc.abstractmethod
  9. def has_name(self) -> bool:
  10. raise NotImplementedError()
  11. @property
  12. @abc.abstractmethod
  13. def name(self) -> str:
  14. raise NotImplementedError()
  15. @classmethod
  16. def parse(cls, obj: Any):
  17. raise NotImplementedError()
  18. @abc.abstractmethod
  19. def create(self, project: Project) -> Optional[LabelType]:
  20. raise NotImplementedError()
  21. @abc.abstractmethod
  22. def create_annotation(self, user, example, mapping):
  23. raise NotImplementedError
  24. def __hash__(self):
  25. return hash(tuple(self.dict()))
  26. class CategoryLabel(Label):
  27. label: str
  28. @validator('label')
  29. def label_is_not_empty(cls, value: str):
  30. if value:
  31. return value
  32. else:
  33. raise ValueError('is not empty.')
  34. def has_name(self) -> bool:
  35. return True
  36. @property
  37. def name(self) -> str:
  38. return self.label
  39. @classmethod
  40. def parse(cls, obj: Any):
  41. if isinstance(obj, str):
  42. return cls(label=obj)
  43. elif isinstance(obj, int):
  44. return cls(label=str(obj))
  45. else:
  46. raise TypeError(f'{obj} is not str.')
  47. def create(self, project: Project) -> Optional[LabelType]:
  48. return CategoryType(text=self.label, project=project)
  49. def create_annotation(self, user, example, mapping: Dict[str, LabelType]):
  50. return Category(
  51. user=user,
  52. example=example,
  53. label=mapping[self.label]
  54. )
  55. class SpanLabel(Label):
  56. label: Union[str, int]
  57. start_offset: int
  58. end_offset: int
  59. def has_name(self) -> bool:
  60. return True
  61. @property
  62. def name(self) -> str:
  63. return self.label
  64. @classmethod
  65. def parse(cls, obj: Any):
  66. if isinstance(obj, list) or isinstance(obj, tuple):
  67. columns = ['start_offset', 'end_offset', 'label']
  68. obj = zip(columns, obj)
  69. return cls.parse_obj(obj)
  70. elif isinstance(obj, dict):
  71. return cls.parse_obj(obj)
  72. else:
  73. raise TypeError(f'{obj} is invalid type.')
  74. def create(self, project: Project) -> Optional[LabelType]:
  75. return SpanType(text=self.label, project=project)
  76. def create_annotation(self, user, example, mapping: Dict[str, LabelType]):
  77. return Span(
  78. user=user,
  79. example=example,
  80. start_offset=self.start_offset,
  81. end_offset=self.end_offset,
  82. label=mapping[self.label]
  83. )
  84. class TextLabel(Label):
  85. text: str
  86. def has_name(self) -> bool:
  87. return False
  88. @property
  89. def name(self) -> str:
  90. return self.text
  91. @classmethod
  92. def parse(cls, obj: Any):
  93. if isinstance(obj, str) and obj:
  94. return cls(text=obj)
  95. else:
  96. raise TypeError(f'{obj} is not str or empty.')
  97. def create(self, project: Project) -> Optional[LabelType]:
  98. return None
  99. def create_annotation(self, user, example, mapping):
  100. return TL(
  101. user=user,
  102. example=example,
  103. text=self.text
  104. )