Description
🐛 Describe the bug
I have been experimenting with configs/official-1124/OLMo-7B-stage1.yaml and training on the dataset listed in that YAML file. Unfortunately, I have run into a strange issue: after resuming from a checkpoint, the variance of the Cross Entropy loss and the Z Loss increases dramatically. For example, I ran a first run up to step 5600 and then re-ran training from the step-4400 checkpoint. Here are the loss graphs from wandb:
You can clearly see that the variance is high after step 4400.
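To make the comparison quantitative rather than visual, one could export the per-step loss (for example, as a CSV from the wandb run page) and compare its spread in equal windows before and after the resume step. This is a minimal sketch assuming the history is a list of `(step, loss)` pairs; the helper names and the 500-step window are illustrative choices, not part of OLMo or wandb, and the window should be short enough that the slow downward loss trend does not dominate the standard deviation.

```python
from statistics import pstdev


def window_std(history, start, end):
    """Population std of the losses whose step lies in [start, end)."""
    values = [loss for step, loss in history if start <= step < end]
    if len(values) < 2:
        raise ValueError(f"need at least 2 points in [{start}, {end})")
    return pstdev(values)


def compare_around_resume(history, resume_step, window=500):
    """Return (std before, std after, after/before) for equal windows around resume_step.

    `history` is assumed to be an iterable of (step, loss) pairs, e.g. parsed
    from a wandb CSV export; the window size is an arbitrary illustrative default.
    """
    history = list(history)
    before = window_std(history, resume_step - window, resume_step)
    after = window_std(history, resume_step, resume_step + window)
    return before, after, after / before
```

Running this on the original run and on the resumed run around step 4400 separates the two cases: a ratio near 1 for the original run and well above 1 for the resumed run would confirm that the jump is tied to the resume rather than to the data at that point.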
I have tried this on the following two systems, and both show the same problem:
- 64 AMD MI300X GPUs across 8 nodes, using ROCm 6.1, PyTorch 2.5.1, and Python 3.11
- 64 NVIDIA A100 GPUs across 8 nodes, using CUDA 12.4, PyTorch 2.5.1, and Python 3.11

I have also tried 16 and 32 for the number of heads and layers in OLMo-7B-stage1.yaml; both show the same issue.
I have been using the OLMo Core checkpointer as follows:
- First, collect the tensors from all nodes in the `model`, `train`, and `optim` folders of the checkpoint into a single folder accessible to all nodes.
- Then, set `--load_path=` to that folder containing all the tensors.
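Because the merge step above is done by hand, one thing worth ruling out is a file from one node silently overwriting a same-named file from another, which could leave a checkpoint that loads without error but holds the wrong data. The sketch below performs the merge with a collision check. It assumes each node's checkpoint directory contains the `model`, `optim`, and `train` subfolders described above; `merge_node_checkpoints` is a hypothetical helper, not an OLMo or OLMo-core API.

```python
import filecmp
import shutil
from pathlib import Path

# Checkpoint subfolders named in the issue.
SUBDIRS = ("model", "optim", "train")


def merge_node_checkpoints(node_dirs, dest):
    """Copy <node>/{model,optim,train}/** into dest/{model,optim,train}/**.

    Identical duplicates (same relative path, same bytes) are skipped; a
    same-named file with different contents raises instead of overwriting.
    Returns the number of files in each merged subfolder.
    """
    dest = Path(dest)
    for node_dir in map(Path, node_dirs):
        for sub in SUBDIRS:
            src_sub = node_dir / sub
            if not src_sub.is_dir():
                continue
            for src in src_sub.rglob("*"):
                if not src.is_file():
                    continue
                target = dest / sub / src.relative_to(src_sub)
                if target.exists():
                    if not filecmp.cmp(src, target, shallow=False):
                        raise FileExistsError(f"conflicting copies of {target}")
                    continue  # byte-identical duplicate written on several nodes
                target.parent.mkdir(parents=True, exist_ok=True)
                shutil.copy2(src, target)

    counts = {}
    for sub in SUBDIRS:
        merged = dest / sub
        counts[sub] = sum(1 for p in merged.rglob("*") if p.is_file()) if merged.is_dir() else 0
    return counts
```

The returned per-folder counts can then be compared against what the checkpointer is expected to write for 64 ranks; a shortfall would point to shards that never made it into the merged folder.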
Below is the config I used (I removed the dataset URLs):
```yaml
run_name: OLMo2-7B-stage1
seed: 6198
dry_run: false

model:
  d_model: 4096
  n_heads: 32
  n_layers: 32
  mlp_hidden_size: 22016
  weight_tying: false
  alibi: false
  rope: true
  rope_theta: 500000
  flash_attention: true
  attention_dropout: 0.0
  include_bias: false
  block_type: sequential
  layer_norm_type: rms
  layer_norm_with_affine: true
  layer_norm_eps: 1e-6
  bias_for_layer_norm: false
  attention_layer_norm: true
  attention_layer_norm_with_affine: true
  norm_after: true
  activation_type: swiglu
  residual_dropout: 0.0
  embedding_dropout: 0.0
  max_sequence_length: 4096
  vocab_size: 100278
  embedding_size: 100352
  eos_token_id: 100257
  pad_token_id: 100277
  init_device: meta
  init_fn: normal
  init_std: 0.02
  init_cutoff_factor: 3

softmax_auxiliary_loss: true
auxiliary_loss_multiplier: 1e-5
fused_loss: true

compile: null

wandb:
  project: "llm-kron"
  entity: "abhijangda-microsoft"
  log_interval: 1
  group: "7B"

optimizer:
  name: adamw
  learning_rate: 3.0e-4
  weight_decay: 0.1
  eps: 1e-8
  decay_norm_and_bias: true
  decay_embeddings: false
  betas:
    - 0.9
    - 0.95
  metrics_log_interval: 1

scheduler:
  name: cosine_with_warmup
  units: tokens
  t_warmup: 8388608000
  t_max: 5e12
  alpha_f: 0.1
  warmup_min_lr: 0.0

tokenizer:
  identifier: tokenizers/allenai_dolma2.json
  truncate_direction: right

save_overwrite: false
save_interval: 1000
save_interval_ephemeral: 250
save_num_checkpoints_to_keep: -1
sharded_checkpointer: olmo_core

save_interval_unsharded: null
save_num_unsharded_checkpoints_to_keep: -1

load_path: null

max_duration: 1ep
global_train_batch_size: 1024
device_train_microbatch_size: 8

precision: amp_bf16

fsdp:
  wrapping_strategy: by_block_and_size
  precision: mixed

max_grad_norm: 1.0
max_grad_norm_ratio: null

speed_monitor:
  window_size: 1

gen1_gc_interval: 1

eval_interval: 1000
eval_subset_num_batches: -1
device_eval_batch_size: ${device_train_microbatch_size}

data:
  pad_direction: right
  # generate_doc_lengths: true
  num_workers: 32
  drop_last: true
  pin_memory: true
  prefetch_factor: 8
  persistent_workers: true
  memmap_dtype: uint32
  timeout: 0
  instance_filter:
    repetition_max_period: 13
    repetition_min_period: 1
    repetition_max_count: 32
```
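One cheap check when a resumed run misbehaves is whether the learning-rate schedule picked up where it left off. The sketch below derives, from the config values above, the tokens per step, the gradient-accumulation factor on 64 GPUs, and the learning rate expected at a given step. The cosine-with-warmup formula is an assumption meant to mirror OLMo's `cosine_with_warmup` scheduler (linear warmup, then cosine decay to `alpha_f * learning_rate` over `t_max`); it is not copied from the OLMo source.

```python
import math

# Values copied from the config above; NUM_GPUS comes from the 8-node x 8-GPU setup.
GLOBAL_BATCH = 1024        # global_train_batch_size (sequences per step)
SEQ_LEN = 4096             # max_sequence_length
MICRO_BATCH = 8            # device_train_microbatch_size
NUM_GPUS = 64
LR = 3.0e-4                # optimizer.learning_rate
T_WARMUP = 8_388_608_000   # scheduler.t_warmup (tokens)
T_MAX = 5e12               # scheduler.t_max (tokens)
ALPHA_F = 0.1              # scheduler.alpha_f
WARMUP_MIN_LR = 0.0        # scheduler.warmup_min_lr

TOKENS_PER_STEP = GLOBAL_BATCH * SEQ_LEN
GRAD_ACCUM = GLOBAL_BATCH // (NUM_GPUS * MICRO_BATCH)


def expected_lr(step):
    """LR after `step` optimizer steps under an assumed cosine-with-warmup schedule (token units)."""
    t = step * TOKENS_PER_STEP
    if t < T_WARMUP:
        # Linear warmup from warmup_min_lr to the peak LR.
        return WARMUP_MIN_LR + (LR - WARMUP_MIN_LR) * t / T_WARMUP
    # Cosine decay from the peak LR towards alpha_f * peak over the remaining tokens.
    eta_min = LR * ALPHA_F
    progress = (t - T_WARMUP) / (T_MAX - T_WARMUP)
    return eta_min + (LR - eta_min) * (1 + math.cos(math.pi * progress)) / 2
```

With these values each step consumes 4,194,304 tokens, each GPU accumulates 2 microbatches per step, warmup ends at step 2000, and the LR at step 4400 should still be within a fraction of a percent of 3.0e-4. If the resumed run logs an LR far from that at step 4400 (for example, one that is back in warmup), the training state was likely not restored from the checkpoint.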
Any idea what could be the issue here?
Versions
absl-py==2.1.0
accelerate==0.18.0
-e git+ssh://git@github.com/abhijangda/OLMo.git@77e47c6d84c018fc33a5eda086056c1402f74381#egg=ai2_olmo
ai2-olmo-core==0.1.0
aiofiles==23.2.1
aiohappyeyeballs==2.4.3
aiohttp==3.11.3
aioshutil==1.5
aiosignal==1.3.1
annotated-types==0.7.0
antlr4-python3-runtime==4.9.3
anyio==4.6.2.post1
anykeystore==0.2
apex==0.1
appdirs==1.4.4
asttokens==2.4.1
astunparse==1.6.3
attrs==24.2.0
autocommand==2.2.2
backoff==2.2.1
backports.tarfile==1.2.0
beaker-gantry==1.10.0
beaker-py==1.32.3
beautifulsoup4==4.12.3
bitsandbytes==0.44.1
black==23.12.1
boltons==21.0.0
boto3==1.35.84
botocore==1.35.84
bracex==2.5.post1
Brotli==1.1.0
build==1.2.2.post1
cached_path==1.6.5
cachetools==5.5.0
certifi==2024.8.30
cffi==1.17.1
chardet==5.2.0
charset-normalizer==3.4.0
click==8.1.7
click-help-colors==0.9.4
click-option-group==0.5.6
cmake==3.31.0.1
codeshield==1.0.1
colorama==0.4.6
coloredlogs==15.0.1
contourpy==1.3.1
cryptacular==1.6.2
cryptography==43.0.3
cupy==13.3.0
cxxfilt==0.3.0
cycler==0.12.1
dataclasses-json==0.6.7
datasets==3.2.0
decorator==5.1.1
defusedxml==0.7.1
Deprecated==1.2.15
dill==0.3.6
distro==1.9.0
docker==7.1.0
docker-pycreds==0.4.0
docutils==0.21.2
effdet==0.4.1
einops==0.8.0
emoji==2.14.0
eval_type_backport==0.2.0
evaluate==0.4.3
exceptiongroup==1.2.2
executing==2.1.0
expecttest==0.2.1
face==24.0.0
fastapi==0.115.5
fastrlock==0.8.2
ffmpy==0.4.0
filelock==3.16.1
filetype==1.2.0
fire==0.7.0
flash_attn @ file:///home/aiscuser/ajangda/flash-attention
flatbuffers==24.3.25
fonttools==4.55.0
frozenlist==1.5.0
fsspec==2023.9.2
ftfy==6.3.1
gitdb==4.0.11
GitPython==3.1.43
glom==22.1.0
google-api-core==2.23.0
google-auth==2.36.0
google-cloud-core==2.4.1
google-cloud-storage==2.19.0
google-cloud-vision==3.8.1
google-crc32c==1.6.0
google-resumable-media==2.7.2
googleapis-common-protos==1.66.0
gradio==5.6.0
gradio_client==1.4.3
greenlet==3.1.1
grpcio==1.68.0
grpcio-status==1.62.3
h11==0.14.0
httpcore==1.0.7
httpx==0.27.2
huggingface-hub==0.26.5
humanfriendly==10.0
hupper==1.12.1
hypothesis==6.119.2
idna==3.10
importlib_metadata==7.1.0
importlib_resources==6.4.0
inflate64==1.0.0
inflect==7.3.1
iniconfig==2.0.0
iopath==0.1.10
ipython==8.29.0
isort==5.12.0
jaraco.classes==3.4.0
jaraco.context==5.3.0
jaraco.functools==4.0.1
jaraco.text==3.12.1
jedi==0.19.2
jeepney==0.8.0
Jinja2==3.1.4
jmespath==1.0.1
joblib==1.4.2
jsonpatch==1.33
jsonpath-python==1.0.6
jsonpointer==3.0.0
jsonschema==4.23.0
jsonschema-specifications==2024.10.1
keyring==25.5.0
kiwisolver==1.4.7
langchain==0.2.17
langchain-community==0.2.19
langchain-core==0.2.43
langchain-openai==0.1.20
langchain-text-splitters==0.2.4
langdetect==1.0.9
langsmith==0.1.143
layoutparser==0.3.4
lightning-utilities==0.11.9
lintrunner==0.12.5
loralib==0.1.2
lxml==5.3.0
markdown-it-py==3.0.0
MarkupSafe==2.1.5
marshmallow==3.23.1
matplotlib==3.9.2
matplotlib-inline==0.1.7
mdurl==0.1.2
more-itertools==10.3.0
mpi4py @ file:///work/ci_py311/mpi4py_1676858691457/work
mpmath==1.3.0
mscclpp @ file:///home/ajangda/mscclpp
msgspec==0.18.6
multidict==6.1.0
multiprocess==0.70.14
multivolumefile==0.2.3
mypy==1.3.0
mypy-extensions==1.0.0
necessary==0.4.3
nest-asyncio==1.6.0
netifaces==0.11.0
networkx==3.4.2
nh3==0.2.20
ninja==1.11.1.1
nltk==3.9.1
numpy==1.26.4
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.21.5
nvidia-nvjitlink-cu12==12.4.127
nvidia-nvtx-cu12==12.1.105
oauthlib==3.2.2
omegaconf==2.3.0
openai==1.39.0
opencv-python==4.10.0.84
opentelemetry-api==1.25.0
opentelemetry-exporter-otlp-proto-common==1.25.0
opentelemetry-exporter-otlp-proto-http==1.25.0
opentelemetry-instrumentation==0.46b0
opentelemetry-instrumentation-requests==0.46b0
opentelemetry-proto==1.25.0
opentelemetry-sdk==1.25.0
opentelemetry-semantic-conventions==0.46b0
opentelemetry-util-http==0.46b0
optimum==1.23.3
optree==0.13.1
ordered-set==4.1.0
orjson==3.10.11
packaging==24.2
pandas==2.2.3
parso==0.8.4
PasteDeploy==3.1.0
pathspec==0.12.1
pbkdf2==1.3
pdf2image==1.17.0
pdfminer.six==20231228
pdfplumber==0.11.4
peewee==3.17.8
peft==0.13.2
petname==2.6
pexpect==4.9.0
pi_heif==0.20.0
pikepdf==9.4.2
pillow==11.0.0
pkginfo==1.12.0
plaster==1.1.2
plaster-pastedeploy==1.0.1
platformdirs==4.2.2
pluggy==1.5.0
portalocker==3.0.0
prettytable==3.12.0
prompt_toolkit==3.0.48
propcache==0.2.0
proto-plus==1.25.0
protobuf==4.25.5
psutil==6.1.0
ptyprocess==0.7.0
pure_eval==0.2.3
py7zr==0.22.0
pyarrow==18.0.0
pyasn1==0.6.1
pyasn1_modules==0.4.1
pybcj==1.0.2
pybind11==2.13.6
pybind11_global==2.13.6
pycocotools==2.0.8
pycparser==2.22
pycryptodomex==3.21.0
pydantic==2.9.2
pydantic_core==2.23.4
pydub==0.25.1
pyfastkron @ file:///home/aiscuser/ajangda/OLMo/pyfastkron-1.0.1-py3-none-any.whl#sha256=600f33c84967e12106e7e2b25f583422bf4a1a1f8dc887b5e8df54fa9bba2082
Pygments==2.18.0
pyparsing==3.2.0
pypdf==5.1.0
pypdfium2==4.30.0
pyppmd==1.1.0
pyproject_hooks==1.2.0
pyramid==2.0.2
pyramid-mailer==0.15.1
pytest==8.3.4
pytest-sphinx==0.6.3
python-dateutil==2.8.2
python-iso639==2024.10.22
python-magic==0.4.27
python-multipart==0.0.12
python3-openid==3.2.0
pytorch-triton-rocm==3.1.0
pytz==2024.2
PyYAML==6.0.1
pyzstd==0.16.2
RapidFuzz==3.10.1
readme_renderer==44.0
referencing==0.35.1
regex==2024.11.6
repoze.sendmail==4.4.1
requests==2.32.3
requests-oauthlib==2.0.0
requests-toolbelt==1.0.0
requirements-parser==0.11.0
responses==0.18.0
rfc3986==2.0.0
rich==13.5.3
rouge_score==0.1.2
rpds-py==0.21.0
rsa==4.9
ruamel.yaml==0.17.40
ruamel.yaml.clib==0.2.12
ruff==0.7.4
s3transfer==0.10.4
safehttpx==0.1.1
safetensors==0.4.5
scikit-learn==1.5.2
scipy==1.14.1
SecretStorage==3.3.3
semantic-version==2.10.0
semgrep==1.96.0
sentence-transformers==3.3.1
sentencepiece==0.2.0
sentry-sdk==2.19.2
setproctitle==1.3.4
shellingham==1.5.4
six==1.16.0
smart-open==7.1.0
smashed==0.21.5
smmap==5.0.1
sniffio==1.3.1
sortedcontainers==2.4.0
soupsieve==2.6
SQLAlchemy==2.0.36
stack-data==0.6.3
starlette==0.41.3
sympy==1.13.1
tabulate==0.9.0
tenacity==8.5.0
termcolor==2.5.0
texttable==1.7.0
threadpoolctl==3.5.0
tiktoken==0.8.0
timm==1.0.11
tokenize_rt==6.1.0
tokenizers==0.13.3
tomli==2.0.1
tomlkit==0.12.0
torch==2.5.1+rocm6.1
torchaudio==2.5.1+rocm6.1
torchmetrics==1.6.0
torchvision==0.20.1+rocm6.1
tqdm==4.67.1
traitlets==5.14.3
transaction==5.0
transformers==4.28.1
translationstring==1.4
triton==3.1.0
trouting==0.3.3
twine==6.0.1
typeguard==4.3.0
typer==0.13.1
types-dataclasses==0.6.6
types-setuptools==75.6.0.20241126
typing-inspect==0.9.0
typing_extensions==4.12.2
tzdata==2024.2
unstructured==0.15.8
unstructured-client==0.27.0
unstructured-inference==0.7.36
unstructured.pytesseract==0.3.13
urllib3==2.2.3
uvicorn==0.32.0
velruse==1.1.1
venusian==3.1.1
wandb==0.19.1
wcmatch==8.5.2
wcwidth==0.2.13
WebOb==1.8.9
websockets==12.0
wrapt==1.16.0
WTForms==3.2.1
wtforms-recaptcha==0.3.2
xxhash==3.5.0
yarl==1.17.2
zipp==3.19.2
zope.deprecation==5.0
zope.interface==7.2
zope.sqlalchemy==3.1