Coverage for website/payments/payables.py: 100.00%
124 statements
« prev ^ index » next coverage.py v7.6.7, created at 2025-08-14 10:31 +0000
1from abc import ABC, abstractmethod
2from decimal import Decimal
3from functools import lru_cache
4from typing import Generic, TypeVar
6from django.apps import apps
7from django.core.exceptions import ObjectDoesNotExist
8from django.db.models import Model
9from django.db.models.signals import pre_save
10from django.utils.functional import classproperty
12from members.models.member import Member
13from payments.exceptions import PaymentError
# Type variable for the Django model wrapped by a `Payable` subclass,
# used as `Payable[MyModel]` for type hinting.
PayableModel = TypeVar("PayableModel", bound=Model)
class NotRegistered(Exception):
    """Raised when no `Payable` class has been registered for a model."""
class Payable(ABC, Generic[PayableModel]):
    """Base class for a wrapper around a model that can be paid for.

    This class provides a common interface for different models for which
    a payment can be made. For each payable model, a subclass of `Payable`
    should be created that implements the necessary properties and methods.

    These `Payable` wrapper classes are then registered in the global `payables`
    registry, which handles logic for preventing disallowed changes to paid model
    instances, and provides a factory for the `Payable` objects from model instances.

    For type hinting, an implementation can specify the generic type `PayableModel`:

    ```
    class MyModelPayable(Payable[MyModel]):
        ...
    ```
    """

    def __init__(self, model: PayableModel):
        self.model = model

    @property
    def pk(self):
        """Primary key of the wrapped model instance."""
        return self.model.pk

    @property
    def payment(self):
        """The `payment` currently set on the wrapped (in-memory) model instance."""
        return self.model.payment

    @payment.setter
    def payment(self, payment):
        self.model.payment = payment

    def get_payment(self):
        """Reload `payment` from the database and return it.

        Side effect: `refresh_from_db` overwrites the in-memory
        `self.model.payment` with the stored value.

        Returns:
            The stored payment, or None when `refresh_from_db` raises
            `ObjectDoesNotExist` (e.g. the model's row no longer exists).
        """
        try:
            self.model.refresh_from_db(fields=["payment"])
        except ObjectDoesNotExist:
            return None
        return self.payment

    @property
    @abstractmethod
    def payment_amount(self) -> Decimal:
        """The amount that should be paid for this model."""

    @property
    @abstractmethod
    def payment_topic(self) -> str:
        """A short description of what the payment is for.

        This will be saved to Payment.topic when a payment is created.
        """

    @property
    @abstractmethod
    def payment_notes(self) -> str:
        """Detailed notes about the payment.

        This will be saved to Payment.notes when a payment is created.
        """

    @property
    @abstractmethod
    def payment_payer(self) -> Member | None:
        """The member who paid or should pay for this model."""

    @property
    def tpay_allowed(self) -> bool:
        """Whether a `tpay` payment is allowed. Defaults to True; override to restrict."""
        return True

    @property
    def paying_allowed(self) -> bool:
        """Whether paying is allowed at all. Defaults to True; override to restrict."""
        return True

    @abstractmethod
    def can_manage_payment(self, member: Member) -> bool:
        """Return whether the given member can manage the payment for this payable."""

    @classproperty
    def immutable_after_payment(cls) -> bool:  # noqa: N805
        """Whether paid instances are protected against changes.

        When True, `Payables.register` connects a pre_save guard on the
        payable model. Both pre_save guards also check this flag at save
        time. Defaults to False.
        """
        return False

    @classproperty
    def immutable_foreign_key_models(cls) -> dict[type[Model], str]:  # noqa: N805
        """Map of related model class -> name of its foreign-key attribute
        that points at the payable model instance.

        Saves of these related models are checked against the payable's
        paid/immutable state.
        """
        return {}

    @classproperty
    def immutable_model_fields_after_payment(  # noqa: N805
        cls,
    ) -> list[str] | dict[type[Model], list[str]]:
        """Fields that may not change once the payable is paid.

        Either a list of field names of the payable model itself, or a dict
        mapping a model class (the payable model and/or models from
        `immutable_foreign_key_models`) to its list of frozen field names.
        Only the dict form can protect fields of related models.
        """
        return []

    def __hash__(self):
        # Hashes on the payment description. No `__eq__` is defined, so
        # equality stays identity-based.
        return hash((self.payment_amount, self.payment_topic, self.payment_notes))
class Payables:
    """Registry mapping Django models to their `Payable` wrapper classes."""

    def __init__(self):
        self._registry: dict[str, type[Payable]] = {}

    def _get_key(self, model: Model | type[Model]):
        """Return the `"app_label.model_name"` registry key for a model class or instance.

        Deliberately not memoized. An `lru_cache` here hashes its argument,
        and Django model instances without a primary key are unhashable
        (TypeError). The cache would also keep up to 1024 instances alive.
        """
        return f"{model._meta.app_label}.{model._meta.model_name}"

    def get_payable(self, model: Model) -> Payable:
        """Wrap a model instance in its registered `Payable` class.

        Raises:
            NotRegistered: if no `Payable` is registered for the instance's model.
        """
        key = self._get_key(model)
        try:
            payable_class = self._registry[key]
        except KeyError:
            raise NotRegistered(f"No Payable registered for {key}") from None
        return payable_class(model)

    def get_payable_models(self) -> list[type[Model]]:
        """Return all registered models."""
        return [apps.get_model(key) for key in self._registry]

    def register(self, model: type[Model], payable_class: type[Payable]):
        """Register a payable class for a model.

        This sets up signals that ensure specified fields are not changed after
        payment: `prevent_saving` on `model` when `immutable_after_payment` is
        set, and a related-model guard on every model listed in
        `immutable_foreign_key_models`. It also makes it possible to get a
        Payable instance given a model instance.
        """
        self._registry[self._get_key(model)] = payable_class
        if payable_class.immutable_after_payment:
            pre_save.connect(
                prevent_saving, sender=model, dispatch_uid=f"prevent_saving_{model}"
            )

        for (
            foreign_model,
            foreign_key_field,
        ) in payable_class.immutable_foreign_key_models.items():
            pre_save.connect(
                prevent_saving_related(foreign_key_field),
                sender=foreign_model,
                dispatch_uid=f"prevent_saving_related_{model}_{foreign_model}",
                # The receiver is a fresh closure with no other reference. With
                # the default weak=True it would be garbage-collected at once
                # and never run.
                weak=False,
            )

    def _unregister(self, model: type[Model]):
        """Unregister a payable class for a model.

        This is for testing purposes only, to clean up the registry between tests.

        Raises:
            NotRegistered: if no `Payable` is registered for `model`.
        """
        key = self._get_key(model)
        payable_class = self._registry.get(key)
        if payable_class is None:
            raise NotRegistered(f"No Payable registered for {key}")
        # Django matches receivers on (dispatch_uid, sender), so the sender used
        # in `register` must be passed here, or nothing gets disconnected.
        if payable_class.immutable_after_payment:
            pre_save.disconnect(sender=model, dispatch_uid=f"prevent_saving_{model}")
        for foreign_model in payable_class.immutable_foreign_key_models:
            pre_save.disconnect(
                sender=foreign_model,
                dispatch_uid=f"prevent_saving_related_{model}_{foreign_model}",
            )
        del self._registry[key]
# Module-level registry instance. The pre_save receivers in this module look
# payables up through it via `payables.get_payable(...)`.
payables = Payables()
def prevent_saving(sender, instance, **kwargs):
    """pre_save receiver that blocks changes to a paid, immutable payable model.

    Raises:
        PaymentError: if the save would unlink the stored payment, or would
            change one of the payable's `immutable_model_fields_after_payment`.
    """
    # Instances without a primary key have not been created yet.
    if not instance.pk:
        return

    payable = payables.get_payable(instance)
    if not payable.immutable_after_payment:
        return

    if not payable.payment:
        # The instance being saved has no payment. If the database still has
        # one, this save would detach it.
        if payable.get_payment() is not None:
            raise PaymentError("You are trying to unlink a payment from its payable.")
        return

    try:
        old_instance = sender.objects.get(pk=instance.pk)
    except sender.DoesNotExist:
        return

    # The spec is either a plain list, or a dict keyed by model class.
    frozen = payable.immutable_model_fields_after_payment
    if isinstance(frozen, dict):
        frozen = frozen[sender]

    if any(getattr(old_instance, name) != getattr(instance, name) for name in frozen):
        raise PaymentError("Cannot change this model")
def prevent_saving_related(foreign_key_field):
    """Build a pre_save receiver guarding a model related to a payable.

    The returned receiver follows `getattr(instance, foreign_key_field)` to the
    payable model instance. When that payable is immutable after payment and
    actually paid, the receiver does two things:
    - It refuses to save related instances that are not yet in the database.
    - It refuses changes to the related model's fields listed under its class
      in the payable's `immutable_model_fields_after_payment` dict.

    Args:
        foreign_key_field: name of the attribute on the related model that
            points at the payable model instance.

    Returns:
        A receiver `(sender, instance, **kwargs)` suitable for `pre_save.connect`.
    """

    def prevent_related_saving_paid_after_immutable(sender, instance, **kwargs):
        parent = getattr(instance, foreign_key_field)
        if parent is None:
            # Unset (nullable) relation: there is no payable to protect.
            return
        payable = payables.get_payable(parent)
        if not payable.immutable_after_payment:
            # Do nothing if the parent is not marked as immutable
            return
        if not payable.payment:
            # Do nothing if the parent is not actually paid
            return
        try:
            old_instance = sender.objects.get(pk=instance.pk)
        except sender.DoesNotExist:
            # The instance is not stored yet; adding it to a paid payable is refused.
            raise PaymentError(
                "Cannot save this model with foreign key to immutable payment"
            ) from None

        immutable_fields = payable.immutable_model_fields_after_payment
        # Only the dict form can name fields of related models. A plain list
        # applies to the payable model itself.
        if not isinstance(immutable_fields, dict):
            return
        for field in immutable_fields[sender]:
            if getattr(old_instance, field) != getattr(instance, field):
                raise PaymentError("Cannot change this model")

    return prevent_related_saving_paid_after_immutable