From 188e68ea237f679977f2a266cf3931c649f6b01b Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Mon, 30 Sep 2024 16:54:12 -0700 Subject: [PATCH] [Serving] Remove draft tokens after finishing request (#2953) Fixed a leak of draft token slots because they are not released when requests are finished. --- cpp/serve/engine.cc | 5 +++- cpp/serve/engine_actions/action_commons.cc | 27 +++++++++++++++------- cpp/serve/engine_actions/action_commons.h | 2 ++ 3 files changed, 25 insertions(+), 9 deletions(-) diff --git a/cpp/serve/engine.cc b/cpp/serve/engine.cc index 81972346a2..67b312a0bf 100644 --- a/cpp/serve/engine.cc +++ b/cpp/serve/engine.cc @@ -406,6 +406,7 @@ class EngineImpl : public Engine { n->actions_ = CreateEngineActions( n->models_, engine_config, model_configs, n->model_workspaces_, logit_processor, sampler, draft_token_workspace_manager, n->tokenizer_, n->trace_recorder_); + n->draft_token_workspace_manager_ = draft_token_workspace_manager; // - Automatically set the threading backend max concurrency. n->engine_config_ = engine_config; n->SetThreadMaxConcurrency(); @@ -595,7 +596,7 @@ class EngineImpl : public Engine { if (!processed_requests.empty()) { ActionStepPostProcess(processed_requests, estate_, models_, tokenizer_, request_stream_callback_, engine_config_->max_single_sequence_length, - trace_recorder_); + draft_token_workspace_manager_, trace_recorder_); return; } } @@ -844,6 +845,8 @@ class EngineImpl : public Engine { FRequestStreamCallback request_stream_callback_; // Engine actions. Array actions_; + // Draft token workspace manager for speculative decoding. + Optional draft_token_workspace_manager_; // Event trace recorder. Optional trace_recorder_; }; diff --git a/cpp/serve/engine_actions/action_commons.cc b/cpp/serve/engine_actions/action_commons.cc index 01520520af..0e8c3eab55 100644 --- a/cpp/serve/engine_actions/action_commons.cc +++ b/cpp/serve/engine_actions/action_commons.cc @@ -126,7 +126,15 @@ void RemoveRequestFromModel(EngineState estate, int64_t req_internal_id, * \param rsentry The request state entry to remove. */ void RemoveRequestStateEntry(EngineState estate, const Array& models, - RequestStateEntry rsentry) { + RequestStateEntry rsentry, + Optional draft_token_workspace_manager) { + if (draft_token_workspace_manager.defined()) { + std::vector draft_token_slots; + for (const RequestModelState& mstate : rsentry->mstates) { + mstate->RemoveAllDraftTokens(&draft_token_slots); + draft_token_workspace_manager.value()->FreeSlots(draft_token_slots); + } + } if (estate->prefix_cache->HasSequence(rsentry->mstates[0]->internal_id)) { // If the sequence is stored in prefix cache, call prefix cache to remove. if (!(rsentry->request->generation_cfg->debug_config.pinned_system_prompt)) { @@ -141,10 +149,11 @@ void RemoveRequestStateEntry(EngineState estate, const Array& models, } } -void ProcessFinishedRequestStateEntries(const std::vector& finished_rsentries, - EngineState estate, const Array& models, - int max_single_sequence_length, - Array* callback_delta_outputs) { +void ProcessFinishedRequestStateEntries( + const std::vector& finished_rsentries, EngineState estate, + const Array& models, int max_single_sequence_length, + Optional draft_token_workspace_manager, + Array* callback_delta_outputs) { NVTXScopedRange nvtx_scope("Process finished requests"); // - Remove the finished request state entries. for (const RequestStateEntry& rsentry : finished_rsentries) { @@ -153,7 +162,7 @@ void ProcessFinishedRequestStateEntries(const std::vector& fi // Mark the status of this entry as finished. rsentry->status = RequestStateStatus::kFinished; // Remove the request state entry from all the models. - RemoveRequestStateEntry(estate, models, rsentry); + RemoveRequestStateEntry(estate, models, rsentry, draft_token_workspace_manager); RequestState rstate = estate->GetRequestState(rsentry->request); int parent_idx = rsentry->parent_idx; @@ -174,7 +183,8 @@ void ProcessFinishedRequestStateEntries(const std::vector& fi rstate->entries[parent_idx]->status = RequestStateStatus::kFinished; // Remove the request state entry from all the models. - RemoveRequestStateEntry(estate, models, rstate->entries[parent_idx]); + RemoveRequestStateEntry(estate, models, rstate->entries[parent_idx], + draft_token_workspace_manager); // Climb up to the parent. parent_idx = rstate->entries[parent_idx]->parent_idx; } @@ -206,6 +216,7 @@ void ActionStepPostProcess(Array requests, EngineState estate, const Ar const Tokenizer& tokenizer, FRequestStreamCallback request_stream_callback, int64_t max_single_sequence_length, + Optional draft_token_workspace_manager, Optional trace_recorder) { NVTXScopedRange nvtx_scope("EngineAction postproc"); int num_requests = requests.size(); @@ -272,7 +283,7 @@ void ActionStepPostProcess(Array requests, EngineState estate, const Ar } ProcessFinishedRequestStateEntries(estate->postproc_workspace.finished_rsentries, estate, models, - max_single_sequence_length, + max_single_sequence_length, draft_token_workspace_manager, &estate->postproc_workspace.callback_delta_outputs); if (!estate->postproc_workspace.callback_delta_outputs.empty()) { diff --git a/cpp/serve/engine_actions/action_commons.h b/cpp/serve/engine_actions/action_commons.h index 6791602fba..e4c30dd8eb 100644 --- a/cpp/serve/engine_actions/action_commons.h +++ b/cpp/serve/engine_actions/action_commons.h @@ -52,6 +52,7 @@ void RemoveRequestFromModel(EngineState estate, int64_t req_internal_id, * \param tokenizer The tokenizer for logprob process. * \param request_stream_callback The request stream callback function. * \param max_single_sequence_length The max single sequence length to help decide + * \param draft_token_workspace_manager The draft token workspace manager. * \param trace_recorder The event trace recorder for requests. * if a request is finished. */ @@ -59,6 +60,7 @@ void ActionStepPostProcess(Array requests, EngineState estate, const Ar const Tokenizer& tokenizer, FRequestStreamCallback request_stream_callback, int64_t max_single_sequence_length, + Optional draft_token_workspace_manager, Optional trace_recorder); /*!