You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

460 lines
9.7 KiB

3 years ago
3 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
  1. from collections import defaultdict
  2. from dataclasses import dataclass
  3. from pathlib import Path
  4. from typing import Dict, List, Type
  5. from pydantic import BaseModel
  6. from typing_extensions import Literal
  7. from .exceptions import FileFormatException
  8. from projects.models import ProjectType
  9. # Define the example directories
  10. EXAMPLE_DIR = Path(__file__).parent.resolve() / "examples"
  11. TASK_AGNOSTIC_DIR = EXAMPLE_DIR / "task_agnostic"
  12. TEXT_CLASSIFICATION_DIR = EXAMPLE_DIR / "text_classification"
  13. SEQUENCE_LABELING_DIR = EXAMPLE_DIR / "sequence_labeling"
  14. RELATION_EXTRACTION_DIR = EXAMPLE_DIR / "relation_extraction"
  15. SEQ2SEQ_DIR = EXAMPLE_DIR / "sequence_to_sequence"
  16. INTENT_DETECTION_DIR = EXAMPLE_DIR / "intent_detection"
  17. IMAGE_CLASSIFICATION_DIR = EXAMPLE_DIR / "image_classification"
  18. SPEECH_TO_TEXT_DIR = EXAMPLE_DIR / "speech_to_text"
  19. # Define the task identifiers
  20. RELATION_EXTRACTION = "RelationExtraction"
  21. encodings = Literal[
  22. "Auto",
  23. "ascii",
  24. "big5",
  25. "big5hkscs",
  26. "cp037",
  27. "cp273",
  28. "cp424",
  29. "cp437",
  30. "cp500",
  31. "cp720",
  32. "cp737",
  33. "cp775",
  34. "cp850",
  35. "cp852",
  36. "cp855",
  37. "cp856",
  38. "cp857",
  39. "cp858",
  40. "cp860",
  41. "cp861",
  42. "cp862",
  43. "cp863",
  44. "cp864",
  45. "cp865",
  46. "cp866",
  47. "cp869",
  48. "cp874",
  49. "cp875",
  50. "cp932",
  51. "cp949",
  52. "cp950",
  53. "cp1006",
  54. "cp1026",
  55. "cp1125",
  56. "cp1140",
  57. "cp1250",
  58. "cp1251",
  59. "cp1252",
  60. "cp1253",
  61. "cp1254",
  62. "cp1255",
  63. "cp1256",
  64. "cp1257",
  65. "cp1258",
  66. "cp65001",
  67. "euc_jp",
  68. "euc_jis_2004",
  69. "euc_jisx0213",
  70. "euc_kr",
  71. "gb2312",
  72. "gbk",
  73. "gb18030",
  74. "hz",
  75. "iso2022_jp",
  76. "iso2022_jp_1",
  77. "iso2022_jp_2",
  78. "iso2022_jp_2004",
  79. "iso2022_jp_3",
  80. "iso2022_jp_ext",
  81. "iso2022_kr",
  82. "latin_1",
  83. "iso8859_2",
  84. "iso8859_3",
  85. "iso8859_4",
  86. "iso8859_5",
  87. "iso8859_6",
  88. "iso8859_7",
  89. "iso8859_8",
  90. "iso8859_9",
  91. "iso8859_10",
  92. "iso8859_11",
  93. "iso8859_13",
  94. "iso8859_14",
  95. "iso8859_15",
  96. "iso8859_16",
  97. "johab",
  98. "koi8_r",
  99. "koi8_t",
  100. "koi8_u",
  101. "kz1048",
  102. "mac_cyrillic",
  103. "mac_greek",
  104. "mac_iceland",
  105. "mac_latin2",
  106. "mac_roman",
  107. "mac_turkish",
  108. "ptcp154",
  109. "shift_jis",
  110. "shift_jis_2004",
  111. "shift_jisx0213",
  112. "utf_32",
  113. "utf_32_be",
  114. "utf_32_le",
  115. "utf_16",
  116. "utf_16_be",
  117. "utf_16_le",
  118. "utf_7",
  119. "utf_8",
  120. "utf_8_sig",
  121. ]
  122. class Format:
  123. name = ""
  124. accept_types = ""
  125. @classmethod
  126. def dict(cls):
  127. return {"name": cls.name, "accept_types": cls.accept_types}
  128. def validate_mime(self, mime: str):
  129. return True
  130. @staticmethod
  131. def is_plain_text():
  132. return False
  133. class CSV(Format):
  134. name = "CSV"
  135. accept_types = "text/csv"
  136. class FastText(Format):
  137. name = "fastText"
  138. accept_types = "text/plain"
  139. class JSON(Format):
  140. name = "JSON"
  141. accept_types = "application/json"
  142. class JSONL(Format):
  143. name = "JSONL"
  144. accept_types = "*"
  145. class Excel(Format):
  146. name = "Excel"
  147. accept_types = "application/vnd.ms-excel, application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
  148. class TextFile(Format):
  149. name = "TextFile"
  150. accept_types = "text/*"
  151. @staticmethod
  152. def is_plain_text():
  153. return True
  154. class TextLine(Format):
  155. name = "TextLine"
  156. accept_types = "text/*"
  157. @staticmethod
  158. def is_plain_text():
  159. return True
  160. class CoNLL(Format):
  161. name = "CoNLL"
  162. accept_types = "text/*"
  163. class ImageFile(Format):
  164. name = "ImageFile"
  165. accept_types = "image/png, image/jpeg, image/bmp, image/gif"
  166. def validate_mime(self, mime: str):
  167. return mime in self.accept_types
  168. class AudioFile(Format):
  169. name = "AudioFile"
  170. accept_types = "audio/ogg, audio/aac, audio/mpeg, audio/wav"
  171. def validate_mime(self, mime: str):
  172. return mime in self.accept_types
# Argument model with a file encoding and two column names.
# NOTE(review): documented with `#` comments rather than a docstring on
# purpose -- Option.dict() merges this model's .schema() into its output, and
# pydantic may surface a class docstring there as a "description"; confirm
# before converting these comments to a docstring.
class ArgColumn(BaseModel):
    encoding: encodings = "utf_8"  # one of the `encodings` literals
    column_data: str = "text"  # presumably the column/key holding the data -- confirm against the readers
    column_label: str = "label"  # presumably the column/key holding the label -- confirm against the readers
# ArgColumn plus a field delimiter. `encoding` is re-declared here with the
# same type and default it already has on ArgColumn.
class ArgDelimiter(ArgColumn):
    encoding: encodings = "utf_8"
    # Allowed delimiters: comma (default), tab, semicolon, pipe, or a space.
    delimiter: Literal[",", "\t", ";", "|", " "] = ","
# Argument model whose only field is the file encoding.
class ArgEncoding(BaseModel):
    encoding: encodings = "utf_8"  # one of the `encodings` literals
# Argument model paired with the CoNLL format: encoding, scheme and delimiter.
class ArgCoNLL(BaseModel):
    encoding: encodings = "utf_8"  # one of the `encodings` literals
    # presumably the sequence-tagging scheme of the file -- confirm against the parser
    scheme: Literal["IOB2", "IOE2", "IOBES", "BILOU"] = "IOB2"
    # Either a single space (default) or the empty string.
    delimiter: Literal[" ", ""] = " "
# Argument model with no fields, used by options that take no arguments.
class ArgNone(BaseModel):
    pass
  188. @dataclass
  189. class Option:
  190. display_name: str
  191. task_id: str
  192. file_format: Type[Format]
  193. arg: Type[BaseModel]
  194. file: Path
  195. @property
  196. def example(self) -> str:
  197. with open(self.file, "r", encoding="utf-8") as f:
  198. return f.read()
  199. def dict(self) -> Dict:
  200. return {
  201. **self.file_format.dict(),
  202. **self.arg.schema(),
  203. "example": self.example,
  204. "task_id": self.task_id,
  205. "display_name": self.display_name,
  206. }
  207. def create_file_format(file_format: str) -> Format:
  208. for format_class in Format.__subclasses__():
  209. if format_class.name == file_format:
  210. return format_class()
  211. raise FileFormatException(file_format)
  212. class Options:
  213. options: Dict[str, List] = defaultdict(list)
  214. @classmethod
  215. def filter_by_task(cls, task_name: str, use_relation: bool = False):
  216. options = cls.options[task_name]
  217. if use_relation:
  218. options = cls.options[task_name] + cls.options[RELATION_EXTRACTION]
  219. return [option.dict() for option in options]
  220. @classmethod
  221. def register(cls, option: Option):
  222. cls.options[option.task_id].append(option)
  223. # Text tasks
  224. text_tasks = [
  225. ProjectType.DOCUMENT_CLASSIFICATION,
  226. ProjectType.SEQUENCE_LABELING,
  227. ProjectType.SEQ2SEQ,
  228. ProjectType.INTENT_DETECTION_AND_SLOT_FILLING,
  229. ]
  230. for task_id in text_tasks:
  231. Options.register(
  232. Option(
  233. display_name=TextFile.name,
  234. task_id=task_id,
  235. file_format=TextFile,
  236. arg=ArgEncoding,
  237. file=TASK_AGNOSTIC_DIR / "text_files.txt",
  238. )
  239. )
  240. Options.register(
  241. Option(
  242. display_name=TextLine.name,
  243. task_id=task_id,
  244. file_format=TextLine,
  245. arg=ArgEncoding,
  246. file=TASK_AGNOSTIC_DIR / "text_lines.txt",
  247. )
  248. )
  249. # Text Classification
  250. Options.register(
  251. Option(
  252. display_name=CSV.name,
  253. task_id=ProjectType.DOCUMENT_CLASSIFICATION,
  254. file_format=CSV,
  255. arg=ArgDelimiter,
  256. file=TEXT_CLASSIFICATION_DIR / "example.csv",
  257. )
  258. )
  259. Options.register(
  260. Option(
  261. display_name=FastText.name,
  262. task_id=ProjectType.DOCUMENT_CLASSIFICATION,
  263. file_format=FastText,
  264. arg=ArgEncoding,
  265. file=TEXT_CLASSIFICATION_DIR / "example.txt",
  266. )
  267. )
  268. Options.register(
  269. Option(
  270. display_name=JSON.name,
  271. task_id=ProjectType.DOCUMENT_CLASSIFICATION,
  272. file_format=JSON,
  273. arg=ArgColumn,
  274. file=TEXT_CLASSIFICATION_DIR / "example.json",
  275. )
  276. )
  277. Options.register(
  278. Option(
  279. display_name=JSONL.name,
  280. task_id=ProjectType.DOCUMENT_CLASSIFICATION,
  281. file_format=JSONL,
  282. arg=ArgColumn,
  283. file=TEXT_CLASSIFICATION_DIR / "example.jsonl",
  284. )
  285. )
  286. Options.register(
  287. Option(
  288. display_name=Excel.name,
  289. task_id=ProjectType.DOCUMENT_CLASSIFICATION,
  290. file_format=Excel,
  291. arg=ArgColumn,
  292. file=TEXT_CLASSIFICATION_DIR / "example.csv",
  293. )
  294. )
  295. # Sequence Labelling
  296. Options.register(
  297. Option(
  298. display_name=JSONL.name,
  299. task_id=ProjectType.SEQUENCE_LABELING,
  300. file_format=JSONL,
  301. arg=ArgColumn,
  302. file=SEQUENCE_LABELING_DIR / "example.jsonl",
  303. )
  304. )
  305. Options.register(
  306. Option(
  307. display_name=CoNLL.name,
  308. task_id=ProjectType.SEQUENCE_LABELING,
  309. file_format=CoNLL,
  310. arg=ArgCoNLL,
  311. file=SEQUENCE_LABELING_DIR / "example.txt",
  312. )
  313. )
  314. # Relation Extraction
  315. Options.register(
  316. Option(
  317. display_name="JSONL(Relation)",
  318. task_id=RELATION_EXTRACTION,
  319. file_format=JSONL,
  320. arg=ArgNone,
  321. file=RELATION_EXTRACTION_DIR / "example.jsonl",
  322. )
  323. )
  324. # Seq2seq
  325. Options.register(
  326. Option(
  327. display_name=CSV.name,
  328. task_id=ProjectType.SEQ2SEQ,
  329. file_format=CSV,
  330. arg=ArgDelimiter,
  331. file=SEQ2SEQ_DIR / "example.csv",
  332. )
  333. )
  334. Options.register(
  335. Option(
  336. display_name=JSON.name,
  337. task_id=ProjectType.SEQ2SEQ,
  338. file_format=JSON,
  339. arg=ArgColumn,
  340. file=SEQ2SEQ_DIR / "example.json",
  341. )
  342. )
  343. Options.register(
  344. Option(
  345. display_name=JSONL.name,
  346. task_id=ProjectType.SEQ2SEQ,
  347. file_format=JSONL,
  348. arg=ArgColumn,
  349. file=SEQ2SEQ_DIR / "example.jsonl",
  350. )
  351. )
  352. Options.register(
  353. Option(
  354. display_name=Excel.name,
  355. task_id=ProjectType.SEQ2SEQ,
  356. file_format=Excel,
  357. arg=ArgColumn,
  358. file=SEQ2SEQ_DIR / "example.csv",
  359. )
  360. )
  361. # Intent detection
  362. Options.register(
  363. Option(
  364. display_name=JSONL.name,
  365. task_id=ProjectType.INTENT_DETECTION_AND_SLOT_FILLING,
  366. file_format=JSONL,
  367. arg=ArgNone,
  368. file=INTENT_DETECTION_DIR / "example.jsonl",
  369. )
  370. )
  371. # Image tasks
  372. image_tasks = [
  373. ProjectType.IMAGE_CLASSIFICATION,
  374. ProjectType.IMAGE_CAPTIONING,
  375. ProjectType.BOUNDING_BOX,
  376. ProjectType.SEGMENTATION,
  377. ]
  378. for task_name in image_tasks:
  379. Options.register(
  380. Option(
  381. display_name=ImageFile.name,
  382. task_id=task_name,
  383. file_format=ImageFile,
  384. arg=ArgNone,
  385. file=IMAGE_CLASSIFICATION_DIR / "image_files.txt",
  386. )
  387. )
  388. # Speech to Text
  389. Options.register(
  390. Option(
  391. display_name=AudioFile.name,
  392. task_id=ProjectType.SPEECH2TEXT,
  393. file_format=AudioFile,
  394. arg=ArgNone,
  395. file=SPEECH_TO_TEXT_DIR / "audio_files.txt",
  396. )
  397. )