Coverage for website/thaliawebsite/api/openapi.py: 24.59%

49 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2025-08-14 10:31 +0000

1import warnings 

2 

3from django.http import HttpRequest 

4 

5from oauth2_provider.scopes import get_scopes_backend 

6from rest_framework import exceptions 

7from rest_framework.request import Request 

8from rest_framework.reverse import reverse 

9from rest_framework.schemas.openapi import AutoSchema, SchemaGenerator 

10from rest_framework.schemas.utils import is_list_view 

11 

12 

class OAuthSchemaGenerator(SchemaGenerator):
    """Schema generator that declares the ``oauth2`` security scheme.

    Operations built by ``OAuthAutoSchema`` (in this module) carry a
    ``security`` requirement naming a scheme called ``"oauth2"``; this
    generator declares that scheme under ``components.securitySchemes`` so
    those references resolve.
    """

    def get_schema(self, request=None, public=False):
        """Build the OpenAPI schema and attach the OAuth2 security scheme.

        :param request: request forwarded unchanged to the parent generator.
        :param public: flag forwarded unchanged to the parent generator.
        :return: the schema dict produced by ``SchemaGenerator.get_schema``
            with ``components.securitySchemes`` set to a single implicit-flow
            ``oauth2`` scheme whose scopes come from the django-oauth-toolkit
            scopes backend.
        """
        schema = super().get_schema(request, public)
        # The parent schema may omit "components" (e.g. when there are no
        # component schemas). Create it when missing so the "oauth2" scheme
        # referenced by operations' "security" entries is always declared.
        components = schema.setdefault("components", {})
        components["securitySchemes"] = {
            "oauth2": {
                "type": "oauth2",
                "description": "OAuth2",
                "flows": {
                    "implicit": {
                        "authorizationUrl": reverse("oauth2_provider:authorize"),
                        "scopes": get_scopes_backend().get_all_scopes(),
                    }
                },
            }
        }
        return schema

30 

31 

class OAuthAutoSchema(AutoSchema):
    """AutoSchema that adds OAuth2 scope requirements to every operation.

    It also prefixes operation ids of admin endpoints with ``Admin``. It
    builds serializers against a synthetic request that exposes a ``member``
    attribute, so views can be introspected during schema generation.
    """

    def get_operation(self, path, method):
        """Return the parent's operation with an ``oauth2`` security requirement.

        The view's ``required_scopes`` are used when the view defines that
        attribute; otherwise the requirement falls back to
        ``["read", "write"]``.
        """
        operation = super().get_operation(path, method)
        if self.view and hasattr(self.view, "required_scopes"):
            operation["security"] = [{"oauth2": self.view.required_scopes}]
        else:
            operation["security"] = [{"oauth2": ["read", "write"]}]
        return operation

    def get_operation_id_base(self, path, method, action):
        """Return the operation id base, prefixed with ``Admin`` for admin paths.

        Only the first character of the parent's name is upper-cased.
        ``str.capitalize()`` would also lowercase the rest, turning
        ``"EventRegistration"`` into ``"Eventregistration"``.
        """
        name = super().get_operation_id_base(path, method, action)
        # NOTE(review): this is a substring match, so any path containing
        # "admin" anywhere (not only an "/admin/" segment) gets the prefix.
        if "admin" in path:
            return "Admin" + name[:1].upper() + name[1:]
        return name

    def get_operation_id(self, path, method):
        """Compute the operation id as ``<action><name>``.

        - List views use ``"list"`` as the action.
        - When the view's ``action`` (or, absent that, the lowercased HTTP
          method) is not a key of ``method_mapping``, the mapped HTTP verb
          and that name are joined with ``_`` and camel-cased.
        - Otherwise the verb mapped from the HTTP method is used.

        ``<name>`` comes from ``get_operation_id_base``.
        """
        method_name = getattr(self.view, "action", method.lower())
        if is_list_view(path, method, self.view):
            action = "list"
        elif method_name not in self.method_mapping:
            # Custom view action: qualify it with the HTTP verb.
            action = self._to_camel_case(
                f"{self.method_mapping[method.lower()]}_{method_name}"
            )
        else:
            action = self.method_mapping[method.lower()]

        name = self.get_operation_id_base(path, method, action)

        return action + name

    def get_serializer(self, path, method):
        """Return the view's serializer instance for schema introspection.

        Before asking the view, this method replaces ``view.request`` with a
        synthetic DRF ``Request`` wrapping an ``HttpRequest`` for
        ``method``/``path``. It also sets ``request.member`` to
        ``request.user``, presumably so view code that reads
        ``request.member`` does not fail during schema generation.

        :return: the serializer instance, or ``None`` in either of two cases:
            the view has no ``get_serializer`` method, or calling it raised an
            ``APIException`` (a warning is emitted in that case).
        """
        view = self.view
        http_request = HttpRequest()
        http_request.method = method
        http_request.path = path
        # This overwrites any request already attached to the view, and it
        # happens even when the view turns out to have no get_serializer().
        view.request = Request(http_request)
        view.request.member = view.request.user

        if not hasattr(view, "get_serializer"):
            return None

        try:
            return view.get_serializer()
        except exceptions.APIException:
            warnings.warn(
                f"{view.__class__.__name__}.get_serializer() raised an "
                "exception during schema generation. Serializer fields "
                f"will not be generated for {method} {path}."
            )
            return None