Browse Source

Handle ValidationError in LabelMaker, fix #1898

fix/1898
Hironsan 2 years ago
parent
commit
dc3bf5c59e
3 changed files with 28 additions and 5 deletions
  1. 11
      backend/data_import/pipeline/label.py
  2. 11
      backend/data_import/pipeline/makers.py
  3. 11
      backend/data_import/tests/test_makers.py

11
backend/data_import/pipeline/label.py

@ -2,7 +2,14 @@ import abc
import uuid import uuid
from typing import Any, Optional from typing import Any, Optional
from pydantic import UUID4, BaseModel, ConstrainedStr, NonNegativeInt, root_validator
from pydantic import (
UUID4,
BaseModel,
ConstrainedStr,
NonNegativeInt,
ValidationError,
root_validator,
)
from .label_types import LabelTypes from .label_types import LabelTypes
from examples.models import Example from examples.models import Example
@ -77,7 +84,7 @@ class SpanLabel(Label):
def check_start_offset_is_less_than_end_offset(cls, values): def check_start_offset_is_less_than_end_offset(cls, values):
start_offset, end_offset = values.get("start_offset"), values.get("end_offset") start_offset, end_offset = values.get("start_offset"), values.get("end_offset")
if start_offset >= end_offset: if start_offset >= end_offset:
raise ValueError("start_offset must be less than end_offset.")
raise ValidationError("start_offset must be less than end_offset.")
return values return values
@classmethod @classmethod

11
backend/data_import/pipeline/makers.py

@ -1,6 +1,7 @@
from typing import List, Optional, Type from typing import List, Optional, Type
import pandas as pd import pandas as pd
from pydantic import ValidationError
from .data import BaseData from .data import BaseData
from .exceptions import FileParseException from .exceptions import FileParseException
@ -93,15 +94,19 @@ class LabelMaker:
return [] return []
df_label = df.explode(self.column) df_label = df.explode(self.column)
df_label = df_label[[UUID_COLUMN, self.column]]
df_label.dropna(subset=[self.column], inplace=True) df_label.dropna(subset=[self.column], inplace=True)
labels = [] labels = []
for row in df_label.to_dict(orient="records"): for row in df_label.to_dict(orient="records"):
try: try:
label = self.label_class.parse(row[UUID_COLUMN], row[self.column]) label = self.label_class.parse(row[UUID_COLUMN], row[self.column])
labels.append(label) labels.append(label)
except ValueError:
pass
except ValidationError as e:
errors = e.errors()
filename = row.get(UPLOAD_NAME_COLUMN, "")
line = row.get(LINE_NUMBER_COLUMN, 0)
for error in errors:
message = str(error["loc"]) + ":" + error["msg"]
self._errors.append(FileParseException(filename, line, message))
return labels return labels
def check_column_existence(self, df: pd.DataFrame) -> bool: def check_column_existence(self, df: pd.DataFrame) -> bool:

11
backend/data_import/tests/test_makers.py

@ -86,3 +86,14 @@ class TestLabelFormatter(TestCase):
) )
labels = label_maker.make(df) labels = label_maker.make(df)
self.assertEqual(len(labels), 1) self.assertEqual(len(labels), 1)
def test_validation_error(self):
label_maker = LabelMaker(column=self.label_column, label_class=self.label_class)
df = pd.DataFrame(
[
{LINE_NUMBER_COLUMN: 1, UUID_COLUMN: uuid.uuid4(), self.label_column: [""]}, # empty label
]
)
labels = label_maker.make(df)
self.assertEqual(len(labels), 0)
self.assertEqual(len(label_maker.errors), 1)
Loading…
Cancel
Save