Skip to content

Commit d73b419

Browse files
authored
feat: SDXS-09 support and update doc (#1356)
1 parent 5c243db commit d73b419

7 files changed

Lines changed: 51 additions & 48 deletions

File tree

docs/distilled_sd.md

Lines changed: 16 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -87,51 +87,32 @@ pipe.save_pretrained("segmindtiny-sd", safe_serialization=True)
8787
```bash
8888
python convert_diffusers_to_original_stable_diffusion.py \
8989
--model_path ./segmindtiny-sd \
90-
--checkpoint_path ./segmind_tiny-sd.ckpt --half
90+
--checkpoint_path ./segmind_tiny-sd.safetensors --half --use_safetensors
9191
```
9292

93-
The file segmind_tiny-sd.ckpt will be generated and is now ready for use with sd.cpp. You can follow a similar process for the other models mentioned above.
93+
The file segmind_tiny-sd.safetensors will be generated and is now ready for use with sd.cpp. You can follow a similar process for the other models mentioned above.
9494

9595

96-
##### Another available .ckpt file:
97-
98-
* https://huggingface.co/ClashSAN/small-sd/resolve/main/tinySDdistilled.ckpt
99-
100-
To use this file, you must first adjust its non-contiguous tensors:
101-
102-
```python
103-
import torch
104-
ckpt = torch.load("tinySDdistilled.ckpt", map_location=torch.device('cpu'))
105-
for key, value in ckpt['state_dict'].items():
106-
if isinstance(value, torch.Tensor):
107-
ckpt['state_dict'][key] = value.contiguous()
108-
torch.save(ckpt, "tinySDdistilled_fixed.ckpt")
109-
```
110-
111-
112-
### SDXS-512
96+
### SDXS-512-DreamShaper
11397

11498
Another very tiny and **incredibly fast** model is SDXS by IDKiro et al. The authors refer to it as *"Real-Time One-Step Latent Diffusion Models with Image Conditions"*. For details, read the paper: https://arxiv.org/pdf/2403.16627 . Once again the authors removed some more blocks of the U-Net part, and unlike other SD1 models they use an adjusted _AutoEncoderTiny_ instead of the default _AutoEncoderKL_ for the VAE part.
99+
##### Some ready-to-run SDXS-512 model files are available online, such as:
115100

116-
##### 1. Download the diffusers model from Hugging Face using Python:
117-
118-
```python
119-
from diffusers import StableDiffusionPipeline
120-
pipe = StableDiffusionPipeline.from_pretrained("IDKiro/sdxs-512-dreamshaper")
121-
pipe.save_pretrained(save_directory="sdxs")
122-
```
123-
##### 2. Create a safetensors file
124-
125-
```bash
126-
python convert_diffusers_to_original_stable_diffusion.py \
127-
--model_path sdxs --checkpoint_path sdxs.safetensors --half --use_safetensors
128-
```
129-
130-
##### 3. Run the model as follows:
101+
* https://huggingface.co/akleine/sdxs-512
102+
* https://huggingface.co/concedo/sdxs-512-tinySDdistilled-GGUF
131103

104+
##### Run the model as follows:
132105
```bash
133106
~/stable-diffusion.cpp/build/bin/sd-cli -m sdxs.safetensors -p "portrait of a lovely cat" \
134107
--cfg-scale 1 --steps 1
135108
```
109+
Both options, ``` --cfg-scale 1 ``` and ``` --steps 1 ```, are mandatory here.
110+
111+
### SDXS-512-0.9
112+
113+
Even though the name "SDXS-512-0.9" is similar to "SDXS-512-DreamShaper", it is *completely different* but also **incredibly fast**. Sometimes it is preferred, so try it yourself.
114+
##### Download a ready-to-run file from here:
115+
116+
* https://huggingface.co/akleine/sdxs-09
136117

137-
Both options: ``` --cfg-scale 1 ``` and ``` --steps 1 ``` are mandatory here.
118+
To use this model, both options ``` --cfg-scale 1 ``` and ``` --steps 1 ``` are again mandatory.

src/common_block.hpp

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,7 @@ class CrossAttention : public GGMLBlock {
277277
int64_t context_dim;
278278
int64_t n_head;
279279
int64_t d_head;
280+
bool xtra_dim = false;
280281

281282
public:
282283
CrossAttention(int64_t query_dim,
@@ -288,7 +289,11 @@ class CrossAttention : public GGMLBlock {
288289
query_dim(query_dim),
289290
context_dim(context_dim) {
290291
int64_t inner_dim = d_head * n_head;
291-
292+
if (context_dim == 320 && d_head == 320) {
293+
// LOG_DEBUG("CrossAttention: temp set dim to 1024 for sdxs_09");
294+
xtra_dim = true;
295+
context_dim = 1024;
296+
}
292297
blocks["to_q"] = std::shared_ptr<GGMLBlock>(new Linear(query_dim, inner_dim, false));
293298
blocks["to_k"] = std::shared_ptr<GGMLBlock>(new Linear(context_dim, inner_dim, false));
294299
blocks["to_v"] = std::shared_ptr<GGMLBlock>(new Linear(context_dim, inner_dim, false));
@@ -313,10 +318,16 @@ class CrossAttention : public GGMLBlock {
313318
int64_t n_context = context->ne[1];
314319
int64_t inner_dim = d_head * n_head;
315320

316-
auto q = to_q->forward(ctx, x); // [N, n_token, inner_dim]
321+
auto q = to_q->forward(ctx, x); // [N, n_token, inner_dim]
322+
if (xtra_dim) {
323+
// LOG_DEBUG("CrossAttention: temp set dim to 1024 for sdxs_09");
324+
context->ne[0] = 1024; // patch dim
325+
}
317326
auto k = to_k->forward(ctx, context); // [N, n_context, inner_dim]
318327
auto v = to_v->forward(ctx, context); // [N, n_context, inner_dim]
319-
328+
if (xtra_dim) {
329+
context->ne[0] = 320; // reset dim to orig
330+
}
320331
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head, nullptr, false, ctx->flash_attn_enabled); // [N, n_token, inner_dim]
321332

322333
x = to_out_0->forward(ctx, x); // [N, n_token, query_dim]

src/model.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1019,6 +1019,7 @@ SDVersion ModelLoader::get_sd_version() {
10191019
bool has_middle_block_1 = false;
10201020
bool has_output_block_311 = false;
10211021
bool has_output_block_71 = false;
1022+
bool has_attn_1024 = false;
10221023

10231024
for (auto& [name, tensor_storage] : tensor_storage_map) {
10241025
if (!(is_xl)) {
@@ -1091,6 +1092,10 @@ SDVersion ModelLoader::get_sd_version() {
10911092
if (tensor_storage.name.find("model.diffusion_model.output_blocks.7.1") != std::string::npos ||
10921093
tensor_storage.name.find("unet.up_blocks.2.attentions.1") != std::string::npos) {
10931094
has_output_block_71 = true;
1095+
if (tensor_storage.name.find("model.diffusion_model.output_blocks.7.1.transformer_blocks.0.attn1.to_k.weight") != std::string::npos) {
1096+
if (tensor_storage.ne[0] == 1024)
1097+
has_attn_1024 = true;
1098+
}
10941099
}
10951100
if (tensor_storage.name == "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight" ||
10961101
tensor_storage.name == "cond_stage_model.model.token_embedding.weight" ||
@@ -1164,7 +1169,7 @@ SDVersion ModelLoader::get_sd_version() {
11641169
}
11651170
if (!has_middle_block_1) {
11661171
if (!has_output_block_71) {
1167-
return VERSION_SDXS;
1172+
return VERSION_SDXS_512_DS;
11681173
}
11691174
return VERSION_SD1_TINY_UNET;
11701175
}
@@ -1174,7 +1179,7 @@ SDVersion ModelLoader::get_sd_version() {
11741179
return VERSION_SD2_INPAINT;
11751180
}
11761181
if (!has_middle_block_1) {
1177-
return VERSION_SD2_TINY_UNET;
1182+
return has_attn_1024 ? VERSION_SDXS_09 : VERSION_SD2_TINY_UNET;
11781183
}
11791184
return VERSION_SD2;
11801185
}

src/model.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@ enum SDVersion {
2828
VERSION_SD2,
2929
VERSION_SD2_INPAINT,
3030
VERSION_SD2_TINY_UNET,
31-
VERSION_SDXS,
31+
VERSION_SDXS_512_DS,
32+
VERSION_SDXS_09,
3233
VERSION_SDXL,
3334
VERSION_SDXL_INPAINT,
3435
VERSION_SDXL_PIX2PIX,
@@ -55,14 +56,14 @@ enum SDVersion {
5556
};
5657

5758
static inline bool sd_version_is_sd1(SDVersion version) {
58-
if (version == VERSION_SD1 || version == VERSION_SD1_INPAINT || version == VERSION_SD1_PIX2PIX || version == VERSION_SD1_TINY_UNET || version == VERSION_SDXS) {
59+
if (version == VERSION_SD1 || version == VERSION_SD1_INPAINT || version == VERSION_SD1_PIX2PIX || version == VERSION_SD1_TINY_UNET || version == VERSION_SDXS_512_DS) {
5960
return true;
6061
}
6162
return false;
6263
}
6364

6465
static inline bool sd_version_is_sd2(SDVersion version) {
65-
if (version == VERSION_SD2 || version == VERSION_SD2_INPAINT || version == VERSION_SD2_TINY_UNET) {
66+
if (version == VERSION_SD2 || version == VERSION_SD2_INPAINT || version == VERSION_SD2_TINY_UNET || version == VERSION_SDXS_09) {
6667
return true;
6768
}
6869
return false;

src/name_conversion.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1120,7 +1120,7 @@ std::string convert_tensor_name(std::string name, SDVersion version) {
11201120
for (const auto& prefix : first_stage_model_prefix_vec) {
11211121
if (starts_with(name, prefix)) {
11221122
name = convert_first_stage_model_name(name.substr(prefix.size()), prefix);
1123-
if (version == VERSION_SDXS) {
1123+
if (version == VERSION_SDXS_512_DS || version == VERSION_SDXS_09) {
11241124
name = "tae." + name;
11251125
} else {
11261126
name = prefix + name;

src/stable-diffusion.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ const char* model_version_to_str[] = {
3030
"SD 2.x",
3131
"SD 2.x Inpaint",
3232
"SD 2.x Tiny UNet",
33-
"SDXS",
33+
"SDXS (512-DS)",
34+
"SDXS (09)",
3435
"SDXL",
3536
"SDXL Inpaint",
3637
"SDXL Instruct-Pix2Pix",
@@ -414,7 +415,7 @@ class StableDiffusionGGML {
414415
}
415416

416417
bool tae_preview_only = sd_ctx_params->tae_preview_only;
417-
if (version == VERSION_SDXS) {
418+
if (version == VERSION_SDXS_512_DS || version == VERSION_SDXS_09) {
418419
tae_preview_only = false;
419420
use_tae = true;
420421
}

src/unet.hpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -217,11 +217,11 @@ class UnetModelBlock : public GGMLBlock {
217217
} else if (sd_version_is_unet_edit(version)) {
218218
in_channels = 8;
219219
}
220-
if (version == VERSION_SD1_TINY_UNET || version == VERSION_SD2_TINY_UNET || version == VERSION_SDXS) {
220+
if (version == VERSION_SD1_TINY_UNET || version == VERSION_SD2_TINY_UNET || version == VERSION_SDXS_512_DS || version == VERSION_SDXS_09) {
221221
num_res_blocks = 1;
222222
channel_mult = {1, 2, 4};
223223
tiny_unet = true;
224-
if (version == VERSION_SDXS) {
224+
if (version == VERSION_SDXS_512_DS) {
225225
attention_resolutions = {4, 2}; // here just like SDXL
226226
}
227227
}
@@ -264,6 +264,10 @@ class UnetModelBlock : public GGMLBlock {
264264
if (version == VERSION_SVD) {
265265
return new SpatialVideoTransformer(in_channels, n_head, d_head, depth, context_dim, use_linear_projection);
266266
} else {
267+
if (version == VERSION_SDXS_09 && n_head == 5) {
268+
n_head = 1; // to carry a special case of sdxs_09 into CrossAttentionLayer,
269+
d_head = 320; // works as long the product remains equal (5*64 == 1*320)
270+
}
267271
return new SpatialTransformer(in_channels, n_head, d_head, depth, context_dim, use_linear_projection);
268272
}
269273
};

0 commit comments

Comments
 (0)