Merging main into callbacks and fixing merge conflict. by Lucaweihs · Pull Request #347 · allenai/allenact · GitHub

Merging main into callbacks and fixing merge conflict. #347


Merged
merged 13 commits into callbacks from main on Jul 29, 2022
1 change: 1 addition & 0 deletions .github/workflows/pytest.yml
@@ -28,6 +28,7 @@ jobs:
python -m pip install --editable="./allenact_plugins[all]"
python -m pip install -r allenact_plugins/babyai_plugin/extra_requirements.txt # Required as babyai is not on PyPI
python -m pip install compress_pickle # Needed for some mapping tests
python -m pip install -U protobuf==3.20.1 # Required until tensorboardX is fixed: https://github.com/lanpa/tensorboardX/issues/663
pip list

- name: Test with pytest
68 changes: 61 additions & 7 deletions allenact/algorithms/onpolicy_sync/engine.py
@@ -66,6 +66,20 @@
from allenact.utils.tensor_utils import batch_observations, detach_recursively
from allenact.utils.viz_utils import VizSuite

try:
# When debugging we don't want to timeout in the VectorSampledTasks

# noinspection PyPackageRequirements
import pydevd

DEBUGGING = True
except ImportError:
DEBUGGING = False

DEBUG_VST_TIMEOUT: Optional[int] = (lambda x: int(x) if x is not None else x)(
os.getenv("ALLENACT_DEBUG_VST_TIMEOUT", None)
)

TRAIN_MODE_STR = "train"
VALID_MODE_STR = "valid"
TEST_MODE_STR = "test"
@@ -102,6 +116,7 @@ def __init__(
deterministic_agents: bool = False,
max_sampler_processes_per_worker: Optional[int] = None,
initial_model_state_dict: Optional[Union[Dict[str, Any], int]] = None,
try_restart_after_task_timeout: bool = False,
**kwargs,
):
"""Initializer.
@@ -129,6 +144,7 @@ def __init__(
self.device = torch.device("cpu") if device == -1 else torch.device(device) # type: ignore
self.distributed_ip = distributed_ip
self.distributed_port = distributed_port
self.try_restart_after_task_timeout = try_restart_after_task_timeout

self.mode = mode.lower().strip()
assert self.mode in [
@@ -235,7 +251,7 @@ def __init__(
# During testing, we sometimes found that the default timeout was too short,
# resulting in the run terminating unexpectedly; we increase it here.
timeout=datetime.timedelta(minutes=3000)
if self.mode == TEST_MODE_STR
if (self.mode == TEST_MODE_STR or DEBUGGING)
else dist.default_pg_timeout,
)
self.is_distributed = True
@@ -290,6 +306,7 @@ def vector_tasks(
else None,
mp_ctx=self.mp_ctx,
max_processes=self.max_sampler_processes_per_worker,
read_timeout=DEBUG_VST_TIMEOUT if DEBUGGING else 5 * 60,
)
return self._vector_tasks

@@ -1421,11 +1438,48 @@ def run_pipeline(self, valid_on_initial_weights: bool = False):
for k, v in self.training_pipeline.current_stage_storage.items()
}

for step in range(cur_stage_training_settings.num_steps):
num_paused = self.collect_step_across_all_task_samplers(
rollout_storage_uuid=self.training_pipeline.rollout_storage_uuid,
uuid_to_storage=uuid_to_storage,
)
vector_tasks_already_restarted = False
step = -1
while step < cur_stage_training_settings.num_steps - 1:
step += 1

try:
num_paused = self.collect_step_across_all_task_samplers(
rollout_storage_uuid=self.training_pipeline.rollout_storage_uuid,
uuid_to_storage=uuid_to_storage,
)
except TimeoutError:
if (
not self.try_restart_after_task_timeout
) or self.mode != TRAIN_MODE_STR:
# A bare `raise` here re-raises the caught exception as though it had never
# been caught, so the original stacktrace is preserved.
raise
elif vector_tasks_already_restarted:
raise RuntimeError(
f"[{self.mode} worker {self.worker_id}] `vector_tasks` has timed out twice in the same"
f" rollout. This suggests that this error was not recoverable. Timeout exception:\n{traceback.format_exc()}"
)
else:
get_logger().warning(
f"[{self.mode} worker {self.worker_id}] `vector_tasks` appears to have crashed during"
f" training as it has timed out. You have set `try_restart_after_task_timeout` to `True` so"
f" we will attempt to restart these tasks from the beginning. USE THIS FEATURE AT YOUR OWN"
f" RISK. Timeout exception:\n{traceback.format_exc()}."
)
self.vector_tasks.close()
self._vector_tasks = None

vector_tasks_already_restarted = True
for (
storage
) in self.training_pipeline.current_stage_storage.values():
storage.after_updates()
self.initialize_storage_and_viz(
storage_to_initialize=list(uuid_to_storage.values())
)
step = -1
continue

# A more informative error message should already have been given in
# `collect_step_across_all_task_samplers` if `num_paused != 0` here, but this serves
@@ -1605,7 +1659,7 @@ def train(
get_logger().error(
f"[{self.mode} worker {self.worker_id}] Encountered {type(e).__name__}, exiting."
)
get_logger().exception(traceback.format_exc())
get_logger().error(traceback.format_exc())
finally:
if training_completed_successfully:
if self.worker_id == 0:
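The debug handling added above keys off an optional environment variable. A minimal standalone sketch of that pattern, assuming a hypothetical helper name `optional_int_from_env` (not AllenAct API), looks like this:

```python
# Standalone sketch of parsing an optional integer timeout from the
# environment, as done above for ALLENACT_DEBUG_VST_TIMEOUT.
import os
from typing import Optional


def optional_int_from_env(name: str) -> Optional[int]:
    value = os.getenv(name)  # None when the variable is not set
    return int(value) if value is not None else None


# e.g. run `export ALLENACT_DEBUG_VST_TIMEOUT=600` before launching under
# pydevd to give VectorSampledTasks a finite read timeout while debugging;
# if the variable is unset, the debug timeout stays disabled (None).
debug_vst_timeout = optional_int_from_env("ALLENACT_DEBUG_VST_TIMEOUT")
print(debug_vst_timeout)
```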
13 changes: 11 additions & 2 deletions allenact/algorithms/onpolicy_sync/runner.py
@@ -320,9 +320,9 @@ def handler(_signo, _frame):
except Exception:
get_logger().error(
f"Error occurred when closing the RL engine used by work {mode}-{id}."
f" We cannot recover from this and will simply exit. The exception:"
f" We cannot recover from this and will simply exit. The exception:\n"
f"{traceback.format_exc()}"
)
get_logger().exception(traceback.format_exc())
sys.exit(1)
sys.exit(0)
else:
@@ -472,6 +472,15 @@ def start_train(

distributed_port = 0 if num_workers == 1 else self.get_port()

if (
num_workers > 1
and "NCCL_ASYNC_ERROR_HANDLING" not in os.environ
and "NCCL_BLOCKING_WAIT" not in os.environ
):
# This ensures the NCCL distributed backend will throw errors
# if we timeout at a call to `barrier()`
os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1"

worker_ids = self.local_worker_ids(TRAIN_MODE_STR)

model_hash = None
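For readers unfamiliar with the NCCL environment variables used above, the defaulting logic can be read as the following standalone sketch (the function name `ensure_nccl_raises_on_timeout` is made up for illustration and is not part of AllenAct):

```python
# Illustrative sketch: default NCCL to async error handling only when the
# user has not already chosen an error-handling strategy themselves.
import os


def ensure_nccl_raises_on_timeout(num_workers: int) -> None:
    if (
        num_workers > 1
        and "NCCL_ASYNC_ERROR_HANDLING" not in os.environ
        and "NCCL_BLOCKING_WAIT" not in os.environ
    ):
        # With async error handling enabled, an NCCL collective (e.g. a call
        # to `torch.distributed.barrier()`) that exceeds the process-group
        # timeout raises an error instead of hanging the process.
        os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1"


ensure_nccl_raises_on_timeout(num_workers=2)
```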
9 changes: 6 additions & 3 deletions allenact/algorithms/onpolicy_sync/storage.py
@@ -199,7 +199,7 @@ def initialize(
self.full_size + 1, num_samplers, action_flat_dim, device=self.device
)

assert self.step == 0, "Must call `after_update` before calling `initialize`"
assert self.step == 0, "Must call `after_updates` before calling `initialize`"
self.insert_observations(observations=observations, time_step=0)
self.prev_actions[0].zero_() # Have to zero previous actions
self.masks[0].zero_() # Have to zero masks
@@ -529,8 +529,11 @@ def after_updates(self, **kwargs):
for key in storage:
storage[key][0][0].copy_(storage[key][0][-1])

self.masks[0].copy_(self.masks[-1])
self.prev_actions[0].copy_(self.prev_actions[-1])
if self._masks_full is not None:
self.masks[0].copy_(self.masks[-1])

if self._prev_actions_full is not None:
self.prev_actions[0].copy_(self.prev_actions[-1])

self._before_update_called = False
self._advantages = None
85 changes: 59 additions & 26 deletions allenact/algorithms/onpolicy_sync/vector_sampled_tasks.py
@@ -144,6 +144,7 @@ class VectorSampledTasks:
_mp_ctx: BaseContext
_connection_read_fns: List[Callable[[], Any]]
_connection_write_fns: List[Callable[[Any], None]]
_read_timeout: Optional[float]

def __init__(
self,
@@ -154,12 +155,16 @@ def __init__(
mp_ctx: Optional[BaseContext] = None,
should_log: bool = True,
max_processes: Optional[int] = None,
read_timeout: Optional[
float
] = 60, # Seconds to wait for a task to return a response before timing out
) -> None:

self._is_waiting = False
self._is_closed = True
self.should_log = should_log
self.max_processes = max_processes
self.read_timeout = read_timeout

assert (
sampler_fn_args is not None and len(sampler_fn_args) > 0
@@ -195,7 +200,8 @@ def __init__(
for args in sampler_fn_args:
args["mp_ctx"] = self._mp_ctx
(
self._connection_read_fns,
connection_poll_fns,
connection_read_fns,
self._connection_write_fns,
) = self._spawn_workers( # noqa
make_sampler_fn=make_sampler_fn,
@@ -204,6 +210,13 @@
],
)

self._connection_read_fns = [
self._create_read_function_with_timeout(
read_fn=read_fn, poll_fn=poll_fn, timeout=self.read_timeout
)
for read_fn, poll_fn in zip(connection_read_fns, connection_poll_fns)
]

self._is_closed = False

for write_fn in self._connection_write_fns:
@@ -234,6 +247,25 @@ def __init__(
space for read_fn in self._connection_read_fns for space in read_fn()
]

@staticmethod
def _create_read_function_with_timeout(
*,
read_fn: Callable[[], Any],
poll_fn: Callable[[float], bool],
timeout: Optional[float],
) -> Callable[[], Any]:
def read_with_timeout(timeout_to_use: Optional[float] = timeout):
if timeout_to_use is not None:
# noinspection PyArgumentList
if not poll_fn(timeout=timeout_to_use):
raise TimeoutError(
f"Did not recieve output from `VectorSampledTask` worker for {timeout_to_use} seconds."
)

return read_fn()

return read_with_timeout

def _reset_sampler_index_to_process_ind_and_subprocess_ind(self):
self.sampler_index_to_process_ind_and_subprocess_ind = [
[i, j]
@@ -297,7 +329,7 @@ def _task_sampling_loop_worker(
"""process worker for creating and interacting with the
Tasks/TaskSampler."""

ptitle("VectorSampledTask: {}".format(worker_id))
ptitle(f"VectorSampledTask: {worker_id}")

sp_vector_sampled_tasks = SingleProcessVectorSampledTasks(
make_sampler_fn=make_sampler_fn,
Expand All @@ -307,7 +339,7 @@ def _task_sampling_loop_worker(
)

if parent_pipe is not None:
parent_pipe.close()
parent_pipe.close() # Means this pipe will close when the calling process closes it
try:
while True:
read_input = connection_read_fn()
@@ -368,7 +400,9 @@ def _task_sampling_loop_worker(
if should_log:
get_logger().info(f"Worker {worker_id} KeyboardInterrupt")
except Exception as e:
get_logger().error(traceback.format_exc())
get_logger().error(
f"Worker {worker_id} encountered an exception:\n{traceback.format_exc()}"
)
raise e
finally:
if child_pipe is not None:
@@ -380,52 +414,50 @@ def _spawn_workers(
self,
make_sampler_fn: Callable[..., TaskSampler],
sampler_fn_args_list: Sequence[Sequence[Dict[str, Any]]],
) -> Tuple[List[Callable[[], Any]], List[Callable[[Any], None]]]:
) -> Tuple[
List[Callable[[], bool]], List[Callable[[], Any]], List[Callable[[Any], None]]
]:
parent_connections, worker_connections = zip(
*[self._mp_ctx.Pipe(duplex=True) for _ in range(self._num_processes)]
)
self._workers = []
k = 0
id: Union[int, str]
for id, stuff in enumerate(
for id, (worker_conn, parent_conn, current_sampler_fn_args_list) in enumerate(
zip(worker_connections, parent_connections, sampler_fn_args_list)
):
worker_conn, parent_conn, current_sampler_fn_args_list = stuff # type: ignore

if len(current_sampler_fn_args_list) != 1:
id = "{}({}-{})".format(
id, k, k + len(current_sampler_fn_args_list) - 1
)
id = f"{id}({k}-{k + len(current_sampler_fn_args_list) - 1})"
k += len(current_sampler_fn_args_list)

if self.should_log:
get_logger().info(
"Starting {}-th VectorSampledTask worker with args {}".format(
id, current_sampler_fn_args_list
)
f"Starting {id}-th VectorSampledTask worker with args {current_sampler_fn_args_list}"
)

ps = self._mp_ctx.Process( # type: ignore
target=self._task_sampling_loop_worker,
args=(
id,
worker_conn.recv,
worker_conn.send,
make_sampler_fn,
current_sampler_fn_args_list,
self._auto_resample_when_done,
self.should_log,
worker_conn,
parent_conn,
kwargs=dict(
worker_id=id,
connection_read_fn=worker_conn.recv,
connection_write_fn=worker_conn.send,
make_sampler_fn=make_sampler_fn,
sampler_fn_args_list=current_sampler_fn_args_list,
auto_resample_when_done=self._auto_resample_when_done,
should_log=self.should_log,
child_pipe=worker_conn,
parent_pipe=parent_conn,
),
)
self._workers.append(ps)
ps.daemon = True
ps.start()
worker_conn.close()
worker_conn.close() # Means this pipe will close when the child process closes it
time.sleep(
0.1
) # Useful to ensure things don't lock up when spawning many envs
return (
[p.poll for p in parent_connections],
[p.recv for p in parent_connections],
[p.send for p in parent_connections],
)
@@ -593,7 +625,8 @@ def close(self) -> None:
if self._is_waiting:
for read_fn in self._connection_read_fns:
try:
read_fn()
# noinspection PyArgumentList
read_fn(0) # Time out immediately
except Exception:
pass

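The timeout mechanism in `_create_read_function_with_timeout` relies on `multiprocessing` pipe connections supporting `poll(timeout)`. A minimal self-contained sketch of that poll-then-recv pattern (illustration only, not AllenAct code):

```python
# Minimal sketch of the poll-then-recv pattern used above: poll() returns
# False if nothing arrives within `timeout` seconds, letting us raise
# TimeoutError instead of blocking forever in recv().
from multiprocessing import Pipe
from multiprocessing.connection import Connection
from typing import Any


def read_with_timeout(conn: Connection, timeout: float) -> Any:
    if not conn.poll(timeout):
        raise TimeoutError(f"No message received within {timeout} seconds.")
    return conn.recv()


if __name__ == "__main__":
    parent_conn, child_conn = Pipe(duplex=True)
    child_conn.send("ready")
    print(read_with_timeout(parent_conn, timeout=1.0))  # prints "ready"
```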