
Commit 5851928

Support for control-lora (#10686)

* run control-lora on diffusers
* cannot load lora adapter
* test
* 1
* add control-lora
* 1
* 1
* 1
* fix PeftAdapterMixin
* fix module_to_save bug
* delete json print
* resolve conflicts
* merged but bug
* change peft.py
* 1
* delete state_dict print
* fix alpha
* Create control_lora.py
* Add files via upload
* rename
* no need modify as peft updated
* add doc
* fix code style
* styling isn't that hard 😉
* empty

---------

Co-authored-by: Sayak Paul <[email protected]>
1 parent 0c1ccc0 commit 5851928

7 files changed: +312 −1

docs/source/en/api/models/controlnet.md

Lines changed: 15 additions & 0 deletions
````diff
@@ -33,6 +33,21 @@ url = "https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/m
 pipe = StableDiffusionControlNetPipeline.from_single_file(url, controlnet=controlnet)
 ```
 
+## Loading from Control LoRA
+
+Control-LoRA was introduced by Stability AI in [stabilityai/control-lora](https://huggingface.co/stabilityai/control-lora) and adds low-rank, parameter-efficient fine-tuning to ControlNet. This approach offers a more efficient and compact way to bring model control to a wider range of consumer GPUs.
+
+```py
+import torch
+from diffusers import ControlNetModel, UNet2DConditionModel
+
+lora_id = "stabilityai/control-lora"
+lora_filename = "control-LoRAs-rank128/control-lora-canny-rank128.safetensors"
+
+unet = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", torch_dtype=torch.bfloat16).to("cuda")
+controlnet = ControlNetModel.from_unet(unet).to(device="cuda", dtype=torch.bfloat16)
+controlnet.load_lora_adapter(lora_id, weight_name=lora_filename, prefix=None, controlnet_config=controlnet.config)
+```
+
 ## ControlNetModel
 
 [[autodoc]] ControlNetModel
````
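Once the adapter is loaded, the controlnet drops into the usual SDXL ControlNet pipeline. A minimal sketch continuing from the snippet above (the prompt and conditioning image are placeholders; the full example script added in this commit shows the complete flow):

```py
from diffusers import StableDiffusionXLControlNetPipeline
from diffusers.utils import load_image

# Reuses `unet`, `controlnet`, and `torch` from the docs snippet above.
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", unet=unet, controlnet=controlnet, torch_dtype=torch.bfloat16
).to("cuda")

canny_image = load_image("canny_edges.png")  # placeholder: a precomputed Canny edge map
image = pipe("aerial view of a futuristic research complex", image=canny_image).images[0]
```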
Lines changed: 41 additions & 0 deletions

````diff
@@ -0,0 +1,41 @@
+# Control-LoRA inference example
+
+Control-LoRA was introduced by Stability AI in [stabilityai/control-lora](https://huggingface.co/stabilityai/control-lora) and adds low-rank, parameter-efficient fine-tuning to ControlNet. This approach offers a more efficient and compact way to bring model control to a wider range of consumer GPUs.
+
+## Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+**Important**
+
+To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date, as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install .
+```
+
+Then cd into the example folder and run:
+
+```bash
+pip install -r requirements.txt
+```
+
+And initialize an [🤗 Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+
+## Inference on SDXL
+
+[stabilityai/control-lora](https://huggingface.co/stabilityai/control-lora) provides a set of Control-LoRA weights for SDXL. Here we use the `canny` condition to generate an image from a text prompt and a reference image.
+
+```bash
+python control_lora.py
+```
+
+## Acknowledgements
+
+- [stabilityai/control-lora](https://huggingface.co/stabilityai/control-lora)
+- [comfyanonymous/ControlNet-v1-1_fp16_safetensors](https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors)
+- [HighCWu/control-lora-v2](https://github.com/HighCWu/control-lora-v2)
````
Lines changed: 58 additions & 0 deletions

```diff
@@ -0,0 +1,58 @@
+import cv2
+import numpy as np
+import torch
+from PIL import Image
+
+from diffusers import (
+    AutoencoderKL,
+    ControlNetModel,
+    StableDiffusionXLControlNetPipeline,
+    UNet2DConditionModel,
+)
+from diffusers.utils import load_image, make_image_grid
+
+
+pipe_id = "stabilityai/stable-diffusion-xl-base-1.0"
+lora_id = "stabilityai/control-lora"
+lora_filename = "control-LoRAs-rank128/control-lora-canny-rank128.safetensors"
+
+unet = UNet2DConditionModel.from_pretrained(pipe_id, subfolder="unet", torch_dtype=torch.bfloat16).to("cuda")
+controlnet = ControlNetModel.from_unet(unet).to(device="cuda", dtype=torch.bfloat16)
+controlnet.load_lora_adapter(lora_id, weight_name=lora_filename, prefix=None, controlnet_config=controlnet.config)
+
+prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting"
+negative_prompt = "low quality, bad quality, sketches"
+
+image = load_image(
+    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png"
+)
+
+controlnet_conditioning_scale = 1.0  # recommended for good generalization
+
+vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae", torch_dtype=torch.bfloat16)
+pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
+    pipe_id,
+    unet=unet,
+    controlnet=controlnet,
+    vae=vae,
+    torch_dtype=torch.bfloat16,
+    safety_checker=None,
+).to("cuda")
+
+image = np.array(image)
+image = cv2.Canny(image, 100, 200)
+image = image[:, :, None]
+image = np.concatenate([image, image, image], axis=2)
+image = Image.fromarray(image)
+
+images = pipe(
+    prompt,
+    negative_prompt=negative_prompt,
+    image=image,
+    controlnet_conditioning_scale=controlnet_conditioning_scale,
+    num_images_per_prompt=4,
+).images
+
+final_image = [image] + images
+grid = make_image_grid(final_image, 1, 5)
+grid.save("hf-logo_canny.png")
```
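Two variations on the script worth noting: `pipe.enable_model_cpu_offload()` (in place of `.to("cuda")`) trades speed for VRAM on smaller GPUs, and other conditions should load the same way by swapping the weight name. A sketch, assuming the depth file sits alongside the canny one in the stabilityai/control-lora repo:

```py
# On a freshly created controlnet, load the depth Control-LoRA instead of canny
# (filename assumed from the stabilityai/control-lora repo layout).
lora_filename = "control-LoRAs-rank128/control-lora-depth-rank128.safetensors"
controlnet.load_lora_adapter(lora_id, weight_name=lora_filename, prefix=None, controlnet_config=controlnet.config)

pipe.enable_model_cpu_offload()  # alternative to pipe.to("cuda") when VRAM is tight
```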

src/diffusers/loaders/peft.py

Lines changed: 16 additions & 0 deletions
```diff
@@ -27,6 +27,7 @@
     MIN_PEFT_VERSION,
     USE_PEFT_BACKEND,
     check_peft_version,
+    convert_sai_sd_control_lora_state_dict_to_peft,
     convert_unet_state_dict_to_peft,
     delete_adapter_layers,
     get_adapter_name,
@@ -232,6 +233,13 @@ def load_lora_adapter(
         if "lora_A" not in first_key:
             state_dict = convert_unet_state_dict_to_peft(state_dict)
 
+        # Control LoRA from SAI is different from BFL Control LoRA
+        # https://huggingface.co/stabilityai/control-lora
+        # https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors
+        is_sai_sd_control_lora = "lora_controlnet" in state_dict
+        if is_sai_sd_control_lora:
+            state_dict = convert_sai_sd_control_lora_state_dict_to_peft(state_dict)
+
         rank = {}
         for key, val in state_dict.items():
             # Cannot figure out rank from lora layers that don't have at least 2 dimensions.
@@ -263,6 +271,14 @@ def load_lora_adapter(
             adapter_name=adapter_name,
         )
 
+        # Adjust LoRA config for Control LoRA
+        if is_sai_sd_control_lora:
+            lora_config.lora_alpha = lora_config.r
+            lora_config.alpha_pattern = lora_config.rank_pattern
+            lora_config.bias = "all"
+            lora_config.modules_to_save = lora_config.exclude_modules
+            lora_config.exclude_modules = None
+
         # <Unsafe code
         # We can be sure that the following works as it just sets attention processors, lora layers and puts all in the same dtype
         # Now we remove any existing hooks to `_pipeline`.
```
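Two details worth unpacking here: detection relies on the sentinel key `lora_controlnet` that SAI's checkpoints carry, and setting `lora_alpha = lora_config.r` pins the effective LoRA scaling `alpha / r` to 1, so the adapter weights apply exactly as stored. A minimal standalone sketch of how ranks can be read off a PEFT-layout state dict (an illustration, not the loader's actual code):

```py
def infer_lora_ranks(state_dict):
    # In PEFT layout, lora_A weights have shape (rank, in_features),
    # so the adapter rank is the first dimension.
    return {
        key: val.shape[0]
        for key, val in state_dict.items()
        if "lora_A" in key and val.ndim == 2
    }

# A rank-128 Control-LoRA would report 128 for e.g. "...attn1.to_q.lora_A.weight".
```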

src/diffusers/models/controlnets/controlnet.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -19,6 +19,7 @@
 from torch.nn import functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import PeftAdapterMixin
 from ...loaders.single_file_model import FromOriginalModelMixin
 from ...utils import BaseOutput, logging
 from ..attention import AttentionMixin
@@ -106,7 +107,7 @@ def forward(self, conditioning):
         return embedding
 
 
-class ControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, FromOriginalModelMixin):
+class ControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin):
     """
     A ControlNet model.
```
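Adding `PeftAdapterMixin` to the class bases is what exposes `load_lora_adapter` on `ControlNetModel` in the first place, and it brings the rest of the model-level adapter API with it. A sketch of the calls this enables (method names from the mixin; the adapter name below is illustrative):

```py
# Assuming `controlnet` was created and the Control-LoRA loaded as in the docs above.
controlnet.disable_lora()                # run the controlnet without the adapter
controlnet.enable_lora()                 # turn it back on
controlnet.delete_adapters("default_0")  # remove it entirely (name is illustrative)
```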

src/diffusers/utils/__init__.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -143,6 +143,7 @@
 from .remote_utils import remote_decode
 from .state_dict_utils import (
     convert_all_state_dict_to_peft,
+    convert_sai_sd_control_lora_state_dict_to_peft,
     convert_state_dict_to_diffusers,
     convert_state_dict_to_kohya,
     convert_state_dict_to_peft,
```
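This re-export makes the converter reachable from the public `diffusers.utils` namespace; a quick sketch of calling it directly (the local file path is illustrative):

```py
from safetensors.torch import load_file

from diffusers.utils import convert_sai_sd_control_lora_state_dict_to_peft

state_dict = load_file("control-lora-canny-rank128.safetensors")  # illustrative path
peft_state_dict = convert_sai_sd_control_lora_state_dict_to_peft(state_dict)
```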

src/diffusers/utils/state_dict_utils.py

Lines changed: 179 additions & 0 deletions
```diff
@@ -56,6 +56,36 @@ class StateDictType(enum.Enum):
     ".to_out.lora_magnitude_vector": ".to_out.0.lora_magnitude_vector",
 }
 
+CONTROL_LORA_TO_DIFFUSERS = {
+    ".to_q.down": ".to_q.lora_A.weight",
+    ".to_q.up": ".to_q.lora_B.weight",
+    ".to_k.down": ".to_k.lora_A.weight",
+    ".to_k.up": ".to_k.lora_B.weight",
+    ".to_v.down": ".to_v.lora_A.weight",
+    ".to_v.up": ".to_v.lora_B.weight",
+    ".to_out.0.down": ".to_out.0.lora_A.weight",
+    ".to_out.0.up": ".to_out.0.lora_B.weight",
+    ".ff.net.0.proj.down": ".ff.net.0.proj.lora_A.weight",
+    ".ff.net.0.proj.up": ".ff.net.0.proj.lora_B.weight",
+    ".ff.net.2.down": ".ff.net.2.lora_A.weight",
+    ".ff.net.2.up": ".ff.net.2.lora_B.weight",
+    ".proj_in.down": ".proj_in.lora_A.weight",
+    ".proj_in.up": ".proj_in.lora_B.weight",
+    ".proj_out.down": ".proj_out.lora_A.weight",
+    ".proj_out.up": ".proj_out.lora_B.weight",
+    ".conv.down": ".conv.lora_A.weight",
+    ".conv.up": ".conv.lora_B.weight",
+    **{f".conv{i}.down": f".conv{i}.lora_A.weight" for i in range(1, 3)},
+    **{f".conv{i}.up": f".conv{i}.lora_B.weight" for i in range(1, 3)},
+    "conv_in.down": "conv_in.lora_A.weight",
+    "conv_in.up": "conv_in.lora_B.weight",
+    ".conv_shortcut.down": ".conv_shortcut.lora_A.weight",
+    ".conv_shortcut.up": ".conv_shortcut.lora_B.weight",
+    **{f".linear_{i}.down": f".linear_{i}.lora_A.weight" for i in range(1, 3)},
+    **{f".linear_{i}.up": f".linear_{i}.lora_B.weight" for i in range(1, 3)},
+    "time_emb_proj.down": "time_emb_proj.lora_A.weight",
+    "time_emb_proj.up": "time_emb_proj.lora_B.weight",
+}
 
 DIFFUSERS_TO_PEFT = {
     ".q_proj.lora_linear_layer.up": ".q_proj.lora_B",
```
```diff
@@ -259,6 +289,155 @@ def convert_unet_state_dict_to_peft(state_dict):
     return convert_state_dict(state_dict, mapping)
 
 
+def convert_sai_sd_control_lora_state_dict_to_peft(state_dict):
+    def _convert_controlnet_to_diffusers(state_dict):
+        is_sdxl = "input_blocks.11.0.in_layers.0.weight" not in state_dict
+        logger.info(f"Using ControlNet lora ({'SDXL' if is_sdxl else 'SD15'})")
+
+        # Retrieves the keys for the input blocks only
+        num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in state_dict if "input_blocks" in layer})
+        input_blocks = {
+            layer_id: [key for key in state_dict if f"input_blocks.{layer_id}" in key]
+            for layer_id in range(num_input_blocks)
+        }
+        layers_per_block = 2
+
+        # op blocks
+        op_blocks = [key for key in state_dict if "0.op" in key]
+
+        converted_state_dict = {}
+        # Conv in layers
+        for key in input_blocks[0]:
+            diffusers_key = key.replace("input_blocks.0.0", "conv_in")
+            converted_state_dict[diffusers_key] = state_dict.get(key)
+
+        # controlnet time embedding blocks
+        time_embedding_blocks = [key for key in state_dict if "time_embed" in key]
+        for key in time_embedding_blocks:
+            diffusers_key = key.replace("time_embed.0", "time_embedding.linear_1").replace(
+                "time_embed.2", "time_embedding.linear_2"
+            )
+            converted_state_dict[diffusers_key] = state_dict.get(key)
+
+        # controlnet label embedding blocks
+        label_embedding_blocks = [key for key in state_dict if "label_emb" in key]
+        for key in label_embedding_blocks:
+            diffusers_key = key.replace("label_emb.0.0", "add_embedding.linear_1").replace(
+                "label_emb.0.2", "add_embedding.linear_2"
+            )
+            converted_state_dict[diffusers_key] = state_dict.get(key)
+
+        # Down blocks
+        for i in range(1, num_input_blocks):
+            block_id = (i - 1) // (layers_per_block + 1)
+            layer_in_block_id = (i - 1) % (layers_per_block + 1)
+
+            resnets = [
+                key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
+            ]
+            for key in resnets:
+                diffusers_key = (
+                    key.replace("in_layers.0", "norm1")
+                    .replace("in_layers.2", "conv1")
+                    .replace("out_layers.0", "norm2")
+                    .replace("out_layers.3", "conv2")
+                    .replace("emb_layers.1", "time_emb_proj")
+                    .replace("skip_connection", "conv_shortcut")
+                )
+                diffusers_key = diffusers_key.replace(
+                    f"input_blocks.{i}.0", f"down_blocks.{block_id}.resnets.{layer_in_block_id}"
+                )
+                converted_state_dict[diffusers_key] = state_dict.get(key)
+
+            if f"input_blocks.{i}.0.op.bias" in state_dict:
+                for key in [key for key in op_blocks if f"input_blocks.{i}.0.op" in key]:
+                    diffusers_key = key.replace(
+                        f"input_blocks.{i}.0.op", f"down_blocks.{block_id}.downsamplers.0.conv"
+                    )
+                    converted_state_dict[diffusers_key] = state_dict.get(key)
+
+            attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
+            if attentions:
+                for key in attentions:
+                    diffusers_key = key.replace(
+                        f"input_blocks.{i}.1", f"down_blocks.{block_id}.attentions.{layer_in_block_id}"
+                    )
+                    converted_state_dict[diffusers_key] = state_dict.get(key)
+
+        # controlnet down blocks
+        for i in range(num_input_blocks):
+            converted_state_dict[f"controlnet_down_blocks.{i}.weight"] = state_dict.get(f"zero_convs.{i}.0.weight")
+            converted_state_dict[f"controlnet_down_blocks.{i}.bias"] = state_dict.get(f"zero_convs.{i}.0.bias")
+
+        # Retrieves the keys for the middle blocks only
+        num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in state_dict if "middle_block" in layer})
+        middle_blocks = {
+            layer_id: [key for key in state_dict if f"middle_block.{layer_id}" in key]
+            for layer_id in range(num_middle_blocks)
+        }
+
+        # Mid blocks
+        for key in middle_blocks.keys():
+            diffusers_key = max(key - 1, 0)
+            if key % 2 == 0:
+                for k in middle_blocks[key]:
+                    diffusers_key_hf = (
+                        k.replace("in_layers.0", "norm1")
+                        .replace("in_layers.2", "conv1")
+                        .replace("out_layers.0", "norm2")
+                        .replace("out_layers.3", "conv2")
+                        .replace("emb_layers.1", "time_emb_proj")
+                        .replace("skip_connection", "conv_shortcut")
+                    )
+                    diffusers_key_hf = diffusers_key_hf.replace(
+                        f"middle_block.{key}", f"mid_block.resnets.{diffusers_key}"
+                    )
+                    converted_state_dict[diffusers_key_hf] = state_dict.get(k)
+            else:
+                for k in middle_blocks[key]:
+                    diffusers_key_hf = k.replace(f"middle_block.{key}", f"mid_block.attentions.{diffusers_key}")
+                    converted_state_dict[diffusers_key_hf] = state_dict.get(k)
+
+        # mid block
+        converted_state_dict["controlnet_mid_block.weight"] = state_dict.get("middle_block_out.0.weight")
+        converted_state_dict["controlnet_mid_block.bias"] = state_dict.get("middle_block_out.0.bias")
+
+        # controlnet cond embedding blocks
+        cond_embedding_blocks = {
+            ".".join(layer.split(".")[:2])
+            for layer in state_dict
+            if "input_hint_block" in layer
+            and ("input_hint_block.0" not in layer)
+            and ("input_hint_block.14" not in layer)
+        }
+        num_cond_embedding_blocks = len(cond_embedding_blocks)
+
+        for idx in range(1, num_cond_embedding_blocks + 1):
+            diffusers_idx = idx - 1
+            cond_block_id = 2 * idx
+
+            converted_state_dict[f"controlnet_cond_embedding.blocks.{diffusers_idx}.weight"] = state_dict.get(
+                f"input_hint_block.{cond_block_id}.weight"
+            )
+            converted_state_dict[f"controlnet_cond_embedding.blocks.{diffusers_idx}.bias"] = state_dict.get(
+                f"input_hint_block.{cond_block_id}.bias"
+            )
+
+        for key in [key for key in state_dict if "input_hint_block.0" in key]:
+            diffusers_key = key.replace("input_hint_block.0", "controlnet_cond_embedding.conv_in")
+            converted_state_dict[diffusers_key] = state_dict.get(key)
+
+        for key in [key for key in state_dict if "input_hint_block.14" in key]:
+            diffusers_key = key.replace("input_hint_block.14", "controlnet_cond_embedding.conv_out")
+            converted_state_dict[diffusers_key] = state_dict.get(key)
+
+        return converted_state_dict
+
+    state_dict = _convert_controlnet_to_diffusers(state_dict)
+    mapping = CONTROL_LORA_TO_DIFFUSERS
+    return convert_state_dict(state_dict, mapping)
+
+
 def convert_all_state_dict_to_peft(state_dict):
     r"""
     Attempts to first `convert_state_dict_to_peft`, and if it doesn't detect `lora_linear_layer` for a valid
```
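One step above that is easy to misread is the down-block arithmetic: with `layers_per_block = 2`, LDM `input_blocks.{i}` maps to diffusers `down_blocks.{(i - 1) // 3}`, with slot `(i - 1) % 3` selecting resnet 0, resnet 1, or the downsampler. A quick standalone check:

```py
layers_per_block = 2  # matches the constant used in the converter above

for i in range(1, 9):
    block_id = (i - 1) // (layers_per_block + 1)
    layer_in_block_id = (i - 1) % (layers_per_block + 1)
    print(f"input_blocks.{i} -> down_blocks.{block_id}, slot {layer_in_block_id}")

# input_blocks.1 -> down_blocks.0, slot 0   (resnet 0)
# input_blocks.2 -> down_blocks.0, slot 1   (resnet 1)
# input_blocks.3 -> down_blocks.0, slot 2   (downsampler)
# input_blocks.4 -> down_blocks.1, slot 0   ...
```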
