taskpool.py (9013B)
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this file,
# You can obtain one at http://mozilla.org/MPL/2.0/.

# flake8: noqa: F821

import fcntl
import os
import select
import time
from subprocess import PIPE, Popen


class TaskPool:
    # Run a series of subprocesses. Try to keep up to a certain number going in
    # parallel at any given time. Enforce time limits.
    #
    # This is implemented using non-blocking I/O, and so is Unix-specific.
    #
    # We assume that, if a task closes its standard error, then it's safe to
    # wait for it to terminate. So an ill-behaved task that closes its standard
    # error and then hangs will hang us, as well. However, as it takes special
    # effort to close one's standard error, this seems unlikely to be a
    # problem in practice.

    # A task we should run in a subprocess. Users should subclass this and
    # fill in the methods as given.
    class Task:
        def __init__(self):
            # Popen object for the running subprocess; set by |start|.
            self.pipe = None
            # Unused; kept for backward compatibility with existing subclasses.
            self.start_time = None
            # Absolute time.time() at which the task should be killed;
            # set by |start|.  Initialized here so the attribute always exists.
            self.deadline = None

        # Record that this task is running, with |pipe| as its Popen object,
        # and should time out at |deadline|.
        def start(self, pipe, deadline):
            self.pipe = pipe
            self.deadline = deadline

        # Return a shell command (a string or sequence of arguments) to be
        # passed to Popen to run the task. The command will be given
        # /dev/null as its standard input, and pipes as its standard output
        # and error.
        def cmd(self):
            raise NotImplementedError

        # TaskPool calls this method to report that the process wrote
        # |string| to its standard output.
        def onStdout(self, string):
            raise NotImplementedError

        # TaskPool calls this method to report that the process wrote
        # |string| to its standard error.
        def onStderr(self, string):
            raise NotImplementedError

        # TaskPool calls this method to report that the process terminated,
        # yielding |returncode|.
        def onFinished(self, returncode):
            raise NotImplementedError

        # TaskPool calls this method to report that the process timed out and
        # was killed.
        def onTimeout(self):
            raise NotImplementedError

    # If a task output handler (onStdout, onStderr) throws this, we terminate
    # the task.
    class TerminateTask(Exception):
        pass

    def __init__(self, tasks, cwd=".", job_limit=4, timeout=150):
        # |tasks|: iterable of Task instances to run.
        # |cwd|: working directory for the subprocesses.
        # |job_limit|: maximum number of subprocesses running at once.
        # |timeout|: per-task wall-clock limit, in seconds.
        self.pending = iter(tasks)
        self.cwd = cwd
        self.job_limit = job_limit
        self.timeout = timeout
        self.next_pending = next(self.pending, None)

    def run_all(self):
        # Run every task to completion (or timeout), keeping up to
        # |self.job_limit| subprocesses alive at a time.

        # The currently running tasks: a set of Task instances.
        running = set()
        with open(os.devnull) as devnull:
            while True:
                # Top up the pool with new subprocesses.
                while len(running) < self.job_limit and self.next_pending:
                    task = self.next_pending
                    p = Popen(
                        task.cmd(),
                        bufsize=16384,
                        stdin=devnull,
                        stdout=PIPE,
                        stderr=PIPE,
                        cwd=self.cwd,
                    )

                    # Put the stdout and stderr pipes in non-blocking mode. See
                    # the post-'select' code below for details.
                    flags = fcntl.fcntl(p.stdout, fcntl.F_GETFL)
                    fcntl.fcntl(p.stdout, fcntl.F_SETFL, flags | os.O_NONBLOCK)
                    flags = fcntl.fcntl(p.stderr, fcntl.F_GETFL)
                    fcntl.fcntl(p.stderr, fcntl.F_SETFL, flags | os.O_NONBLOCK)

                    task.start(p, time.time() + self.timeout)
                    running.add(task)
                    self.next_pending = next(self.pending, None)

                # If we have no tasks running, and the above wasn't able to
                # start any new ones, then we must be done!
                if not running:
                    break

                # How many seconds do we have until the earliest deadline?
                now = time.time()
                secs_to_next_deadline = max(
                    min([t.deadline for t in running]) - now, 0
                )

                # Wait for output or a timeout.
                stdouts_and_stderrs = [t.pipe.stdout for t in running] + [
                    t.pipe.stderr for t in running
                ]
                (readable, w, x) = select.select(
                    stdouts_and_stderrs, [], [], secs_to_next_deadline
                )
                finished = set()
                terminate = set()
                for t in running:
                    # Since we've placed the pipes in non-blocking mode, these
                    # 'read's will simply return as many bytes as are available,
                    # rather than blocking until they have accumulated the full
                    # amount requested (or reached EOF). The 'read's should
                    # never throw, since 'select' has told us there was
                    # something available.
                    # NOTE(review): a multi-byte UTF-8 sequence split across
                    # two reads would raise UnicodeDecodeError here; unchanged
                    # from the original behavior.
                    if t.pipe.stdout in readable:
                        output = t.pipe.stdout.read(16384)
                        if len(output):
                            try:
                                t.onStdout(output.decode("utf-8"))
                            except TerminateTask:
                                terminate.add(t)
                    if t.pipe.stderr in readable:
                        output = t.pipe.stderr.read(16384)
                        if len(output):
                            try:
                                t.onStderr(output.decode("utf-8"))
                            except TerminateTask:
                                terminate.add(t)
                        else:
                            # An empty read on a readable pipe means EOF. We
                            # assume that, once a task has closed its stderr,
                            # it will soon terminate. If a task closes its
                            # stderr and then hangs, we'll hang too, here.
                            t.pipe.wait()
                            t.onFinished(t.pipe.returncode)
                            finished.add(t)
                # Remove the finished tasks from the running set. (Do this here
                # to avoid mutating the set while iterating over it.)
                running -= finished

                # Terminate any tasks whose handlers have asked us to do so.
                # A task may appear in both |terminate| and |finished| (e.g.
                # onStdout raised TerminateTask and stderr hit EOF in the same
                # round); such tasks have already been waited for and removed,
                # so skip them to avoid a KeyError.
                for t in terminate - finished:
                    t.pipe.terminate()
                    t.pipe.wait()
                    running.discard(t)

                # Terminate any tasks which have missed their deadline. Use a
                # fresh timestamp: 'select' may have slept until (or past) the
                # earliest deadline, so the 'now' captured before it is stale.
                now = time.time()
                finished = set()
                for t in running:
                    if now >= t.deadline:
                        t.pipe.terminate()
                        t.pipe.wait()
                        t.onTimeout()
                        finished.add(t)
                # Remove the finished tasks from the running set. (Do this here
                # to avoid mutating the set while iterating over it.)
                running -= finished


def get_cpu_count():
    """
    Guess at a reasonable parallelism count to set as the default for the
    current machine and run.
    """
    # Python 2.6+
    try:
        import multiprocessing

        return multiprocessing.cpu_count()
    except (ImportError, NotImplementedError):
        pass

    # POSIX
    try:
        res = int(os.sysconf("SC_NPROCESSORS_ONLN"))
        if res > 0:
            return res
    except (AttributeError, ValueError):
        pass

    # Windows
    try:
        res = int(os.environ["NUMBER_OF_PROCESSORS"])
        if res > 0:
            return res
    except (KeyError, ValueError):
        pass

    return 1


if __name__ == "__main__":
    # Test TaskPool by using it to implement the unique 'sleep sort' algorithm.
    def sleep_sort(ns, timeout):
        sorted = []

        class SortableTask(TaskPool.Task):
            def __init__(self, n):
                super().__init__()
                self.n = n

            def cmd(self):
                return ["sh", "-c", "echo out; sleep %d; echo err>&2" % (self.n,)]

            def onStdout(self, text):
                print("%d stdout: %r" % (self.n, text))

            def onStderr(self, text):
                print("%d stderr: %r" % (self.n, text))

            def onFinished(self, returncode):
                print("%d (rc=%d)" % (self.n, returncode))
                sorted.append(self.n)

            def onTimeout(self):
                print("%d timed out" % (self.n,))

        p = TaskPool([SortableTask(_) for _ in ns], job_limit=len(ns), timeout=timeout)
        p.run_all()
        return sorted

    print(repr(sleep_sort([1, 1, 2, 3, 5, 8, 13, 21, 34], 15)))