Auto-detect bf16 support for CUDA #993
@@ -20,6 +20,7 @@
 )
 from trl import DataCollatorForCompletionOnlyLM, SFTTrainer
 import click
+import psutil
 import torch

 # Local
@@ -94,6 +95,7 @@ def report_cuda_device(args_device: torch.device, min_vram: int = 0) -> None:
     """Report CUDA/ROCm device properties"""
     print(f" NVidia CUDA version: {torch.version.cuda or 'n/a'}")
     print(f" AMD ROCm HIP version: {torch.version.hip or 'n/a'}")
+    print(f" Supports bf16: {torch.cuda.is_bf16_supported()}")

 def _gib(size: int) -> str:
     return "{:.1f} GiB".format(size / 1024**3)
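As a rough cross-check of what the new "Supports bf16" line reports (a sketch, not part of the patch): on NVIDIA GPUs bf16 generally requires compute capability 8.0 (Ampere) or newer, which can be read directly alongside is_bf16_supported():

# Sketch: relate torch.cuda.is_bf16_supported() to the device's compute capability.
import torch

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    print(f"compute capability: {major}.{minor}")
    # True on compute capability >= 8.0 with a reasonably recent CUDA build
    print(f"is_bf16_supported(): {torch.cuda.is_bf16_supported()}")
else:
    print("no CUDA device visible")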
@@ -173,6 +175,50 @@ def linux_train(
         hpu.init()
         report_hpu_device(device)

+    # each device type registers a module, e.g. torch.cpu or torch.cuda
+    device_module = getattr(torch, device.type, None)
+    # bfloat16 is not supported on older CUDA versions and on devices
+    # with CUDA compute capability < 8.0.
+    if hasattr(device_module, "is_bf16_supported"):
+        use_bf16 = device_module.is_bf16_supported()
+        use_fp16 = not use_bf16
+    elif device.type == "cpu":
+        # TODO: check if Torch and CPU support AVX2, F16C, AVX512
+        use_bf16 = False
+        use_fp16 = False
+    else:
+        # assume bf16 is supported unless the device says otherwise
+        use_bf16 = True
+        use_fp16 = False
+
+    torch_dtype = "auto" if device.type == "cuda" else None
+    if device.type == "cpu":
+        total_memory = psutil.virtual_memory().total / (1024**3)
+        if total_memory < 60:
Review comment (suggested change): A system with 64 GB of RAM will report:
And we have.
+            # Using our default model, a system with 32 GB of RAM
+            # will get OOM killed using torch_dtype=None, though we
+            # seem to get much better performance with this setting
+            # where there's enough memory. Using `None` makes it
+            # use float32 as opposed to float16 or bf16.
+            #
+            # Anecdotally, 64 GB seems to be enough, but this calculation
Review comment: A system with 64 GB of RAM will report ~62.6 GiB, so we base our calculation on
Reply: Since it's such a rough guess, 60 still seems fine? We need to actually do some math at some point ...
Reply: I'll share my math in a few :) stay tuned!
Reply: Some more numbers: essentially, a system with 48 GB of RAM should be able to run both training and inferencing, although 48 GB of RAM is not very common.
+            # may come out to be slightly less than 64 GB, so we just check
+            # for 60 GB. It would be better to do a smarter calculation on
+            # the actual memory requirement here.
+            torch_dtype = "auto"
+
+    # torch compile fails to build, see PyTorch #124707
+    # scaled_dot_product_attention(): argument 'is_causal' must be bool, not SymBool
+    use_torch_compile = False
+    # if device.type == "cuda" and torch.version.cuda is not None:
Review comment: leftover?
+    #     # check for NVIDIA V100, A100, or H100
+    #     cap = torch.cuda.get_device_capability(device)
+    #     use_torch_compile = cap in {(7, 0), (8, 0), (9, 0)}
+
+    print(
+        f"LINUX_TRAIN.PY: {use_bf16=}, {use_fp16=}, {torch_dtype=}, {use_torch_compile=}"
+    )
+
     print("LINUX_TRAIN.PY: LOADING DATASETS")
     # Get the file name
     train_dataset = load_dataset("json", data_files=train_file, split="train")
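For reference, the figure that the < 60 check compares against comes from psutil; a quick sketch of how it is obtained (actual numbers vary by machine, the ~62.6 GiB value is quoted from the review comments above):

# Sketch: the total-memory figure the < 60 GiB heuristic is based on.
import psutil

total_gib = psutil.virtual_memory().total / (1024**3)
# A nominal 64 GB machine typically reports a bit less than 64 GiB here
# (about 62.6 GiB in one reviewer's test), hence the 60 GiB cutoff.
print(f"total memory: {total_gib:.1f} GiB")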
@@ -194,6 +240,7 @@ def linux_train(

     if four_bit_quant:
         print("LINUX_TRAIN.PY: USING 4-bit quantization with BitsAndBytes")
+        use_bf16 = False
Review comment: Yeah, I was thinking this should go here. I noticed we were doing this already by setting it to !fp16 below (I think?)
         use_fp16 = True
         bnb_config = BitsAndBytesConfig(
             load_in_4bit=True,

@@ -203,7 +250,6 @@
         )
     else:
         print("LINUX_TRAIN.PY: NOT USING 4-bit quantization")
-        use_fp16 = False
         bnb_config = None

     # Loading the model
@@ -214,7 +260,7 @@

     model = AutoModelForCausalLM.from_pretrained(
         model_name,
-        torch_dtype="auto",
+        torch_dtype=torch_dtype,
         quantization_config=bnb_config,
         config=config,
         trust_remote_code=True,
@@ -340,7 +386,7 @@ def model_generate(user, **kwargs):
         num_train_epochs=num_epochs,
         per_device_train_batch_size=per_device_train_batch_size,
         fp16=use_fp16,
-        bf16=not use_fp16,
+        bf16=use_bf16,
         # use_ipex=True, # TODO CPU test this possible optimization
         use_cpu=model.device.type == "cpu",
         save_strategy="epoch",
Review comment: I think it would be super useful to document why torch_dtype=None is faster on CPUs.

So on CPUs, dtype=float32 will be faster than dtype=bfloat16. Why? A link to a nice explanation of that would be great.

I guess torch_dtype=None gives dtype=float32 on CPUs? Why? I can't quickly find any docs that explain that - does None just mean that we use whatever torch.get_default_dtype() gives, which is float32 by default?

What dtype does torch_dtype="auto" give on CPUs? Sounds like dtype=bfloat16? Why? Is it detecting that this particular model was saved with that dtype? I can't seem to find the docs on huggingface.co/docs, so see: https://github.com/huggingface/transformers/blob/f3f640dce14bee3b3930a774c3dfac92977eee7f/src/transformers/modeling_utils.py#L2878-L2898

We don't have a dtype in the model's config.json, so that doesn't seem to be a factor, but it could be.

If the torch_dtype="auto" behavior is model-specific, but we know we want float32 except on low-memory systems ... maybe for CPUs, we should just explicitly set torch_dtype to either float32 or bfloat16?
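One way to answer these questions empirically (a sketch, not part of the PR; the model name is only a placeholder, substitute the model under test):

# Sketch: compare the dtype from_pretrained ends up with for torch_dtype=None vs "auto".
import torch
from transformers import AutoModelForCausalLM

# float32 unless someone called torch.set_default_dtype()
print(f"default dtype: {torch.get_default_dtype()}")

for dtype_arg in (None, "auto"):
    # "facebook/opt-125m" is just an example checkpoint
    model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m", torch_dtype=dtype_arg)
    print(f"torch_dtype={dtype_arg!r} -> model dtype: {model.dtype}")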
Review comment: It very much depends on the hardware and compiler. In general, x86_64 CPUs have support for standard precision and double precision floats (fp32, fp64). Half precision (fp16) instructions were added in ISA level x86_64-v3, and brain float (bf16) SIMD instructions were added in x86_64-v4. The document https://pytorch.org/docs/stable/amp.html#ops-that-can-autocast-to-float16 explains autocasting.

AFAIK we want to use
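For experimenting with the CPU autocast behavior described in that document, a small sketch (on CPUs without native bf16 instructions the ops still run, just via slower emulation):

# Sketch: CPU autocast casts listed ops (e.g. matmul) to bfloat16.
import torch

a = torch.randn(256, 256)  # float32 inputs
b = torch.randn(256, 256)
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    c = a @ b
print(c.dtype)  # torch.bfloat16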
Review comment: To quote from another comment I just made ... I don't understand what's going on here well enough to explain this. At least in my two test environments the current code seems to be a nice improvement.

It would definitely be better if we could have a clearer explanation. I imagine there's a bit of luck involved right now, and instead of checking something else, we need to set the ideal configuration.
Review comment: Based on @tiran's comment, it's possible that my laptop has this support while my server does not, which could explain why "auto" gets the best performance on my laptop.

If "auto" isn't doing an adequate job of checking whether that is supported, and is choosing it even when it's not actually supported on my server, maybe that's killing the performance. I'll have to keep digging here ...
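One way to keep digging (a sketch): time a float32 matmul against a bfloat16 one on each machine and see whether bf16 is actually a win there.

# Sketch: rough CPU timing of float32 vs bfloat16 matmul.
import time
import torch

def bench(dtype, n=1024, iters=20):
    a = torch.randn(n, n, dtype=dtype)
    b = torch.randn(n, n, dtype=dtype)
    start = time.perf_counter()
    for _ in range(iters):
        _ = a @ b
    return (time.perf_counter() - start) / iters

print(f"float32:  {bench(torch.float32):.4f} s/iter")
print(f"bfloat16: {bench(torch.bfloat16):.4f} s/iter")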
Review comment: What is the output of torch.backends.cpu.get_cpu_capability() on your server and your laptop? I get AVX512 on a server with an Intel Xeon Platinum with the avx512 instruction set, and AVX2 on an Intel Core i7-8650U with avx2 but without AVX512.
on an Intel Core i7-8650U with avx2 but without AVX512.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
On my laptop -- 11th Gen Intel(R) Core(TM) i7-1185G7 @ 3.00GHz
On the server -- AMD EPYC 7R32
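For reference, the capability check being compared above is a one-liner (a sketch):

# Sketch: print the SIMD level PyTorch detected for this CPU (e.g. AVX2, AVX512).
import torch

print(torch.backends.cpu.get_cpu_capability())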