
ielab/Starbucks


Starbucks

Starbucks: Improved Training for 2D Matryoshka Embeddings

We propose Starbucks, a new pre-training and fine-tuning method for 2D Matryoshka Representation Learning (2D MRL).

Starbucks is composed of two key processes: Starbucks Masked Autoencoding (SMAE) pre-training and Starbucks Representation Learning (SRL) fine-tuning.

In Starbucks, the model loss is computed over a fixed, limited list of target layer-dimension pairs, ranging from smaller to larger sizes, much like how the coffeehouse chain Starbucks offers coffee in different cup sizes, from Demi to Trenta.
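To make the idea of a target list concrete, here is a minimal pure-Python sketch. The six (layers, dimensions) pairs are the sizes reported in the result tables of this README. The helper names `sub_embedding` and `starbucks_loss`, the layout of `hidden_states`, and the plain averaging of per-pair losses are illustrative assumptions, not the repository's training code.

```python
# Target layer-dimension pairs (n layers, d dimensions), as listed in the
# result tables of this README.
STARBUCKS_SIZES = {
    "Demi": (2, 32),
    "Short": (4, 64),
    "Tall": (6, 128),
    "Grande": (8, 256),
    "Venti": (10, 512),
    "Trenta": (12, 768),
}


def sub_embedding(hidden_states, n_layers, dim):
    """Return the first `dim` values of the output of layer `n_layers`.

    Assumption: hidden_states[i] holds the embedding produced by layer i + 1.
    """
    return hidden_states[n_layers - 1][:dim]


def starbucks_loss(hidden_states, loss_fn):
    """Average a task loss over the target pairs only.

    Assumption: per-pair losses are simply averaged; the actual training
    code may weight or combine them differently.
    """
    losses = [
        loss_fn(sub_embedding(hidden_states, n, d))
        for n, d in STARBUCKS_SIZES.values()
    ]
    return sum(losses) / len(losses)
```

Only six sub-networks contribute to the loss here, rather than every combination of the 12 layers and all candidate dimensions.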

General guidelines

Our codebase is built on top of torch and transformers.

We recommend installing the required dependencies in a conda environment:

conda create -n starbucks python=3.10
conda activate starbucks

pip install torch
pip install transformers datasets peft
pip install deepspeed accelerate
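After installation, a quick check like the following can confirm that the packages named above are importable in the active environment. This is a hedged convenience sketch using only the standard library; the helper name `missing_packages` is hypothetical and not part of the repository.

```python
import importlib.util

# Package names taken from the pip commands above.
REQUIRED = ["torch", "transformers", "datasets", "peft", "deepspeed", "accelerate"]


def missing_packages(names=REQUIRED):
    """Return the names that cannot be found in the current environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_packages()
    print("missing:", ", ".join(missing) if missing else "none")
```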

For SMAE pre-training, see smae.

For SRL fine-tuning on the retrieval task, see retrieval.

For SRL fine-tuning on the STS task, see sts.

Model Checkpoints

We have released our model checkpoints on the Hugging Face Model Hub:

Pre-trained SMAE: bert-base-uncased-fineweb100bt-smae

Fine-tuned Starbucks_STS: Starbucks_STS

Fine-tuned Starbucks_Retrieval: Starbucks-msmarco

Appendix: Full Result Tables

(Table A1) Complete Results Across All Starbucks Sizes for In-Domain Datasets

Complete results across all Starbucks sizes for in-domain datasets. Spearman's correlation is used to evaluate the STS task. For the retrieval task, MRR@10 is used on the MS MARCO dev set, and nDCG@10 is used on DL19 and DL20. * denotes a statistically significant difference (p<0.05) between the corresponding method and SMAE-SRL (Starbucks).

Size Parameters Method STS-B STS12 STS13 STS14 STS15 STS16 SICK-R Average MARCO dev DL19 DL20
Demi (n=2, d=32) BERT 0.5919* 0.4700* 0.6052* 0.5424* 0.6530* 0.6397* 0.5996* 0.5860 0.0001* 0.0000* 0.0000*
BERT-Separate 0.7134 0.6637 0.6770* 0.6374* 0.7451* 0.7059* 0.7210* 0.6948 0.2092* 0.4599* 0.4826
BERT-2DMSE 0.7197 0.6204* 0.6747* 0.6278* 0.7336* 0.7058* 0.7066* 0.6841 0.1589* 0.3459* 0.3828*
BERT-SRL 0.7413 0.6789 0.7190 0.6684* 0.7655* 0.7309 0.7455 0.7214 0.2147* 0.5003 0.5294
SMAE-SRL (Starbucks) 0.7455 0.6854 0.7368 0.6955 0.7972 0.7459 0.7473 0.7362 0.2282 0.5122 0.5042
Short (n=4, d=64) BERT 0.6021* 0.4478* 0.5881* 0.5230* 0.6599* 0.6455* 0.6083* 0.5821 0.0000* 0.0000* 0.0000*
BERT-Separate 0.7399* 0.6882* 0.7387* 0.6800* 0.7819* 0.7419* 0.7579* 0.7326 0.2692* 0.5527 0.5487*
BERT-2DMSE 0.7311* 0.6364* 0.6977* 0.6551* 0.7637* 0.7312* 0.7323* 0.7068 0.2266* 0.4988* 0.5179*
BERT-SRL 0.7621* 0.7126 0.7400* 0.6900* 0.7924* 0.7613 0.7678 0.7466 0.2722* 0.5653 0.5853
SMAE-SRL (Starbucks) 0.7933 0.7296 0.7824 0.7444 0.8308 0.7840 0.7700 0.7764 0.2950 0.5766 0.5920
Tall (n=6, d=128) BERT 0.6128* 0.4689* 0.6037* 0.5464* 0.6728* 0.6624* 0.6271* 0.5991 0.0000* 0.0000* 0.0000*
BERT-Separate 0.7853* 0.7150 0.7794 0.7261* 0.7959* 0.7778 0.7802 0.7657 0.2935* 0.5939 0.6028
BERT-2DMSE 0.7501* 0.6434* 0.7061* 0.6630* 0.7739* 0.7421* 0.7522* 0.7187 0.2532* 0.5464* 0.5503*
BERT-SRL 0.7838* 0.7158 0.7685* 0.7146* 0.8066* 0.7762 0.7792 0.7635 0.2983* 0.6102 0.6247
SMAE-SRL (Starbucks) 0.8202 0.7374 0.7998 0.7654 0.8426 0.8029 0.7831 0.7931 0.3274 0.6346 0.6319
Grande (n=8, d=256) BERT 0.6433* 0.5121* 0.6370* 0.5712* 0.7062* 0.6737* 0.6504* 0.6277 0.0009* 0.0000* 0.0000*
BERT-Separate 0.8174 0.7304 0.8097 0.7588 0.8289* 0.7998 0.7807 0.7894 0.3160* 0.6185 0.6259
BERT-2DMSE 0.7721* 0.6635* 0.7400* 0.6756* 0.7949* 0.7570* 0.7774 0.7401 0.2770* 0.5604* 0.5856*
BERT-SRL 0.8112 0.7310 0.7925 0.7397* 0.8246* 0.7959 0.7879 0.7833 0.3159* 0.6342 0.6344
SMAE-SRL (Starbucks) 0.8278 0.7382 0.8033 0.7736 0.8469 0.8092 0.7874 0.7981 0.3369 0.6525 0.6292
Venti (n=10, d=512) BERT 0.7331* 0.6748* 0.6983* 0.6488* 0.7876* 0.7047* 0.7133* 0.7087 0.0000* 0.0100* 0.0047*
BERT-Separate 0.8259 0.7379 0.8088 0.7691 0.8412 0.7931 0.7866 0.7946 0.3279* 0.6502 0.6133
BERT-2DMSE 0.7850* 0.7097* 0.7422* 0.6927* 0.8218* 0.7626* 0.7832 0.7568 0.3022* 0.5769* 0.6082
BERT-SRL 0.8196 0.7346 0.8070 0.7635 0.8391 0.8003 0.7903 0.7935* 0.3249* 0.6405 0.6499
SMAE-SRL (Starbucks) 0.8292 0.7405 0.7989 0.7772 0.8503 0.8138 0.7860 0.7994 0.3416 0.6493 0.6310
Trenta (n=12, d=768) BERT 0.8411 0.7413 0.8319* 0.7902 0.8521 0.8151 0.7949* 0.8095 0.3341 0.6404 0.6410
BERT-Separate 0.8411 0.7413 0.8319* 0.7902 0.8521 0.8151 0.7949* 0.8095 0.3341 0.6404 0.6410
BERT-2DMSE 0.8196 0.7414 0.8087 0.7616 0.8417 0.8017 0.7986* 0.7962 0.3185* 0.6021 0.6001*
BERT-SRL 0.8258 0.7456 0.8185 0.7752 0.8474 0.8059 0.7889 0.8010 0.3266* 0.6550 0.6518
SMAE-SRL (Starbucks) 0.8274 0.7404 0.7998 0.7789 0.8497 0.8131 0.7844 0.7991 0.3403 0.6558 0.6358
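For readers unfamiliar with the retrieval metrics named in the caption of Table A1, the sketch below computes MRR@10 and nDCG@10 for a single query. It is an illustration under stated assumptions, not the evaluation code used for the tables: the function names are hypothetical, nDCG uses linear gain (some tools use an exponential gain), and the reported numbers are averages over all queries in a dataset.

```python
import math


def mrr_at_k(ranked_relevance, k=10):
    """Reciprocal rank of the first relevant result in the top k, else 0."""
    for rank, rel in enumerate(ranked_relevance[:k], start=1):
        if rel > 0:
            return 1.0 / rank
    return 0.0


def dcg_at_k(gains, k=10):
    """Discounted cumulative gain with linear gain and a log2 discount."""
    return sum(g / math.log2(rank + 1) for rank, g in enumerate(gains[:k], start=1))


def ndcg_at_k(ranked_gains, all_gains, k=10):
    """DCG of the system ranking divided by the DCG of the ideal ranking.

    ranked_gains: relevance grades in the order the system returned them.
    all_gains: relevance grades of every judged relevant document for the query.
    """
    ideal = dcg_at_k(sorted(all_gains, reverse=True), k)
    return dcg_at_k(ranked_gains, k) / ideal if ideal > 0 else 0.0
```

For example, a query whose first relevant passage appears at rank 3 contributes 1/3 to MRR@10, and a query whose only relevant passage appears below rank 10 contributes 0.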

(Table A2) Complete Results Across All Starbucks Sizes for Out-of-Domain Datasets

Complete results across all Starbucks sizes for out-of-domain datasets, evaluated using nDCG@10. * denotes a statistically significant difference (p<0.05) between the corresponding method and SMAE-SRL (Starbucks).

Size Parameters Method Db.-Ent. Quora Scidocs Scifact Cli.-Fever Arguana Touche20 Fiqa Trec-Cvd Nfcorpus Fever Hpqa Avg.
Domain General General Scientific Scientific Scientific Debate Debate Finance Biomedical Biomedical Wiki Wiki
Demi (n=2, d=32) BERT 0.0011* 0.0001* 0.0013* 0.0079* 0.0000* 0.0005* 0.0000* 0.0022* 0.0000* 0.0135* 0.0000* 0.0000* 0.0022
BERT-Separate 0.1657 0.6900 0.0480* 0.2692 0.0657* 0.1517* 0.1707* 0.0832* 0.3969* 0.1607 0.3474* 0.1620* 0.2259
BERT-2DMSE 0.1314* 0.5291* 0.0414* 0.2158* 0.0582* 0.1094* 0.1601* 0.0556* 0.3262* 0.1371* 0.2111* 0.1001* 0.1730
BERT-SRL 0.1717 0.6865 0.0539 0.2577 0.0703* 0.1471* 0.1895* 0.0806* 0.3791* 0.1622 0.3472* 0.1698 0.2263
SMAE-SRL (Starbucks) 0.1680 0.6912 0.0602 0.2717 0.0977 0.1801 0.2336 0.1140 0.5051 0.1768 0.3630 0.1711 0.2527
Short (n=4, d=64) BERT 0.0000* 0.0001* 0.0006* 0.0000* 0.0000* 0.0012* 0.0000* 0.0000* 0.0014* 0.0100* 0.0000* 0.0001* 0.0011
BERT-Separate 0.2205 0.7769* 0.0741* 0.3809 0.0912* 0.2032* 0.1890 0.1524* 0.5089* 0.2154 0.5084* 0.3028* 0.3020
BERT-2DMSE 0.1915* 0.7157* 0.0633* 0.3122* 0.0902* 0.1851* 0.1807* 0.1164* 0.4445* 0.1811* 0.4155* 0.2241* 0.2600
BERT-SRL 0.2227 0.7709* 0.0758* 0.3616 0.0990* 0.2028* 0.2074 0.1325* 0.4838* 0.2060* 0.5009* 0.2996* 0.2969
SMAE-SRL (Starbucks) 0.2353 0.7829 0.0858 0.3857 0.1093 0.2273 0.2220 0.1859 0.6180 0.2270 0.5243 0.3159 0.3266
Tall (n=6, d=128) BERT 0.0010* 0.0821* 0.0017* 0.0087* 0.0000* 0.0048* 0.0000* 0.0004* 0.0126* 0.0142* 0.0000* 0.0000* 0.0105
BERT-Separate 0.2498* 0.8054* 0.0929* 0.4074 0.1122* 0.2294* 0.2306 0.1649* 0.5527* 0.2191* 0.5786* 0.3718* 0.3346
BERT-2DMSE 0.2029* 0.7587* 0.0692* 0.3492* 0.1070* 0.2102* 0.2142 0.1432* 0.4644* 0.1918* 0.4638* 0.2669* 0.2868
BERT-SRL 0.2443* 0.7997* 0.0899* 0.4154 0.1214* 0.2295* 0.2210 0.1687* 0.5619* 0.2185* 0.5495* 0.3580* 0.3315
SMAE-SRL (Starbucks) 0.2740 0.8150 0.1054 0.4333 0.1304 0.2537 0.2402 0.2202 0.6451 0.2609 0.5967 0.4044 0.3649
Grande (n=8, d=256) BERT 0.0020* 0.3888* 0.0034* 0.0494* 0.0084* 0.0461* 0.0000* 0.0030* 0.0520* 0.0419* 0.0006* 0.0015* 0.0498
BERT-Separate 0.2820 0.8260* 0.1020* 0.4688 0.1289* 0.2416* 0.2137 0.2077* 0.5650* 0.2396* 0.6321* 0.4238* 0.3609
BERT-2DMSE 0.2223* 0.5341* 0.0724* 0.3605* 0.1091* 0.2237* 0.2342 0.1558* 0.5109* 0.1993* 0.5088* 0.2954* 0.2855
BERT-SRL 0.2729* 0.8220* 0.0992* 0.4600 0.1329* 0.2449* 0.2202 0.1971* 0.6103 0.2380* 0.6056* 0.4081* 0.3593
SMAE-SRL (Starbucks) 0.2920 0.7808 0.1168 0.4538 0.1430 0.2651 0.2451 0.2379 0.6672 0.2719 0.6243 0.4380 0.3780
Venti (n=10, d=512) BERT 0.0005* 0.1195* 0.0040* 0.0250* 0.0000* 0.0085* 0.0000* 0.0000* 0.0138* 0.0247* 0.0000* 0.0000* 0.0163
BERT-Separate 0.3109 0.7687* 0.1087* 0.4665 0.1376* 0.2646 0.2299 0.2277* 0.5898* 0.2543* 0.6666* 0.4652* 0.3742
BERT-2DMSE 0.2442* 0.1566* 0.0807* 0.3560* 0.1305* 0.2301* 0.2466 0.1791* 0.5437* 0.2030* 0.5687* 0.3267* 0.2722
BERT-SRL 0.2856 0.7751* 0.1070* 0.4727 0.1485 0.2553* 0.2486 0.2206* 0.6103* 0.2460* 0.6430* 0.4439 0.3714
SMAE-SRL (Starbucks) 0.2968 0.0178 0.1175 0.4629 0.1538 0.2673 0.2545 0.2486 0.6645 0.2733 0.6359 0.4466 0.3200
Trenta (n=12, d=768) BERT 0.3018 0.8224* 0.1109* 0.4597 0.1604 0.2707 0.2337* 0.2200* 0.6041 0.2585* 0.6601* 0.4548* 0.3798
BERT-Separate 0.3018 0.8224* 0.1109* 0.4597 0.1604 0.2707 0.2337* 0.2200* 0.6041 0.2585* 0.6601* 0.4548* 0.3798
BERT-2DMSE 0.2552* 0.8152* 0.0878* 0.3863* 0.1391* 0.2363* 0.1936 0.1851* 0.5673 0.2105* 0.5848* 0.3532* 0.3345
BERT-SRL 0.2925 0.3419* 0.1072* 0.4826 0.1569 0.2533* 0.2628* 0.2217* 0.6294 0.2508* 0.6470* 0.4466 0.3411
SMAE-SRL (Starbucks) 0.2997 0.8477 0.1206 0.4715 0.1618 0.2734 0.2022 0.2508 0.6140 0.2738 0.6275 0.4486 0.3826
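The Avg. column in Table A2 appears to be the unweighted mean of the twelve per-dataset nDCG@10 scores. The short check below reproduces it for one row, Demi SMAE-SRL (Starbucks), using the values copied from the table. The helper name `average_score` is hypothetical, and only this row is verified here.

```python
from statistics import mean

# nDCG@10 scores for the Demi (n=2, d=32) SMAE-SRL (Starbucks) row of Table A2.
demi_smae_srl = {
    "Db.-Ent.": 0.1680,
    "Quora": 0.6912,
    "Scidocs": 0.0602,
    "Scifact": 0.2717,
    "Cli.-Fever": 0.0977,
    "Arguana": 0.1801,
    "Touche20": 0.2336,
    "Fiqa": 0.1140,
    "Trec-Cvd": 0.5051,
    "Nfcorpus": 0.1768,
    "Fever": 0.3630,
    "Hpqa": 0.1711,
}


def average_score(scores):
    """Unweighted mean over datasets, rounded to four decimals like the table."""
    return round(mean(scores.values()), 4)


print(average_score(demi_smae_srl))  # 0.2527, matching the Avg. column
```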
