1717from toolkit .accelerator import unwrap_model
1818from optimum .quanto import freeze , QTensor
1919from toolkit .util .quantize import quantize , get_qtype , quantize_model
20+ from toolkit .unloader import FakeTextEncoder
2021
2122from transformers import AutoProcessor , Mistral3ForConditionalGeneration
2223from .src .model import Flux2 , Flux2Params
@@ -124,32 +125,49 @@ def load_te(self):
124125 tokenizer = AutoProcessor .from_pretrained (MISTRAL_PATH )
125126 return text_encoder , tokenizer
126127
127- def load_model (self ):
128- dtype = self .torch_dtype
129- self .print_and_status_update ("Loading Flux2 model" )
130- # will be updated if we detect a existing checkpoint in training folder
128+ def _should_defer_text_encoder_load (self ) -> bool :
129+ return bool (self .model_config .model_kwargs .get ("defer_text_encoder_load" , False ))
130+
131+ def _should_skip_text_encoder_startup_load (self ) -> bool :
132+ return bool (self .model_config .model_kwargs .get ("skip_text_encoder_load" , False ))
133+
134+ def _should_defer_transformer_load (self ) -> bool :
135+ return bool (self .model_config .model_kwargs .get ("defer_transformer_load" , False ))
136+
137+ def _make_placeholder_text_encoder (self ):
138+ return FakeTextEncoder (device = self .device_torch , dtype = self .torch_dtype ), None
139+
140+ def _make_placeholder_transformer (self ):
141+ return FakeTextEncoder (device = self .device_torch , dtype = self .torch_dtype )
142+
143+ def should_delay_train_setup_until_after_dataset (self ) -> bool :
144+ return self ._should_defer_transformer_load ()
145+
146+ def _resolve_transformer_path (self ):
131147 model_path = self .model_config .name_or_path
132148 transformer_path = model_path
133149
134- self .print_and_status_update ("Loading transformer" )
135- with torch .device ("meta" ):
136- transformer = Flux2 (self .get_flux2_params ())
137-
138- # use local path if provided
139150 if os .path .exists (os .path .join (transformer_path , self .flux2_te_filename )):
140151 transformer_path = os .path .join (transformer_path , self .flux2_te_filename )
141152
142153 if not os .path .exists (transformer_path ):
143- # assume it is from the hub
144154 transformer_path = huggingface_hub .hf_hub_download (
145155 repo_id = model_path ,
146156 filename = self .flux2_te_filename ,
147157 token = HF_TOKEN ,
148158 )
149159
160+ return transformer_path
161+
162+ def _load_transformer_module (self ):
163+ dtype = self .torch_dtype
164+ self .print_and_status_update ("Loading transformer" )
165+ with torch .device ("meta" ):
166+ transformer = Flux2 (self .get_flux2_params ())
167+
168+ transformer_path = self ._resolve_transformer_path ()
150169 transformer_state_dict = load_file (transformer_path , device = "cpu" )
151170
152- # cast to dtype
153171 for key in transformer_state_dict :
154172 transformer_state_dict [key ] = transformer_state_dict [key ].to (dtype )
155173
@@ -158,7 +176,6 @@ def load_model(self):
158176 transformer .to (self .quantize_device , dtype = dtype )
159177
160178 if self .model_config .quantize :
161- # patch the state dict method
162179 patch_dequantization_on_save (transformer )
163180 self .print_and_status_update ("Quantizing Transformer" )
164181 quantize_model (self , transformer )
@@ -181,7 +198,76 @@ def load_model(self):
181198 self .print_and_status_update ("Moving transformer to CPU" )
182199 transformer .to ("cpu" )
183200
201+ return transformer
202+
203+ def load_text_encoder_components (self ):
204+ has_real_text_encoder = False
205+ if isinstance (self .text_encoder , list ):
206+ has_real_text_encoder = any (
207+ encoder is not None and not isinstance (encoder , FakeTextEncoder )
208+ for encoder in self .text_encoder
209+ )
210+ else :
211+ has_real_text_encoder = (
212+ self .text_encoder is not None
213+ and not isinstance (self .text_encoder , FakeTextEncoder )
214+ )
215+
216+ if has_real_text_encoder :
217+ return
218+
219+ self .print_and_status_update ("Lazy-loading text encoder" )
184220 text_encoder , tokenizer = self .load_te ()
221+ self .pipeline .text_encoder = text_encoder
222+ self .pipeline .tokenizer = tokenizer
223+ self .text_encoder = [text_encoder ]
224+ self .tokenizer = [tokenizer ]
225+ text_encoder .to (self .device_torch )
226+ text_encoder .requires_grad_ (False )
227+ text_encoder .eval ()
228+ flush ()
229+
230+ def load_transformer_components (self ):
231+ if self .transformer is not None and not isinstance (self .transformer , FakeTextEncoder ):
232+ return
233+
234+ self .print_and_status_update ("Lazy-loading transformer" )
235+ transformer = self ._load_transformer_module ()
236+ self .pipeline .transformer = transformer
237+ self .model = transformer
238+ flush ()
239+
240+ def load_deferred_components (self ):
241+ self .load_transformer_components ()
242+
243+ def load_model (self ):
244+ dtype = self .torch_dtype
245+ self .print_and_status_update ("Loading Flux2 model" )
246+ # will be updated if we detect a existing checkpoint in training folder
247+ model_path = self .model_config .name_or_path
248+ if self ._should_defer_transformer_load ():
249+ self .print_and_status_update ("Deferring transformer load" )
250+ self .print_and_status_update (
251+ "Transformer weights will be loaded after dataset cache preparation"
252+ )
253+ transformer = self ._make_placeholder_transformer ()
254+ else :
255+ transformer = self ._load_transformer_module ()
256+
257+ if self ._should_skip_text_encoder_startup_load ():
258+ self .print_and_status_update ("Skipping text encoder startup load" )
259+ self .print_and_status_update (
260+ "Reusing cached blank prompt embeddings; startup text stack will stay unloaded"
261+ )
262+ text_encoder , tokenizer = self ._make_placeholder_text_encoder ()
263+ elif self ._should_defer_text_encoder_load ():
264+ self .print_and_status_update ("Deferring text encoder load" )
265+ self .print_and_status_update (
266+ "Text stack will be loaded only if prompt encoding is needed later"
267+ )
268+ text_encoder , tokenizer = self ._make_placeholder_text_encoder ()
269+ else :
270+ text_encoder , tokenizer = self .load_te ()
185271
186272 self .print_and_status_update ("Loading VAE" )
187273 vae_path = self .model_config .vae_path
0 commit comments