update loss utilities to take stage by H-Huang · Pull Request #1077 · pytorch/PiPPy · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

update loss utilities to take stage #1077

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 22, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 18 additions & 13 deletions pippy/PipelineSchedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ def _maybe_compute_loss(self, stage, output, target_mbs, mb_index):
f"[{stage.stage_index}] Loss of microbatch {mb_index}: {loss}"
)

def _maybe_get_loss(self, mb_index):
def _maybe_get_loss(self, stage, mb_index):
valid_index = 0 <= mb_index < len(self._internal_losses)
if self._has_backward and valid_index:
if stage.is_last and self._has_backward and valid_index:
return self._internal_losses[mb_index]
elif len(self._internal_losses) != 0 and not valid_index:
raise RuntimeError(
Expand All @@ -56,12 +56,17 @@ def _maybe_get_loss(self, mb_index):
else:
return None

def _update_losses(self, losses):
def _update_losses(self, stages, losses):
"""
Update the losses to those in the internal state
"""
# if stages not a list turn into a list
if not isinstance(stages, list):
stages = [stages]
contains_last_stage = any([stage.is_last for stage in stages])

Comment on lines +63 to +67
Copy link
Contributor
@kwen2501 kwen2501 Apr 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because of this, maybe consider having two versions of this utility? One for single-stage schedules, the other for multi.

# Return losses if there is a container passed in
if losses is not None:
if contains_last_stage and losses is not None:
if len(self._internal_losses) != self._n_microbatches:
raise RuntimeError(
f"Expecting {self._n_microbatches} losses but got {len(self._internal_losses)}"
Expand Down Expand Up @@ -330,7 +335,7 @@ def step_microbatches(
for work in works.values():
work.wait()

loss = self._maybe_get_loss(i)
loss = self._maybe_get_loss(self._stage, i)
self._stage.backward_one_chunk(loss=loss)

ops = self._stage.get_bwd_send_ops()
Expand All @@ -342,7 +347,7 @@ def step_microbatches(
)

# Return losses if there is a container passed in
self._update_losses(losses)
self._update_losses(self._stage, losses)

# Wait for all backward sends to finish
for work in bwd_sends_to_wait:
Expand Down Expand Up @@ -423,7 +428,7 @@ def step_microbatches(
for work in works.values():
work.wait()

loss = self._maybe_get_loss(bwd_mb_index)
loss = self._maybe_get_loss(self._stage, bwd_mb_index)
self._stage.backward_one_chunk(loss=loss)

ops = self._stage.get_bwd_send_ops()
Expand All @@ -440,7 +445,7 @@ def step_microbatches(
work.wait()

# Return losses if there is a container passed in
self._update_losses(losses)
self._update_losses(self._stage, losses)


class PipelineScheduleMulti(PipelineSchedule):
Expand Down Expand Up @@ -553,14 +558,14 @@ def step_microbatches(
if ops:
dist.batch_isend_irecv(ops).pop().wait()

loss = self._maybe_get_loss(i)
loss = self._maybe_get_loss(stage, i)
stage.backward_one_chunk(loss=loss)

ops = stage.get_bwd_send_ops()
if ops:
dist.batch_isend_irecv(ops)

self._update_losses(losses)
self._update_losses(self._stages, losses)


class ScheduleInterleaved1F1B(PipelineScheduleMulti):
Expand Down Expand Up @@ -739,7 +744,7 @@ def backward_stage_local_index(step):
)

# bwd
loss = self._maybe_get_loss(bwd_mb_index)
loss = self._maybe_get_loss(bwd_stage, bwd_mb_index)
bwd_stage.backward_one_chunk(loss=loss)
ops.extend(bwd_stage.get_bwd_send_ops())

Expand All @@ -764,7 +769,7 @@ def backward_stage_local_index(step):
for work in works.values():
work.wait()

loss = self._maybe_get_loss(bwd_mb_index)
loss = self._maybe_get_loss(bwd_stage, bwd_mb_index)
bwd_stage.backward_one_chunk(loss=loss)

ops = bwd_stage.get_bwd_send_ops()
Expand All @@ -776,4 +781,4 @@ def backward_stage_local_index(step):
work.wait()

# Return losses if there is a container passed in
self._update_losses(losses)
self._update_losses(self._stages, losses)
Loading
319F
0