"""
Kaggle-Optimized LLM Engine - 4-bit Qwen3.5-35B with multi-GPU sharding
Production-ready for T4x2 with aggressive VRAM optimizations
"""
import torch
import json
import logging
import os
from typing import Dict, List, Optional, Union
from dataclasses import dataclass
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


@dataclass
class KaggleConfig:
    """Kaggle-specific configuration"""
    mock_mode: bool = False  # Toggle between Mock and Real LLM
    # Qwen2.5 Instruct ships 32B (not 35B); 32B is the largest size that fits on T4x2 in 4-bit
    model_name: str = "Qwen/Qwen2.5-32B-Instruct"
    max_memory_per_gpu: float = 15.5  # GB per T4 (leave some margin)
    # FlashAttention-2 requires Ampere or newer GPUs; T4 (Turing) must use SDPA
    use_flash_attention: bool = False
    use_torch_compile: bool = False  # Disable for compatibility
    inference_mode: bool = True
    gradient_checkpointing: bool = False


class KaggleLLMEngine:
    """
    Production-ready LLM engine for Kaggle T4x2
    """

    def __init__(self, config: KaggleConfig):
        self.config = config
        self.model = None
        self.tokenizer = None
        self.device_map = "auto"

        # Mock LLM for testing
        if config.mock_mode:
            from agent.mock_llm import get_mock_llm
            self.mock_llm = get_mock_llm()
            logger.info("Using Mock LLM for testing")
        else:
            self.mock_llm = None

        self._setup_quantization_config()
        self._setup_generation_config()

    def _setup_quantization_config(self):
        """Setup 4-bit quantization with aggressive memory optimization"""
        try:
            from transformers import BitsAndBytesConfig
            self.quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_use_double_quant=True,
            )
            logger.info("4-bit quantization config loaded")
        except ImportError:
            logger.error("bitsandbytes not available")
            self.quantization_config = None
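
    # Rough VRAM arithmetic (an estimate, not a measurement): a ~32B-parameter
    # model in NF4 needs on the order of 32B * 0.5 bytes ≈ 16 GB of weights,
    # plus KV-cache and activation overhead, which is why the load is sharded
    # across both 16 GB T4s instead of being placed on a single GPU.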

    def _setup_generation_config(self):
        """Setup generation config for speed and memory efficiency"""
        self.generation_config = {
            "max_new_tokens": 256,  # Aggressive token limit
            "temperature": 0.1,     # Ignored while do_sample is False (greedy decoding)
            "top_p": 0.95,          # Ignored while do_sample is False
            "do_sample": False,
            "pad_token_id": 151643,  # Qwen2.5 <|endoftext|>
            "eos_token_id": 151645,  # Qwen2.5 <|im_end|>
            "use_cache": True,
            # generate() must return a plain tensor here: generate_response()
            # slices outputs[0] by the prompt length to strip the input tokens.
            "return_dict_in_generate": False,
        }

    def load_model(self) -> bool:
        """
        Load model with multi-GPU sharding and memory optimizations
        """
        if self.config.mock_mode:
            logger.info("Mock mode enabled - skipping model loading")
            return True

        try:
            logger.info("Loading tokenizer...")
            from transformers import AutoTokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.config.model_name,
                trust_remote_code=True,
                padding_side="left"
            )
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            logger.info("Loading model with multi-GPU sharding...")
            self._load_model_with_sharding()

            logger.info("Applying inference optimizations...")
            self._apply_inference_optimizations()

            logger.info("Model loaded successfully")
            self._log_memory_usage()
            return True
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            return False

    def _load_model_with_sharding(self):
        """Load model with aggressive memory management"""
        from transformers import AutoModelForCausalLM
        import gc

        # Clear cache before loading
        gc.collect()
        torch.cuda.empty_cache()

        # Cap per-GPU usage so device_map="auto" shards within the T4 budget
        max_memory = {
            i: int(self.config.max_memory_per_gpu * 1024**3)  # bytes per GPU
            for i in range(torch.cuda.device_count())
        }

        # Load with device_map="auto" for automatic sharding
        self.model = AutoModelForCausalLM.from_pretrained(
            self.config.model_name,
            quantization_config=self.quantization_config,
            device_map=self.device_map,
            max_memory=max_memory,
            trust_remote_code=True,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            # Memory optimizations
            use_cache=True,
            attn_implementation="flash_attention_2" if self.config.use_flash_attention else "sdpa",
        )

        # Configure model for inference
        self.model.config.use_cache = True
        if self.config.gradient_checkpointing:
            self.model.gradient_checkpointing_enable()
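
    # With device_map="auto", accelerate splits the decoder layers across the
    # two T4s (roughly evenly) and only spills remaining weights to CPU RAM if
    # both per-GPU max_memory budgets are exhausted.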

    def _apply_inference_optimizations(self):
        """Apply inference-specific optimizations"""
        if self.config.inference_mode:
            # torch.inference_mode() is a context manager, not a model wrapper;
            # eval() puts the modules into inference behaviour (no dropout, etc.)
            # and generation itself already runs under torch.inference_mode().
            self.model.eval()

        # Disable gradient computation
        for param in self.model.parameters():
            param.requires_grad = False

        # Enable memory-efficient attention if available
        if hasattr(self.model, 'enable_attention_slicing'):
            self.model.enable_attention_slicing()

        # Enable CPU offloading for very large models if needed
        if hasattr(self.model, 'enable_cpu_offload'):
            # Only enable if we're running out of GPU memory
            pass

    def _log_memory_usage(self):
        """Log detailed memory usage"""
        if torch.cuda.is_available():
            for i in range(torch.cuda.device_count()):
                allocated = torch.cuda.memory_allocated(i) / 1024**3
                cached = torch.cuda.memory_reserved(i) / 1024**3
                max_allocated = torch.cuda.max_memory_allocated(i) / 1024**3
                logger.info(
                    f"GPU {i}: {allocated:.2f}GB allocated, "
                    f"{cached:.2f}GB cached, {max_allocated:.2f}GB max"
                )

    def generate_response(self, prompt: str, max_new_tokens: Optional[int] = None) -> str:
        """
        Generate response with optimized inference
        """
        if self.config.mock_mode:
            # Use Mock LLM
            return self.mock_llm.generate_response(prompt, max_new_tokens)

        if self.model is None or self.tokenizer is None:
            raise RuntimeError("Model not loaded")

        try:
            # Override generation config if provided
            gen_config = self.generation_config.copy()
            if max_new_tokens is not None:
                gen_config["max_new_tokens"] = max_new_tokens

            # Tokenize with memory efficiency
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=2048,  # Conservative context window
            ).to(self.model.device)

            # Generate with memory management
            with torch.inference_mode():
                outputs = self.model.generate(
                    **inputs,
                    **gen_config
                )

            # Decode only the newly generated tokens (skip the prompt)
            response = self.tokenizer.decode(
                outputs[0][inputs['input_ids'].shape[1]:],
                skip_special_tokens=True
            )

            # Cleanup
            del inputs, outputs
            torch.cuda.empty_cache()
            return response.strip()
        except Exception as e:
            logger.error(f"Generation failed: {e}")
            return ""

    def generate_json_response(self, prompt: str) -> Dict:
        """
        Generate JSON response for ReAct agent
        """
        if self.config.mock_mode:
            return self.mock_llm.generate_json_response(prompt)

        # Add JSON formatting instruction
        json_prompt = f"""{prompt}
Respond in strict JSON format with these keys:
- "thought": Your reasoning about the current situation
- "action": The tool to execute (e.g., "get_node_status", "execute_command")
- "action_input": The parameters for the action (JSON object)
Example:
{{"thought": "I need to check the node status", "action": "get_node_status", "action_input": {{"node_id": "Node-01"}}}}
Your response:"""

        response = self.generate_response(json_prompt, max_new_tokens=256)

        try:
            # Extract JSON from response
            import re
            json_match = re.search(r'\{.*\}', response, re.DOTALL)
            if json_match:
                return json.loads(json_match.group())
            else:
                return {
                    "thought": response[:100],
                    "action": "error",
                    "action_input": {"error": "Failed to parse JSON response"}
                }
        except Exception as e:
            logger.error(f"JSON parsing failed: {e}")
            return {
                "thought": "JSON parsing failed",
                "action": "error",
                "action_input": {"error": str(e)}
            }

    def get_memory_stats(self) -> Dict:
        """Get detailed memory statistics"""
        stats = {}
        if torch.cuda.is_available():
            for i in range(torch.cuda.device_count()):
                stats[f"gpu_{i}"] = {
                    "allocated_gb": torch.cuda.memory_allocated(i) / 1024**3,
                    "cached_gb": torch.cuda.memory_reserved(i) / 1024**3,
                    "max_allocated_gb": torch.cuda.max_memory_allocated(i) / 1024**3,
                    "utilization_percent": (
                        torch.cuda.memory_allocated(i)
                        / (self.config.max_memory_per_gpu * 1024**3)
                    ) * 100,
                }
        return stats
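
    # Illustrative shape of the stats dict (numbers made up, not measured):
    # {"gpu_0": {"allocated_gb": 9.8, "cached_gb": 10.4, "max_allocated_gb": 11.2,
    #            "utilization_percent": 63.2},
    #  "gpu_1": {...}}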

    def cleanup(self):
        """Clean up model and tokenizer from memory"""
        if self.model is not None:
            del self.model
            self.model = None  # keep the attribute so later checks do not raise AttributeError
        if self.tokenizer is not None:
            del self.tokenizer
            self.tokenizer = None
        torch.cuda.empty_cache()
        import gc
        gc.collect()
        logger.info("Model and tokenizer cleaned up from memory")

    def health_check(self) -> Dict:
        """Perform health check on the LLM engine"""
        if self.config.mock_mode:
            return {
                "status": "healthy",
                "mode": "mock",
                "model": self.mock_llm.model_name
            }

        try:
            # Test generation
            test_prompt = "Test prompt"
            response = self.generate_response(test_prompt, max_new_tokens=10)

            # Get memory stats
            memory_stats = self.get_memory_stats()

            return {
                "status": "healthy",
                "mode": "production",
                "model": self.config.model_name,
                "test_response": response[:50],
                "memory_stats": memory_stats
            }
        except Exception as e:
            return {
                "status": "unhealthy",
                "error": str(e)
            }


# Factory function
def create_kaggle_llm_engine(mock_mode: bool = False) -> KaggleLLMEngine:
    """Create LLM engine with Kaggle optimizations"""
    config = KaggleConfig(mock_mode=mock_mode)
    return KaggleLLMEngine(config)


if __name__ == "__main__":
    # Test the Kaggle engine
    print("Testing Kaggle LLM Engine...")

    # Test with Mock mode first
    engine = create_kaggle_llm_engine(mock_mode=True)
    if engine.load_model():
        print("Mock engine loaded successfully")

        # Test generation
        response = engine.generate_response("Test prompt", max_new_tokens=50)
        print(f"Mock response: {response}")

        # Test JSON generation
        json_response = engine.generate_json_response("I need to check node status")
        print(f"JSON response: {json_response}")

        # Health check
        health = engine.health_check()
        print(f"Health check: {health}")

        engine.cleanup()

    print("Kaggle engine test completed")