diff --git a/hobo/multitenant/threads.py b/hobo/multitenant/threads.py index 46f3f9f..bc7387e 100644 --- a/hobo/multitenant/threads.py +++ b/hobo/multitenant/threads.py @@ -14,6 +14,7 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +import functools import threading _Thread_start = threading.Thread.start @@ -28,21 +29,34 @@ def _new_start(self): return _Thread_start(self) +def wrap_run(self, func): + if getattr(func, '_wrapped', False): + return func + + @functools.wraps(func) + def wrapper(): + tenant = getattr(self, 'tenant', None) + + if tenant is not None: + from django.db import connection + + old_tenant = connection.tenant + connection.set_tenant(self.tenant) + try: + func() + finally: + connection.set_tenant(old_tenant) + connection.close() + else: + func() + + wrapper._wrapped = True + return wrapper + + def _new__bootstrap_inner(self): - tenant = getattr(self, 'tenant', None) - - if tenant is not None: - from django.db import connection - - old_tenant = connection.tenant - connection.set_tenant(self.tenant) - try: - _Thread__bootstrap_inner(self) - finally: - connection.set_tenant(old_tenant) - connection.close() - else: - _Thread__bootstrap_inner(self) + self.run = wrap_run(self, self.run) + _Thread__bootstrap_inner(self) def install_tenant_aware_threads():