High CrossEntropy and Z Loss variance after loading from checkpoint · Issue #776 · allenai/OLMo · GitHub
Open
@abhijangda

Description


🐛 Describe the bug

I have been experimenting with configs/official-1124/OLMo-7B-stage1.yaml, training on the dataset specified in the YAML file. Unfortunately, I have run into a strange issue: after loading from a checkpoint, the variance of the cross-entropy and Z losses increases dramatically. For example, I ran the first training run until step 5600 and then restarted training from the step-4400 checkpoint. Here are the loss graphs from wandb:

[Image: wandb loss graphs (cross-entropy and Z loss)]

You can clearly see that the variance is high after step 4400.
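One way to put a number on "the variance is high" is to compare the rolling standard deviation of the logged loss before and after the resume point. The sketch below is a hypothetical helper, not part of OLMo: it assumes the per-step values (for example, the cross-entropy or Z loss series) have already been exported from wandb into plain Python lists, and the window size is an arbitrary choice.

```python
import statistics


def rolling_std(values, window):
    """Population std. dev. of every length-`window` sliding window over `values`."""
    if len(values) < window:
        raise ValueError(f"need at least {window} values, got {len(values)}")
    return [statistics.pstdev(values[i:i + window]) for i in range(len(values) - window + 1)]


def compare_variance(steps, losses, resume_step, window=50):
    """Hypothetical helper: mean rolling std of the loss before vs. from `resume_step` on.

    The two segments are split first, so no window straddles the resume point.
    """
    before = [loss for step, loss in zip(steps, losses) if step < resume_step]
    after = [loss for step, loss in zip(steps, losses) if step >= resume_step]
    return (
        statistics.mean(rolling_std(before, window)),
        statistics.mean(rolling_std(after, window)),
    )


if __name__ == "__main__":
    # Synthetic data: +/-0.01 oscillation before step 100, +/-0.1 afterwards.
    steps = list(range(200))
    losses = [2.0 + (0.01 if s % 2 else -0.01) * (1 if s < 100 else 10) for s in steps]
    before_std, after_std = compare_variance(steps, losses, resume_step=100, window=10)
    print(f"before: {before_std:.4f}, after: {after_std:.4f}")  # prints: before: 0.0100, after: 0.1000
```

With real data, `steps` and `losses` would hold the exported series of the resumed run and `resume_step` would be 4400; a ratio well above 1 between the two returned numbers quantifies the jump visible in the plot.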

I have tried this on the following two systems, and both show the same problem:

  • 64 AMD MI300X GPUs across 8 nodes, using ROCm 6.1, PyTorch 2.5.1, and Python 3.11
  • 64 NVIDIA A100 GPUs across 8 nodes, using CUDA 12.4, PyTorch 2.5.1, and Python 3.11

I have tried setting the number of heads and layers in OLMo-7B-stage1.yaml to 16 and to 32, and both settings show the same issue.
I have been using the OLMo Core checkpointer, loading from it as follows:

  1. First, collect the tensors from all nodes in the model, train, and optim folders of the checkpoint into a single folder accessible to all nodes.
  2. Then, set --load_path= to that folder containing all the tensors.
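The consolidation in step 1 can be sketched as a small Python script. This is a hypothetical helper, not an OLMo or OLMo Core utility: it assumes each node wrote its own copy of the checkpoint directory containing model/, train/, and optim/ subfolders, copies every file into the same subfolder of a shared destination, and makes no assumption about the shard file names themselves. A file that appears on more than one node is copied once if the copies are byte-identical; otherwise the script stops, since silently overwriting a shard would corrupt the checkpoint.

```python
import filecmp
import shutil
from pathlib import Path

# Checkpoint subfolders named in the steps above.
SUBDIRS = ("model", "train", "optim")


def merge_checkpoint_shards(node_dirs, dest):
    """Hypothetical helper: merge per-node checkpoint folders into `dest`.

    Returns the number of files copied. Raises RuntimeError if two nodes
    hold different files under the same relative path.
    """
    dest = Path(dest)
    copied = 0
    for node_dir in map(Path, node_dirs):
        for sub in SUBDIRS:
            src_root = node_dir / sub
            if not src_root.is_dir():
                continue  # this node did not write this subfolder
            for src in src_root.rglob("*"):
                if not src.is_file():
                    continue
                target = dest / sub / src.relative_to(src_root)
                if target.exists():
                    # Same relative path from another node: must be identical.
                    if not filecmp.cmp(src, target, shallow=False):
                        raise RuntimeError(f"conflicting shard file: {target}")
                    continue
                target.parent.mkdir(parents=True, exist_ok=True)
                shutil.copy2(src, target)
                copied += 1
    return copied
```

For example, calling `merge_checkpoint_shards` with the step-4400 checkpoint directory of each of the 8 nodes and a shared destination folder (all paths are placeholders) would produce the single folder that `--load_path=` then points at. Comparing file contents rather than just names is a deliberate choice: a name collision with different bytes usually means two ranks' shards are being mixed up.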

Below is the config I used (I removed the dataset URLs):

run_name: OLMo2-7B-stage1
seed: 6198
dry_run: false

model:
  d_model: 4096
  n_heads: 32
  n_layers: 32
  mlp_hidden_size: 22016
  weight_tying: false
  alibi: false
  rope: true
  rope_theta: 500000
  flash_attention: true
  attention_dropout: 0.0
  include_bias: false
  block_type: sequential
  layer_norm_type: rms
  layer_norm_with_affine: true
  layer_norm_eps: 1e-6
  bias_for_layer_norm: false
  attention_layer_norm: true
  attention_layer_norm_with_affine: true
  norm_after: true
  activation_type: swiglu
  residual_dropout: 0.0
  embedding_dropout: 0.0
  max_sequence_length: 4096
  vocab_size: 100278
  embedding_size: 100352
  eos_token_id: 100257
  pad_token_id: 100277
  init_device: meta
  init_fn: normal
  init_std: 0.02
  init_cutoff_factor: 3

softmax_auxiliary_loss: true
auxiliary_loss_multiplier: 1e-5
fused_loss: true

compile: null

wandb:
  project: "llm-kron"
  entity: "abhijangda-microsoft"
  log_interval: 1
  group: "7B"

optimizer:
  name: adamw
  learning_rate: 3.0e-4
  weight_decay: 0.1
  eps: 1e-8
  decay_norm_and_bias: true
  decay_embeddings: false
  betas:
  - 0.9
  - 0.95
  metrics_log_interval: 1

scheduler:
  name: cosine_with_warmup
  units: tokens
  t_warmup: 8388608000
  t_max: 5e12
  alpha_f: 0.1
  warmup_min_lr: 0.0

tokenizer:
  identifier: tokenizers/allenai_dolma2.json
  truncate_direction: right

save_overwrite: false

save_interval: 1000
save_interval_ephemeral: 250
save_num_checkpoints_to_keep: -1
sharded_checkpointer: olmo_core

save_interval_unsharded: null
save_num_unsharded_checkpoints_to_keep: -1

load_path: null

max_duration: 1ep
global_train_batch_size: 1024
device_train_microbatch_size: 8

precision: amp_bf16

fsdp:
  wrapping_strategy: by_block_and_size
  precision: mixed

max_grad_norm: 1.0
max_grad_norm_ratio: null

speed_monitor:
  window_size: 1

gen1_gc_interval: 1

eval_interval: 1000
eval_subset_num_batches: -1
device_eval_batch_size: ${device_train_microbatch_size}
data:
  pad_direction: right
  # generate_doc_lengths: true
  num_workers: 32
  drop_last: true
  pin_memory: true
  prefetch_factor: 8
  persistent_workers: true
  memmap_dtype: uint32
  timeout: 0
  instance_filter:
    repetition_max_period: 13
    repetition_min_period: 1
    repetition_max_count: 32
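For reference, the config above implies some simple arithmetic that is worth checking against the logs of the original and resumed runs, for instance that both report the same learning rate around step 4400. The sketch below assumes every instance is max_sequence_length tokens long, that the per-device batch is global_train_batch_size divided evenly across the GPUs, and the standard cosine-with-linear-warmup formula measured in tokens. OLMo's own scheduler may differ in details (for example, how warmup_min_lr is applied), so treat the learning-rate values as approximate.

```python
import math

# Values copied from the config above.
GLOBAL_BATCH = 1024          # global_train_batch_size (instances per optimizer step)
MICROBATCH = 8               # device_train_microbatch_size
SEQ_LEN = 4096               # model.max_sequence_length
MAX_LR = 3.0e-4              # optimizer.learning_rate
T_WARMUP = 8_388_608_000     # scheduler.t_warmup (tokens)
T_MAX = 5e12                 # scheduler.t_max (tokens)
ALPHA_F = 0.1                # scheduler.alpha_f
WARMUP_MIN_LR = 0.0          # scheduler.warmup_min_lr

# Assumption: every instance is exactly SEQ_LEN tokens.
TOKENS_PER_STEP = GLOBAL_BATCH * SEQ_LEN


def grad_accum_steps(num_gpus):
    """Micro-batches per device per optimizer step (assumes an even split across GPUs)."""
    if GLOBAL_BATCH % num_gpus:
        raise ValueError("global batch is not divisible by the number of GPUs")
    per_device = GLOBAL_BATCH // num_gpus
    if per_device % MICROBATCH:
        raise ValueError("per-device batch is not divisible by the micro-batch size")
    return per_device // MICROBATCH


def lr_at_step(step):
    """Assumed cosine schedule with linear warmup, both measured in tokens."""
    t = step * TOKENS_PER_STEP
    if t < T_WARMUP:
        # Linear warmup from WARMUP_MIN_LR to MAX_LR.
        return WARMUP_MIN_LR + (MAX_LR - WARMUP_MIN_LR) * t / T_WARMUP
    eta_min = MAX_LR * ALPHA_F
    # Fraction of the post-warmup budget consumed, clamped at 1 after T_MAX.
    progress = min(t - T_WARMUP, T_MAX - T_WARMUP) / (T_MAX - T_WARMUP)
    return eta_min + (MAX_LR - eta_min) * (1 + math.cos(math.pi * progress)) / 2


if __name__ == "__main__":
    print(TOKENS_PER_STEP, T_WARMUP // TOKENS_PER_STEP, grad_accum_steps(64))  # prints: 4194304 2000 2
    for step in (4400, 5600):
        print(step, f"{lr_at_step(step):.8e}")
```

Under these assumptions, warmup ends at step 2000, each of the 64 GPUs runs 2 micro-batches of 8 instances per step, and with t_max: 5e12 tokens the cosine decay has barely started by step 4400, so the learning rate on either side of the resume point should stay very close to 3.0e-4.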

Any idea what could be the issue here?

Versions

absl-py==2.1.0
accelerate==0.18.0
-e git+ssh://git@github.com/abhijangda/OLMo.git@77e47c6d84c018fc33a5eda086056c1402f74381#egg=ai2_olmo
ai2-olmo-core==0.1.0
aiofiles==23.2.1
aiohappyeyeballs==2.4.3
aiohttp==3.11.3
aioshutil==1.5
aiosignal==1.3.1
annotated-types==0.7.0
antlr4-python3-runtime==4.9.3
anyio==4.6.2.post1
anykeystore==0.2
apex==0.1
appdirs==1.4.4
asttokens==2.4.1
astunparse==1.6.3
attrs==24.2.0
autocommand==2.2.2
backoff==2.2.1
backports.tarfile==1.2.0
beaker-gantry==1.10.0
beaker-py==1.32.3
beautifulsoup4==4.12.3
bitsandbytes==0.44.1
black==23.12.1
boltons==21.0.0
boto3==1.35.84
botocore==1.35.84
bracex==2.5.post1
Brotli==1.1.0
build==1.2.2.post1
cached_path==1.6.5
cachetools==5.5.0
certifi==2024.8.30
cffi==1.17.1
chardet==5.2.0
charset-normalizer==3.4.0
click==8.1.7
click-help-colors==0.9.4
click-option-group==0.5.6
cmake==3.31.0.1
codeshield==1.0.1
colorama==0.4.6
coloredlogs==15.0.1
contourpy==1.3.1
cryptacular==1.6.2
cryptography==43.0.3
cupy==13.3.0
cxxfilt==0.3.0
cycler==0.12.1
dataclasses-json==0.6.7
datasets==3.2.0
decorator==5.1.1
defusedxml==0.7.1
Deprecated==1.2.15
dill==0.3.6
distro==1.9.0
docker==7.1.0
docker-pycreds==0.4.0
docutils==0.21.2
effdet==0.4.1
einops==0.8.0
emoji==2.14.0
eval_type_backport==0.2.0
evaluate==0.4.3
exceptiongroup==1.2.2
executing==2.1.0
expecttest==0.2.1
face==24.0.0
fastapi==0.115.5
fastrlock==0.8.2
ffmpy==0.4.0
filelock==3.16.1
filetype==1.2.0
fire==0.7.0
flash_attn @ file:///home/aiscuser/ajangda/flash-attention
flatbuffers==24.3.25
fonttools==4.55.0
frozenlist==1.5.0
fsspec==2023.9.2
ftfy==6.3.1
gitdb==4.0.11
GitPython==3.1.43
glom==22.1.0
google-api-core==2.23.0
google-auth==2.36.0
google-cloud-core==2.4.1
google-cloud-storage==2.19.0
google-cloud-vision==3.8.1
google-crc32c==1.6.0
google-resumable-media==2.7.2
googleapis-common-protos==1.66.0
gradio==5.6.0
gradio_client==1.4.3
greenlet==3.1.1
grpcio==1.68.0
grpcio-status==1.62.3
h11==0.14.0
httpcore==1.0.7
httpx==0.27.2
huggingface-hub==0.26.5
humanfriendly==10.0
hupper==1.12.1
hypothesis==6.119.2
idna==3.10
importlib_metadata==7.1.0
importlib_resources==6.4.0
inflate64==1.0.0
inflect==7.3.1
iniconfig==2.0.0
iopath==0.1.10
ipython==8.29.0
isort==5.12.0
jaraco.classes==3.4.0
jaraco.context==5.3.0
jaraco.functools==4.0.1
jaraco.text==3.12.1
jedi==0.19.2
jeepney==0.8.0
Jinja2==3.1.4
jmespath==1.0.1
joblib==1.4.2
jsonpatch==1.33
jsonpath-python==1.0.6
jsonpointer==3.0.0
jsonschema==4.23.0
jsonschema-specifications==2024.10.1
keyring==25.5.0
kiwisolver==1.4.7
langchain==0.2.17
langchain-community==0.2.19
langchain-core==0.2.43
langchain-openai==0.1.20
langchain-text-splitters==0.2.4
langdetect==1.0.9
langsmith==0.1.143
layoutparser==0.3.4
lightning-utilities==0.11.9
lintrunner==0.12.5
loralib==0.1.2
lxml==5.3.0
markdown-it-py==3.0.0
MarkupSafe==2.1.5
marshmallow==3.23.1
matplotlib==3.9.2
matplotlib-inline==0.1.7
mdurl==0.1.2
more-itertools==10.3.0
mpi4py @ file:///work/ci_py311/mpi4py_1676858691457/work
mpmath==1.3.0
mscclpp @ file:///home/ajangda/mscclpp
msgspec==0.18.6
multidict==6.1.0
multiprocess==0.70.14
multivolumefile==0.2.3
mypy==1.3.0
mypy-extensions==1.0.0
necessary==0.4.3
nest-asyncio==1.6.0
netifaces==0.11.0
networkx==3.4.2
nh3==0.2.20
ninja==1.11.1.1
nltk==3.9.1
numpy==1.26.4
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.21.5
nvidia-nvjitlink-cu12==12.4.127
nvidia-nvtx-cu12==12.1.105
oauthlib==3.2.2
omegaconf==2.3.0
openai==1.39.0
opencv-python==4.10.0.84
opentelemetry-api==1.25.0
opentelemetry-exporter-otlp-proto-common==1.25.0
opentelemetry-exporter-otlp-proto-http==1.25.0
opentelemetry-instrumentation==0.46b0
opentelemetry-instrumentation-requests==0.46b0
opentelemetry-proto==1.25.0
opentelemetry-sdk==1.25.0
opentelemetry-semantic-conventions==0.46b0
opentelemetry-util-http==0.46b0
optimum==1.23.3
optree==0.13.1
ordered-set==4.1.0
orjson==3.10.11
packaging==24.2
pandas==2.2.3
parso==0.8.4
PasteDeploy==3.1.0
pathspec==0.12.1
pbkdf2==1.3
pdf2image==1.17.0
pdfminer.six==20231228
pdfplumber==0.11.4
peewee==3.17.8
peft==0.13.2
petname==2.6
pexpect==4.9.0
pi_heif==0.20.0
pikepdf==9.4.2
pillow==11.0.0
pkginfo==1.12.0
plaster==1.1.2
plaster-pastedeploy==1.0.1
platformdirs==4.2.2
pluggy==1.5.0
portalocker==3.0.0
prettytable==3.12.0
prompt_toolkit==3.0.48
propcache==0.2.0
proto-plus==1.25.0
protobuf==4.25.5
psutil==6.1.0
ptyprocess==0.7.0
pure_eval==0.2.3
py7zr==0.22.0
pyarrow==18.0.0
pyasn1==0.6.1
pyasn1_modules==0.4.1
pybcj==1.0.2
pybind11==2.13.6
pybind11_global==2.13.6
pycocotools==2.0.8
pycparser==2.22
pycryptodomex==3.21.0
pydantic==2.9.2
pydantic_core==2.23.4
pydub==0.25.1
pyfastkron @ file:///home/aiscuser/ajangda/OLMo/pyfastkron-1.0.1-py3-none-any.whl#sha256=600f33c84967e12106e7e2b25f583422bf4a1a1f8dc887b5e8df54fa9bba2082
Pygments==2.18.0
pyparsing==3.2.0
pypdf==5.1.0
pypdfium2==4.30.0
pyppmd==1.1.0
pyproject_hooks==1.2.0
pyramid==2.0.2
pyramid-mailer==0.15.1
pytest==8.3.4
pytest-sphinx==0.6.3
python-dateutil==2.8.2
python-iso639==2024.10.22
python-magic==0.4.27
python-multipart==0.0.12
python3-openid==3.2.0
pytorch-triton-rocm==3.1.0
pytz==2024.2
PyYAML==6.0.1
pyzstd==0.16.2
RapidFuzz==3.10.1
readme_renderer==44.0
referencing==0.35.1
regex==2024.11.6
repoze.sendmail==4.4.1
requests==2.32.3
requests-oauthlib==2.0.0
requests-toolbelt==1.0.0
requirements-parser==0.11.0
responses==0.18.0
rfc3986==2.0.0
rich==13.5.3
rouge_score==0.1.2
rpds-py==0.21.0
rsa==4.9
ruamel.yaml==0.17.40
ruamel.yaml.clib==0.2.12
ruff==0.7.4
s3transfer==0.10.4
safehttpx==0.1.1
safetensors==0.4.5
scikit-learn==1.5.2
scipy==1.14.1
SecretStorage==3.3.3
semantic-version==2.10.0
semgrep==1.96.0
sentence-transformers==3.3.1
sentencepiece==0.2.0
sentry-sdk==2.19.2
setproctitle==1.3.4
shellingham==1.5.4
six==1.16.0
smart-open==7.1.0
smashed==0.21.5
smmap==5.0.1
sniffio==1.3.1
sortedcontainers==2.4.0
soupsieve==2.6
SQLAlchemy==2.0.36
stack-data==0.6.3
starlette==0.41.3
sympy==1.13.1
tabulate==0.9.0
tenacity==8.5.0
termcolor==2.5.0
texttable==1.7.0
threadpoolctl==3.5.0
tiktoken==0.8.0
timm==1.0.11
tokenize_rt==6.1.0
tokenizers==0.13.3
tomli==2.0.1
tomlkit==0.12.0
torch==2.5.1+rocm6.1
torchaudio==2.5.1+rocm6.1
torchmetrics==1.6.0
torchvision==0.20.1+rocm6.1
tqdm==4.67.1
traitlets==5.14.3
transaction==5.0
transformers==4.28.1
translationstring==1.4
triton==3.1.0
trouting==0.3.3
twine==6.0.1
typeguard==4.3.0
typer==0.13.1
types-dataclasses==0.6.6
types-setuptools==75.6.0.20241126
typing-inspect==0.9.0
typing_extensions==4.12.2
tzdata==2024.2
unstructured==0.15.8
unstructured-client==0.27.0
unstructured-inference==0.7.36
unstructured.pytesseract==0.3.13
urllib3==2.2.3
uvicorn==0.32.0
velruse==1.1.1
venusian==3.1.1
wandb==0.19.1
wcmatch==8.5.2
wcwidth==0.2.13
WebOb==1.8.9
websockets==12.0
wrapt==1.16.0
WTForms==3.2.1
wtforms-recaptcha==0.3.2
xxhash==3.5.0
yarl==1.17.2
zipp==3.19.2
zope.deprecation==5.0
zope.interface==7.2
zope.sqlalchemy==3.1

Metadata

Assignees

Labels

type/bug (An issue about a bug)

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests
