[llm]add bf16 moment adamw by lugimzzz · Pull Request #9732 · PaddlePaddle/PaddleNLP · GitHub

[llm]add bf16 moment adamw #9732


Merged

wawltor merged 10 commits into PaddlePaddle:develop from lugimzzz:bf16 on Mar 3, 2025
Conversation

lugimzzz (Contributor) commented on Jan 2, 2025

PR types

New features

PR changes

Others

Description

To use it, just add `--optim adamw_16bit_moment` to the training command.
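
For illustration, a minimal sketch of how this flag could be set programmatically, assuming the optimizer is selected through `TrainingArguments.optim` as in the PaddleNLP trainer (only the value `adamw_16bit_moment` comes from this PR; the other fields are illustrative):

```python
from paddlenlp.trainer import TrainingArguments

# Hedged sketch: only the optim value "adamw_16bit_moment" is from this PR;
# output_dir and the rest of the configuration are illustrative.
args = TrainingArguments(
    output_dir="./checkpoints",   # hypothetical output location
    optim="adamw_16bit_moment",   # keep AdamW moments in 16-bit precision
)
```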

codecov bot commented on Jan 2, 2025

Codecov Report

Attention: Patch coverage is 11.64384% with 129 lines in your changes missing coverage. Please review.

Project coverage is 51.03%. Comparing base (8e4ff07) to head (b2761bc).
Report is 367 commits behind head on develop.

Files with missing lines        Patch %    Lines
paddlenlp/utils/optimizer.py    9.35%      126 Missing ⚠️
paddlenlp/trainer/trainer.py    40.00%     3 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #9732      +/-   ##
===========================================
- Coverage    51.41%   51.03%   -0.39%     
===========================================
  Files          745      745              
  Lines       118351   119410    +1059     
===========================================
+ Hits         60856    60939      +83     
- Misses       57495    58471     +976     

☔ View full report in Codecov by Sentry.

paddle-bot bot commented on Jan 6, 2025

Thanks for your contribution!

@lugimzzz changed the title from "[llm]add adam" to "[llm]add bf16 adamw" on Feb 20, 2025
@lugimzzz changed the title from "[llm]add bf16 adamw" to "[llm]add bf16 moment adamw" on Feb 20, 2025
@@ -0,0 +1,15 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
Collaborator:

2024 -> 2025

Contributor Author (@lugimzzz):

done

# Update param
if master_weight_ptr is not None:
    tl.store(master_weight_ptr + offsets, param, mask=mask)
tl.store(param_ptr + offsets, param.to(tl.bfloat16), mask=mask)
Collaborator:

The design here needs to consider the original dtype of the optimizer parameters. Should the float16 case be handled as well? Some open-source models are float16.

Contributor Author (@lugimzzz):

This is handled now.
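
For reference, dtype-aware handling in a Triton kernel might look like the following. This is a hedged sketch only; the helper name `store_param` and the `IS_FP16` constexpr flag are hypothetical, not from this PR:

```python
import triton
import triton.language as tl

@triton.jit
def store_param(param_ptr, param, offsets, mask, IS_FP16: tl.constexpr):
    # Cast the updated value back to the parameter's original dtype
    # (float16 or bfloat16) before storing it.
    if IS_FP16:
        tl.store(param_ptr + offsets, param.to(tl.float16), mask=mask)
    else:
        tl.store(param_ptr + offsets, param.to(tl.bfloat16), mask=mask)
```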

    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
Collaborator:

Why are the offsets computed with arange here?

Contributor Author (@lugimzzz):

My understanding is that we need to load all elements in [0, BLOCK_SIZE) of this block to operate on them.
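
For context, this is the standard Triton indexing idiom being discussed: each program instance covers one contiguous block of BLOCK_SIZE elements, and a mask guards against reading past the end of the tensor. A self-contained sketch (a generic copy kernel, not code from this PR):

```python
import triton
import triton.language as tl

@triton.jit
def copy_kernel(src_ptr, dst_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    # Per-element offsets for this program's block of BLOCK_SIZE elements.
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Mask out offsets past the end of the tensor.
    mask = offsets < n_elements
    x = tl.load(src_ptr + offsets, mask=mask)
    tl.store(dst_ptr + offsets, x, mask=mask)
```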

@@ -149,3 +154,227 @@ def adamw_python(
beta1_pow[:], beta2_pow[:] = beta1 * beta1_pow[:], beta2 * beta2_pow[:]
# check how to update this
return


class AdamWPython(AdamW):
Collaborator:

Isn't this name a bit odd? Wouldn't something indicating a plain implementation, like AdamWSlow, be a better fit?

Contributor Author (@lugimzzz):

Renamed to AdamWCustom.

    type = core.VarDesc.VarType.DENSE_TENSOR
except:
    type = core.VarDesc.VarType.LOD_TENSOR
self._add_accumulator(
Collaborator:

In the paper, beta1 and beta2 are float32, right?

Contributor Author (@lugimzzz):

yes
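
To make the precision split concrete: a minimal sketch of one AdamW step where the moments are stored in bfloat16 while beta1/beta2 and their power accumulators stay float32. This pure-Python formulation and the function name are illustrative, not code from this PR:

```python
import paddle

def adamw_step_bf16_moments(param, grad, m_bf16, v_bf16,
                            beta1_pow, beta2_pow,
                            lr=1e-3, beta1=0.9, beta2=0.999,
                            eps=1e-8, weight_decay=0.01):
    # Moments are stored in bfloat16 but updated in float32;
    # beta1/beta2 and their power accumulators remain float32 scalars.
    g = grad.astype("float32")
    m = m_bf16.astype("float32") * beta1 + g * (1.0 - beta1)
    v = v_bf16.astype("float32") * beta2 + (g * g) * (1.0 - beta2)
    m_hat = m / (1.0 - beta1_pow)
    v_hat = v / (1.0 - beta2_pow)
    # Decoupled weight decay, as in AdamW.
    param = param - lr * (m_hat / (paddle.sqrt(v_hat) + eps)
                          + weight_decay * param)
    return param, m.astype("bfloat16"), v.astype("bfloat16")
```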

Collaborator @wawltor left a review comment:

LGTM

@wawltor merged commit 4b65eec into PaddlePaddle:develop on Mar 3, 2025
9 of 12 checks passed
@lugimzzz deleted the bf16 branch on March 3, 2025 at 09:21