1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
|
# Author: Giampaolo Rodola' <g.rodola [AT] gmail [DOT] com>
# License: MIT
from __future__ import with_statement
import contextlib
import signal
import sys
def _sigterm_handler(signum, frame):
    """SIGTERM handler: turn the signal into ``SystemExit(0)``.

    Raising SystemExit (instead of letting the default action kill the
    process) unwinds the stack, so ``finally`` clauses and the
    handle_exit() callback get a chance to run.
    """
    sys.exit(0)

# True while a handle_exit() context is active; used to reject nesting.
_sigterm_handler.__enter_ctx__ = False


@contextlib.contextmanager
def handle_exit(callback=None, append=False):
    """A context manager which properly handles SIGTERM and SIGINT
    (KeyboardInterrupt) signals, registering a function which is
    guaranteed to be called after signals are received.
    Also, it makes sure to execute previously registered SIGTERM
    handlers as well (if any).

    Example::

        app = App()
        with handle_exit(app.stop):
            app.start()

    :param callback: optional zero-argument callable, invoked when the
        context exits (normally, on KeyboardInterrupt, or on SIGTERM).
    :param append: if False, raise RuntimeError when a handler other
        than the default one is already registered for SIGTERM; if True,
        keep it and run it right after ours.
    :raises RuntimeError: on nested use, or on a pre-existing SIGTERM
        handler with ``append=False``. In both cases the currently
        installed SIGTERM handler is left untouched.

    The SIGTERM handler that was active on entry is reinstalled on exit.
    Note that signal.signal() may only be called from the main thread.
    """
    # Reject nesting before touching any signal state so the outer
    # context's handler stays exactly as it was.
    if _sigterm_handler.__enter_ctx__:
        raise RuntimeError("can't use nested contexts")

    # Inspect the current handler *before* replacing it: if we have to
    # refuse, the caller's handler must remain installed.
    old_handler = signal.getsignal(signal.SIGTERM)
    foreign = (old_handler != signal.SIG_DFL and
               old_handler != _sigterm_handler)
    if foreign and not append:
        raise RuntimeError("there is already a handler registered for "
                           "SIGTERM: %r" % old_handler)

    if foreign and callable(old_handler):
        # Chain: ours first, then the previous one. Ours raises
        # SystemExit, so the old one has to run from a 'finally'.
        def handler(signum, frame):
            try:
                _sigterm_handler(signum, frame)
            finally:
                old_handler(signum, frame)
        signal.signal(signal.SIGTERM, handler)
    else:
        # SIG_DFL, our own handler, SIG_IGN, or None (a handler not
        # installed from Python): there is nothing callable to chain.
        signal.signal(signal.SIGTERM, _sigterm_handler)

    _sigterm_handler.__enter_ctx__ = True
    try:
        yield
    except KeyboardInterrupt:
        pass
    except SystemExit as err:
        # code != 0 refers to an application error (e.g. explicit
        # sys.exit('some error') call); we don't want that to pass
        # silently. sys.exit() with no argument (code None) is also
        # re-raised. The 'finally' clause below is always executed.
        if err.code != 0:
            raise
    finally:
        _sigterm_handler.__enter_ctx__ = False
        try:
            if callback is not None:
                callback()
        finally:
            # Reinstall the handler that was active on entry, even if the
            # callback failed. None cannot be passed to signal.signal(),
            # so in that case our handler is left in place.
            if old_handler is not None:
                signal.signal(signal.SIGTERM, old_handler)
if __name__ == '__main__':
    # ===============================================================
    # --- test suite
    # ===============================================================
    import unittest
    import os

    class TestOnExit(unittest.TestCase):

        def setUp(self):
            # Start every test from default dispositions: SIG_DFL for
            # SIGTERM and Python's KeyboardInterrupt-raising handler for
            # SIGINT (test_sigint relies on the latter).
            signal.signal(signal.SIGTERM, signal.SIG_DFL)
            signal.signal(signal.SIGINT, signal.default_int_handler)
            self.flag = None

        def tearDown(self):
            # Undo handlers installed by a test (test_sigint_old replaces
            # the SIGINT handler, which would otherwise disable Ctrl-C for
            # the rest of the run). Done before asserting so a failing
            # assertion cannot skip it.
            signal.signal(signal.SIGTERM, signal.SIG_DFL)
            signal.signal(signal.SIGINT, signal.default_int_handler)
            # make sure execution continued past the ctx manager
            self.assertTrue(self.flag is not None)

        def test_base(self):
            with handle_exit():
                pass
            self.flag = True

        def test_callback(self):
            callback = []
            with handle_exit(lambda: callback.append(None)):
                pass
            self.flag = True
            self.assertEqual(callback, [None])

        def test_kinterrupt(self):
            with handle_exit():
                raise KeyboardInterrupt
            self.flag = True

        def test_sigterm(self):
            with handle_exit():
                os.kill(os.getpid(), signal.SIGTERM)
            self.flag = True

        def test_sigint(self):
            with handle_exit():
                os.kill(os.getpid(), signal.SIGINT)
            self.flag = True

        def test_sigterm_old(self):
            # make sure the old handler gets executed
            queue = []
            signal.signal(signal.SIGTERM, lambda s, f: queue.append('old'))
            with handle_exit(lambda: queue.append('new'), append=True):
                os.kill(os.getpid(), signal.SIGTERM)
            self.flag = True
            self.assertEqual(queue, ['old', 'new'])

        def test_sigint_old(self):
            # make sure the old handler gets executed
            queue = []
            signal.signal(signal.SIGINT, lambda s, f: queue.append('old'))
            with handle_exit(lambda: queue.append('new'), append=True):
                os.kill(os.getpid(), signal.SIGINT)
            self.flag = True
            self.assertEqual(queue, ['old', 'new'])

        def test_no_append(self):
            # make sure we can't use the context manager if there's
            # already a handler registered for SIGTERM; no callback is
            # passed so a missing RuntimeError is reported by self.fail()
            # rather than masked by an error raised from the callback
            signal.signal(signal.SIGTERM, lambda s, f: sys.exit(0))
            try:
                with handle_exit():
                    pass
            except RuntimeError:
                pass
            else:
                self.fail("exception not raised")
            finally:
                self.flag = True

        def test_nested_context(self):
            self.flag = True
            try:
                with handle_exit():
                    with handle_exit():
                        pass
            except RuntimeError:
                pass
            else:
                self.fail("exception not raised")

    unittest.main()
|