diff --git a/backend/data_export/pipeline/catalog.py b/backend/data_export/pipeline/catalog.py index 3eb74b43..8ec81bb3 100644 --- a/backend/data_export/pipeline/catalog.py +++ b/backend/data_export/pipeline/catalog.py @@ -69,11 +69,13 @@ class Options: @classmethod def filter_by_task(cls, task_name: str): options = cls.options[task_name] - return [{**format.dict(), **option.schema(), "example": example} for format, option, example in options] + return [ + {**file_format.dict(), **option.schema(), "example": example} for file_format, option, example in options + ] @classmethod - def register(cls, task: str, format: Type[Format], option: Type[BaseModel], example: str): - cls.options[task].append((format, option, example)) + def register(cls, task: str, file_format: Type[Format], option: Type[BaseModel], example: str): + cls.options[task].append((file_format, option, example)) # Text Classification diff --git a/backend/data_export/pipeline/data.py b/backend/data_export/pipeline/data.py index 937e1aac..1698c132 100644 --- a/backend/data_export/pipeline/data.py +++ b/backend/data_export/pipeline/data.py @@ -4,9 +4,9 @@ from typing import Any, Dict, List, Union class Record: def __init__( - self, id: int, data: str, label: Union[List[Any], Dict[Any, Any]], user: str, metadata: Dict[Any, Any] + self, data_id: int, data: str, label: Union[List[Any], Dict[Any, Any]], user: str, metadata: Dict[Any, Any] ): - self.id = id + self.id = data_id self.data = data self.label = label self.user = user diff --git a/backend/data_export/pipeline/repositories.py b/backend/data_export/pipeline/repositories.py index 227e82f4..a45fc5b5 100644 --- a/backend/data_export/pipeline/repositories.py +++ b/backend/data_export/pipeline/repositories.py @@ -31,7 +31,7 @@ class FileRepository(BaseRepository): label_per_user = self.reduce_user(label_per_user) for user, label in label_per_user.items(): yield Record( - id=example.id, + data_id=example.id, data=str(example.filename).split("/")[-1], label=label, user=user, @@ -45,7 +45,7 @@ class FileRepository(BaseRepository): # This means I will allow each user to be able to approve the doc. if len(label_per_user) == 0: yield Record( - id=example.id, data=str(example.filename).split("/")[-1], label=[], user="unknown", metadata={} + data_id=example.id, data=str(example.filename).split("/")[-1], label=[], user="unknown", metadata={} ) def label_per_user(self, example) -> Dict: @@ -82,7 +82,7 @@ class TextRepository(BaseRepository): if self.project.collaborative_annotation: label_per_user = self.reduce_user(label_per_user) for user, label in label_per_user.items(): - yield Record(id=doc.id, data=doc.text, label=label, user=user, metadata=doc.meta) + yield Record(data_id=doc.id, data=doc.text, label=label, user=user, metadata=doc.meta) # todo: # If there is no label, export the doc with `unknown` user. # This is a quick solution. @@ -90,7 +90,7 @@ class TextRepository(BaseRepository): # with the user who approved the doc. # This means I will allow each user to be able to approve the doc. if len(label_per_user) == 0: - yield Record(id=doc.id, data=doc.text, label=[], user="unknown", metadata={}) + yield Record(data_id=doc.id, data=doc.text, label=[], user="unknown", metadata={}) @abc.abstractmethod def label_per_user(self, doc) -> Dict: diff --git a/backend/data_export/tests/test_writer.py b/backend/data_export/tests/test_writer.py index 7b40c129..8997fbe4 100644 --- a/backend/data_export/tests/test_writer.py +++ b/backend/data_export/tests/test_writer.py @@ -9,9 +9,9 @@ from ..pipeline.writers import CsvWriter, IntentAndSlotWriter class TestCSVWriter(unittest.TestCase): def setUp(self): self.records = [ - Record(id=0, data="exampleA", label=["labelA"], user="admin", metadata={"hidden": "secretA"}), - Record(id=1, data="exampleB", label=["labelB"], user="admin", metadata={"hidden": "secretB"}), - Record(id=2, data="exampleC", label=["labelC"], user="admin", metadata={"meta": "secretC"}), + Record(data_id=0, data="exampleA", label=["labelA"], user="admin", metadata={"hidden": "secretA"}), + Record(data_id=1, data="exampleB", label=["labelB"], user="admin", metadata={"hidden": "secretB"}), + Record(data_id=2, data="exampleC", label=["labelC"], user="admin", metadata={"meta": "secretC"}), ] def test_create_header(self): @@ -29,8 +29,8 @@ class TestCSVWriter(unittest.TestCase): def test_label_order(self): writer = CsvWriter(".") - record1 = Record(id=0, data="", label=["labelA", "labelB"], user="", metadata={}) - record2 = Record(id=0, data="", label=["labelB", "labelA"], user="", metadata={}) + record1 = Record(data_id=0, data="", label=["labelA", "labelB"], user="", metadata={}) + record2 = Record(data_id=0, data="", label=["labelB", "labelA"], user="", metadata={}) line1 = writer.create_line(record1) line2 = writer.create_line(record2) expected = "labelA#labelB" @@ -61,7 +61,11 @@ class TestCSVWriter(unittest.TestCase): class TestIntentWriter(unittest.TestCase): def setUp(self): self.record = Record( - id=0, data="exampleA", label={"cats": ["positive"], "entities": [(0, 1, "LOC")]}, user="admin", metadata={} + data_id=0, + data="exampleA", + label={"cats": ["positive"], "entities": [(0, 1, "LOC")]}, + user="admin", + metadata={}, ) def test_create_line(self):