You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

128 lines
3.3 KiB

2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
  1. import abc
  2. from typing import Any, Dict, Optional
  3. from pydantic import BaseModel, validator
  4. from label_types.models import CategoryType, LabelType, SpanType
  5. from labels.models import Category, Span
  6. from labels.models import TextLabel as TL
  7. from projects.models import Project
  8. class Label(BaseModel, abc.ABC):
  9. @abc.abstractmethod
  10. def has_name(self) -> bool:
  11. raise NotImplementedError()
  12. @property
  13. @abc.abstractmethod
  14. def name(self) -> str:
  15. raise NotImplementedError()
  16. @classmethod
  17. def parse(cls, obj: Any):
  18. raise NotImplementedError()
  19. @abc.abstractmethod
  20. def create(self, project: Project) -> Optional[LabelType]:
  21. raise NotImplementedError()
  22. @abc.abstractmethod
  23. def create_annotation(self, user, example, mapping):
  24. raise NotImplementedError
  25. def __hash__(self):
  26. return hash(tuple(self.dict()))
  27. class CategoryLabel(Label):
  28. label: str
  29. @validator("label")
  30. def label_is_not_empty(cls, value: str):
  31. if value:
  32. return value
  33. else:
  34. raise ValueError("is not empty.")
  35. def has_name(self) -> bool:
  36. return True
  37. @property
  38. def name(self) -> str:
  39. return self.label
  40. @classmethod
  41. def parse(cls, obj: Any):
  42. if isinstance(obj, str):
  43. return cls(label=obj)
  44. elif isinstance(obj, int):
  45. return cls(label=str(obj))
  46. else:
  47. raise TypeError(f"{obj} is not str.")
  48. def create(self, project: Project) -> Optional[LabelType]:
  49. return CategoryType(text=self.label, project=project)
  50. def create_annotation(self, user, example, mapping: Dict[str, LabelType]):
  51. return Category(user=user, example=example, label=mapping[self.label])
  52. class SpanLabel(Label):
  53. label: str
  54. start_offset: int
  55. end_offset: int
  56. def has_name(self) -> bool:
  57. return True
  58. @property
  59. def name(self) -> str:
  60. return self.label
  61. @classmethod
  62. def parse(cls, obj: Any):
  63. if isinstance(obj, list) or isinstance(obj, tuple):
  64. columns = ["start_offset", "end_offset", "label"]
  65. obj = zip(columns, obj)
  66. return cls.parse_obj(obj)
  67. elif isinstance(obj, dict):
  68. return cls.parse_obj(obj)
  69. else:
  70. raise TypeError(f"{obj} is invalid type.")
  71. def create(self, project: Project) -> Optional[LabelType]:
  72. return SpanType(text=self.label, project=project)
  73. def create_annotation(self, user, example, mapping: Dict[str, LabelType]):
  74. return Span(
  75. user=user,
  76. example=example,
  77. start_offset=self.start_offset,
  78. end_offset=self.end_offset,
  79. label=mapping[self.label],
  80. )
  81. class TextLabel(Label):
  82. text: str
  83. def has_name(self) -> bool:
  84. return False
  85. @property
  86. def name(self) -> str:
  87. return self.text
  88. @classmethod
  89. def parse(cls, obj: Any):
  90. if isinstance(obj, str) and obj:
  91. return cls(text=obj)
  92. else:
  93. raise TypeError(f"{obj} is not str or empty.")
  94. def create(self, project: Project) -> Optional[LabelType]:
  95. return None
  96. def create_annotation(self, user, example, mapping):
  97. return TL(user=user, example=example, text=self.text)