diff --git a/serve/mlc_serve/model/tvm_model.py b/serve/mlc_serve/model/tvm_model.py
index 0c28ff7003..c467e5be7f 100644
--- a/serve/mlc_serve/model/tvm_model.py
+++ b/serve/mlc_serve/model/tvm_model.py
@@ -420,13 +420,19 @@ def init_tvm_model(
 
     if engine_config.max_num_batched_tokens > 0:
         LOG.info("Running memory profiling.")
-        num_blocks = get_num_cache_blocks(
-            model,
-            [1] * engine_config.max_num_batched_tokens,
-            model_artifact_config.num_hidden_layers,
-            num_kv_heads,
-            head_size,
-        )
+        try:
+            num_blocks = get_num_cache_blocks(
+                model,
+                [1] * engine_config.max_num_batched_tokens,
+                model_artifact_config.num_hidden_layers,
+                num_kv_heads,
+                head_size,
+            )
+        except tvm.error.InternalError:
+            raise RuntimeError(
+                "Memory profiling failed with max_num_batched_tokens = "
+                f"{engine_config.max_num_batched_tokens}."
+            )
     else:
         num_blocks = 500
 
@@ -450,13 +456,16 @@ def init_tvm_model(
     else:
         init_cache_func = tvm.get_global_func("tvm.contrib.vllm.allocate_kv_cache")
 
-    model.cache_blocks = init_cache_func(
-        head_size,
-        model_artifact_config.num_hidden_layers,
-        num_kv_heads,
-        CacheManager.block_size,
-        num_blocks,
-    )
+    try:
+        model.cache_blocks = init_cache_func(
+            head_size,
+            model_artifact_config.num_hidden_layers,
+            num_kv_heads,
+            CacheManager.block_size,
+            num_blocks,
+        )
+    except tvm.error.InternalError:
+        raise RuntimeError(f"Failed to allocate {num_blocks} cache blocks.")
 
     cache_manager = CacheManager(
         num_blocks,
diff --git a/serve/tests/unittest/test_engine_init.py b/serve/tests/unittest/test_engine_init.py
index d79b1b7116..ea50aee49b 100644
--- a/serve/tests/unittest/test_engine_init.py
+++ b/serve/tests/unittest/test_engine_init.py
@@ -4,6 +4,8 @@
 
 from mlc_serve.engine import get_engine_config
 from mlc_serve.model.paged_cache_model import PagedCacheModelModule
+from mlc_serve.model.base import get_model_artifact_config
+from mlc_serve.model.tvm_model import init_tvm_model
 
 
 def _test_insufficient_cache_blocks_fail(artifact_path):
@@ -34,9 +36,30 @@ def try_init(max_num_seqs):
     assert "Try reducing" in str(e_info.value)
 
 
+def _test_catch_cache_alloc_oom(artifact_path):
+    model_artifact_path = os.path.join(artifact_path, "llama-2-13b-chat-hf-q0f16")
+
+    if not os.path.exists(model_artifact_path):
+        return
+
+    model_artifact_config = get_model_artifact_config(model_artifact_path)
+
+    engine_config = get_engine_config(
+        {
+            "max_num_batched_tokens": 40960
+        }
+    )
+
+    with pytest.raises(RuntimeError) as e_info:
+        init_tvm_model(model_artifact_config, engine_config)
+
+    assert "Failed to allocate" in str(e_info.value)
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--artifact-path", type=str, default="dist")
     args = parser.parse_args()
 
     _test_insufficient_cache_blocks_fail(args.artifact_path)
+    _test_catch_cache_alloc_oom(args.artifact_path)
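
Note (not part of the patch): both hunks apply the same pattern, catching TVM's low-level tvm.error.InternalError at the two call sites that can exhaust GPU memory and re-raising it as a RuntimeError with an actionable message. A minimal sketch of that pattern in isolation, using a hypothetical helper name (wrap_tvm_oom does not exist in the codebase):

    import tvm

    def wrap_tvm_oom(fn, *args, error_message):
        # Run a TVM-backed callable and convert a device-side failure
        # (surfaced as tvm.error.InternalError) into a RuntimeError that
        # callers and tests can match on, mirroring the inline try/except
        # blocks added by this patch.
        try:
            return fn(*args)
        except tvm.error.InternalError:
            raise RuntimeError(error_message)

The patch inlines the try/except at each site rather than factoring out a shared helper, which keeps each failure message specific to its call; the new test then asserts on the "Failed to allocate" substring of the allocation-path message.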