
Merge pull request #1675 from doccano/enhancement/mypy

[Enhancement] Add mypy to the workflow
Hiroki Nakayama, 2 years ago (committed by GitHub)
commit 5ff1dab66f
18 changed files with 90 additions and 43 deletions
  1. .github/workflows/ci.yml (3)
  2. Pipfile (4)
  3. Pipfile.lock (33)
  4. backend/api/tests/utils.py (4)
  5. backend/auto_labeling/pipeline/labels.py (6)
  6. backend/data_export/pipeline/factories.py (4)
  7. backend/data_export/pipeline/repositories.py (21)
  8. backend/data_export/pipeline/writers.py (8)
  9. backend/data_import/pipeline/builders.py (6)
  10. backend/data_import/pipeline/cleaners.py (6)
  11. backend/data_import/pipeline/labels.py (4)
  12. backend/data_import/pipeline/parsers.py (8)
  13. backend/data_import/pipeline/readers.py (2)
  14. backend/data_import/pipeline/writers.py (6)
  15. backend/projects/permissions.py (4)
  16. backend/projects/serializers.py (6)
  17. backend/projects/tests/utils.py (2)
  18. pyproject.toml (6)

.github/workflows/ci.yml (3)

@@ -35,6 +35,9 @@ jobs:
 run: |
 pipenv run black
 working-directory: ./backend
+- name: mypy
+run: |
+pipenv run mypy
 - name: Run tests
 run: |
 pipenv run test

Pipfile (4)

@@ -19,6 +19,9 @@ watchdog = "*"
 argh = "*"
 black = "*"
 pyproject-flake8 = "*"
+types-chardet = "*"
+types-requests = "*"
+types-waitress = "*"
 [packages]
 django = "~=3.2"
@@ -63,6 +66,7 @@ python_version = "3.8"
 isort = "isort api -c --skip migrations"
 flake8 = "pflake8 --filename \"*.py\" --extend-exclude \"*/migrations\""
 black = "black --check ."
+mypy = "mypy ."
 wait_for_db = "python manage.py wait_for_db"
 test = "python manage.py test --pattern=\"test*.py\""
 migrate = "python manage.py migrate"

Pipfile.lock (33)

@@ -1,7 +1,7 @@
 {
 "_meta": {
 "hash": {
-"sha256": "1ec1af8bf3969c0798488c46364a48661b7789b66b2682c23f858ca760a4c89c"
+"sha256": "b73f92aa76c8d2a6fdc0927b1853c46c27ae9508d5c2173b6e79823fcd1bf634"
 },
 "pipfile-spec": 6,
 "requires": {
@@ -1680,6 +1680,37 @@
 ],
 "version": "==2.0.0"
 },
+"types-chardet": {
+"hashes": [
+"sha256:519850a12ab0009f3ec5bdca35ce1c0de4eb4a67a2110aa206386e6219b3ecd8",
+"sha256:8990a86d4c7cfa6c6c5889fc49e456e477851e75b5adb396d42ae106d0ae02ea"
+],
+"index": "pypi",
+"version": "==4.0.3"
+},
+"types-requests": {
+"hashes": [
+"sha256:8ec9f5f84adc6f579f53943312c28a84e87dc70201b54f7c4fbc7d22ecfa8a3e",
+"sha256:c2f4e4754d07ca0a88fd8a89bbc6c8a9f90fb441f9c9b572fd5c484f04817486"
+],
+"index": "pypi",
+"version": "==2.27.8"
+},
+"types-urllib3": {
+"hashes": [
+"sha256:4a54f6274ab1c80968115634a55fb9341a699492b95e32104a7c513db9fe02e9",
+"sha256:abd2d4857837482b1834b4817f0587678dcc531dbc9abe4cde4da28cef3f522c"
+],
+"version": "==1.26.9"
+},
+"types-waitress": {
+"hashes": [
+"sha256:d7843d13487effb0e0774ec294f42ca63ed9f74a9296b47e4e290ddb21a05292",
+"sha256:fdd57199a5a7b5b3e65973feb137964bd750cdb1af4f7cc7c9d6053342f86ff2"
+],
+"index": "pypi",
+"version": "==2.0.6"
+},
 "typing-extensions": {
 "hashes": [
 "sha256:4ca091dea149f945ec56afb48dae714f21e8692ef22a395223bcd328961b6a0e",

backend/api/tests/utils.py (4)

@@ -1,10 +1,12 @@
+from typing import Any, Dict
 from rest_framework import status
 from rest_framework.test import APITestCase
 class CRUDMixin(APITestCase):
 url = ""
-data = {}
+data: Dict[str, Any] = {}
 def assert_fetch(self, user=None, expected=status.HTTP_403_FORBIDDEN):
 if user:
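
Note: mypy cannot infer an element type from an empty literal, so a bare class attribute such as data = {} is reported as needing a type annotation; the explicit Dict[str, Any] annotation above resolves that. A minimal, standalone sketch of the pattern (not doccano code):

    from typing import Any, Dict

    class Unannotated:
        data = {}  # mypy: Need type annotation for "data"

    class Annotated:
        data: Dict[str, Any] = {}  # explicit annotation, so the check passes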

backend/auto_labeling/pipeline/labels.py (6)

@@ -6,13 +6,13 @@ from django.contrib.auth.models import User
 from projects.models import Project
 from examples.models import Example
-from label_types.models import CategoryType, SpanType
+from label_types.models import CategoryType, LabelType, SpanType
 from labels.models import Label, Category, Span, TextLabel
 class LabelCollection(abc.ABC):
-label_type = None
-model = None
+label_type: LabelType = None
+model: Label = None
 def __init__(self, labels):
 self.labels = labels

backend/data_export/pipeline/factories.py (4)

@@ -11,7 +11,7 @@ from projects.models import (
 from . import catalog, repositories, writers
-def create_repository(project) -> repositories.BaseRepository:
+def create_repository(project):
 mapping = {
 DOCUMENT_CLASSIFICATION: repositories.TextClassificationRepository,
 SEQUENCE_LABELING: repositories.SequenceLabelingRepository,
@@ -22,7 +22,7 @@ def create_repository(project) -> repositories.BaseRepository:
 }
 if project.project_type not in mapping:
 ValueError(f"Invalid project type: {project.project_type}")
-repository = mapping.get(project.project_type)(project)
+repository = mapping[project.project_type](project)
 return repository
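
Note: dict.get returns an Optional value, so calling its result directly is rejected by mypy, while indexing is typed as returning the value type. A small illustration with hypothetical names:

    from typing import Callable, Dict

    handlers: Dict[str, Callable[[], str]] = {"a": lambda: "A"}

    # handlers.get("a") is Optional[Callable[[], str]], so mypy reports that
    # the call below might be calling None:
    # result = handlers.get("a")()

    result = handlers["a"]()  # indexing returns the value type, so this passes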

backend/data_export/pipeline/repositories.py (21)

@@ -1,12 +1,14 @@
 import abc
 import itertools
 from collections import defaultdict
-from typing import Dict, Iterator, List
+from typing import Dict, Iterator, List, Tuple, Union
 from projects.models import Project
 from examples.models import Example
 from .data import Record
+SpanType = Tuple[int, int, str]
 class BaseRepository(abc.ABC):
 def __init__(self, project: Project):
@@ -144,16 +146,15 @@ class IntentDetectionSlotFillingRepository(TextRepository):
 )
 def label_per_user(self, doc) -> Dict:
-category_per_user = defaultdict(list)
-span_per_user = defaultdict(list)
-label_per_user = defaultdict(dict)
+category_per_user: Dict[str, List[str]] = defaultdict(list)
+span_per_user: Dict[str, List[SpanType]] = defaultdict(list)
+label_per_user: Dict[str, Dict[str, Union[List[str], List[SpanType]]]] = defaultdict(dict)
 for a in doc.categories.all():
 category_per_user[a.user.username].append(a.label.text)
 for a in doc.spans.all():
-label = (a.start_offset, a.end_offset, a.label.text)
-span_per_user[a.user.username].append(label)
-for user, label in category_per_user.items():
-label_per_user[user]["cats"] = label
-for user, label in span_per_user.items():
-label_per_user[user]["entities"] = label
+span_per_user[a.user.username].append((a.start_offset, a.end_offset, a.label.text))
+for user, cats in category_per_user.items():
+label_per_user[user]["cats"] = cats
+for user, span in span_per_user.items():
+label_per_user[user]["entities"] = span
 return label_per_user
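
Note: as with empty literals, defaultdict(list) gives mypy nothing to infer key and value types from, so each accumulator now carries an explicit annotation, and the SpanType alias keeps the nested annotations readable. A condensed sketch of the same pattern:

    from collections import defaultdict
    from typing import Dict, List, Tuple

    SpanType = Tuple[int, int, str]  # (start_offset, end_offset, label text)

    # without the annotation, mypy asks for one ("Need type annotation ...")
    span_per_user: Dict[str, List[SpanType]] = defaultdict(list)
    span_per_user["alice"].append((0, 5, "PER"))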

backend/data_export/pipeline/writers.py (8)

@@ -58,9 +58,9 @@ class CsvWriter(BaseWriter):
 def write(self, records: Iterator[Record]) -> str:
 writers = {}
 file_handlers = set()
-records = list(records)
-header = self.create_header(records)
-for record in records:
+record_list = list(records)
+header = self.create_header(record_list)
+for record in record_list:
 filename = os.path.join(self.tmpdir, f"{record.user}.{self.extension}")
 if filename not in writers:
 f = open(filename, mode="a", encoding="utf-8")
@@ -82,7 +82,7 @@ class CsvWriter(BaseWriter):
 def create_line(self, record) -> Dict:
 return {"id": record.id, "data": record.data, "label": "#".join(sorted(record.label)), **record.metadata}
-def create_header(self, records: List[Record]) -> Iterable[str]:
+def create_header(self, records: List[Record]) -> List[str]:
 header = ["id", "data", "label"]
 header += sorted(set(itertools.chain(*[r.metadata.keys() for r in records])))
 return header
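
Note: the records parameter is declared as Iterator[Record], so rebinding it to list(records) changes its type to List[Record] and mypy reports an incompatible assignment; the new record_list name keeps one type per name. A small sketch with a hypothetical function:

    from typing import Iterator, List

    def consume(values: Iterator[int]) -> int:
        # values = list(values)  # mypy: incompatible types in assignment
        value_list: List[int] = list(values)  # a fresh name keeps the two types apart
        return sum(value_list)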

backend/data_import/pipeline/builders.py (6)

@@ -1,6 +1,6 @@
 import abc
 from logging import getLogger
-from typing import Any, Dict, List, Optional, Type, TypeVar
+from typing import Any, Dict, List, Optional, Type
 from pydantic import ValidationError
@@ -10,7 +10,6 @@ from .labels import Label
 from .readers import Builder, Record
 logger = getLogger(__name__)
-T = TypeVar("T")
 class PlainBuilder(Builder):
@@ -34,7 +33,8 @@ def build_data(row: Dict[Any, Any], name: str, data_class: Type[BaseData], filen
 class Column(abc.ABC):
-def __init__(self, name: str, value_class: Type[T]):
+# Todo: need to redesign.
+def __init__(self, name: str, value_class: Any):
 self.name = name
 self.value_class = value_class

backend/data_import/pipeline/cleaners.py (6)

@@ -1,7 +1,7 @@
 from typing import List
 from projects.models import Project
-from .labels import CategoryLabel, Label, SpanLabel
+from .labels import Label, SpanLabel
 class Cleaner:
@@ -21,7 +21,7 @@ class SpanCleaner(Cleaner):
 super().__init__(project)
 self.allow_overlapping = getattr(project, "allow_overlapping", False)
-def clean(self, labels: List[SpanLabel]) -> List[SpanLabel]:
+def clean(self, labels: List[SpanLabel]) -> List[SpanLabel]: # type: ignore
 if self.allow_overlapping:
 return labels
@@ -44,7 +44,7 @@ class CategoryCleaner(Cleaner):
 super().__init__(project)
 self.exclusive = getattr(project, "single_class_classification", False)
-def clean(self, labels: List[CategoryLabel]) -> List[CategoryLabel]:
+def clean(self, labels: List[Label]) -> List[Label]:
 if self.exclusive:
 return labels[:1]
 else:
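
Note: assuming the base Cleaner.clean is declared against List[Label] (the import change suggests as much), an override that narrows the parameter to List[SpanLabel] breaks the usual substitution rule and mypy flags the signature; this PR silences that case with a type: ignore and widens CategoryCleaner.clean instead. A minimal sketch with hypothetical classes:

    from typing import List

    class Label: ...
    class SpanLabel(Label): ...

    class Cleaner:
        def clean(self, labels: List[Label]) -> List[Label]:
            return labels

    class SpanCleaner(Cleaner):
        # narrowing the parameter type is an incompatible override for mypy,
        # hence the "# type: ignore" used in the real SpanCleaner
        def clean(self, labels: List[SpanLabel]) -> List[SpanLabel]:  # type: ignore
            return labels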

backend/data_import/pipeline/labels.py (4)

@@ -1,5 +1,5 @@
 import abc
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, Optional
 from pydantic import BaseModel, validator
@@ -68,7 +68,7 @@ class CategoryLabel(Label):
 class SpanLabel(Label):
-label: Union[str, int]
+label: str
 start_offset: int
 end_offset: int

backend/data_import/pipeline/parsers.py (8)

@@ -47,7 +47,7 @@ def detect_encoding(filename: str, buffer_size: int = io.DEFAULT_BUFFER_SIZE) ->
 if detector.done:
 break
 if detector.done:
-return detector.result["encoding"]
+return detector.result["encoding"] or "utf-8"
 else:
 return "utf-8"
@@ -164,7 +164,7 @@ class JSONParser(Parser):
 def __init__(self, encoding: str = DEFAULT_ENCODING, **kwargs):
 self.encoding = encoding
-self._errors = []
+self._errors: List[FileParseException] = []
 def parse(self, filename: str) -> Iterator[Dict[Any, Any]]:
 encoding = decide_encoding(filename, self.encoding)
@@ -191,7 +191,7 @@ class JSONLParser(Parser):
 def __init__(self, encoding: str = DEFAULT_ENCODING, **kwargs):
 self.encoding = encoding
-self._errors = []
+self._errors: List[FileParseException] = []
 def parse(self, filename: str) -> Iterator[Dict[Any, Any]]:
 reader = LineReader(filename, self.encoding)
@@ -288,7 +288,7 @@ class CoNLLParser(Parser):
 self.encoding = encoding
 self.delimiter = delimiter
 mapping = {"IOB2": IOB2, "IOE2": IOE2, "IOBES": IOBES, "BILOU": BILOU}
-self._errors = []
+self._errors: List[FileParseException] = []
 if scheme in mapping:
 self.scheme = mapping[scheme]
 else:
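
Note: chardet reports None when it cannot settle on an encoding, so returning the detected value from a function presumably annotated to return str fails type checking once the types-chardet stubs are installed; the or "utf-8" fallback narrows the value to str. A short sketch with a hypothetical helper:

    from typing import Optional

    def pick_encoding(detected: Optional[str]) -> str:
        # "detected or 'utf-8'" narrows Optional[str] to str for the return value
        return detected or "utf-8"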

backend/data_import/pipeline/readers.py (2)

@@ -100,7 +100,7 @@ class Reader(BaseReader):
 self.filenames = filenames
 self.parser = parser
 self.builder = builder
-self._errors = []
+self._errors: List[FileParseException] = []
 def __iter__(self) -> Iterator[Record]:
 for filename in self.filenames:

backend/data_import/pipeline/writers.py (6)

@@ -9,7 +9,7 @@ from projects.models import Project
 from examples.models import Example
 from label_types.models import CategoryType, SpanType
 from .exceptions import FileParseException
-from .readers import BaseReader
+from .readers import BaseReader, Record
 class Writer(abc.ABC):
@@ -33,7 +33,7 @@ def group_by_class(instances):
 class Examples:
 def __init__(self, buffer_size: int = settings.IMPORT_BATCH_SIZE):
 self.buffer_size = buffer_size
-self.buffer = []
+self.buffer: List[Record] = []
 def __len__(self):
 return len(self.buffer)
@@ -85,7 +85,7 @@ class Examples:
 class BulkWriter(Writer):
 def __init__(self, batch_size: int):
 self.examples = Examples(batch_size)
-self._errors = []
+self._errors: List[FileParseException] = []
 def save(self, reader: BaseReader, project: Project, user, cleaner):
 it = iter(reader)

backend/projects/permissions.py (4)

@@ -50,5 +50,5 @@ class IsAnnotationApprover(RolePermission):
 role_name = settings.ROLE_ANNOTATION_APPROVER
-IsProjectMember = IsAnnotator | IsAnnotationApprover | IsProjectAdmin
-IsProjectStaffAndReadOnly = IsAnnotatorAndReadOnly | IsAnnotationApproverAndReadOnly
+IsProjectMember = IsAnnotator | IsAnnotationApprover | IsProjectAdmin # type: ignore
+IsProjectStaffAndReadOnly = IsAnnotatorAndReadOnly | IsAnnotationApproverAndReadOnly # type: ignore

backend/projects/serializers.py (6)

@@ -49,7 +49,7 @@ class ProjectSerializer(serializers.ModelSerializer):
 class Meta:
 model = Project
-fields = (
+fields = [
 "id",
 "name",
 "description",
@@ -66,7 +66,7 @@ class ProjectSerializer(serializers.ModelSerializer):
 "can_define_category",
 "can_define_span",
 "tags",
-)
+]
 read_only_fields = (
 "updated_at",
 "is_text_project",
@@ -86,7 +86,7 @@ class TextClassificationProjectSerializer(ProjectSerializer):
 class SequenceLabelingProjectSerializer(ProjectSerializer):
 class Meta(ProjectSerializer.Meta):
 model = SequenceLabelingProject
-fields = ProjectSerializer.Meta.fields + ("allow_overlapping", "grapheme_mode")
+fields = ProjectSerializer.Meta.fields + ["allow_overlapping", "grapheme_mode"]
 class Seq2seqProjectSerializer(ProjectSerializer):
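
Note: mypy infers tuple literals as fixed-length tuple types, so a subclass Meta whose fields tuple is longer than ProjectSerializer.Meta.fields is reported as an incompatible override; with lists, both attributes are simply List[str]. A reduced sketch with hypothetical classes:

    from typing import List

    class BaseMeta:
        fields: List[str] = ["id", "name"]

    class ChildMeta(BaseMeta):
        # with tuple literals the two attributes would be different fixed-length
        # tuple types and mypy would flag the override; List[str] stays compatible
        fields = BaseMeta.fields + ["allow_overlapping", "grapheme_mode"]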

backend/projects/tests/utils.py (2)

@@ -49,7 +49,7 @@ def assign_user_to_role(project_member, project, role_name):
 return mapping
-def make_project(task: str, users: List[str], roles: List[str] = None, collaborative_annotation=False, **kwargs):
+def make_project(task: str, users: List[str], roles: List[str], collaborative_annotation=False, **kwargs):
 create_default_roles()
 # create users.

pyproject.toml (6)

@@ -13,3 +13,9 @@ max-line-length = 120
 max-complexity = 18
 ignore = "E203,E266,W503,"
 filename = "backend/*"
+[tool.mypy]
+python_version = "3.8"
+ignore_missing_imports = true
+show_error_codes = true
+exclude = "(migrations)|(app/settings.py)"
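
Note: the new [tool.mypy] table pins the target interpreter, tolerates third-party packages without stubs, skips the generated migrations and the settings module, and prints error codes. The codes make targeted ignores possible; a small hypothetical snippet, not from the repository:

    # with show_error_codes = true, mypy appends a code to each message, e.g.
    #   error: Incompatible types in assignment (...)  [assignment]
    # which allows silencing a single check rather than the whole line:
    timeout: int = None  # type: ignore[assignment]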