Skip to content

Commit

Permalink
[Serving] Remove draft tokens after finishing request (#2953)
Browse files Browse the repository at this point in the history
Fixed a leak of draft token slots: the slots were not released when
requests finished.
  • Loading branch information
vinx13 authored Sep 30, 2024
1 parent d2cd68e commit 188e68e
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 9 deletions.
5 changes: 4 additions & 1 deletion cpp/serve/engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,7 @@ class EngineImpl : public Engine {
n->actions_ = CreateEngineActions(
n->models_, engine_config, model_configs, n->model_workspaces_, logit_processor, sampler,
draft_token_workspace_manager, n->tokenizer_, n->trace_recorder_);
n->draft_token_workspace_manager_ = draft_token_workspace_manager;
// - Automatically set the threading backend max concurrency.
n->engine_config_ = engine_config;
n->SetThreadMaxConcurrency();
Expand Down Expand Up @@ -595,7 +596,7 @@ class EngineImpl : public Engine {
if (!processed_requests.empty()) {
ActionStepPostProcess(processed_requests, estate_, models_, tokenizer_,
request_stream_callback_, engine_config_->max_single_sequence_length,
trace_recorder_);
draft_token_workspace_manager_, trace_recorder_);
return;
}
}
Expand Down Expand Up @@ -844,6 +845,8 @@ class EngineImpl : public Engine {
FRequestStreamCallback request_stream_callback_;
// Engine actions.
Array<EngineAction> actions_;
// Draft token workspace manager for speculative decoding.
Optional<DraftTokenWorkspaceManager> draft_token_workspace_manager_;
// Event trace recorder.
Optional<EventTraceRecorder> trace_recorder_;
};
Expand Down
27 changes: 19 additions & 8 deletions cpp/serve/engine_actions/action_commons.cc
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,15 @@ void RemoveRequestFromModel(EngineState estate, int64_t req_internal_id,
* \param rsentry The request state entry to remove.
*/
void RemoveRequestStateEntry(EngineState estate, const Array<Model>& models,
RequestStateEntry rsentry) {
RequestStateEntry rsentry,
Optional<DraftTokenWorkspaceManager> draft_token_workspace_manager) {
if (draft_token_workspace_manager.defined()) {
std::vector<int> draft_token_slots;
for (const RequestModelState& mstate : rsentry->mstates) {
mstate->RemoveAllDraftTokens(&draft_token_slots);
draft_token_workspace_manager.value()->FreeSlots(draft_token_slots);
}
}
if (estate->prefix_cache->HasSequence(rsentry->mstates[0]->internal_id)) {
// If the sequence is stored in prefix cache, call prefix cache to remove.
if (!(rsentry->request->generation_cfg->debug_config.pinned_system_prompt)) {
Expand All @@ -141,10 +149,11 @@ void RemoveRequestStateEntry(EngineState estate, const Array<Model>& models,
}
}

void ProcessFinishedRequestStateEntries(const std::vector<RequestStateEntry>& finished_rsentries,
EngineState estate, const Array<Model>& models,
int max_single_sequence_length,
Array<RequestStreamOutput>* callback_delta_outputs) {
void ProcessFinishedRequestStateEntries(
const std::vector<RequestStateEntry>& finished_rsentries, EngineState estate,
const Array<Model>& models, int max_single_sequence_length,
Optional<DraftTokenWorkspaceManager> draft_token_workspace_manager,
Array<RequestStreamOutput>* callback_delta_outputs) {
NVTXScopedRange nvtx_scope("Process finished requests");
// - Remove the finished request state entries.
for (const RequestStateEntry& rsentry : finished_rsentries) {
Expand All @@ -153,7 +162,7 @@ void ProcessFinishedRequestStateEntries(const std::vector<RequestStateEntry>& fi
// Mark the status of this entry as finished.
rsentry->status = RequestStateStatus::kFinished;
// Remove the request state entry from all the models.
RemoveRequestStateEntry(estate, models, rsentry);
RemoveRequestStateEntry(estate, models, rsentry, draft_token_workspace_manager);

RequestState rstate = estate->GetRequestState(rsentry->request);
int parent_idx = rsentry->parent_idx;
Expand All @@ -174,7 +183,8 @@ void ProcessFinishedRequestStateEntries(const std::vector<RequestStateEntry>& fi
rstate->entries[parent_idx]->status = RequestStateStatus::kFinished;
// Remove the request state entry from all the models.

RemoveRequestStateEntry(estate, models, rstate->entries[parent_idx]);
RemoveRequestStateEntry(estate, models, rstate->entries[parent_idx],
draft_token_workspace_manager);
// Climb up to the parent.
parent_idx = rstate->entries[parent_idx]->parent_idx;
}
Expand Down Expand Up @@ -206,6 +216,7 @@ void ActionStepPostProcess(Array<Request> requests, EngineState estate, const Ar
const Tokenizer& tokenizer,
FRequestStreamCallback request_stream_callback,
int64_t max_single_sequence_length,
Optional<DraftTokenWorkspaceManager> draft_token_workspace_manager,
Optional<EventTraceRecorder> trace_recorder) {
NVTXScopedRange nvtx_scope("EngineAction postproc");
int num_requests = requests.size();
Expand Down Expand Up @@ -272,7 +283,7 @@ void ActionStepPostProcess(Array<Request> requests, EngineState estate, const Ar
}

ProcessFinishedRequestStateEntries(estate->postproc_workspace.finished_rsentries, estate, models,
max_single_sequence_length,
max_single_sequence_length, draft_token_workspace_manager,
&estate->postproc_workspace.callback_delta_outputs);

if (!estate->postproc_workspace.callback_delta_outputs.empty()) {
Expand Down
2 changes: 2 additions & 0 deletions cpp/serve/engine_actions/action_commons.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,15 @@ void RemoveRequestFromModel(EngineState estate, int64_t req_internal_id,
* \param tokenizer The tokenizer for logprob process.
* \param request_stream_callback The request stream callback function.
* \param max_single_sequence_length The max single sequence length to help decide
* if a request is finished.
* \param draft_token_workspace_manager The draft token workspace manager.
* \param trace_recorder The event trace recorder for requests.
*/
void ActionStepPostProcess(Array<Request> requests, EngineState estate, const Array<Model>& models,
const Tokenizer& tokenizer,
FRequestStreamCallback request_stream_callback,
int64_t max_single_sequence_length,
Optional<DraftTokenWorkspaceManager> draft_token_workspace_manager,
Optional<EventTraceRecorder> trace_recorder);

/*!
Expand Down

0 comments on commit 188e68e

Please sign in to comment.