diff --git a/luigi/worker.py b/luigi/worker.py
index 8a00fdb246..b75595368e 100644
--- a/luigi/worker.py
+++ b/luigi/worker.py
@@ -124,7 +124,8 @@ class TaskProcess(multiprocessing.Process):
     }
 
     def __init__(self, task, worker_id, result_queue, status_reporter,
-                 use_multiprocessing=False, worker_timeout=0, check_unfulfilled_deps=True):
+                 use_multiprocessing=False, worker_timeout=0, check_unfulfilled_deps=True,
+                 check_complete_on_run=False):
         super(TaskProcess, self).__init__()
         self.task = task
         self.worker_id = worker_id
@@ -134,6 +135,7 @@ def __init__(self, task, worker_id, result_queue, status_reporter,
         self.timeout_time = time.time() + self.worker_timeout if self.worker_timeout else None
         self.use_multiprocessing = use_multiprocessing or self.timeout_time is not None
         self.check_unfulfilled_deps = check_unfulfilled_deps
+        self.check_complete_on_run = check_complete_on_run
 
     def _run_get_new_deps(self):
         task_gen = self.task.run()
@@ -186,8 +188,6 @@ def run(self):
 
             if _is_external(self.task):
                 # External task
-                # TODO(erikbern): We should check for task completeness after non-external tasks too!
-                # This will resolve #814 and make things a lot more consistent
                 if self.task.complete():
                     status = DONE
                 else:
@@ -197,7 +197,13 @@ def run(self):
             else:
                 with self._forward_attributes():
                     new_deps = self._run_get_new_deps()
-                status = DONE if not new_deps else PENDING
+                if not new_deps:
+                    if not self.check_complete_on_run or self.task.complete():
+                        status = DONE
+                    else:
+                        raise TaskException("Task finished running, but complete() is still returning false.")
+                else:
+                    status = PENDING
 
             if new_deps:
                 logger.info(
@@ -215,15 +221,17 @@ def run(self):
             raise
         except BaseException as ex:
             status = FAILED
-            logger.exception("[pid %s] Worker %s failed %s", os.getpid(), self.worker_id, self.task)
-            self.task.trigger_event(Event.FAILURE, self.task, ex)
-            raw_error_message = self.task.on_failure(ex)
-            expl = raw_error_message
+            expl = self._handle_run_exception(ex)
 
         finally:
             self.result_queue.put(
                 (self.task.task_id, status, expl, missing, new_deps))
 
+    def _handle_run_exception(self, ex):
+        logger.exception("[pid %s] Worker %s failed %s", os.getpid(), self.worker_id, self.task)
+        self.task.trigger_event(Event.FAILURE, self.task, ex)
+        return self.task.on_failure(ex)
+
     def _recursive_terminate(self):
         import psutil
 
@@ -447,6 +455,10 @@ class worker(Config):
     check_unfulfilled_deps = BoolParameter(default=True,
                                            description='If true, check for completeness of '
                                            'dependencies before running a task')
+    check_complete_on_run = BoolParameter(default=False,
+                                          description='If true, only mark tasks as done after running if they are complete. '
+                                          'Regardless of this setting, the worker will always check if external '
+                                          'tasks are complete before marking them as done.')
     force_multiprocessing = BoolParameter(default=False,
                                           description='If true, use multiprocessing also when '
                                           'running with 1 worker')
@@ -1016,6 +1028,7 @@ def _create_task_process(self, task):
             use_multiprocessing=use_multiprocessing,
             worker_timeout=self._config.timeout,
             check_unfulfilled_deps=self._config.check_unfulfilled_deps,
+            check_complete_on_run=self._config.check_complete_on_run,
         )
 
     def _purge_children(self):
diff --git a/test/helpers.py b/test/helpers.py
index c407835e30..64911a06e4 100644
--- a/test/helpers.py
+++ b/test/helpers.py
@@ -138,6 +138,24 @@ def run(self):
         self.comp = True
 
 
+# string subclass that matches arguments containing the specified substring
+# for use in mock 'called_with' assertions
+class StringContaining(str):
+    def __eq__(self, other_str):
+        try:
+            return self in other_str
+        except TypeError:
+            # other_str cannot contain a substring (e.g. None); let Python
+            # fall back to its default comparison instead of raising
+            return NotImplemented
+
+    def __ne__(self, other_str):
+        # str supplies its own exact-match __ne__, which would otherwise
+        # contradict the substring-based __eq__ above
+        result = self.__eq__(other_str)
+        return result if result is NotImplemented else not result
+
+
 class LuigiTestCase(unittest.TestCase):
     """
     Tasks registred within a test case will get unregistered in a finalizer
diff --git a/test/worker_task_test.py b/test/worker_task_test.py
index c0355fb9e9..5196210296 100644
--- a/test/worker_task_test.py
+++ b/test/worker_task_test.py
@@ -18,7 +18,7 @@
 from subprocess import check_call
 import sys
 
-from helpers import LuigiTestCase
+from helpers import LuigiTestCase, StringContaining
 import mock
 from psutil import Process
 from time import sleep
@@ -87,6 +87,25 @@ def on_failure(self, exception):
             task_process.run()
             mock_put.assert_called_once_with((task.task_id, FAILED, "test failure expl", [], []))
 
+    def test_fail_on_false_complete(self):
+        class NeverCompleteTask(luigi.Task):
+            def complete(self):
+                return False
+
+        task = NeverCompleteTask()
+        result_queue = multiprocessing.Queue()
+        task_process = TaskProcess(task, 1, result_queue, mock.Mock(), check_complete_on_run=True)
+
+        with mock.patch.object(result_queue, 'put') as mock_put:
+            task_process.run()
+            mock_put.assert_called_once_with((
+                task.task_id,
+                FAILED,
+                StringContaining("finished running, but complete() is still returning false"),
+                [],
+                None
+            ))
+
     def test_cleanup_children_on_terminate(self):
         """
         Subprocesses spawned by tasks should be terminated on terminate