Coverage for website/thaliawebsite/api/openapi.py: 24.59%
49 statements
« prev ^ index » next coverage.py v7.6.7, created at 2025-08-14 10:31 +0000
« prev ^ index » next coverage.py v7.6.7, created at 2025-08-14 10:31 +0000
1import warnings
3from django.http import HttpRequest
5from oauth2_provider.scopes import get_scopes_backend
6from rest_framework import exceptions
7from rest_framework.request import Request
8from rest_framework.reverse import reverse
9from rest_framework.schemas.openapi import AutoSchema, SchemaGenerator
10from rest_framework.schemas.utils import is_list_view
class OAuthSchemaGenerator(SchemaGenerator):
    """OpenAPI schema generator that declares an ``oauth2`` security scheme.

    The scheme describes the OAuth2 implicit flow. Its authorization URL is
    resolved from the ``oauth2_provider:authorize`` route and its scopes come
    from the configured scopes backend.
    """

    def get_schema(self, request=None, public=False):
        """Generate the schema and register the ``oauth2`` security scheme.

        Args:
            request: Passed through unchanged to the parent generator.
            public: Passed through unchanged to the parent generator.

        Returns:
            The schema produced by the parent generator, with
            ``components.securitySchemes`` set to the ``oauth2`` definition.
        """
        schema = super().get_schema(request, public)

        # OAuthAutoSchema in this module attaches an "oauth2" security
        # requirement to every operation, so the scheme has to be declared
        # even when the parent generator emitted no "components" section
        # (previously the declaration was silently skipped in that case).
        components = schema.setdefault("components", {})
        components["securitySchemes"] = {
            "oauth2": {
                "type": "oauth2",
                "description": "OAuth2",
                "flows": {
                    "implicit": {
                        "authorizationUrl": reverse("oauth2_provider:authorize"),
                        "scopes": get_scopes_backend().get_all_scopes(),
                    }
                },
            }
        }
        return schema
class OAuthAutoSchema(AutoSchema):
    """Schema inspector that adds OAuth2 scopes and custom operation ids."""

    def get_operation(self, path, method):
        """Return the operation with an ``oauth2`` security requirement.

        The scopes are the view's ``required_scopes`` when the view has that
        attribute; otherwise ``["read", "write"]`` is used.
        """
        operation = super().get_operation(path, method)
        view = self.view
        if view and hasattr(view, "required_scopes"):
            scopes = view.required_scopes
        else:
            scopes = ["read", "write"]
        operation["security"] = [{"oauth2": scopes}]
        return operation

    def get_operation_id_base(self, path, method, action):
        """Return the base name for the operation id.

        Paths containing the substring ``admin`` get an ``Admin`` prefix.
        """
        base = super().get_operation_id_base(path, method, action)
        # NOTE: str.capitalize() upper-cases the first character and
        # lower-cases the rest, so any inner capitals in the base name are
        # flattened for admin paths.
        return ("Admin" + base.capitalize()) if "admin" in path else base

    def get_operation_id(self, path, method):
        """Compose the operation id as ``<action><base name>``.

        List views use the action ``list``. Otherwise the action is the
        ``method_mapping`` entry for the HTTP method; when the view's
        ``action`` attribute is not itself a ``method_mapping`` key, that
        attribute is appended and the result is camel-cased.
        """
        view = self.view
        verb = method.lower()
        # Fall back to the lower-cased HTTP method when the view has no
        # ``action`` attribute.
        view_action = getattr(view, "action", verb)

        if is_list_view(path, method, view):
            action = "list"
        else:
            action = self.method_mapping[verb]
            if view_action not in self.method_mapping:
                action = self._to_camel_case(f"{action}_{view_action}")

        return action + self.get_operation_id_base(path, method, action)

    def get_serializer(self, path, method):
        """Return the view's serializer, or ``None`` if it cannot be built.

        Before asking the view for a serializer, a fresh request carrying
        only ``method`` and ``path`` is installed on the view, and that
        request's ``user`` is mirrored onto a ``member`` attribute.
        """
        view = self.view

        placeholder = HttpRequest()
        placeholder.method, placeholder.path = method, path
        drf_request = Request(placeholder)
        view.request = drf_request
        # Presumably project code reads ``request.member`` -- confirm
        # against the views/serializers that use it.
        drf_request.member = drf_request.user

        if hasattr(view, "get_serializer"):
            try:
                return view.get_serializer()
            except exceptions.APIException:
                message = (
                    f"{view.__class__.__name__}.get_serializer() raised an "
                    "exception during schema generation. Serializer fields "
                    f"will not be generated for {method} {path}."
                )
                warnings.warn(message)
        return None