Skip to content

Commit fab4c3b

Browse files
authored
finish trials when exiting experiment (#74)
* finish trials when exiting experiment Signed-off-by: kerthcet <kerthcet@gmail.com> * fix test error Signed-off-by: kerthcet <kerthcet@gmail.com> * rename cancelled to completed Signed-off-by: kerthcet <kerthcet@gmail.com> * rename cancelled to done Signed-off-by: kerthcet <kerthcet@gmail.com> * release v0.0.4 Signed-off-by: kerthcet <kerthcet@gmail.com> * rename FINISHED to COMPLETED Signed-off-by: kerthcet <kerthcet@gmail.com> --------- Signed-off-by: kerthcet <kerthcet@gmail.com>
1 parent 49c2fc1 commit fab4c3b

11 files changed

Lines changed: 125 additions & 87 deletions

File tree

README.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ pip install alphatrion
3434
### Install from Source
3535

3636
* Git clone the repository
37-
* Run `uv sync` to install dependencies from `pyproject.toml`.
3837
* Run `source start.sh` to activate the virtual environment.
3938

4039

alphatrion/experiment/base.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,14 @@ async def __aenter__(self):
5454
return self
5555

5656
async def __aexit__(self, exc_type, exc_val, exc_tb):
57+
self.complete()
58+
59+
def complete(self):
60+
for t in list(self._trials.values()):
61+
t.complete()
5762
self._trials = dict()
63+
# Set to None at the end of the experiment because
64+
# it will be used in trial.complete().
5865
self._runtime.current_exp = None
5966

6067
@classmethod

alphatrion/log/log.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,4 +102,4 @@ async def log_metrics(metrics: dict[str, float]):
102102
)
103103

104104
if should_early_stop:
105-
trial.cancel()
105+
trial.complete()

alphatrion/metadata/sql_models.py

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,11 @@
1212
class TrialStatus(enum.Enum):
1313
PENDING = "pending"
1414
RUNNING = "running"
15-
FINISHED = "finished"
15+
COMPLETED = "completed"
1616
FAILED = "failed"
1717

1818

19-
COMPLETED_STATUS = [TrialStatus.FINISHED, TrialStatus.FAILED]
19+
FINISHED_STATUS = [TrialStatus.COMPLETED, TrialStatus.FAILED]
2020

2121

2222
class Project(Base):
@@ -26,9 +26,11 @@ class Project(Base):
2626
name = Column(String, nullable=False)
2727
description = Column(String, nullable=True)
2828

29-
created_at = Column(DateTime(timezone=True), default=datetime.now(UTC))
29+
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
3030
updated_at = Column(
31-
DateTime(timezone=True), default=datetime.now(UTC), onupdate=datetime.now(UTC)
31+
DateTime(timezone=True),
32+
default=lambda: datetime.now(UTC),
33+
onupdate=lambda: datetime.now(UTC),
3234
)
3335
is_del = Column(Integer, default=0, comment="0 for not deleted, 1 for deleted")
3436

@@ -43,9 +45,11 @@ class Experiment(Base):
4345
project_id = Column(UUID(as_uuid=True), nullable=False)
4446
meta = Column(JSON, nullable=True, comment="Additional metadata for the experiment")
4547

46-
created_at = Column(DateTime(timezone=True), default=datetime.now(UTC))
48+
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
4749
updated_at = Column(
48-
DateTime(timezone=True), default=datetime.now(UTC), onupdate=datetime.now(UTC)
50+
DateTime(timezone=True),
51+
default=lambda: datetime.now(UTC),
52+
onupdate=lambda: datetime.now(UTC),
4953
)
5054
is_del = Column(Integer, default=0, comment="0 for not deleted, 1 for deleted")
5155

@@ -60,16 +64,19 @@ class Trial(Base):
6064
description = Column(String, nullable=True)
6165
meta = Column(JSON, nullable=True, comment="Additional metadata for the trial")
6266
params = Column(JSON, nullable=True, comment="Parameters for the experiment")
67+
duration = Column(Float, nullable=True, comment="Duration of the trial in seconds")
6368
status = Column(
6469
Enum(TrialStatus),
6570
default=TrialStatus.PENDING,
6671
nullable=False,
6772
comment="Status of the trial",
6873
)
6974

70-
created_at = Column(DateTime(timezone=True), default=datetime.now(UTC))
75+
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
7176
updated_at = Column(
72-
DateTime(timezone=True), default=datetime.now(UTC), onupdate=datetime.now(UTC)
77+
DateTime(timezone=True),
78+
default=lambda: datetime.now(UTC),
79+
onupdate=lambda: datetime.now(UTC),
7380
)
7481

7582

@@ -80,9 +87,11 @@ class Run(Base):
8087
project_id = Column(UUID(as_uuid=True), nullable=False)
8188
trial_id = Column(UUID(as_uuid=True), nullable=False)
8289

83-
created_at = Column(DateTime(timezone=True), default=datetime.now(UTC))
90+
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
8491
updated_at = Column(
85-
DateTime(timezone=True), default=datetime.now(UTC), onupdate=datetime.now(UTC)
92+
DateTime(timezone=True),
93+
default=lambda: datetime.now(UTC),
94+
onupdate=lambda: datetime.now(UTC),
8695
)
8796

8897

@@ -96,9 +105,11 @@ class Model(Base):
96105
version = Column(String, nullable=False)
97106
meta = Column(JSON, nullable=True, comment="Additional metadata for the model")
98107

99-
created_at = Column(DateTime(timezone=True), default=datetime.now(UTC))
108+
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
100109
updated_at = Column(
101-
DateTime(timezone=True), default=datetime.now(UTC), onupdate=datetime.now(UTC)
110+
DateTime(timezone=True),
111+
default=lambda: datetime.now(UTC),
112+
onupdate=lambda: datetime.now(UTC),
102113
)
103114
is_del = Column(Integer, default=0, comment="0 for not deleted, 1 for deleted")
104115

alphatrion/trial/trial.py

Lines changed: 25 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
import contextvars
2-
import os
32
import uuid
43
from datetime import UTC, datetime
54

65
from pydantic import BaseModel, Field, model_validator
76

8-
from alphatrion.metadata.sql_models import COMPLETED_STATUS, TrialStatus
7+
from alphatrion.metadata.sql_models import FINISHED_STATUS, TrialStatus
98
from alphatrion.run.run import Run
109
from alphatrion.runtime.runtime import global_runtime
1110
from alphatrion.utils.context import Context
@@ -127,10 +126,6 @@ def __init__(self, exp_id: int, config: TrialConfig | None = None):
127126
self._config = config or TrialConfig()
128127
self._runtime = global_runtime()
129128
self._step = 0
130-
self._context = Context(
131-
cancel_func=self._stop,
132-
timeout=self._timeout(),
133-
)
134129
self._construct_meta()
135130
self._runs = dict()
136131
self._running_tasks = dict()
@@ -141,7 +136,7 @@ async def __aenter__(self):
141136
return self
142137

143138
async def __aexit__(self, exc_type, exc_val, exc_tb):
144-
self.cancel()
139+
self.complete()
145140
if self._token:
146141
current_trial_id.reset(self._token)
147142

@@ -218,19 +213,18 @@ def should_early_stop(self, metric_key: str, metric_value: float) -> bool:
218213

219214
def _timeout(self) -> int | None:
220215
timeout = self._config.max_runtime_seconds
221-
if timeout < 0:
216+
if timeout is None or timeout < 0:
222217
return None
223218

224-
# Adjust timeout based on the trial start time from environment variable,
225-
# this is useful when running in cloud env when the trial process may be
226-
# restarted.
227-
start_time = os.environ.get("ALPHATRION_TRIAL_START_TIME", None)
228-
if start_time is not None:
229-
elapsed = (
230-
datetime.now(UTC)
231-
- datetime.fromisoformat(start_time).replace(tzinfo=UTC)
232-
).total_seconds()
233-
timeout -= int(elapsed)
219+
obj = self._get_obj()
220+
if obj is None:
221+
return timeout
222+
223+
elapsed = (
224+
datetime.now(UTC) - obj.created_at.replace(tzinfo=UTC)
225+
).total_seconds()
226+
timeout -= int(elapsed)
227+
234228
return timeout
235229

236230
# Make sure you have a termination condition, either by timeout or by calling complete()
@@ -240,7 +234,7 @@ def _timeout(self) -> int | None:
240234
async def wait(self):
241235
await self._context.wait()
242236

243-
def cancelled(self) -> bool:
237+
def done(self) -> bool:
244238
return self._context.cancelled()
245239

246240
# If the name is same in the same experiment, it will refer to the existing trial.
@@ -254,7 +248,7 @@ def _start(
254248
trial_obj = self._runtime._metadb.get_trial_by_name(
255249
trial_name=name, exp_id=self._exp_id
256250
)
257-
# FIXME: what if the existing trial is finished, will lead to confusion?
251+
# FIXME: what if the existing trial is completed, will lead to confusion?
258252
if trial_obj:
259253
self._id = trial_obj.uuid
260254
else:
@@ -268,25 +262,29 @@ def _start(
268262
status=TrialStatus.RUNNING,
269263
)
270264

265+
self._context = Context(
266+
cancel_func=self._stop,
267+
timeout=self._timeout(),
268+
)
269+
271270
# We don't reset the trial id context var here, because
272271
# each trial runs in its own context.
273272
self._token = current_trial_id.set(self._id)
274-
self._context.start()
275273

276-
# cancel function should be called manually as a pair of start
274+
# complete function should be called manually as a pair of start
277275
# FIXME: watch for system signals to cancel the trial gracefully,
278-
# or it could lead to trial not being marked as finished.
279-
def cancel(self):
276+
# or it could lead to trial not being marked as completed.
277+
def complete(self):
280278
self._context.cancel()
281279

282280
def _stop(self):
283281
trial = self._runtime._metadb.get_trial(trial_id=self._id)
284-
if trial is not None and trial.status not in COMPLETED_STATUS:
282+
if trial is not None and trial.status not in FINISHED_STATUS:
285283
duration = (
286284
datetime.now(UTC) - trial.created_at.replace(tzinfo=UTC)
287285
).total_seconds()
288286
self._runtime._metadb.update_trial(
289-
trial_id=self._id, status=TrialStatus.FINISHED, duration=duration
287+
trial_id=self._id, status=TrialStatus.COMPLETED, duration=duration
290288
)
291289

292290
self._runtime.current_exp.unregister_trial(self._id)
@@ -321,7 +319,7 @@ def start_run(self, call_func: callable) -> Run:
321319
task.add_done_callback(
322320
lambda t: (
323321
setattr(self, "_total_runs_counter", self._total_runs_counter + 1),
324-
self.cancel()
322+
self.complete()
325323
if self._total_runs_counter >= self._config.max_runs_per_trial
326324
else None,
327325
)

alphatrion/utils/context.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
# Inspired by golang context package
66
class Context:
7-
def __init__(self, cancel_func: Callable | None = None, timeout=None):
7+
def __init__(self, cancel_func: Callable | None = None, timeout: int | None = None):
88
"""A context for managing cancellation and timeouts.
99
:param cancel_func: A function to call when the context is cancelled.
1010
:param timeout: Timeout in seconds. If None, no timeout is set.
@@ -13,9 +13,6 @@ def __init__(self, cancel_func: Callable | None = None, timeout=None):
1313
self._cancel_func = cancel_func
1414
self._timeout = timeout
1515

16-
def start(self):
17-
# If timeout is None, it means no timeout is set.
18-
# If timeout is negative, it means already timed out.
1916
if self._timeout is not None:
2017
asyncio.create_task(self._auto_cancel(self._timeout))
2118

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphatrion"
3-
version = "0.0.3"
3+
version = "0.0.4"
44
description = "⚒️ AlphaTrion is an open-source framework to help build GenAI applications, including experiment tracking, adaptive model routing, prompt optimization and performance evaluation."
55
license = {text = "Apache-2.0"}
66
readme = "README.md"

tests/integration/test_log.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ async def test_log_artifact():
5252
versions = exp._runtime._artifact.list_versions(exp_obj.uuid)
5353
assert len(versions) == 0
5454

55-
trial.cancel()
55+
trial.complete()
5656

5757
got_exp = exp._runtime._metadb.get_exp(exp_id=exp._id)
5858
assert got_exp is not None
@@ -61,7 +61,7 @@ async def test_log_artifact():
6161
got_trial = exp._runtime._metadb.get_trial(trial_id=trial._id)
6262
assert got_trial is not None
6363
assert got_trial.name == "first-trial"
64-
assert got_trial.status == TrialStatus.FINISHED
64+
assert got_trial.status == TrialStatus.COMPLETED
6565

6666

6767
@pytest.mark.asyncio
@@ -84,11 +84,11 @@ async def test_log_params():
8484
assert new_trial.status == TrialStatus.RUNNING
8585
assert current_trial_id.get() == trial.id
8686

87-
trial.cancel()
87+
trial.complete()
8888

8989
trial = exp.start_trial(name="second-trial", params={"param1": 0.1})
9090
assert current_trial_id.get() == trial.id
91-
trial.cancel()
91+
trial.complete()
9292

9393

9494
@pytest.mark.asyncio
@@ -135,7 +135,7 @@ async def log_metric(metrics: dict):
135135
assert run_id_2 is not None
136136
assert run_id_2 != run_id_1
137137

138-
trial.cancel()
138+
trial.complete()
139139

140140

141141
@pytest.mark.asyncio
@@ -201,7 +201,7 @@ async def log_metric(value: float):
201201
versions = exp._runtime._artifact.list_versions(exp.id)
202202
assert len(versions) == 3
203203

204-
trial.cancel()
204+
trial.complete()
205205

206206

207207
@pytest.mark.asyncio
@@ -266,7 +266,7 @@ async def log_metric(value: float):
266266
versions = exp._runtime._artifact.list_versions(exp.id)
267267
assert len(versions) == 3
268268

269-
trial.cancel()
269+
trial.complete()
270270

271271

272272
@pytest.mark.asyncio
@@ -356,7 +356,7 @@ async def fake_work(value: float):
356356
max_runs_per_trial=5,
357357
),
358358
) as trial:
359-
while not trial.cancelled():
359+
while not trial.done():
360360
run = trial.start_run(lambda: fake_work(1))
361361
await run.wait()
362362

0 commit comments

Comments
 (0)