Skip to content

Commit ed03684

Browse files
committed
fix(vlm): preload falcon config for eager attention
Signed-off-by: Georg Heiler <georg.kf.heiler@gmail.com>
1 parent 3a3fc28 commit ed03684

2 files changed

Lines changed: 92 additions & 1 deletion

File tree

docling/models/inference_engines/vlm/transformers_engine.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from packaging import version
1212
from PIL.Image import Image
1313
from transformers import (
14+
AutoConfig,
1415
AutoModel,
1516
AutoModelForCausalLM,
1617
AutoModelForImageTextToText,
@@ -120,6 +121,23 @@ def _falcon_ocr_category_from_prompt(prompt: str) -> str:
120121
return "plain"
121122

122123

124+
def _force_falcon_ocr_eager_attention_config(config_obj: Any) -> None:
125+
if config_obj is None:
126+
return
127+
if getattr(config_obj, "_attn_implementation", None) in {
128+
None,
129+
"sdpa",
130+
"paged|sdpa",
131+
}:
132+
config_obj._attn_implementation = "eager"
133+
if getattr(config_obj, "_attn_implementation_internal", None) in {
134+
None,
135+
"sdpa",
136+
"paged|sdpa",
137+
}:
138+
config_obj._attn_implementation_internal = "eager"
139+
140+
123141
class TransformersVlmEngine(BaseVlmEngine, HuggingFaceModelDownloadMixin):
124142
"""HuggingFace Transformers engine for VLM inference.
125143
@@ -251,6 +269,17 @@ def download_wrapper(repo_id: str, revision: str) -> Path:
251269
elif model_type == TransformersModelType.AUTOMODEL_IMAGETEXTTOTEXT:
252270
model_cls = AutoModelForImageTextToText # type: ignore[assignment]
253271

272+
attn_implementation = self._get_attn_implementation()
273+
model_config = None
274+
if _value_mentions_falcon_ocr(repo_id):
275+
model_config = AutoConfig.from_pretrained(
276+
artifacts_path,
277+
trust_remote_code=self.options.trust_remote_code,
278+
revision=revision,
279+
attn_implementation=attn_implementation,
280+
)
281+
_force_falcon_ocr_eager_attention_config(model_config)
282+
254283
# Load processor
255284
self.processor = AutoProcessor.from_pretrained(
256285
artifacts_path,
@@ -269,10 +298,11 @@ def download_wrapper(repo_id: str, revision: str) -> Path:
269298
artifacts_path,
270299
device_map=self.device,
271300
dtype=torch_dtype,
272-
attn_implementation=self._get_attn_implementation(),
301+
attn_implementation=attn_implementation,
273302
trust_remote_code=self.options.trust_remote_code,
274303
revision=revision,
275304
quantization_config=quantization_config,
305+
config=model_config,
276306
)
277307

278308
self.vlm_model.eval()

tests/test_transformers_vlm_engine.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,15 @@ def fake_model_from_pretrained(*args, **kwargs):
3030
captured["model_kwargs"] = kwargs
3131
return FakeModel()
3232

33+
def fake_auto_config_from_pretrained(*args, **kwargs):
34+
_ = args
35+
captured["config_kwargs"] = kwargs
36+
return SimpleNamespace(
37+
model_type="falcon_ocr",
38+
_attn_implementation=kwargs.get("attn_implementation"),
39+
_attn_implementation_internal=kwargs.get("attn_implementation"),
40+
)
41+
3342
def fake_generation_config_from_pretrained(*args, **kwargs):
3443
return SimpleNamespace()
3544

@@ -45,6 +54,10 @@ def fake_generation_config_from_pretrained(*args, **kwargs):
4554
"docling.models.inference_engines.vlm.transformers_engine.AutoModelForCausalLM.from_pretrained",
4655
fake_model_from_pretrained,
4756
)
57+
monkeypatch.setattr(
58+
"docling.models.inference_engines.vlm.transformers_engine.AutoConfig.from_pretrained",
59+
fake_auto_config_from_pretrained,
60+
)
4861
monkeypatch.setattr(
4962
"docling.models.inference_engines.vlm.transformers_engine.GenerationConfig.from_pretrained",
5063
fake_generation_config_from_pretrained,
@@ -67,6 +80,8 @@ def fake_generation_config_from_pretrained(*args, **kwargs):
6780
)
6881

6982
assert captured["model_kwargs"]["attn_implementation"] == "eager"
83+
assert captured["config_kwargs"]["attn_implementation"] == "eager"
84+
assert captured["model_kwargs"]["config"]._attn_implementation == "eager"
7085

7186

7287
def test_transformers_engine_accepts_legacy_private_attn_implementation_key(
@@ -86,6 +101,15 @@ def fake_model_from_pretrained(*args, **kwargs):
86101
captured["model_kwargs"] = kwargs
87102
return FakeModel()
88103

104+
def fake_auto_config_from_pretrained(*args, **kwargs):
105+
_ = args
106+
captured["config_kwargs"] = kwargs
107+
return SimpleNamespace(
108+
model_type="falcon_ocr",
109+
_attn_implementation=kwargs.get("attn_implementation"),
110+
_attn_implementation_internal=kwargs.get("attn_implementation"),
111+
)
112+
89113
def fake_generation_config_from_pretrained(*args, **kwargs):
90114
return SimpleNamespace()
91115

@@ -101,6 +125,10 @@ def fake_generation_config_from_pretrained(*args, **kwargs):
101125
"docling.models.inference_engines.vlm.transformers_engine.AutoModelForCausalLM.from_pretrained",
102126
fake_model_from_pretrained,
103127
)
128+
monkeypatch.setattr(
129+
"docling.models.inference_engines.vlm.transformers_engine.AutoConfig.from_pretrained",
130+
fake_auto_config_from_pretrained,
131+
)
104132
monkeypatch.setattr(
105133
"docling.models.inference_engines.vlm.transformers_engine.GenerationConfig.from_pretrained",
106134
fake_generation_config_from_pretrained,
@@ -123,6 +151,8 @@ def fake_generation_config_from_pretrained(*args, **kwargs):
123151
)
124152

125153
assert captured["model_kwargs"]["attn_implementation"] == "eager"
154+
assert captured["config_kwargs"]["attn_implementation"] == "eager"
155+
assert captured["model_kwargs"]["config"]._attn_implementation == "eager"
126156

127157

128158
def test_transformers_engine_defaults_falcon_ocr_to_eager(
@@ -142,6 +172,15 @@ def fake_model_from_pretrained(*args, **kwargs):
142172
captured["model_kwargs"] = kwargs
143173
return FakeModel()
144174

175+
def fake_auto_config_from_pretrained(*args, **kwargs):
176+
_ = args
177+
captured["config_kwargs"] = kwargs
178+
return SimpleNamespace(
179+
model_type="falcon_ocr",
180+
_attn_implementation=kwargs.get("attn_implementation"),
181+
_attn_implementation_internal=kwargs.get("attn_implementation"),
182+
)
183+
145184
def fake_generation_config_from_pretrained(*args, **kwargs):
146185
return SimpleNamespace()
147186

@@ -157,6 +196,10 @@ def fake_generation_config_from_pretrained(*args, **kwargs):
157196
"docling.models.inference_engines.vlm.transformers_engine.AutoModelForCausalLM.from_pretrained",
158197
fake_model_from_pretrained,
159198
)
199+
monkeypatch.setattr(
200+
"docling.models.inference_engines.vlm.transformers_engine.AutoConfig.from_pretrained",
201+
fake_auto_config_from_pretrained,
202+
)
160203
monkeypatch.setattr(
161204
"docling.models.inference_engines.vlm.transformers_engine.GenerationConfig.from_pretrained",
162205
fake_generation_config_from_pretrained,
@@ -178,6 +221,8 @@ def fake_generation_config_from_pretrained(*args, **kwargs):
178221
)
179222

180223
assert captured["model_kwargs"]["attn_implementation"] == "eager"
224+
assert captured["config_kwargs"]["attn_implementation"] == "eager"
225+
assert captured["model_kwargs"]["config"]._attn_implementation == "eager"
181226

182227

183228
def test_transformers_engine_falls_back_without_generation_config_file(
@@ -199,8 +244,18 @@ def fake_processor_from_pretrained(*args, **kwargs):
199244
def fake_model_from_pretrained(*args, **kwargs):
200245
model = FakeModel()
201246
captured["model"] = model
247+
captured["model_kwargs"] = kwargs
202248
return model
203249

250+
def fake_auto_config_from_pretrained(*args, **kwargs):
251+
_ = args
252+
captured["config_kwargs"] = kwargs
253+
return SimpleNamespace(
254+
model_type="falcon_ocr",
255+
_attn_implementation=kwargs.get("attn_implementation"),
256+
_attn_implementation_internal=kwargs.get("attn_implementation"),
257+
)
258+
204259
def fake_generation_config_from_pretrained(*args, **kwargs):
205260
_ = (args, kwargs)
206261
raise OSError("missing file named generation_config.json")
@@ -221,6 +276,10 @@ def fake_generation_config_from_model_config(model_config):
221276
"docling.models.inference_engines.vlm.transformers_engine.AutoModelForCausalLM.from_pretrained",
222277
fake_model_from_pretrained,
223278
)
279+
monkeypatch.setattr(
280+
"docling.models.inference_engines.vlm.transformers_engine.AutoConfig.from_pretrained",
281+
fake_auto_config_from_pretrained,
282+
)
224283
monkeypatch.setattr(
225284
"docling.models.inference_engines.vlm.transformers_engine.GenerationConfig.from_pretrained",
226285
fake_generation_config_from_pretrained,
@@ -247,6 +306,8 @@ def fake_generation_config_from_model_config(model_config):
247306

248307
assert engine.generation_config.source == "fallback"
249308
assert captured["fallback_model_config"] is captured["model"].config
309+
assert captured["config_kwargs"]["attn_implementation"] == "eager"
310+
assert captured["model_kwargs"]["config"]._attn_implementation == "eager"
250311

251312

252313
def test_transformers_engine_uses_falcon_native_generate_batch() -> None:

0 commit comments

Comments
 (0)