Skip to content

Commit c97702e

Browse files
authored
feat: add sd-webui style Hires. fix support (#1451)
1 parent 44cca3d commit c97702e

14 files changed

Lines changed: 868 additions & 134 deletions

File tree

examples/cli/main.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,9 @@ void parse_args(int argc, const char** argv, SDCliParams& cli_params, SDContextP
278278
bool valid = cli_params.resolve_and_validate();
279279
if (valid && cli_params.mode != METADATA) {
280280
valid = ctx_params.resolve_and_validate(cli_params.mode) &&
281-
gen_params.resolve_and_validate(cli_params.mode, ctx_params.lora_model_dir);
281+
gen_params.resolve_and_validate(cli_params.mode,
282+
ctx_params.lora_model_dir,
283+
ctx_params.hires_upscalers_dir);
282284
}
283285

284286
if (!valid) {
@@ -688,6 +690,10 @@ int main(int argc, const char* argv[]) {
688690
vae_decode_only = false;
689691
}
690692

693+
if (gen_params.hires_enabled && !gen_params.hires_upscaler_model_path.empty()) {
694+
vae_decode_only = false;
695+
}
696+
691697
sd_ctx_params_t sd_ctx_params = ctx_params.to_sd_ctx_params_t(vae_decode_only, true, cli_params.taesd_preview);
692698

693699
SDImageVec results;

examples/common/common.cpp

Lines changed: 173 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -351,7 +351,10 @@ ArgOptions SDContextParams::get_options() {
351351
"--lora-model-dir",
352352
"lora model directory",
353353
&lora_model_dir},
354-
354+
{"",
355+
"--hires-upscalers-dir",
356+
"highres fix upscaler model directory",
357+
&hires_upscalers_dir},
355358
{"",
356359
"--tensor-type-rules",
357360
"weight type per tensor pattern (example: \"^vae\\.=f16,model\\.=q8_0\")",
@@ -649,6 +652,7 @@ std::string SDContextParams::to_string() const {
649652
<< " wtype: " << sd_type_name(wtype) << ",\n"
650653
<< " tensor_type_rules: \"" << tensor_type_rules << "\",\n"
651654
<< " lora_model_dir: \"" << lora_model_dir << "\",\n"
655+
<< " hires_upscalers_dir: \"" << hires_upscalers_dir << "\",\n"
652656
<< " photo_maker_path: \"" << photo_maker_path << "\",\n"
653657
<< " rng_type: " << sd_rng_type_name(rng_type) << ",\n"
654658
<< " sampler_rng_type: " << sd_rng_type_name(sampler_rng_type) << ",\n"
@@ -777,6 +781,10 @@ ArgOptions SDGenerationParams::get_options() {
777781
"--pm-id-embed-path",
778782
"path to PHOTOMAKER v2 id embed",
779783
&pm_id_embed_path},
784+
{"",
785+
"--hires-upscaler",
786+
"highres fix upscaler, Latent (nearest) or a model name/path under --hires-upscalers-dir (default: Latent (nearest))",
787+
&hires_upscaler},
780788
};
781789

782790
options.int_options = {
@@ -826,6 +834,22 @@ ArgOptions SDGenerationParams::get_options() {
826834
"--upscale-tile-size",
827835
"tile size for ESRGAN upscaling (default: 128)",
828836
&upscale_tile_size},
837+
{"",
838+
"--hires-width",
839+
"highres fix target width, 0 to use --hires-scale (default: 0)",
840+
&hires_width},
841+
{"",
842+
"--hires-height",
843+
"highres fix target height, 0 to use --hires-scale (default: 0)",
844+
&hires_height},
845+
{"",
846+
"--hires-steps",
847+
"highres fix second pass sample steps, 0 to reuse --steps (default: 0)",
848+
&hires_steps},
849+
{"",
850+
"--hires-upscale-tile-size",
851+
"highres fix upscaler tile size, reserved for model-backed upscalers (default: 128)",
852+
&hires_upscale_tile_size},
829853
};
830854

831855
options.float_options = {
@@ -913,6 +937,14 @@ ArgOptions SDGenerationParams::get_options() {
913937
"--vae-tile-overlap",
914938
"tile overlap for vae tiling, in fraction of tile size (default: 0.5)",
915939
&vae_tiling_params.target_overlap},
940+
{"",
941+
"--hires-scale",
942+
"highres fix scale when target size is not set (default: 2.0)",
943+
&hires_scale},
944+
{"",
945+
"--hires-denoising-strength",
946+
"highres fix second pass denoising strength (default: 0.7)",
947+
&hires_denoising_strength},
916948
};
917949

918950
options.bool_options = {
@@ -936,6 +968,11 @@ ArgOptions SDGenerationParams::get_options() {
936968
"process vae in tiles to reduce memory usage",
937969
true,
938970
&vae_tiling_params.enabled},
971+
{"",
972+
"--hires",
973+
"enable highres fix",
974+
true,
975+
&hires_enabled},
939976
};
940977

941978
auto on_seed_arg = [&](int argc, const char** argv, int index) {
@@ -1424,6 +1461,37 @@ static bool parse_lora_json_field(const json& parent,
14241461
return true;
14251462
}
14261463

1464+
static bool resolve_model_file_from_dir(const std::string& model_name,
1465+
const std::string& model_dir,
1466+
const std::vector<std::string>& valid_ext,
1467+
const char* label,
1468+
std::string& resolved_path) {
1469+
if (model_dir.empty()) {
1470+
LOG_ERROR("%s directory is empty", label);
1471+
return false;
1472+
}
1473+
if (model_name.empty() ||
1474+
model_name.find('/') != std::string::npos ||
1475+
model_name.find('\\') != std::string::npos ||
1476+
fs::path(model_name).has_root_path() ||
1477+
fs::path(model_name).has_extension()) {
1478+
LOG_ERROR("%s must be a model name without path or extension: %s", label, model_name.c_str());
1479+
return false;
1480+
}
1481+
1482+
fs::path model_dir_path = model_dir;
1483+
for (const auto& ext : valid_ext) {
1484+
fs::path try_path = model_dir_path / (model_name + ext);
1485+
if (fs::exists(try_path) && fs::is_regular_file(try_path)) {
1486+
resolved_path = try_path.lexically_normal().string();
1487+
return true;
1488+
}
1489+
}
1490+
1491+
LOG_ERROR("can not find %s %s in %s", label, model_name.c_str(), model_dir_path.lexically_normal().string().c_str());
1492+
return false;
1493+
}
1494+
14271495
bool SDGenerationParams::from_json_str(
14281496
const std::string& json_str,
14291497
const std::function<std::string(const std::string&)>& lora_path_resolver) {
@@ -1487,6 +1555,34 @@ bool SDGenerationParams::from_json_str(
14871555
load_if_exists("increase_ref_index", increase_ref_index);
14881556
load_if_exists("embed_image_metadata", embed_image_metadata);
14891557

1558+
if (j.contains("hires") && j["hires"].is_object()) {
1559+
const json& hires_json = j["hires"];
1560+
if (hires_json.contains("enabled") && hires_json["enabled"].is_boolean()) {
1561+
hires_enabled = hires_json["enabled"];
1562+
}
1563+
if (hires_json.contains("upscaler") && hires_json["upscaler"].is_string()) {
1564+
hires_upscaler = hires_json["upscaler"];
1565+
}
1566+
if (hires_json.contains("scale") && hires_json["scale"].is_number()) {
1567+
hires_scale = hires_json["scale"];
1568+
}
1569+
if (hires_json.contains("target_width") && hires_json["target_width"].is_number_integer()) {
1570+
hires_width = hires_json["target_width"];
1571+
}
1572+
if (hires_json.contains("target_height") && hires_json["target_height"].is_number_integer()) {
1573+
hires_height = hires_json["target_height"];
1574+
}
1575+
if (hires_json.contains("steps") && hires_json["steps"].is_number_integer()) {
1576+
hires_steps = hires_json["steps"];
1577+
}
1578+
if (hires_json.contains("denoising_strength") && hires_json["denoising_strength"].is_number()) {
1579+
hires_denoising_strength = hires_json["denoising_strength"];
1580+
}
1581+
if (hires_json.contains("upscale_tile_size") && hires_json["upscale_tile_size"].is_number_integer()) {
1582+
hires_upscale_tile_size = hires_json["upscale_tile_size"];
1583+
}
1584+
}
1585+
14901586
auto parse_sample_params_json = [&](const json& sample_json,
14911587
sd_sample_params_t& target_params,
14921588
std::vector<int>& target_skip_layers,
@@ -1800,7 +1896,7 @@ bool SDGenerationParams::initialize_cache_params() {
18001896
return true;
18011897
}
18021898

1803-
bool SDGenerationParams::resolve(const std::string& lora_model_dir, bool strict) {
1899+
bool SDGenerationParams::resolve(const std::string& lora_model_dir, const std::string& hires_upscalers_dir, bool strict) {
18041900
if (high_noise_sample_params.sample_steps <= 0) {
18051901
high_noise_sample_params.sample_steps = -1;
18061902
}
@@ -1819,6 +1915,27 @@ bool SDGenerationParams::resolve(const std::string& lora_model_dir, bool strict)
18191915
sample_params.sample_steps = std::clamp(sample_params.sample_steps, 1, 100);
18201916
}
18211917

1918+
hires_upscaler_model_path.clear();
1919+
if (hires_enabled) {
1920+
if (hires_upscaler.empty()) {
1921+
hires_upscaler = "Latent (nearest)";
1922+
}
1923+
resolved_hires_upscaler = str_to_sd_hires_upscaler(hires_upscaler.c_str());
1924+
if (resolved_hires_upscaler == SD_HIRES_UPSCALER_NONE) {
1925+
hires_enabled = false;
1926+
} else if (resolved_hires_upscaler == SD_HIRES_UPSCALER_COUNT) {
1927+
static const std::vector<std::string> valid_ext = {".gguf", ".safetensors", ".pt", ".pth"};
1928+
if (!resolve_model_file_from_dir(hires_upscaler,
1929+
hires_upscalers_dir,
1930+
valid_ext,
1931+
"hires upscaler",
1932+
hires_upscaler_model_path)) {
1933+
return false;
1934+
}
1935+
resolved_hires_upscaler = SD_HIRES_UPSCALER_MODEL;
1936+
}
1937+
}
1938+
18221939
prompt_with_lora = prompt;
18231940
if (!lora_model_dir.empty()) {
18241941
extract_and_remove_lora(lora_model_dir);
@@ -1883,6 +2000,29 @@ bool SDGenerationParams::validate(SDMode mode) {
18832000
return false;
18842001
}
18852002

2003+
if (hires_enabled) {
2004+
if (hires_width < 0 || hires_height < 0) {
2005+
LOG_ERROR("error: hires target width and height must be >= 0");
2006+
return false;
2007+
}
2008+
if (hires_scale <= 0.f && hires_width <= 0 && hires_height <= 0) {
2009+
LOG_ERROR("error: hires scale must be positive when target size is not set");
2010+
return false;
2011+
}
2012+
if (hires_steps < 0) {
2013+
LOG_ERROR("error: hires steps must be >= 0");
2014+
return false;
2015+
}
2016+
if (hires_denoising_strength <= 0.f || hires_denoising_strength > 1.f) {
2017+
LOG_ERROR("error: hires denoising strength must be in (0.0, 1.0]");
2018+
return false;
2019+
}
2020+
if (hires_upscale_tile_size < 1) {
2021+
LOG_ERROR("error: hires upscale tile size must be positive");
2022+
return false;
2023+
}
2024+
}
2025+
18862026
if (mode == UPSCALE) {
18872027
if (init_image_path.length() == 0) {
18882028
LOG_ERROR("error: upscale mode needs an init image (--init-img)\n");
@@ -1893,8 +2033,11 @@ bool SDGenerationParams::validate(SDMode mode) {
18932033
return true;
18942034
}
18952035

1896-
bool SDGenerationParams::resolve_and_validate(SDMode mode, const std::string& lora_model_dir, bool strict) {
1897-
if (!resolve(lora_model_dir, strict)) {
2036+
bool SDGenerationParams::resolve_and_validate(SDMode mode,
2037+
const std::string& lora_model_dir,
2038+
const std::string& hires_upscalers_dir,
2039+
bool strict) {
2040+
if (!resolve(lora_model_dir, hires_upscalers_dir, strict)) {
18982041
return false;
18992042
}
19002043
if (!validate(mode)) {
@@ -1965,6 +2108,16 @@ sd_img_gen_params_t SDGenerationParams::to_sd_img_gen_params_t() {
19652108
params.pm_params = pm_params;
19662109
params.vae_tiling_params = vae_tiling_params;
19672110
params.cache = cache_params;
2111+
2112+
params.hires.enabled = hires_enabled;
2113+
params.hires.upscaler = resolved_hires_upscaler;
2114+
params.hires.model_path = hires_upscaler_model_path.empty() ? nullptr : hires_upscaler_model_path.c_str();
2115+
params.hires.scale = hires_scale;
2116+
params.hires.target_width = hires_width;
2117+
params.hires.target_height = hires_height;
2118+
params.hires.steps = hires_steps;
2119+
params.hires.denoising_strength = hires_denoising_strength;
2120+
params.hires.upscale_tile_size = hires_upscale_tile_size;
19682121
return params;
19692122
}
19702123

@@ -2089,6 +2242,15 @@ std::string SDGenerationParams::to_string() const {
20892242
<< " seed: " << seed << ",\n"
20902243
<< " upscale_repeats: " << upscale_repeats << ",\n"
20912244
<< " upscale_tile_size: " << upscale_tile_size << ",\n"
2245+
<< " hires: { enabled: " << (hires_enabled ? "true" : "false")
2246+
<< ", upscaler: \"" << hires_upscaler << "\""
2247+
<< ", model_path: \"" << hires_upscaler_model_path << "\""
2248+
<< ", scale: " << hires_scale
2249+
<< ", target_width: " << hires_width
2250+
<< ", target_height: " << hires_height
2251+
<< ", steps: " << hires_steps
2252+
<< ", denoising_strength: " << hires_denoising_strength
2253+
<< ", upscale_tile_size: " << hires_upscale_tile_size << " },\n"
20922254
<< " vae_tiling_params: { "
20932255
<< vae_tiling_params.enabled << ", "
20942256
<< vae_tiling_params.tile_size_x << ", "
@@ -2162,6 +2324,13 @@ std::string get_image_params(const SDContextParams& ctx_params, const SDGenerati
21622324
if (gen_params.clip_skip != -1) {
21632325
parameter_string += "Clip skip: " + std::to_string(gen_params.clip_skip) + ", ";
21642326
}
2327+
if (gen_params.hires_enabled) {
2328+
parameter_string += "Hires upscale: " + gen_params.hires_upscaler + ", ";
2329+
parameter_string += "Hires scale: " + std::to_string(gen_params.hires_scale) + ", ";
2330+
parameter_string += "Hires resize: " + std::to_string(gen_params.hires_width) + "x" + std::to_string(gen_params.hires_height) + ", ";
2331+
parameter_string += "Hires steps: " + std::to_string(gen_params.hires_steps) + ", ";
2332+
parameter_string += "Denoising strength: " + std::to_string(gen_params.hires_denoising_strength) + ", ";
2333+
}
21652334
parameter_string += "Version: stable-diffusion.cpp";
21662335
return parameter_string;
21672336
}

examples/common/common.h

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ struct SDContextParams {
101101
sd_type_t wtype = SD_TYPE_COUNT;
102102
std::string tensor_type_rules;
103103
std::string lora_model_dir = ".";
104+
std::string hires_upscalers_dir;
104105

105106
std::map<std::string, std::string> embedding_map;
106107
std::vector<sd_embedding_t> embedding_vec;
@@ -190,12 +191,23 @@ struct SDGenerationParams {
190191
int upscale_repeats = 1;
191192
int upscale_tile_size = 128;
192193

194+
bool hires_enabled = false;
195+
std::string hires_upscaler = "Latent (nearest)";
196+
std::string hires_upscaler_model_path;
197+
float hires_scale = 2.f;
198+
int hires_width = 0;
199+
int hires_height = 0;
200+
int hires_steps = 0;
201+
float hires_denoising_strength = 0.7f;
202+
int hires_upscale_tile_size = 128;
203+
193204
std::map<std::string, float> lora_map;
194205
std::map<std::string, float> high_noise_lora_map;
195206

196207
// Derived and normalized fields.
197208
std::string prompt_with_lora; // for metadata record only
198209
std::vector<sd_lora_t> lora_vec;
210+
// Value-initialized: resolve() assigns this only when hires is enabled, while
// to_sd_img_gen_params_t() copies it unconditionally, so it must never be
// left indeterminate.
sd_hires_upscaler_t resolved_hires_upscaler{};
199211

200212
// Owned execution payload.
201213
SDImageOwner init_image;
@@ -225,9 +237,12 @@ struct SDGenerationParams {
225237
void set_width_and_height_if_unset(int w, int h);
226238
int get_resolved_width() const;
227239
int get_resolved_height() const;
228-
bool resolve(const std::string& lora_model_dir, bool strict = false);
240+
bool resolve(const std::string& lora_model_dir, const std::string& hires_upscalers_dir, bool strict = false);
229241
bool validate(SDMode mode);
230-
bool resolve_and_validate(SDMode mode, const std::string& lora_model_dir, bool strict = false);
242+
bool resolve_and_validate(SDMode mode,
243+
const std::string& lora_model_dir,
244+
const std::string& hires_upscalers_dir,
245+
bool strict = false);
231246
sd_img_gen_params_t to_sd_img_gen_params_t();
232247
sd_vid_gen_params_t to_sd_vid_gen_params_t();
233248
std::string to_string() const;

0 commit comments

Comments
 (0)