diff --git a/backend/data_import/pipeline/label.py b/backend/data_import/pipeline/label.py index a70cfd9a..a68d78b4 100644 --- a/backend/data_import/pipeline/label.py +++ b/backend/data_import/pipeline/label.py @@ -2,7 +2,7 @@ import abc import uuid from typing import Any, Optional -from pydantic import UUID4, BaseModel, validator +from pydantic import UUID4, BaseModel, NonNegativeInt, validator from .label_types import LabelTypes from examples.models import Example @@ -70,8 +70,8 @@ class CategoryLabel(Label): class SpanLabel(Label): label: str - start_offset: int - end_offset: int + start_offset: NonNegativeInt + end_offset: NonNegativeInt def __lt__(self, other): return self.start_offset < other.start_offset diff --git a/backend/data_import/tests/test_label.py b/backend/data_import/tests/test_label.py index 010a5d48..9c3ce964 100644 --- a/backend/data_import/tests/test_label.py +++ b/backend/data_import/tests/test_label.py @@ -82,6 +82,10 @@ class TestSpanLabel(TestLabel): self.assertEqual(span.start_offset, 0) self.assertEqual(span.end_offset, 1) + def test_invalid_negative_offset(self): + with self.assertRaises(ValueError): + SpanLabel(label="A", start_offset=-1, end_offset=1, example_uuid=uuid.uuid4()) + def test_parse_invalid_dict(self): example_uuid = uuid.uuid4() with self.assertRaises(ValueError):