Limit number of sequences (#220)
* Add max_num_seq

* add to engine

* Limit max_num_seq when growing the batch

* Add max num seq to args

* Add max_num_seq_per_request

* Expose gpu memory utilization to engine config

* Fix dataclass

* Remove the default value of gpu_memory_utilization

* Fix params

* Apply to torch model
yelite authored Feb 23, 2024
1 parent a0e680c commit 8ee6aaa
Showing 7 changed files with 45 additions and 3 deletions.
3 changes: 3 additions & 0 deletions serve/mlc_serve/engine/base.py
@@ -31,11 +31,14 @@ class MLCServeEngineConfig:
# TODO(@sunggg): figure out better defaults
use_staging_engine: bool = True
max_num_batched_tokens: int = 4096
max_num_seq: int = 256
max_num_seq_per_request: Optional[int] = None  # defaults to `max_num_seq // 4`
min_decode_steps: int = 32
max_decode_steps: int = 48
init_timeout: int = 120
model_type: str = "tvm" # "tvm", "torch"
num_shards: Optional[int] = None  # Needs to be specified if model_type is "torch"
gpu_memory_utilization: float = 0.9

@classmethod
def _from_json(config_cls, json_obj: Dict[Any, Any]):
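
For illustration, a minimal standalone sketch of how the three new knobs interact; `EngineLimits` is a hypothetical stand-in that mirrors, rather than imports, the repository's `MLCServeEngineConfig`, and the `// 4` fallback reflects the default applied in engine_common.py below.

from dataclasses import dataclass
from typing import Optional

@dataclass
class EngineLimits:  # hypothetical stand-in for MLCServeEngineConfig
    max_num_batched_tokens: int = 4096
    max_num_seq: int = 256
    max_num_seq_per_request: Optional[int] = None  # resolved to max_num_seq // 4 downstream
    gpu_memory_utilization: float = 0.9

limits = EngineLimits(max_num_seq=128)
per_request = limits.max_num_seq_per_request or limits.max_num_seq // 4
print(per_request)  # 32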
14 changes: 14 additions & 0 deletions serve/mlc_serve/engine/engine_common.py
@@ -405,6 +405,8 @@ class EngineBase:
model_artifact_config: ModelArtifactConfig
max_context_length: int
max_num_batched_tokens: int
max_num_seq: int
max_num_seq_per_request: int
max_decode_steps: int
min_decode_steps: int
kv_cache_size: int
@@ -426,6 +428,10 @@ def __init__(self, model_module: ModelModule):
), "max_context_length must not be zero"
self.max_context_length = self.model_artifact_config.max_context_length
self.max_num_batched_tokens = model_module.engine_config.max_num_batched_tokens
self.max_num_seq = model_module.engine_config.max_num_seq
self.max_num_seq_per_request = model_module.engine_config.max_num_seq_per_request
if self.max_num_seq_per_request is None:
self.max_num_seq_per_request = self.max_num_seq // 4
self.max_decode_steps = min(
self.cache_manager.get_kv_cache_size(),
model_module.engine_config.max_decode_steps,
@@ -592,6 +598,14 @@ def try_grow_batch(self, num_new_batched_tokens) -> Optional[int]:
)
return None

current_num_seq = sum(len(s.generation_sequences) for s in self.current_batch.values())
if current_num_seq + len(state.generation_sequences) > self.max_num_seq:
LOG.debug(
"Stop growing the batch due to max number of sequences.",
)
return None


self.queue.popleft()
self.cache_manager.allocate(state.request_id, num_tokens, state.num_sequences)
self.current_batch[state.request_id] = state
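
To make the new admission rule concrete, here is a simplified standalone sketch (types and names are assumed stand-ins) of the check try_grow_batch now performs before popping a request off the queue: the batch only grows while the total number of generation sequences stays within max_num_seq.

from dataclasses import dataclass

@dataclass
class FakeRequestState:  # simplified stand-in for the engine's RequestState
    request_id: str
    generation_sequences: list

def can_admit(current_batch: dict, state: FakeRequestState, max_num_seq: int) -> bool:
    current_num_seq = sum(len(s.generation_sequences) for s in current_batch.values())
    return current_num_seq + len(state.generation_sequences) <= max_num_seq

batch = {"a": FakeRequestState("a", ["s0", "s1"]), "b": FakeRequestState("b", ["s0"])}
print(can_admit(batch, FakeRequestState("c", ["s0", "s1"]), max_num_seq=4))  # False: 3 + 2 > 4
print(can_admit(batch, FakeRequestState("c", ["s0"]), max_num_seq=4))        # True:  3 + 1 <= 4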
6 changes: 6 additions & 0 deletions serve/mlc_serve/engine/staging_engine_worker.py
@@ -113,6 +113,12 @@ def add(self, request_states: list[RequestState]):
"The prompt is too long for the given set of engine"
" parameters."
)
elif state.num_sequences > self.max_num_seq_per_request:
self.cancelled_requests.append(state)
state.validation_err = ValidationError(
f"The number of sequences ({state.num_sequences}) is greater"
f"than the maximum allowed value ({self.max_num_seq_per_request})"
)
else:
valid_states.append(state)

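
A minimal standalone sketch of the validation added above: a request asking for more parallel sequences than max_num_seq_per_request is cancelled up front with a ValidationError rather than scheduled. The helper below is illustrative only and returns the message as a plain string.

from typing import Optional

def validate_num_sequences(num_sequences: int, max_num_seq_per_request: int) -> Optional[str]:
    # Same condition as the elif branch above; returns an error message instead of scheduling.
    if num_sequences > max_num_seq_per_request:
        return (
            f"The number of sequences ({num_sequences}) is greater "
            f"than the maximum allowed value ({max_num_seq_per_request})"
        )
    return None

print(validate_num_sequences(8, 64))    # None -> request proceeds
print(validate_num_sequences(100, 64))  # message -> request is cancelled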
2 changes: 1 addition & 1 deletion serve/mlc_serve/model/model_common.py
@@ -35,7 +35,7 @@ def get_num_cache_blocks(
num_layers,
num_kv_heads,
head_size,
gpu_memory_utilization=0.9, # the default used by vllm
gpu_memory_utilization,
):
cache_block_size = CacheManager.get_cache_block_size(
block_size, num_layers, num_kv_heads, head_size
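
The body of get_num_cache_blocks is not part of this diff, so the sketch below only illustrates the usual vLLM-style budget split such a helper tends to implement: the granted fraction of GPU memory, minus the peak usage measured during profiling, is divided into fixed-size KV cache blocks. All names here are assumptions, not the repository's API.

def estimate_num_cache_blocks(
    used_memory_bytes: int,
    cache_block_size_bytes: int,
    total_gpu_memory_bytes: int,
    gpu_memory_utilization: float,
) -> int:
    budget = total_gpu_memory_bytes * gpu_memory_utilization - used_memory_bytes
    return max(int(budget // cache_block_size_bytes), 0)

# 80 GB card, 60 GB peak during profiling, 2 MiB per block, 90% utilization.
print(estimate_num_cache_blocks(60 << 30, 2 << 20, 80 << 30, 0.9))  # 6144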
11 changes: 10 additions & 1 deletion serve/mlc_serve/model/torch_model.py
@@ -168,6 +168,8 @@ def profile_and_init_cache(
hf_config,
num_shards,
max_num_batched_tokens,
max_num_seq,
gpu_memory_utilization,
):
num_kv_heads = hf_config.num_key_value_heads // num_shards
num_hidden_layers = hf_config.num_hidden_layers
@@ -177,7 +179,9 @@

if max_num_batched_tokens > 0:
LOG.info("Running memory profiling.")
seq_lens = [1] * max_num_batched_tokens
seq_len = max_num_batched_tokens // max_num_seq
seq_lens = [seq_len] * max_num_seq
seq_lens[-1] += max_num_batched_tokens % max_num_seq
used_memory_bytes = profile_memory_usage(
pt_model, seq_lens, num_hidden_layers, hf_config.vocab_size
)
@@ -187,6 +191,7 @@
hf_config.num_hidden_layers,
num_kv_heads,
head_size,
gpu_memory_utilization,
)
else:
num_blocks = 500
@@ -423,6 +428,8 @@ def exposed_init_model(
hf_config,
num_shards,
engine_config.max_num_batched_tokens,
engine_config.max_num_seq,
engine_config.gpu_memory_utilization,
)

return num_blocks
@@ -593,6 +600,8 @@ def __init__(
hf_config,
1,
engine_config.max_num_batched_tokens,
engine_config.max_num_seq,
engine_config.gpu_memory_utilization,
)
self.model_rpc = None

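
The profiling change above, and the identical one in tvm_model.py below, replaces max_num_batched_tokens sequences of length one with max_num_seq sequences whose lengths still sum to max_num_batched_tokens, which better matches the new batching limit. A standalone illustration of the arithmetic:

max_num_batched_tokens = 4096
max_num_seq = 30

seq_len = max_num_batched_tokens // max_num_seq
seq_lens = [seq_len] * max_num_seq
seq_lens[-1] += max_num_batched_tokens % max_num_seq  # fold the remainder into the last sequence

assert sum(seq_lens) == max_num_batched_tokens
print(len(seq_lens), seq_lens[0], seq_lens[-1])  # 30 136 152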
8 changes: 7 additions & 1 deletion serve/mlc_serve/model/tvm_model.py
@@ -588,14 +588,20 @@ def init_tvm_model(
if engine_config.max_num_batched_tokens > 0:
LOG.info("Running memory profiling.")
try:
seq_lens = [1] * engine_config.max_num_batched_tokens
max_num_seq = engine_config.max_num_seq
max_num_batched_tokens = engine_config.max_num_batched_tokens
seq_len = max_num_batched_tokens // max_num_seq
seq_lens = [seq_len] * max_num_seq
seq_lens[-1] += max_num_batched_tokens % max_num_seq

used_memory_bytes = model.profile_memory_usage(seq_lens)
num_blocks = get_num_cache_blocks(
used_memory_bytes,
block_size,
model_artifact_config.num_hidden_layers,
num_kv_heads,
head_size,
engine_config.gpu_memory_utilization,
)
except tvm.error.InternalError:
raise RuntimeError(
4 changes: 4 additions & 0 deletions serve/mlc_serve/utils.py
@@ -30,8 +30,10 @@ def get_default_mlc_serve_argparser(description="", allow_override=False):
parser.add_argument("--use-sync-engine", action="store_true")
parser.add_argument("--num-sequences-to-sample", type=int, default=1)
parser.add_argument("--max-num-batched-tokens", type=int, default=4096)
parser.add_argument("--max-num-seq", type=int, default=256)
parser.add_argument("--min-decode-steps", type=int, default=32)
parser.add_argument("--max-decode-steps", type=int, default=56)
parser.add_argument("--gpu-memory-utilization", type=float, default=0.9)
parser.add_argument("--debug-logging", action="store_true")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--num-shards", type=int, default=1) # Needed for PT models
@@ -73,10 +75,12 @@ def create_mlc_engine(args: argparse.Namespace, start_engine=True) -> InferenceE
{
"use_staging_engine": args.use_staging_engine,
"max_num_batched_tokens": args.max_num_batched_tokens,
"max_num_seq": args.max_num_seq,
"min_decode_steps": args.min_decode_steps,
"max_decode_steps": args.max_decode_steps,
"model_type": model_type,
"num_shards": num_shards,
"gpu_memory_utilization": args.gpu_memory_utilization,
}
)

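
As a quick way to see the new flags in action without a model, the standalone argparse sketch below mirrors (rather than imports) the additions to get_default_mlc_serve_argparser; in the real code path the parsed values feed the engine config dict built in create_mlc_engine.

import argparse

# Standalone mirror of the two new flags; not the repository's parser.
parser = argparse.ArgumentParser(description="demo of the new serve flags")
parser.add_argument("--max-num-seq", type=int, default=256)
parser.add_argument("--gpu-memory-utilization", type=float, default=0.9)

args = parser.parse_args(["--max-num-seq", "128", "--gpu-memory-utilization", "0.85"])
print(args.max_num_seq, args.gpu_memory_utilization)  # 128 0.85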
