Browse Source

Handle ValidationError in LabelMaker, fix #1898

fix/1898
Hironsan 2 years ago
parent
commit
dc3bf5c59e
3 changed files with 28 additions and 5 deletions
  1. 11
      backend/data_import/pipeline/label.py
  2. 11
      backend/data_import/pipeline/makers.py
  3. 11
      backend/data_import/tests/test_makers.py

11
backend/data_import/pipeline/label.py

@ -2,7 +2,14 @@ import abc
import uuid
from typing import Any, Optional
from pydantic import UUID4, BaseModel, ConstrainedStr, NonNegativeInt, root_validator
from pydantic import (
UUID4,
BaseModel,
ConstrainedStr,
NonNegativeInt,
ValidationError,
root_validator,
)
from .label_types import LabelTypes
from examples.models import Example
@ -77,7 +84,7 @@ class SpanLabel(Label):
def check_start_offset_is_less_than_end_offset(cls, values):
    """Reject a span whose start offset is not strictly before its end offset.

    Args:
        cls: The class this validator is bound to (not used here).
        values: Mapping of field values; ``start_offset`` and ``end_offset``
            are read from it with ``.get``.

    Returns:
        ``values`` unchanged when ``start_offset`` is less than ``end_offset``.

    Raises:
        ValueError: If ``start_offset`` is greater than or equal to
            ``end_offset``.
    """
    start_offset, end_offset = values.get("start_offset"), values.get("end_offset")
    if start_offset >= end_offset:
        # Raise ValueError rather than pydantic's ValidationError: validators
        # are expected to raise ValueError/TypeError/AssertionError, which
        # pydantic itself collects into a ValidationError. ValidationError is
        # not meant to be constructed from a bare message string.
        raise ValueError("start_offset must be less than end_offset.")
    return values
@classmethod

11
backend/data_import/pipeline/makers.py

@ -1,6 +1,7 @@
from typing import List, Optional, Type
import pandas as pd
from pydantic import ValidationError
from .data import BaseData
from .exceptions import FileParseException
@ -93,15 +94,19 @@ class LabelMaker:
return []
df_label = df.explode(self.column)
df_label = df_label[[UUID_COLUMN, self.column]]
df_label.dropna(subset=[self.column], inplace=True)
labels = []
for row in df_label.to_dict(orient="records"):
try:
label = self.label_class.parse(row[UUID_COLUMN], row[self.column])
labels.append(label)
except ValueError:
pass
except ValidationError as e:
errors = e.errors()
filename = row.get(UPLOAD_NAME_COLUMN, "")
line = row.get(LINE_NUMBER_COLUMN, 0)
for error in errors:
message = str(error["loc"]) + ":" + error["msg"]
self._errors.append(FileParseException(filename, line, message))
return labels
def check_column_existence(self, df: pd.DataFrame) -> bool:

11
backend/data_import/tests/test_makers.py

@ -86,3 +86,14 @@ class TestLabelFormatter(TestCase):
)
labels = label_maker.make(df)
self.assertEqual(len(labels), 1)
def test_validation_error(self):
    """An empty label produces no labels and records exactly one error."""
    maker = LabelMaker(column=self.label_column, label_class=self.label_class)
    record = {
        LINE_NUMBER_COLUMN: 1,
        UUID_COLUMN: uuid.uuid4(),
        self.label_column: [""],  # empty label
    }
    frame = pd.DataFrame([record])
    made = maker.make(frame)
    # The invalid row must be dropped from the output and reported instead.
    self.assertEqual(len(made), 0)
    self.assertEqual(len(maker.errors), 1)
Loading…
Cancel
Save