Skip to content

Commit 616c1bb

Browse files
committed
Defer Flux2 loads; DrawThings & sampling updates
1 parent 83e9129 commit 616c1bb

9 files changed

Lines changed: 795 additions & 225 deletions

File tree

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,4 +182,5 @@ cython_debug/
182182
aitk_db.db
183183
/notes.md
184184
/data
185-
.claude
185+
.claude
186+
/.aitk_bin/

extensions_built_in/diffusion_models/flux2/flux2_model.py

Lines changed: 98 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from toolkit.accelerator import unwrap_model
1818
from optimum.quanto import freeze, QTensor
1919
from toolkit.util.quantize import quantize, get_qtype, quantize_model
20+
from toolkit.unloader import FakeTextEncoder
2021

2122
from transformers import AutoProcessor, Mistral3ForConditionalGeneration
2223
from .src.model import Flux2, Flux2Params
@@ -124,32 +125,49 @@ def load_te(self):
124125
tokenizer = AutoProcessor.from_pretrained(MISTRAL_PATH)
125126
return text_encoder, tokenizer
126127

127-
def load_model(self):
128-
dtype = self.torch_dtype
129-
self.print_and_status_update("Loading Flux2 model")
130-
# will be updated if we detect a existing checkpoint in training folder
128+
def _should_defer_text_encoder_load(self) -> bool:
129+
return bool(self.model_config.model_kwargs.get("defer_text_encoder_load", False))
130+
131+
def _should_skip_text_encoder_startup_load(self) -> bool:
132+
return bool(self.model_config.model_kwargs.get("skip_text_encoder_load", False))
133+
134+
def _should_defer_transformer_load(self) -> bool:
135+
return bool(self.model_config.model_kwargs.get("defer_transformer_load", False))
136+
137+
def _make_placeholder_text_encoder(self):
138+
return FakeTextEncoder(device=self.device_torch, dtype=self.torch_dtype), None
139+
140+
def _make_placeholder_transformer(self):
141+
return FakeTextEncoder(device=self.device_torch, dtype=self.torch_dtype)
142+
143+
def should_delay_train_setup_until_after_dataset(self) -> bool:
144+
return self._should_defer_transformer_load()
145+
146+
def _resolve_transformer_path(self):
131147
model_path = self.model_config.name_or_path
132148
transformer_path = model_path
133149

134-
self.print_and_status_update("Loading transformer")
135-
with torch.device("meta"):
136-
transformer = Flux2(self.get_flux2_params())
137-
138-
# use local path if provided
139150
if os.path.exists(os.path.join(transformer_path, self.flux2_te_filename)):
140151
transformer_path = os.path.join(transformer_path, self.flux2_te_filename)
141152

142153
if not os.path.exists(transformer_path):
143-
# assume it is from the hub
144154
transformer_path = huggingface_hub.hf_hub_download(
145155
repo_id=model_path,
146156
filename=self.flux2_te_filename,
147157
token=HF_TOKEN,
148158
)
149159

160+
return transformer_path
161+
162+
def _load_transformer_module(self):
163+
dtype = self.torch_dtype
164+
self.print_and_status_update("Loading transformer")
165+
with torch.device("meta"):
166+
transformer = Flux2(self.get_flux2_params())
167+
168+
transformer_path = self._resolve_transformer_path()
150169
transformer_state_dict = load_file(transformer_path, device="cpu")
151170

152-
# cast to dtype
153171
for key in transformer_state_dict:
154172
transformer_state_dict[key] = transformer_state_dict[key].to(dtype)
155173

@@ -158,7 +176,6 @@ def load_model(self):
158176
transformer.to(self.quantize_device, dtype=dtype)
159177

160178
if self.model_config.quantize:
161-
# patch the state dict method
162179
patch_dequantization_on_save(transformer)
163180
self.print_and_status_update("Quantizing Transformer")
164181
quantize_model(self, transformer)
@@ -181,7 +198,76 @@ def load_model(self):
181198
self.print_and_status_update("Moving transformer to CPU")
182199
transformer.to("cpu")
183200

201+
return transformer
202+
203+
def load_text_encoder_components(self):
204+
has_real_text_encoder = False
205+
if isinstance(self.text_encoder, list):
206+
has_real_text_encoder = any(
207+
encoder is not None and not isinstance(encoder, FakeTextEncoder)
208+
for encoder in self.text_encoder
209+
)
210+
else:
211+
has_real_text_encoder = (
212+
self.text_encoder is not None
213+
and not isinstance(self.text_encoder, FakeTextEncoder)
214+
)
215+
216+
if has_real_text_encoder:
217+
return
218+
219+
self.print_and_status_update("Lazy-loading text encoder")
184220
text_encoder, tokenizer = self.load_te()
221+
self.pipeline.text_encoder = text_encoder
222+
self.pipeline.tokenizer = tokenizer
223+
self.text_encoder = [text_encoder]
224+
self.tokenizer = [tokenizer]
225+
text_encoder.to(self.device_torch)
226+
text_encoder.requires_grad_(False)
227+
text_encoder.eval()
228+
flush()
229+
230+
def load_transformer_components(self):
231+
if self.transformer is not None and not isinstance(self.transformer, FakeTextEncoder):
232+
return
233+
234+
self.print_and_status_update("Lazy-loading transformer")
235+
transformer = self._load_transformer_module()
236+
self.pipeline.transformer = transformer
237+
self.model = transformer
238+
flush()
239+
240+
def load_deferred_components(self):
241+
self.load_transformer_components()
242+
243+
def load_model(self):
244+
dtype = self.torch_dtype
245+
self.print_and_status_update("Loading Flux2 model")
246+
# will be updated if we detect a existing checkpoint in training folder
247+
model_path = self.model_config.name_or_path
248+
if self._should_defer_transformer_load():
249+
self.print_and_status_update("Deferring transformer load")
250+
self.print_and_status_update(
251+
"Transformer weights will be loaded after dataset cache preparation"
252+
)
253+
transformer = self._make_placeholder_transformer()
254+
else:
255+
transformer = self._load_transformer_module()
256+
257+
if self._should_skip_text_encoder_startup_load():
258+
self.print_and_status_update("Skipping text encoder startup load")
259+
self.print_and_status_update(
260+
"Reusing cached blank prompt embeddings; startup text stack will stay unloaded"
261+
)
262+
text_encoder, tokenizer = self._make_placeholder_text_encoder()
263+
elif self._should_defer_text_encoder_load():
264+
self.print_and_status_update("Deferring text encoder load")
265+
self.print_and_status_update(
266+
"Text stack will be loaded only if prompt encoding is needed later"
267+
)
268+
text_encoder, tokenizer = self._make_placeholder_text_encoder()
269+
else:
270+
text_encoder, tokenizer = self.load_te()
185271

186272
self.print_and_status_update("Loading VAE")
187273
vae_path = self.model_config.vae_path

extensions_built_in/sd_trainer/DiffusionTrainer.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -360,9 +360,24 @@ def sample_step_hook(self, img_num, total_imgs):
360360
self.update_status(
361361
"running", f"Generating images - {img_num + 1}/{total_imgs}")
362362

363-
def sample(self, step=None, is_first=False, is_startup_sample=False, remote_base_only=False):
363+
def sample(
364+
self,
365+
step=None,
366+
is_first=False,
367+
is_startup_sample=False,
368+
remote_base_only=False,
369+
lora_source_path=None,
370+
cleanup_lora_source_path=True,
371+
):
364372
self.maybe_stop()
365-
super().sample(step, is_first, is_startup_sample=is_startup_sample, remote_base_only=remote_base_only)
373+
super().sample(
374+
step,
375+
is_first,
376+
is_startup_sample=is_startup_sample,
377+
remote_base_only=remote_base_only,
378+
lora_source_path=lora_source_path,
379+
cleanup_lora_source_path=cleanup_lora_source_path,
380+
)
366381
self.maybe_stop()
367382

368383
def before_waiting_for_background_samples(self, pending_batches: int):

0 commit comments

Comments
 (0)