8000 Unified Weight Adapter system for better maintainability and future feature of Lora system by KohakuBlueleaf · Pull Request #7540 · comfyanonymous/ComfyUI · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Unified Weight Adapter system for better maintainability and future feature of Lora system #7540

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Apr 22, 2025
321 changes: 17 additions & 304 deletions comfy/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import comfy.utils
import comfy.model_management
import comfy.model_base
import comfy.weight_adapter as weight_adapter
import logging
import torch

Expand Down Expand Up @@ -49,139 +50,12 @@ def load_lora(lora, to_load, log_missing=True):
dora_scale = lora[dora_scale_name]
loaded_keys.add(dora_scale_name)

reshape_name = "{}.reshape_weight".format(x)
reshape = None
if reshape_name in lora.keys():
try:
reshape = lora[reshape_name].tolist()
loaded_keys.add(reshape_name)
except:
pass

regular_lora = "{}.lora_up.weight".format(x)
diffusers_lora = "{}_lora.up.weight".format(x)
diffusers2_lora = "{}.lora_B.weight".format(x)
diffusers3_lora = "{}.lora.up.weight".format(x)
mochi_lora = "{}.lora_B".format(x)
transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
A_name = None

if regular_lora in lora.keys():
A_name = regular_lora
B_name = "{}.lora_down.weight".format(x)
mid_name = "{}.lora_mid.weight".format(x)
elif diffusers_lora in lora.keys():
A_name = diffusers_lora
B_name = "{}_lora.down.weight".format(x)
mid_name = None
elif diffusers2_lora in lora.keys():
A_name = diffusers2_lora
B_name = "{}.lora_A.weight".format(x)
mid_name = None
elif diffusers3_lora in lora.keys():
A_name = diffusers3_lora
B_name = "{}.lora.down.weight".format(x)
mid_name = None
elif mochi_lora in lora.keys():
A_name = mochi_lora
B_name = "{}.lora_A".format(x)
mid_name = None
elif transformers_lora in lora.keys():
A_name = transformers_lora
B_name ="{}.lora_linear_layer.down.weight".format(x)
mid_name = None

if A_name is not None:
mid = None
if mid_name is not None and mid_name in lora.keys():
mid = lora[mid_name]
loaded_keys.add(mid_name)
patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale, reshape))
loaded_keys.add(A_name)
loaded_keys.add(B_name)


######## loha
hada_w1_a_name = "{}.hada_w1_a".format(x)
hada_w1_b_name = "{}.hada_w1_b".format(x)
hada_w2_a_name = "{}.hada_w2_a".format(x)
hada_w2_b_name = "{}.hada_w2_b".format(x)
hada_t1_name = "{}.hada_t1".format(x)
hada_t2_name = "{}.hada_t2".format(x)
if hada_w1_a_name in lora.keys():
hada_t1 = None
hada_t2 = None
if hada_t1_name in lora.keys():
hada_t1 = lora[hada_t1_name]
hada_t2 = lora[hada_t2_name]
loaded_keys.add(hada_t1_name)
loaded_keys.add(hada_t2_name)

patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale))
loaded_keys.add(hada_w1_a_name)
loaded_keys.add(hada_w1_b_name)
loaded_keys.add(hada_w2_a_name)
loaded_keys.add(hada_w2_b_name)


######## lokr
lokr_w1_name = "{}.lokr_w1".format(x)
lokr_w2_name = "{}.lokr_w2".format(x)
lokr_w1_a_name = "{}.lokr_w1_a".format(x)
lokr_w1_b_name = "{}.lokr_w1_b".format(x)
lokr_t2_name = "{}.lokr_t2".format(x)
lokr_w2_a_name = "{}.lokr_w2_a".format(x)
lokr_w2_b_name = "{}.lokr_w2_b".format(x)

lokr_w1 = None
if lokr_w1_name in lora.keys():
lokr_w1 = lora[lokr_w1_name]
loaded_keys.add(lokr_w1_name)

lokr_w2 = None
if lokr_w2_name in lora.keys():
lokr_w2 = lora[lokr_w2_name]
loaded_keys.add(lokr_w2_name)

lokr_w1_a = None
if lokr_w1_a_name in lora.keys():
lokr_w1_a = lora[lokr_w1_a_name]
loaded_keys.add(lokr_w1_a_name)

lokr_w1_b = None
if lokr_w1_b_name in lora.keys():
lokr_w1_b = lora[lokr_w1_b_name]
loaded_keys.add(lokr_w1_b_name)

lokr_w2_a = None
if lokr_w2_a_name in lora.keys():
lokr_w2_a = lora[lokr_w2_a_name]
loaded_keys.add(lokr_w2_a_name)

lokr_w2_b = None
if lokr_w2_b_name in lora.keys():
lokr_w2_b = lora[lokr_w2_b_name]
loaded_keys.add(lokr_w2_b_name)

lokr_t2 = None
if lokr_t2_name in lora.keys():
lokr_t2 = lora[lokr_t2_name]
loaded_keys.add(lokr_t2_name)

if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale))

#glora
a1_name = "{}.a1.weight".format(x)
a2_name = "{}.a2.weight".format(x)
b1_name = "{}.b1.weight".format(x)
b2_name = "{}.b2.weight".format(x)
if a1_name in lora:
patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale))
loaded_keys.add(a1_name)
loaded_keys.add(a2_name)
loaded_keys.add(b1_name)
loaded_keys.add(b2_name)
for adapter_cls in weight_adapter.adapters:
adapter = adapter_cls.load(x, lora, alpha, dora_scale, loaded_keys)
if adapter is not None:
patch_dict[to_load[x]] = adapter
loaded_keys.update(adapter.loaded_keys)
continue

w_norm_name = "{}.w_norm".format(x)
b_norm_name = "{}.b_norm".format(x)
Expand Down Expand Up @@ -408,26 +282,6 @@ def model_lora_keys_unet(model, key_map={}):
return key_map


def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function):
dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
lora_diff *= alpha
weight_calc = weight + function(lora_diff).type(weight.dtype)
weight_norm = (
weight_calc.transpose(0, 1)
.reshape(weight_calc.shape[1], -1)
.norm(dim=1, keepdim=True)
.reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
.transpose(0, 1)
)

weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
if strength != 1.0:
weight_calc -= weight
weight += strength * (weight_calc)
else:
weight[:] = weight_calc
return weight

def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor:
"""
Pad a tensor to a new shape with zeros.
Expand Down Expand Up @@ -482,6 +336,16 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori
if isinstance(v, list):
v = (calculate_weight(v[1:], v[0][1](comfy.model_management.cast_to_device(v[0][0], weight.device, intermediate_dtype, copy=True), inplace=True), key, intermediate_dtype=intermediate_dtype), )

if isinstance(v, weight_adapter.WeightAdapterBase):
output = v.calculate_weight(weight, key, strength, strength_model, offset, function, intermediate_dtype, original_weights)
if output is None:
logging.warning("Calculate Weight Failed: {} {}".format(v.name, key))
else:
weight = output
if old_weight is not None:
weight = old_weight
continue

if len(v) == 1:
patch_type = "diff"
elif len(v) == 2:
Expand All @@ -508,157 +372,6 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori
diff_weight = comfy.model_management.cast_to_device(target_weight, weight.device, intermediate_dtype) - \
comfy.model_management.cast_to_device(original_weights[key][0][0], weight.device, intermediate_dtype)
weight += function(strength * comfy.model_management.cast_to_device(diff_weight, weight.device, weight.dtype))
elif patch_type == "lora": #lora/locon
mat1 = comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype)
mat2 = comfy.model_management.cast_to_device(v[1], weight.device, intermediate_dtype)
dora_scale = v[4]
reshape = v[5]

if reshape is not None:
weight = pad_tensor_to_shape(weight, reshape)

if v[2] is not None:
alpha = v[2] / mat2.shape[0]
else:
alpha = 1.0

if v[3] is not None:
#locon mid weights, hopefully the math is fine because I didn't properly test it
mat3 = comfy.model_management.cast_to_device(v[3], weight.device, intermediate_dtype)
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
try:
lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
if dora_scale is not None:
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "lokr":
w1 = v[0]
w2 = v[1]
w1_a = v[3]
w1_b = v[4]
w2_a = v[5]
w2_b = v[6]
t2 = v[7]
dora_scale = v[8]
dim = None

if w1 is None:
dim = w1_b.shape[0]
w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w1_b, weight.device, intermediate_dtype))
else:
w1 = comfy.model_management.cast_to_device(w1, weight.device, intermediate_dtype)

if w2 is None:
dim = w2_b.shape[0]
if t2 is None:
w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype))
else:
w2 = torch.einsum('i j k l, j r, i p -> p r k l',
comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype))
else:
w2 = comfy.model_management.cast_to_device(w2, weight.device, intermediate_dtype)

if len(w2.shape) == 4:
w1 = w1.unsqueeze(2).unsqueeze(2)
if v[2] is not None and dim is not None:
alpha = v[2] / dim
else:
alpha = 1.0

try:
lora_diff = torch.kron(w1, w2).reshape(weight.shape)
if dora_scale is not None:
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "loha":
w1a = v[0]
w1b = v[1]
if v[2] is not None:
alpha = v[2] / w1b.shape[0]
else:
alpha = 1.0

w2a = v[3]
w2b = v[4]
dora_scale = v[7]
if v[5] is not None: #cp decomposition
t1 = v[5]
t2 = v[6]
m1 = torch.einsum('i j k l, j r, i p -> p r k l',
comfy.model_management.cast_to_device(t1, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype))

m2 = torch.einsum('i j k l, j r, i p -> p r k l',
comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype))
else:
m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype))
m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype))

try:
lora_diff = (m1 * m2).reshape(weight.shape)
if dora_scale is not None:
1E0A weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "glora":
dora_scale = v[5]

old_glora = False
if v[3].shape[1] == v[2].shape[0] == v[0].shape[0] == v[1].shape[1]:
rank = v[0].shape[0]
old_glora = True

if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]:
if old_glora and v[1].shape[0] == weight.shape[0] and weight.shape[0] == weight.shape[1]:
pass
else:
old_glora = False
rank = v[1].shape[0]

a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)

if v[4] is not None:
alpha = v[4] / rank
else:
alpha = 1.0

try:
if old_glora:
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape) #old lycoris glora
else:
if weight.dim() > 2:
lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
else:
lora_diff = torch.mm(torch.mm(weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
lora_diff += torch.mm(b1, b2).reshape(weight.shape)

if dora_scale is not None:
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
else:
logging.warning("patch type not recognized {} {}".format(patch_type, key))

Expand Down
13 changes: 13 additions & 0 deletions comfy/weight_adapter/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from .base import WeightAdapterBase
from .lora import LoRAAdapter
from .loha import LoHaAdapter
from .lokr import LoKrAdapter
from .glora import GLoRAAdapter


adapters: list[type[WeightAdapterBase]] = [
LoRAAdapter,
LoHaAdapter,
LoKrAdapter,
GLoRAAdapter,
]
Loading
0