Skip to content

Commit 586b6f1

Browse files
authored
feat: adapt res samplers for flow models for eta > 0 (#1436)
1 parent 9097ce5 commit 586b6f1

1 file changed

Lines changed: 27 additions & 6 deletions

File tree

src/denoiser.hpp

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -808,6 +808,18 @@ static std::tuple<float, float, float> get_ancestral_step_flow(float sigma_from,
808808
return {sigma_down, sigma_up, alpha_scale};
809809
}
810810

811+
static std::tuple<float, float, float> get_ancestral_step(float sigma_from,
812+
float sigma_to,
813+
float eta,
814+
bool is_flow_denoiser) {
815+
if (is_flow_denoiser) {
816+
return get_ancestral_step_flow(sigma_from, sigma_to, eta);
817+
} else {
818+
auto [sigma_down, sigma_up] = get_ancestral_step(sigma_from, sigma_to, eta);
819+
return {sigma_down, sigma_up, 1.0f};
820+
}
821+
}
822+
811823
static sd::Tensor<float> sample_euler_ancestral(denoise_cb_t model,
812824
sd::Tensor<float> x,
813825
const std::vector<float>& sigmas,
@@ -1247,6 +1259,7 @@ static sd::Tensor<float> sample_res_multistep(denoise_cb_t model,
12471259
sd::Tensor<float> x,
12481260
const std::vector<float>& sigmas,
12491261
std::shared_ptr<RNG> rng,
1262+
bool is_flow_denoiser,
12501263
float eta) {
12511264
sd::Tensor<float> old_denoised = x;
12521265
bool have_old_sigma = false;
@@ -1278,7 +1291,8 @@ static sd::Tensor<float> sample_res_multistep(denoise_cb_t model,
12781291

12791292
float sigma_from = sigmas[i];
12801293
float sigma_to = sigmas[i + 1];
1281-
auto [sigma_down, sigma_up] = get_ancestral_step(sigma_from, sigma_to, eta);
1294+
1295+
auto [sigma_down, sigma_up, alpha_scale] = get_ancestral_step(sigma_from, sigma_to, eta, is_flow_denoiser);
12821296

12831297
if (sigma_down == 0.0f || !have_old_sigma) {
12841298
x += ((x - denoised) / sigma_from) * (sigma_down - sigma_from);
@@ -1305,7 +1319,10 @@ static sd::Tensor<float> sample_res_multistep(denoise_cb_t model,
13051319
x = sigma_fn(h) * x + h * (b1 * denoised + b2 * old_denoised);
13061320
}
13071321

1308-
if (sigmas[i + 1] > 0 && sigma_up > 0.0f) {
1322+
if (sigma_to > 0.0f && sigma_up > 0.0f) {
1323+
if (is_flow_denoiser) {
1324+
x *= alpha_scale;
1325+
}
13091326
x += sd::Tensor<float>::randn_like(x, rng) * sigma_up;
13101327
}
13111328

@@ -1320,6 +1337,7 @@ static sd::Tensor<float> sample_res_2s(denoise_cb_t model,
13201337
sd::Tensor<float> x,
13211338
const std::vector<float>& sigmas,
13221339
std::shared_ptr<RNG> rng,
1340+
bool is_flow_denoiser,
13231341
float eta) {
13241342
const float c2 = 0.5f;
13251343
auto t_fn = [](float sigma) -> float { return -logf(sigma); };
@@ -1348,7 +1366,7 @@ static sd::Tensor<float> sample_res_2s(denoise_cb_t model,
13481366
}
13491367
sd::Tensor<float> denoised = std::move(denoised_opt);
13501368

1351-
auto [sigma_down, sigma_up] = get_ancestral_step(sigma_from, sigma_to, eta);
1369+
auto [sigma_down, sigma_up, alpha_scale] = get_ancestral_step(sigma_from, sigma_to, eta, is_flow_denoiser);
13521370

13531371
sd::Tensor<float> x0 = x;
13541372
if (sigma_down == 0.0f || sigma_from == 0.0f) {
@@ -1377,7 +1395,10 @@ static sd::Tensor<float> sample_res_2s(denoise_cb_t model,
13771395
x = x0 + h * (b1 * eps1 + b2 * eps2);
13781396
}
13791397

1380-
if (sigmas[i + 1] > 0 && sigma_up > 0.0f) {
1398+
if (sigma_to > 0.0f && sigma_up > 0.0f) {
1399+
if (is_flow_denoiser) {
1400+
x *= alpha_scale;
1401+
}
13811402
x += sd::Tensor<float>::randn_like(x, rng) * sigma_up;
13821403
}
13831404
}
@@ -1664,9 +1685,9 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
16641685
case IPNDM_V_SAMPLE_METHOD:
16651686
return sample_ipndm_v(model, std::move(x), sigmas);
16661687
case RES_MULTISTEP_SAMPLE_METHOD:
1667-
return sample_res_multistep(model, std::move(x), sigmas, rng, eta);
1688+
return sample_res_multistep(model, std::move(x), sigmas, rng, is_flow_denoiser, eta);
16681689
case RES_2S_SAMPLE_METHOD:
1669-
return sample_res_2s(model, std::move(x), sigmas, rng, eta);
1690+
return sample_res_2s(model, std::move(x), sigmas, rng, is_flow_denoiser, eta);
16701691
case ER_SDE_SAMPLE_METHOD:
16711692
return sample_er_sde(model, std::move(x), sigmas, rng, is_flow_denoiser, eta);
16721693
case DDIM_TRAILING_SAMPLE_METHOD:

0 commit comments

Comments
 (0)