Context
ByteDance recently released BAGEL, one of the first open models with image generation.
The model is relatively small in scale (14B total, a mixture of two 7B models), and uses relatively standard components (Qwen2 + SigLIP/ViT, plus a VAE for image generation).
It's probably a good opportunity to keep extending eole's multi-modality support, and to keep factorizing/structuring the related code paths.
Notes
There is no official HF code for now, only the original repo.
The config is a bit messy and will require some patching in our conversion (see the sketch below).
At first glance, it looks relatively similar to Gemma3 for the vision "understanding" part.
I did not check the image generation part yet; not sure how much work will be required.
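Regarding the config patching: the official inference script does not use the shipped JSON configs as-is, it hard-codes a few overrides before building the model (QK-norm, untied embeddings and the MoT decoder layer for the LLM; no RoPE and one fewer hidden layer for the ViT). A converter would likely have to replicate these. A minimal sketch, assuming we just merge the overrides into the loaded JSON before mapping to eole fields (the helper name is hypothetical, not existing eole code):

import json
import os

def load_patched_bagel_configs(model_path):
    # Hypothetical helper: load BAGEL's llm_config.json / vit_config.json and
    # apply the overrides hard-coded in the official inference script,
    # since they are not serialized in the JSON files.
    with open(os.path.join(model_path, "llm_config.json")) as f:
        llm_config = json.load(f)
    # Overrides applied by the official script on top of llm_config.json
    llm_config["qk_norm"] = True
    llm_config["tie_word_embeddings"] = False
    llm_config["layer_module"] = "Qwen2MoTDecoderLayer"

    with open(os.path.join(model_path, "vit_config.json")) as f:
        vit_config = json.load(f)
    # Same for the ViT: no RoPE, and the last hidden layer is dropped
    vit_config["rope"] = False
    vit_config["num_hidden_layers"] -= 1

    return llm_config, vit_config

The mapping to eole config fields would then start from these patched dicts rather than from the raw JSON files.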
Testing
For reference, here is a quick script put together from the official inference.ipynb notebook, with the "gdp" example we already use for pixtral/gemma3 testing.
import os
from copy import deepcopy
from typing import (
    Any,
    AsyncIterable,
    Callable,
    Dict,
    Generator,
    List,
    NamedTuple,
    Optional,
    Tuple,
    Union,
)

import requests
from io import BytesIO

from PIL import Image
import torch
from accelerate import infer_auto_device_map, load_checkpoint_and_dispatch, init_empty_weights

from data.transforms import ImageTransform
from data.data_utils import pil_img2rgb, add_special_tokens
from modeling.bagel import (
    BagelConfig, Bagel, Qwen2Config, Qwen2ForCausalLM, SiglipVisionConfig, SiglipVisionModel
)
from modeling.qwen2 import Qwen2Tokenizer
from modeling.bagel.qwen2_navit import NaiveCache
from modeling.autoencoder import load_ae
from safetensors.torch import load_file

model_path = "./BAGEL-7B-MoT"  # Download from https://huggingface.co/ByteDance-Seed/BAGEL-7B-MoT

# LLM config preparing
llm_config = Qwen2Config.from_json_file(os.path.join(model_path, "llm_config.json"))
llm_config.qk_norm = True
llm_config.tie_word_embeddings = False
llm_config.layer_module = "Qwen2MoTDecoderLayer"

# ViT config preparing
vit_config = SiglipVisionConfig.from_json_file(os.path.join(model_path, "vit_config.json"))
vit_config.rope = False
vit_config.num_hidden_layers = vit_config.num_hidden_layers - 1

# VAE loading
vae_model, vae_config = load_ae(local_path=os.path.join(model_path, "ae.safetensors"))

# Bagel config preparing
config = BagelConfig(
    visual_gen=True,
    visual_und=True,
    llm_config=llm_config,
    vit_config=vit_config,
    vae_config=vae_config,
    vit_max_num_patch_per_side=70,
    connector_act='gelu_pytorch_tanh',
    latent_patch_size=2,
    max_latent_size=64,
)

with init_empty_weights():
    language_model = Qwen2ForCausalLM(llm_config)
    vit_model = SiglipVisionModel(vit_config)
    model = Bagel(language_model, vit_model, config)
    model.vit_model.vision_model.embeddings.convert_conv2d_to_linear(vit_config, meta=True)

# Tokenizer Preparing
tokenizer = Qwen2Tokenizer.from_pretrained(model_path)
tokenizer, new_token_ids, _ = add_special_tokens(tokenizer)

# Image Transform Preparing
vae_transform = ImageTransform(1024, 512, 16)
vit_transform = ImageTransform(980, 224, 14)

# max_mem_per_gpu = "24GiB"  # Modify it according to your GPU setting
max_mem_per_gpu = "40GiB"  # Modify it according to your GPU setting

device_map = infer_auto_device_map(
    model,
    max_memory={i: max_mem_per_gpu for i in range(torch.cuda.device_count())},
    no_split_module_classes=["Bagel", "Qwen2MoTDecoderLayer"],
)
print(device_map)

same_device_modules = [
    'language_model.model.embed_tokens',
    'time_embedder',
    'latent_pos_embed',
    'vae2llm',
    'llm2vae',
    'connector',
    'vit_pos_embed'
]

if torch.cuda.device_count() == 1:
    first_device = device_map.get(same_device_modules[0], "cuda:0")
    for k in same_device_modules:
        if k in device_map:
            device_map[k] = first_device
        else:
            device_map[k] = "cuda:0"
else:
    first_device = device_map.get(same_device_modules[0])
    for k in same_device_modules:
        if k in device_map:
            device_map[k] = first_device

model = load_checkpoint_and_dispatch(
    model,
    checkpoint=os.path.join(model_path, "ema.safetensors"),
    device_map=device_map,
    offload_buffers=True,
    dtype=torch.bfloat16,
    offload_folder="./offload",
)

model = model.eval()
print('Model loaded')

from inferencer import InterleaveInferencer

inferencer = InterleaveInferencer(
    model=model,
    vae_model=vae_model,
    tokenizer=tokenizer,
    vae_transform=vae_transform,
    vit_transform=vit_transform,
    new_token_ids=new_token_ids
)

import random
import numpy as np

seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

inference_hyper = dict(
    max_think_token_n=1000,
    do_sample=False,
    # text_temperature=0.3,
)

# image = Image.open('test_images/meme.jpg')
image = Image.open("test_images/gdp.png")
# prompt = "Can someone explain what’s funny about this meme??"
prompt = "List the top 5 countries in Europe with the highest GDP from this image"
# display(image)

print(prompt)
print('-' * 10)

output_dict = inferencer(image=image, text=prompt, understanding_output=True, **inference_hyper)
print(output_dict['text'])
Outputs
List the top 5 countries in Europe with the highest GDP from this image
----------
1. **Germany** - $3.99T (4.65%)
2. **France** - $2.78T (3.24%)
3. **United Kingdom** - $2.82T (3.29%)
4. **Italy** - $2.07T (2.42%)
5. **Spain** - $1.43T (1.66%)