Skip to content

Commit 4f747bb

Browse files
committed
Fix import of new llm init function. Cleanup code smells
1 parent 3f8fafc commit 4f747bb

2 files changed

Lines changed: 77 additions & 45 deletions

File tree

src/gee_mcp/server/coderun.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@
22

33
from loguru import logger
44

5-
from .genai import init_genai_client
5+
from .genai import init_llm_client
66
from .helpers import extract_xml_tag
77

88

99
class GEEPythonExecution:
1010
def __init__(self, genai_client=None):
11-
self.genai_client = genai_client or init_genai_client()
11+
self.genai_client = genai_client or init_llm_client()
1212

1313
def exec(self, code):
1414
namespace: dict = {}

src/gee_mcp/server/genai.py

Lines changed: 75 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@
33
import os
44
from abc import abstractmethod, ABC
55
import anthropic
6+
import openai
67
from openai import OpenAI
78
from google import genai as google_genai
89
from google.genai import types
910
from loguru import logger
10-
from typing import TypedDict
11+
from typing import Any, TypedDict
1112
from pathlib import Path
1213
from enum import Enum
1314

@@ -29,9 +30,9 @@ class LLMCallReturn(TypedDict):
2930
{"answer": answer, "thought": thought, "response": response}
3031
"""
3132
answer: str
32-
thought: str
33-
resposne: str
34-
33+
thought: str | None
34+
response: Any # raw provider SDK response object, or None for cache hits
35+
3536

3637
def __init__(
3738
self,
@@ -41,44 +42,57 @@ def __init__(
4142
):
4243

4344
if self._provider is None:
44-
raise RuntimeError("something ... ")
45+
raise RuntimeError(
46+
f"{type(self).__name__} must set a class-level `_provider`"
47+
)
4548

49+
self.api_key = api_key
4650
self.model = model
4751
self.cache_dir = cache_dir
4852
if cache_dir:
4953
os.makedirs(cache_dir, exist_ok=True)
5054
logger.debug(f"response caching enabled at {cache_dir}")
5155

52-
def _cache_key(self, text: str, include_thinking: bool) -> str:
56+
def _cache_path(self, text: str, include_thinking: bool) -> str:
5357
content = f"{self.model}::{include_thinking}::{text}"
54-
return hashlib.sha256(content.encode()).hexdigest()[:16]
55-
56-
57-
def _check_cache(self, text, include_thinking) -> LLMCallReturn:
58-
# check cache
59-
key = self._cache_key(text, include_thinking)
60-
self._cache_path = os.path.join(self.cache_dir, f"{key}.json")
61-
if os.path.exists(self._cache_path):
62-
logger.debug(f"cache hit: {key}")
63-
with open(self._cache_path) as f:
58+
key = hashlib.sha256(content.encode()).hexdigest()[:16]
59+
return os.path.join(self.cache_dir, f"{key}.json")
60+
61+
def _check_cache(self, path: str) -> LLMCallReturn | None:
62+
if os.path.exists(path):
63+
logger.debug(f"cache hit: {path}")
64+
with open(path) as f:
6465
cached = json.load(f)
6566
return {
6667
"answer": cached["answer"],
6768
"thought": cached.get("thought"),
68-
"response": cached.get("response"),
69+
"response": None, # raw provider response isn't persisted
6970
}
70-
71-
def _save_to_cache(self, call_return: LLMCallReturn):
72-
with open(self._cache_path, "w") as f:
73-
json.dump(call_return, f, indent=2)
71+
logger.debug(f"cache miss: {path}")
72+
return None
73+
74+
def _save_to_cache(self, path: str, call_return: LLMCallReturn):
75+
# `response` is a provider SDK object and isn't JSON-serializable; only
76+
# the extracted text is persisted (re-reads get `response: None`).
77+
with open(path, "w") as f:
78+
json.dump(
79+
{"answer": call_return["answer"], "thought": call_return.get("thought")},
80+
f,
81+
indent=2,
82+
)
7483
logger.debug("cached response")
7584

7685
def call(self, text: str, include_thinking: bool=True) -> LLMCallReturn:
77-
if self.cache_dir:
78-
return self._check_cache(text=text, include_thinking=include_thinking)
86+
cache_path = (
87+
self._cache_path(text, include_thinking) if self.cache_dir else None
88+
)
89+
if cache_path is not None:
90+
cached = self._check_cache(cache_path)
91+
if cached is not None:
92+
return cached
7993
call_return = self._call(text=text, include_thinking=include_thinking)
80-
if self.cache_dir:
81-
self._save_to_cache(call_return)
94+
if cache_path is not None:
95+
self._save_to_cache(cache_path, call_return)
8296
return call_return
8397

8498
@abstractmethod
@@ -102,11 +116,34 @@ def __init__(self,
102116
self.client = OpenAI(api_key=api_key)
103117

104118
def _call(self, text: str, include_thinking: bool=True) -> BaseLLM.LLMCallReturn:
105-
model_response = self.client.responses.create(model=self.model, input=text)
106-
answer = model_response.output_text
119+
kwargs: dict = {"model": self.model, "input": text}
120+
if include_thinking:
121+
# Only takes effect on reasoning models (o-series, gpt-5, ...); the
122+
# raw chain-of-thought is never returned, just this summary. Some
123+
# orgs require verification before summaries are permitted, hence
124+
# the fallback below.
125+
kwargs["reasoning"] = {"summary": "auto"}
126+
127+
try:
128+
model_response = self.client.responses.create(**kwargs)
129+
except openai.BadRequestError:
130+
if "reasoning" not in kwargs:
131+
raise
132+
kwargs.pop("reasoning")
133+
model_response = self.client.responses.create(**kwargs)
134+
135+
thought = None
136+
for item in model_response.output:
137+
if item.type == "reasoning":
138+
parts = [
139+
p.text for p in (item.summary or []) if getattr(p, "text", None)
140+
]
141+
thought = "\n".join(parts) or None
142+
break
143+
107144
return {
108-
"answer": answer,
109-
"thought": None,
145+
"answer": model_response.output_text,
146+
"thought": thought,
110147
"response": model_response,
111148
}
112149

@@ -283,21 +320,16 @@ def init_llm_client(
283320

284321
if provider is None:
285322
raise ValueError("LLM provider not found.")
286-
287-
if not (provider in LLMProvider):
288-
raise ValueError(f"LLM provider must be one of the following: {[prov.value for prov in LLMProvider]}")
289-
323+
324+
try:
325+
llm_provider = LLMProvider(provider)
326+
except ValueError:
327+
raise ValueError(
328+
f"LLM provider must be one of the following: "
329+
f"{[prov.value for prov in LLMProvider]}"
330+
)
331+
290332
if model is None:
291333
raise ValueError("LLM identity not configured.")
292-
293-
# if provider == LLMProvider.GOOGLE.value:
294-
# return init_google_genai_client(model=model)
295-
296-
# if provider == LLMProvider.ANTHROPIC.value:
297-
# pass
298-
299-
# if provider == LLMProvider.OPENAI.value:
300-
# pass
301334

302-
llm_provider = LLMProvider(provider)
303335
return LLM_INIT_DICT[llm_provider](model=model)

0 commit comments

Comments
 (0)