@@ -30,6 +30,15 @@ def fake_model_from_pretrained(*args, **kwargs):
3030 captured ["model_kwargs" ] = kwargs
3131 return FakeModel ()
3232
def fake_auto_config_from_pretrained(*args, **kwargs):
    """Fake used in place of ``AutoConfig.from_pretrained`` in this test.

    Stores the keyword arguments it was called with under
    ``captured["config_kwargs"]`` and returns a lightweight config object
    whose attention attributes echo the ``attn_implementation`` keyword
    (``None`` when that keyword was not passed).
    """
    del args  # positional arguments play no role in this fake
    captured["config_kwargs"] = kwargs
    requested_attn = kwargs.get("attn_implementation")
    fake_config = SimpleNamespace(
        model_type="falcon_ocr",
        _attn_implementation=requested_attn,
        _attn_implementation_internal=requested_attn,
    )
    return fake_config
41+
def fake_generation_config_from_pretrained(*args, **kwargs):
    """Fake used in place of ``GenerationConfig.from_pretrained``.

    Accepts and ignores any arguments, returning a fresh, empty
    ``SimpleNamespace`` on every call.
    """
    del args, kwargs  # nothing about the call influences the result
    empty_generation_config = SimpleNamespace()
    return empty_generation_config
3544
@@ -45,6 +54,10 @@ def fake_generation_config_from_pretrained(*args, **kwargs):
4554 "docling.models.inference_engines.vlm.transformers_engine.AutoModelForCausalLM.from_pretrained" ,
4655 fake_model_from_pretrained ,
4756 )
57+ monkeypatch .setattr (
58+ "docling.models.inference_engines.vlm.transformers_engine.AutoConfig.from_pretrained" ,
59+ fake_auto_config_from_pretrained ,
60+ )
4861 monkeypatch .setattr (
4962 "docling.models.inference_engines.vlm.transformers_engine.GenerationConfig.from_pretrained" ,
5063 fake_generation_config_from_pretrained ,
@@ -67,6 +80,8 @@ def fake_generation_config_from_pretrained(*args, **kwargs):
6780 )
6881
6982 assert captured ["model_kwargs" ]["attn_implementation" ] == "eager"
83+ assert captured ["config_kwargs" ]["attn_implementation" ] == "eager"
84+ assert captured ["model_kwargs" ]["config" ]._attn_implementation == "eager"
7085
7186
7287def test_transformers_engine_accepts_legacy_private_attn_implementation_key (
@@ -86,6 +101,15 @@ def fake_model_from_pretrained(*args, **kwargs):
86101 captured ["model_kwargs" ] = kwargs
87102 return FakeModel ()
88103
def fake_auto_config_from_pretrained(*args, **kwargs):
    """Fake used in place of ``AutoConfig.from_pretrained`` in this test.

    Stores the keyword arguments it was called with under
    ``captured["config_kwargs"]`` and returns a lightweight config object
    whose attention attributes echo the ``attn_implementation`` keyword
    (``None`` when that keyword was not passed).
    """
    del args  # positional arguments play no role in this fake
    captured["config_kwargs"] = kwargs
    requested_attn = kwargs.get("attn_implementation")
    fake_config = SimpleNamespace(
        model_type="falcon_ocr",
        _attn_implementation=requested_attn,
        _attn_implementation_internal=requested_attn,
    )
    return fake_config
112+
def fake_generation_config_from_pretrained(*args, **kwargs):
    """Fake used in place of ``GenerationConfig.from_pretrained``.

    Accepts and ignores any arguments, returning a fresh, empty
    ``SimpleNamespace`` on every call.
    """
    del args, kwargs  # nothing about the call influences the result
    empty_generation_config = SimpleNamespace()
    return empty_generation_config
91115
@@ -101,6 +125,10 @@ def fake_generation_config_from_pretrained(*args, **kwargs):
101125 "docling.models.inference_engines.vlm.transformers_engine.AutoModelForCausalLM.from_pretrained" ,
102126 fake_model_from_pretrained ,
103127 )
128+ monkeypatch .setattr (
129+ "docling.models.inference_engines.vlm.transformers_engine.AutoConfig.from_pretrained" ,
130+ fake_auto_config_from_pretrained ,
131+ )
104132 monkeypatch .setattr (
105133 "docling.models.inference_engines.vlm.transformers_engine.GenerationConfig.from_pretrained" ,
106134 fake_generation_config_from_pretrained ,
@@ -123,6 +151,8 @@ def fake_generation_config_from_pretrained(*args, **kwargs):
123151 )
124152
125153 assert captured ["model_kwargs" ]["attn_implementation" ] == "eager"
154+ assert captured ["config_kwargs" ]["attn_implementation" ] == "eager"
155+ assert captured ["model_kwargs" ]["config" ]._attn_implementation == "eager"
126156
127157
128158def test_transformers_engine_defaults_falcon_ocr_to_eager (
@@ -142,6 +172,15 @@ def fake_model_from_pretrained(*args, **kwargs):
142172 captured ["model_kwargs" ] = kwargs
143173 return FakeModel ()
144174
def fake_auto_config_from_pretrained(*args, **kwargs):
    """Fake used in place of ``AutoConfig.from_pretrained`` in this test.

    Stores the keyword arguments it was called with under
    ``captured["config_kwargs"]`` and returns a lightweight config object
    whose attention attributes echo the ``attn_implementation`` keyword
    (``None`` when that keyword was not passed).
    """
    del args  # positional arguments play no role in this fake
    captured["config_kwargs"] = kwargs
    requested_attn = kwargs.get("attn_implementation")
    fake_config = SimpleNamespace(
        model_type="falcon_ocr",
        _attn_implementation=requested_attn,
        _attn_implementation_internal=requested_attn,
    )
    return fake_config
183+
def fake_generation_config_from_pretrained(*args, **kwargs):
    """Fake used in place of ``GenerationConfig.from_pretrained``.

    Accepts and ignores any arguments, returning a fresh, empty
    ``SimpleNamespace`` on every call.
    """
    del args, kwargs  # nothing about the call influences the result
    empty_generation_config = SimpleNamespace()
    return empty_generation_config
147186
@@ -157,6 +196,10 @@ def fake_generation_config_from_pretrained(*args, **kwargs):
157196 "docling.models.inference_engines.vlm.transformers_engine.AutoModelForCausalLM.from_pretrained" ,
158197 fake_model_from_pretrained ,
159198 )
199+ monkeypatch .setattr (
200+ "docling.models.inference_engines.vlm.transformers_engine.AutoConfig.from_pretrained" ,
201+ fake_auto_config_from_pretrained ,
202+ )
160203 monkeypatch .setattr (
161204 "docling.models.inference_engines.vlm.transformers_engine.GenerationConfig.from_pretrained" ,
162205 fake_generation_config_from_pretrained ,
@@ -178,6 +221,8 @@ def fake_generation_config_from_pretrained(*args, **kwargs):
178221 )
179222
180223 assert captured ["model_kwargs" ]["attn_implementation" ] == "eager"
224+ assert captured ["config_kwargs" ]["attn_implementation" ] == "eager"
225+ assert captured ["model_kwargs" ]["config" ]._attn_implementation == "eager"
181226
182227
183228def test_transformers_engine_falls_back_without_generation_config_file (
@@ -199,8 +244,18 @@ def fake_processor_from_pretrained(*args, **kwargs):
def fake_model_from_pretrained(*args, **kwargs):
    """Fake used in place of ``AutoModelForCausalLM.from_pretrained``.

    Builds a ``FakeModel``, records both the instance and the keyword
    arguments of the call in ``captured`` (keys ``"model"`` and
    ``"model_kwargs"``), and returns that same instance.
    """
    fake_instance = FakeModel()
    # Record the instance first, then the kwargs, so the test can inspect
    # both what was built and how the loader was invoked.
    captured["model"] = fake_instance
    captured["model_kwargs"] = kwargs
    return fake_instance
203249
def fake_auto_config_from_pretrained(*args, **kwargs):
    """Fake used in place of ``AutoConfig.from_pretrained`` in this test.

    Stores the keyword arguments it was called with under
    ``captured["config_kwargs"]`` and returns a lightweight config object
    whose attention attributes echo the ``attn_implementation`` keyword
    (``None`` when that keyword was not passed).
    """
    del args  # positional arguments play no role in this fake
    captured["config_kwargs"] = kwargs
    requested_attn = kwargs.get("attn_implementation")
    fake_config = SimpleNamespace(
        model_type="falcon_ocr",
        _attn_implementation=requested_attn,
        _attn_implementation_internal=requested_attn,
    )
    return fake_config
258+
def fake_generation_config_from_pretrained(*args, **kwargs):
    """Fake used in place of ``GenerationConfig.from_pretrained``.

    Always fails with ``OSError``, whatever arguments it receives.

    Raises:
        OSError: on every call, with the message
            ``"missing file named generation_config.json"``.
    """
    del args, kwargs  # the call arguments never affect the outcome
    missing_file_error = OSError("missing file named generation_config.json")
    raise missing_file_error
@@ -221,6 +276,10 @@ def fake_generation_config_from_model_config(model_config):
221276 "docling.models.inference_engines.vlm.transformers_engine.AutoModelForCausalLM.from_pretrained" ,
222277 fake_model_from_pretrained ,
223278 )
279+ monkeypatch .setattr (
280+ "docling.models.inference_engines.vlm.transformers_engine.AutoConfig.from_pretrained" ,
281+ fake_auto_config_from_pretrained ,
282+ )
224283 monkeypatch .setattr (
225284 "docling.models.inference_engines.vlm.transformers_engine.GenerationConfig.from_pretrained" ,
226285 fake_generation_config_from_pretrained ,
@@ -247,6 +306,8 @@ def fake_generation_config_from_model_config(model_config):
247306
248307 assert engine .generation_config .source == "fallback"
249308 assert captured ["fallback_model_config" ] is captured ["model" ].config
309+ assert captured ["config_kwargs" ]["attn_implementation" ] == "eager"
310+ assert captured ["model_kwargs" ]["config" ]._attn_implementation == "eager"
250311
251312
252313def test_transformers_engine_uses_falcon_native_generate_batch () -> None :
0 commit comments