Skip to content

Commit ec14600

Browse files
committed
Implement async scheduling for single card
1 parent 96f9d3e commit ec14600

7 files changed

Lines changed: 487 additions & 85 deletions

File tree

gllm/async_worker.py

Lines changed: 0 additions & 63 deletions
This file was deleted.

gllm/entrypoints/api_server.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,11 @@ async def run_server(args):
169169
help="Experimental feature for worker implemented by async",
170170
action="store_true",
171171
)
172+
parser.add_argument(
173+
"--async-scheduling",
174+
help="Overlap CPU input preparation for the next batch with GPU execution of the current batch",
175+
action="store_true",
176+
)
172177
parser.add_argument(
173178
"--gpu-memory-util",
174179
type=float,
@@ -301,6 +306,7 @@ async def run_server(args):
301306
assigned_layers=args.assigned_layers,
302307
schedule_method=args.schedule_method,
303308
use_async_worker=args.use_async_worker,
309+
async_scheduling=args.async_scheduling,
304310
use_thinking=not args.disable_thinking,
305311
disable_cuda_graph=args.disable_cuda_graph,
306312
max_cuda_graph_bs=args.max_cuda_graph_bs,

gllm/layers/sampler.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,15 @@
66
class Sampler:
77

88
def forward(self, logits: torch.Tensor, input_data: InputData):
9+
return self.forward_gpu(logits, input_data).cpu().numpy().tolist()
10+
11+
def forward_gpu(self, logits: torch.Tensor, input_data: InputData) -> torch.Tensor:
12+
"""Same as forward() but returns a GPU tensor without D2H copy.
13+
14+
Used by async scheduling so the D2H transfer can be initiated with
15+
non_blocking=True on a dedicated copy stream, overlapping with
16+
CPU-side scheduling work for the next batch.
17+
"""
918
# repetition_penalty
1019
logits /= torch.where(logits > 0, input_data.repetition_penalty, 1.0)
1120
logits *= torch.where(logits <= 0, 1.0, input_data.repetition_penalty)
@@ -17,8 +26,7 @@ def forward(self, logits: torch.Tensor, input_data: InputData):
1726

1827
q = torch.empty_like(probs)
1928
q.exponential_()
20-
return probs.div_(q).argmax(dim=1).cpu().numpy().tolist()
21-
# return torch.multinomial(probs, 1).squeeze(1).cpu().numpy().tolist()
29+
return probs.div_(q).argmax(dim=1)
2230

2331
def _apply_top_k_top_p(
2432
self,

gllm/llm_engine.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,12 @@
66
import tqdm
77
from logger import logger
88

9-
from gllm.async_worker import AsyncWorker, run_worker_async
109
from gllm.comm import IPCPackage, zmqComm
1110
from gllm.id_allocator import IDAllocator
12-
from gllm.model_runner import ModelRunner
11+
from gllm.model_runner import ModelRunner, AsyncModelRunner
1312
from gllm.sequence import Sequence
1413
from gllm.utils import get_model_load_pbar, init_logger, random_uuid
15-
from gllm.worker import Worker, run_worker
14+
from gllm.worker import Worker, AsyncWorker, run_worker
1615

1716

1817
class LLM:
@@ -40,6 +39,7 @@ def __init__(
4039
assigned_layers=None,
4140
schedule_method="chunked_prefill",
4241
use_async_worker=False,
42+
async_scheduling=False,
4343
use_thinking=True,
4444
disable_cuda_graph=False,
4545
max_cuda_graph_bs=32,
@@ -50,7 +50,8 @@ def __init__(
5050
init_logger()
5151
self.model_path = model_path
5252
self.load_format = load_format
53-
self.model_runner = ModelRunner(
53+
model_runner_cls = AsyncModelRunner if async_scheduling else ModelRunner
54+
self.model_runner = model_runner_cls(
5455
load_format=load_format,
5556
model_path=model_path,
5657
gpu_memory_util=gpu_memory_util,
@@ -90,8 +91,11 @@ def __init__(
9091
self.assigned_layers = assigned_layers
9192
self.schedule_method = schedule_method
9293
self.use_async_worker = use_async_worker
94+
self.async_scheduling = async_scheduling
9395

9496
logger.info(f"Schedule method: {schedule_method}")
97+
if async_scheduling:
98+
logger.info("Async scheduling enabled")
9599

96100
# Interact with workers
97101
self.wait_lists: List[Sequence] = []
@@ -168,7 +172,7 @@ def init_workers(self):
168172
self.load_progress()
169173

170174
def start_worker(self, local_rank, pp_rank, tp_rank):
171-
worker_cls = Worker if not self.use_async_worker else AsyncWorker
175+
worker_cls = Worker if not self.async_scheduling else AsyncWorker
172176
comm = zmqComm(
173177
self.host,
174178
self.zmq_port_base,
@@ -195,7 +199,7 @@ def start_worker(self, local_rank, pp_rank, tp_rank):
195199
self.schedule_method,
196200
)
197201
process = self.ctx.Process(
198-
target=run_worker if not self.use_async_worker else run_worker_async,
202+
target=run_worker,
199203
args=(worker,),
200204
daemon=True,
201205
)

0 commit comments

Comments
 (0)