Coverage for website/payments/payables.py: 100.00%

124 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2025-08-14 10:31 +0000

1from abc import ABC, abstractmethod 

2from decimal import Decimal 

3from functools import lru_cache 

4from typing import Generic, TypeVar 

5 

6from django.apps import apps 

7from django.core.exceptions import ObjectDoesNotExist 

8from django.db.models import Model 

9from django.db.models.signals import pre_save 

10from django.utils.functional import classproperty 

11 

12from members.models.member import Member 

13from payments.exceptions import PaymentError 

14 

# Type variable for the Django model wrapped by a `Payable`; `bound=Model`
# restricts it to Django models, so subclasses can be written as
# `Payable[MyModel]`.
PayableModel = TypeVar("PayableModel", bound=Model)

16 

17 

class NotRegistered(Exception):
    """Raised when a model has no `Payable` wrapper in the `payables` registry."""

20 

21 

class Payable(ABC, Generic[PayableModel]):
    """Common payment interface wrapped around a payable Django model.

    Every model that can be paid for gets its own `Payable` subclass, which
    fills in the abstract properties and methods declared here. Those
    subclasses are registered in the global `payables` registry; the registry
    builds `Payable` objects from model instances and guards paid instances
    against disallowed changes.

    Subclasses may name the wrapped model type for type checkers:

    ```
    class MyModelPayable(Payable[MyModel]):
        ...
    ```
    """

    def __init__(self, model: PayableModel):
        # The wrapped model instance; the accessors below delegate to it.
        self.model = model

    @property
    def pk(self):
        """Primary key of the wrapped model instance."""
        return self.model.pk

    @property
    def payment(self):
        """The `payment` currently set on the wrapped model (in-memory value)."""
        return self.model.payment

    @payment.setter
    def payment(self, value):
        self.model.payment = value

    def get_payment(self):
        """Reload `payment` from the database and return it.

        Returns None when ``refresh_from_db`` raises ``ObjectDoesNotExist``.
        Note that the reload overwrites the in-memory ``payment`` of the model.
        """
        try:
            self.model.refresh_from_db(fields=["payment"])
        except ObjectDoesNotExist:
            return None
        else:
            return self.payment

    @property
    @abstractmethod
    def payment_amount(self) -> Decimal:
        """Amount that has to be paid for the wrapped model."""

    @property
    @abstractmethod
    def payment_topic(self) -> str:
        """Short description of what is being paid for.

        Stored as Payment.topic when a payment is created.
        """

    @property
    @abstractmethod
    def payment_notes(self) -> str:
        """Detailed notes describing the payment.

        Stored as Payment.notes when a payment is created.
        """

    @property
    @abstractmethod
    def payment_payer(self) -> Member | None:
        """Member who paid, or is expected to pay, for the wrapped model (may be None)."""

    @property
    def tpay_allowed(self) -> bool:
        """Whether tpay may be used for this payable; allowed unless overridden."""
        return True

    @property
    def paying_allowed(self) -> bool:
        """Whether this payable may be paid at all; allowed unless overridden."""
        return True

    @abstractmethod
    def can_manage_payment(self, member: Member) -> bool:
        """Tell whether `member` is allowed to manage the payment of this payable."""

    @classproperty
    def immutable_after_payment(cls) -> bool:  # noqa: N805
        """Whether the registry installs save guards for paid instances (default: no)."""
        return False

    @classproperty
    def immutable_foreign_key_models(cls) -> dict[type[Model], str]:  # noqa: N805
        """Related model classes mapped to their foreign-key field to this model.

        Saves of these related models are guarded as well when
        `immutable_after_payment` is set.
        """
        return {}

    @classproperty
    def immutable_model_fields_after_payment(cls) -> list[str]:  # noqa: N805
        """Field names that must not change once paid.

        Either a plain list of field names, or a dict keyed by model class
        (the save guards distinguish the two with ``isinstance(..., dict)``).
        """
        return []

    def __hash__(self):
        # Hash on the values describing the payment instead of object identity.
        fingerprint = (self.payment_amount, self.payment_topic, self.payment_notes)
        return hash(fingerprint)

115 

116 

class Payables:
    """Registry of `Payable` wrapper classes, keyed by Django model.

    Besides acting as a factory for `Payable` objects, registering a model
    connects the ``pre_save`` guards that stop disallowed changes to paid
    instances (see `register`).
    """

    def __init__(self):
        # Maps "<app_label>.<model_name>" (see `_get_key`) to a Payable subclass.
        self._registry: dict[str, type[Payable]] = {}

    def _get_key(self, model: Model | type[Model]):
        """Return the registry key "<app_label>.<model_name>" for a model or instance.

        Deliberately not memoised. An ``lru_cache`` here is keyed on model
        *instances*, which has three problems:

        - It raises ``TypeError`` for unsaved instances, because Django model
          instances without a primary key are unhashable.
        - It keeps up to 1024 instances alive.
        - It conflates proxy and concrete instances that compare equal by pk.

        Reading two ``_meta`` attributes is cheap enough without a cache.
        """
        meta = model._meta
        return f"{meta.app_label}.{meta.model_name}"

    def get_payable(self, model: Model) -> Payable:
        """Wrap a model instance in the `Payable` class registered for its model.

        Raises:
            NotRegistered: if no `Payable` class is registered for the model.
        """
        key = self._get_key(model)
        try:
            payable_class = self._registry[key]
        except KeyError:
            raise NotRegistered(f"No Payable registered for {key}") from None
        return payable_class(model)

    def get_payable_models(self) -> list[type[Model]]:
        """Return all registered models."""
        return [apps.get_model(key) for key in self._registry]

    def register(self, model: type[Model], payable_class: type[Payable]):
        """Register a payable class for a model.

        This makes `get_payable` work for instances of ``model``.

        If the payable class sets ``immutable_after_payment``, it also connects
        ``pre_save`` receivers. They raise ``PaymentError`` when restricted fields
        of a paid instance change, or when related models listed in
        ``immutable_foreign_key_models`` are saved in a disallowed way.
        """
        self._registry[self._get_key(model)] = payable_class
        if not payable_class.immutable_after_payment:
            # Both receivers return without checking anything for payables that
            # are not immutable after payment, so connecting them is pointless.
            return

        pre_save.connect(
            prevent_saving, sender=model, dispatch_uid=f"prevent_saving_{model}"
        )
        for foreign_model, foreign_key_field in (
            payable_class.immutable_foreign_key_models.items()
        ):
            pre_save.connect(
                prevent_saving_related(foreign_key_field),
                sender=foreign_model,
                # The receiver is a fresh closure that nothing else references.
                # With Django's default weak references, it would be garbage
                # collected immediately and never run.
                weak=False,
                dispatch_uid=f"prevent_saving_related_{model}_{foreign_model}",
            )

    def _unregister(self, model: type[Model]):
        """Unregister a payable class for a model.

        This is for testing purposes only, to clean up the registry between tests.

        Raises:
            NotRegistered: if no payable class is registered for ``model``.
        """
        key = self._get_key(model)
        try:
            payable_class = self._registry[key]
        except KeyError:
            raise NotRegistered(f"No Payable registered for {key}") from None

        if payable_class.immutable_after_payment:
            # Django identifies receivers by (dispatch_uid, sender). The sender
            # used in `register` must be passed again, otherwise nothing is
            # disconnected.
            pre_save.disconnect(sender=model, dispatch_uid=f"prevent_saving_{model}")
            for foreign_model in payable_class.immutable_foreign_key_models:
                pre_save.disconnect(
                    sender=foreign_model,
                    dispatch_uid=f"prevent_saving_related_{model}_{foreign_model}",
                )
        del self._registry[key]

169 

170 

# Module-level registry instance; the `prevent_saving` and
# `prevent_saving_related` signal receivers look payables up through it.
payables = Payables()

172 

173 

def prevent_saving(sender, instance, **kwargs):
    """``pre_save`` receiver that blocks disallowed changes to a paid payable model.

    Args:
        sender: The model class being saved.
        instance: The model instance being saved.

    Raises:
        PaymentError: In either of two cases:
            - The payment is being unlinked: it is unset in memory but still
              present in the database.
            - A field from the payable's ``immutable_model_fields_after_payment``
              differs from the stored row.
    """
    # Instances without a primary key have not been created yet.
    if not instance.pk:
        return

    payable = payables.get_payable(instance)
    if not payable.immutable_after_payment:
        return

    if not payable.payment:
        # No payment in memory. That is fine unless the database still holds
        # one, which means the payment is being unlinked.
        if payable.get_payment() is not None:
            raise PaymentError("You are trying to unlink a payment from its payable.")
        return

    try:
        stored = sender.objects.get(pk=instance.pk)
    except sender.DoesNotExist:
        return

    # The field spec is either a list, or a dict keyed by model class.
    field_spec = payable.immutable_model_fields_after_payment
    if isinstance(field_spec, dict):
        field_spec = field_spec[sender]

    changed = any(
        getattr(stored, name) != getattr(instance, name) for name in field_spec
    )
    if changed:
        raise PaymentError("Cannot change this model")

202 

203 

def prevent_saving_related(foreign_key_field):
    """Build a ``pre_save`` receiver for models that reference a payable model.

    Args:
        foreign_key_field: Name of the attribute on the saved instance that holds
            the parent (payable) model instance.

    Returns:
        A receiver ``(sender, instance, **kwargs)``. It only acts when the
        parent's payable is ``immutable_after_payment`` and the parent has a
        payment. In that case it raises ``PaymentError`` if either:

        - the saved instance is not in the database yet, or
        - a field listed for ``sender`` in the parent's
          ``immutable_model_fields_after_payment`` dict differs from the
          stored row.

    Note:
        The returned function is a fresh closure. Django stores receivers as weak
        references by default, so connect it with ``weak=False`` (or keep another
        reference to it), otherwise it is garbage-collected immediately.
    """

    def prevent_related_saving_paid_after_immutable(sender, instance, **kwargs):
        parent = getattr(instance, foreign_key_field)
        if parent is None:
            # An unset (nullable) foreign key links to no payable, so there is no
            # paid parent to protect; no registry key can be built for None.
            return
        payable = payables.get_payable(parent)
        if not payable.immutable_after_payment:
            # Do nothing if the parent is not marked as immutable
            return
        if not payable.payment:
            # Do nothing if the parent is not actually paid
            return
        try:
            old_instance = sender.objects.get(pk=instance.pk)
        except sender.DoesNotExist:
            # Rows not yet in the database may not be attached to a paid,
            # immutable parent.
            raise PaymentError(
                "Cannot save this model with foreign key to immutable payment"
            )

        # Only a dict spec lists fields for related models; a plain list applies
        # to the payable model itself, so nothing is checked here in that case.
        immutable_fields = (
            payable.immutable_model_fields_after_payment[sender]
            if isinstance(payable.immutable_model_fields_after_payment, dict)
            else []
        )
        for field in immutable_fields:
            if getattr(old_instance, field) != getattr(instance, field):
                raise PaymentError("Cannot change this model")

    return prevent_related_saving_paid_after_immutable