# Source code for artiq.master.scheduler

import asyncio
import logging
import csv
import os.path
from enum import Enum
from time import time

from sipyco.sync_struct import Notifier
from sipyco.asyncio_tools import TaskObject, Condition

from artiq.master.worker import Worker, log_worker_exception
from artiq.tools import asyncio_wait_or_cancel


logger = logging.getLogger(__name__)


class RunStatus(Enum):
    pending = 0
    flushing = 1
    preparing = 2
    prepare_done = 3
    running = 4
    run_done = 5
    analyzing = 6
    deleting = 7
    paused = 8
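
    # Normal lifecycle, as driven by the pipeline stages below: pending ->
    # (flushing ->) preparing -> prepare_done -> running <-> paused ->
    # run_done -> analyzing -> deleting.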


def _mk_worker_method(name):
    async def worker_method(self, *args, **kwargs):
        if self.worker.closed.is_set():
            return True
        m = getattr(self.worker, name)
        try:
            return await m(*args, **kwargs)
        except Exception as e:
            if isinstance(e, asyncio.CancelledError):
                raise
            if self.worker.closed.is_set():
                logger.debug("suppressing worker exception of terminated run",
                             exc_info=True)
                # Return completion on termination
                return True
            else:
                raise
    return worker_method
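

# Illustrative use of the factory above (this mirrors the bindings made in
# Run below):
#
#   class Run:
#       prepare = _mk_worker_method("prepare")
#
# `await run.prepare()` then forwards to the worker's `prepare` coroutine,
# while a call on an already-closed worker reports completion instead of
# raising.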


class Run:
    def __init__(self, rid, pipeline_name,
                 wd, expid, priority, due_date, flush,
                 pool, **kwargs):
        # called through pool
        self.rid = rid
        self.pipeline_name = pipeline_name
        self.wd = wd
        self.expid = expid
        self.priority = priority
        self.due_date = due_date
        self.flush = flush

        self.worker = Worker(pool.worker_handlers)
        self.termination_requested = False

        self._status = RunStatus.pending

        notification = {
            "pipeline": self.pipeline_name,
            "expid": self.expid,
            "priority": self.priority,
            "due_date": self.due_date,
            "flush": self.flush,
            "status": self._status.name
        }
        notification.update(kwargs)
        self._notifier = pool.notifier
        self._notifier[self.rid] = notification
        self._state_changed = pool.state_changed

    @property
    def status(self):
        return self._status

    @status.setter
    def status(self, value):
        self._status = value
        if not self.worker.closed.is_set():
            self._notifier[self.rid]["status"] = self._status.name
        self._state_changed.notify()

    def priority_key(self):
        """Return a comparable value that defines a run priority order.

        Applies only to runs whose due date has already elapsed.
        """
        return (self.priority, -(self.due_date or 0), -self.rid)

    async def close(self):
        # called through pool
        await self.worker.close()
        del self._notifier[self.rid]

    _build = _mk_worker_method("build")

    async def build(self):
        await self._build(self.rid, self.pipeline_name,
                          self.wd, self.expid,
                          self.priority)

    prepare = _mk_worker_method("prepare")
    run = _mk_worker_method("run")
    resume = _mk_worker_method("resume")
    analyze = _mk_worker_method("analyze")


class RunPool:
    def __init__(self, ridc, worker_handlers, notifier, experiment_db, log_submissions):
        self.runs = dict()
        self.state_changed = Condition()

        self.ridc = ridc
        self.worker_handlers = worker_handlers
        self.notifier = notifier
        self.experiment_db = experiment_db
        self.log_submissions = log_submissions

    def log_submission(self, rid, expid):
        start_time = time()
        with open(self.log_submissions, 'a', newline='') as f:
            writer = csv.writer(f)
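            # one CSV row per submission: RID, UNIX submission time, file path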
            writer.writerow([rid, start_time, expid["file"]])

    def submit(self, expid, priority, due_date, flush, pipeline_name):
        """
        Submits an experiment to be run by this pool

        If expid has the attribute `repo_rev`, treat it as a git revision or
        reference and resolve into a unique git hash before submission
        """
        # mutates expid to insert head repository revision if None and
        # replaces relative path with the absolute one.
        # called through scheduler.
        rid = self.ridc.get()
        if "repo_rev" in expid:
            repo_rev_or_ref = expid["repo_rev"] or self.experiment_db.cur_rev
            wd, repo_msg, repo_rev = self.experiment_db.repo_backend.request_rev(repo_rev_or_ref)

            # Mutate expid's repo_rev to that returned from request_rev, in case
            # a branch was passed instead of a hash
            expid["repo_rev"] = repo_rev
        else:
            if "file" in expid:
                expid["file"] = os.path.abspath(expid["file"])
            wd, repo_msg = None, None

        run = Run(rid, pipeline_name, wd, expid, priority, due_date, flush,
                  self, repo_msg=repo_msg)
        if self.log_submissions is not None:
            self.log_submission(rid, expid)
        self.runs[rid] = run
        self.state_changed.notify()
        return rid
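
    # Illustrative submission (the expid contents here are hypothetical;
    # pools are normally driven through Scheduler.submit below):
    #
    #   rid = pool.submit({"file": "repository/flop.py",
    #                      "class_name": "Flopping_F_Simulation",
    #                      "arguments": {}},
    #                     priority=0, due_date=None, flush=False,
    #                     pipeline_name="main")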

    async def delete(self, rid):
        # called through deleter
        if rid not in self.runs:
            return
        run = self.runs[rid]
        await run.close()
        if "repo_rev" in run.expid:
            self.experiment_db.repo_backend.release_rev(run.expid["repo_rev"])
        del self.runs[rid]


class PrepareStage(TaskObject):
    def __init__(self, pool, delete_cb):
        self.pool = pool
        self.delete_cb = delete_cb

    def _get_run(self):
        """If a run should get prepared now, return it. Otherwise, return a
        float giving the time until the next check, or None if no time-based
        check is required.

        The latter can be the case if there are no due-date runs, or none
        of them are going to become next-in-line before further pool state
        changes (which will also cause a re-evaluation).
        """
        pending_runs = list(
            filter(lambda r: r.status == RunStatus.pending,
                   self.pool.runs.values()))

        now = time()
        def is_runnable(r):
            return (r.due_date or 0) < now

        prepared_max = max((r.priority_key() for r in self.pool.runs.values()
                            if r.status == RunStatus.prepare_done),
                           default=None)
        def takes_precedence(r):
            return prepared_max is None or r.priority_key() > prepared_max

        candidate = max(filter(is_runnable, pending_runs),
                        key=lambda r: r.priority_key(),
                        default=None)
        if candidate is not None and takes_precedence(candidate):
            return candidate

        return min((r.due_date - now for r in pending_runs
                    if (not is_runnable(r) and takes_precedence(r))),
                   default=None)

    async def _do(self):
        while True:
            run = self._get_run()
            if run is None:
                await self.pool.state_changed.wait()
            elif isinstance(run, float):
                await asyncio_wait_or_cancel([self.pool.state_changed.wait()],
                                             timeout=run)
            else:
                if run.flush:
                    run.status = RunStatus.flushing
                    while not all(r.status in (RunStatus.pending,
                                               RunStatus.deleting)
                                  or r.priority < run.priority
                                  or r is run
                                  for r in self.pool.runs.values()):
                        ev = [self.pool.state_changed.wait(),
                              run.worker.closed.wait()]
                        await asyncio_wait_or_cancel(
                            ev, return_when=asyncio.FIRST_COMPLETED)
                        if run.worker.closed.is_set():
                            break
                    if run.worker.closed.is_set():
                        continue
                run.status = RunStatus.preparing
                try:
                    await run.build()
                    await run.prepare()
                except Exception:
                    logger.error("got worker exception in prepare stage, "
                                 "deleting RID %d", run.rid)
                    log_worker_exception()
                    self.delete_cb(run.rid)
                else:
                    run.status = RunStatus.prepare_done


class RunStage(TaskObject):
    def __init__(self, pool, delete_cb):
        self.pool = pool
        self.delete_cb = delete_cb

    def _get_run(self):
        prepared_runs = filter(lambda r: r.status == RunStatus.prepare_done,
                               self.pool.runs.values())
        try:
            r = max(prepared_runs, key=lambda r: r.priority_key())
        except ValueError:
            # prepared_runs is an empty sequence
            r = None
        return r

    async def _do(self):
        stack = []
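        # `stack` holds runs that have paused: a newly prepared run preempts
        # the run on top of the stack only if its priority_key() is strictly
        # higher; otherwise the most recently paused run is resumed first.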

        while True:
            next_irun = self._get_run()
            if not stack or (
                    next_irun is not None and
                    next_irun.priority_key() > stack[-1].priority_key()):
                while next_irun is None:
                    await self.pool.state_changed.wait()
                    next_irun = self._get_run()
                stack.append(next_irun)

            run = stack.pop()
            try:
                if run.status == RunStatus.paused:
                    run.status = RunStatus.running
                    # clear "termination requested" flag now
                    # so that if it is set again during the resume, this
                    # results in another exception.
                    request_termination = run.termination_requested
                    run.termination_requested = False
                    completed = await run.resume(request_termination)
                else:
                    run.status = RunStatus.running
                    completed = await run.run()
            except Exception:
                logger.error("got worker exception in run stage, "
                             "deleting RID %d", run.rid)
                log_worker_exception()
                self.delete_cb(run.rid)
            else:
                if completed:
                    run.status = RunStatus.run_done
                else:
                    run.status = RunStatus.paused
                    stack.append(run)


class AnalyzeStage(TaskObject):
    def __init__(self, pool, delete_cb):
        self.pool = pool
        self.delete_cb = delete_cb

    def _get_run(self):
        run_runs = filter(lambda r: r.status == RunStatus.run_done,
                          self.pool.runs.values())
        try:
            r = max(run_runs, key=lambda r: r.priority_key())
        except ValueError:
            # run_runs is an empty sequence
            r = None
        return r

    async def _do(self):
        while True:
            run = self._get_run()
            while run is None:
                await self.pool.state_changed.wait()
                run = self._get_run()
            run.status = RunStatus.analyzing
            try:
                await run.analyze()
            except Exception:
                logger.error("got worker exception in analyze stage of RID %d.",
                             run.rid)
                log_worker_exception()
            self.delete_cb(run.rid)


class Pipeline:
    def __init__(self, ridc, deleter, worker_handlers, notifier, experiment_db, log_submissions):
        self.pool = RunPool(ridc, worker_handlers, notifier, experiment_db, log_submissions)
        self._prepare = PrepareStage(self.pool, deleter.delete)
        self._run = RunStage(self.pool, deleter.delete)
        self._analyze = AnalyzeStage(self.pool, deleter.delete)

    def start(self, *, loop=None):
        self._prepare.start(loop=loop)
        self._run.start(loop=loop)
        self._analyze.start(loop=loop)

    async def stop(self):
        # NB: restart of a stopped pipeline is not supported
        await self._analyze.stop()
        await self._run.stop()
        await self._prepare.stop()


class Deleter(TaskObject):
    """Provides a synchronous interface for instigating deletion of runs.

    :meth:`RunPool.delete` is an async function (it needs to close the worker
    connection, etc.), so we maintain a queue of RIDs to delete on a background task.
    """
    def __init__(self, pipelines):
        self._pipelines = pipelines
        self._queue = asyncio.Queue()

    def delete(self, rid):
        """Delete the run with the given RID.

        Multiple calls for the same RID are silently ignored.
        """
        logger.debug("delete request for RID %d", rid)
        for pipeline in self._pipelines.values():
            if rid in pipeline.pool.runs:
                pipeline.pool.runs[rid].status = RunStatus.deleting
                break
        self._queue.put_nowait(rid)

    async def join(self):
        await self._queue.join()

    async def _delete(self, rid):
        # By looking up the run by RID, we implicitly make sure to delete each run only
        # once.
        for pipeline in self._pipelines.values():
            if rid in pipeline.pool.runs:
                logger.debug("deleting RID %d...", rid)
                await pipeline.pool.delete(rid)
                logger.debug("deletion of RID %d completed", rid)
                break

    async def _gc_pipelines(self):
        pipeline_names = list(self._pipelines.keys())
        for name in pipeline_names:
            if not self._pipelines[name].pool.runs:
                logger.debug("garbage-collecting pipeline '%s'...", name)
                await self._pipelines[name].stop()
                del self._pipelines[name]
                logger.debug("garbage-collection of pipeline '%s' completed",
                             name)

    async def _do(self):
        while True:
            rid = await self._queue.get()
            await self._delete(rid)
            await self._gc_pipelines()
            self._queue.task_done()


class Scheduler:
    def __init__(self, ridc, worker_handlers, experiment_db, log_submissions):
        self.notifier = Notifier(dict())

        self._pipelines = dict()
        self._worker_handlers = worker_handlers
        self._experiment_db = experiment_db
        self._terminated = False

        self._ridc = ridc
        self._deleter = Deleter(self._pipelines)
        self._log_submissions = log_submissions

    def start(self, *, loop=None):
        self._loop = loop
        self._deleter.start(loop=self._loop)

    async def stop(self):
        # NB: restart of a stopped scheduler is not supported
        self._terminated = True  # prevent further runs from being created
        for pipeline in self._pipelines.values():
            for rid in pipeline.pool.runs.keys():
                self._deleter.delete(rid)
        await self._deleter.join()
        await self._deleter.stop()
        if self._pipelines:
            logger.warning("some pipelines were not garbage-collected")
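
    # Illustrative wiring (a sketch; `ridc`, `worker_handlers` and
    # `experiment_db` stand in for the objects the ARTIQ master normally
    # provides):
    #
    #   scheduler = Scheduler(ridc, worker_handlers, experiment_db,
    #                         log_submissions=None)
    #   scheduler.start(loop=asyncio.get_running_loop())
    #   ...
    #   await scheduler.stop()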

    def submit(self, pipeline_name, expid, priority=0, due_date=None,
               flush=False):
        """Submits a new run.

        When called through an experiment, the default values of
        ``pipeline_name``, ``expid`` and ``priority`` correspond to those of
        the current run."""
        # mutates expid to insert head repository revision if None, and
        # replaces relative file path with absolute one
        if self._terminated:
            return
        try:
            pipeline = self._pipelines[pipeline_name]
        except KeyError:
            logger.debug("creating pipeline '%s'", pipeline_name)
            pipeline = Pipeline(self._ridc, self._deleter,
                                self._worker_handlers, self.notifier,
                                self._experiment_db, self._log_submissions)
            self._pipelines[pipeline_name] = pipeline
            pipeline.start(loop=self._loop)
        return pipeline.pool.submit(expid, priority, due_date, flush,
                                    pipeline_name)
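
    # Illustrative call (hypothetical expid contents): the pipeline named
    # "main" is created on first use, then the run is queued:
    #
    #   rid = scheduler.submit("main", {"file": "repository/flop.py",
    #                                   "class_name": "Flopping_F_Simulation",
    #                                   "arguments": {}})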

    def delete(self, rid):
        """Kills the run with the specified RID."""
        self._deleter.delete(rid)

    def request_termination(self, rid):
        """Requests graceful termination of the run with the specified RID."""
        for pipeline in self._pipelines.values():
            if rid in pipeline.pool.runs:
                run = pipeline.pool.runs[rid]
                if (run.status == RunStatus.running
                        or run.status == RunStatus.paused):
                    run.termination_requested = True
                else:
                    self.delete(rid)
                break

    def get_status(self):
        """Returns a dictionary containing information about the runs
        currently tracked by the scheduler.

        Must not be modified."""
        return self.notifier.raw_view

    def check_pause(self, rid):
        """Returns ``True`` if there is a condition that could make
        :meth:`pause` not return immediately (termination requested or
        higher-priority run).

        The typical purpose of this function is to check from a kernel
        whether returning control to the host and pausing would have an
        effect, in order to avoid the cost of switching kernels in the
        common case where :meth:`pause` does nothing.

        This function does not have side effects, and does not have to be
        followed by a call to :meth:`pause`.
        """
        for pipeline in self._pipelines.values():
            if rid in pipeline.pool.runs:
                run = pipeline.pool.runs[rid]
                if run.status != RunStatus.running:
                    return False
                if run.termination_requested:
                    return True

                prepared_runs = filter(
                    lambda r: r.status == RunStatus.prepare_done,
                    pipeline.pool.runs.values())
                try:
                    r = max(prepared_runs, key=lambda r: r.priority_key())
                except ValueError:
                    # prepared_runs is an empty sequence
                    return False
                return r.priority_key() > run.priority_key()
        raise KeyError("RID not found")
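
    # Typical usage (a sketch; `self.scheduler` is the virtual device an
    # experiment sees, and the surrounding loop is hypothetical):
    #
    #   while True:
    #       ...
    #       if self.scheduler.check_pause(rid):
    #           break       # return control to the host, which calls pause()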

    def check_termination(self, rid):
        """Returns ``True`` if termination is requested."""
        for pipeline in self._pipelines.values():
            if rid in pipeline.pool.runs:
                run = pipeline.pool.runs[rid]
                if run.termination_requested:
                    return True
        return False