[Fix][Bench] Improve determinism and log dump (#2960)
This PR updates the benchmark with the following changes (short illustrative sketches of each change follow the list):

1. Fix the preprocessing of the ShareGPT dataset. In ShareGPT, a conversation
sometimes starts with the GPT role rather than the human role, while the
benchmark previously always took the first message of a conversation as the
request input and the second as the output. A GPT reply from the dataset could
therefore be sent to the benchmarked server as a request input, which is not
the intended behavior, so this PR filters out such conversations.

2. Ensure that, for a given seed, tokenizer, and number of requests, the
sampled requests are identical across runs. This improves determinism and
reproducibility.

3. Introduce logging of raw request records to a file. A raw request record
carries all information about a request, including its input prompt, measured
metrics, error message, etc.
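
A minimal sketch of change 1, mirroring the dataset.py hunk further below: only
conversations whose first turn comes from the "human" role are kept, so the
prompt/completion pair is always (human message, GPT reply). Here raw_dataset
is assumed to be the parsed ShareGPT JSON list.

    # Keep only conversations that start with a human turn; otherwise the
    # benchmark could send a GPT reply to the server as the request prompt.
    _dataset = [
        (data["conversations"][0]["value"], data["conversations"][1]["value"])
        for data in raw_dataset
        if len(data["conversations"]) >= 2
        and data["conversations"][0]["from"] == "human"
    ]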
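
For change 2, a sketch of seeded sampling, assuming a hypothetical
sample_requests helper (the real logic lives in bench/__main__.py and
bench/dataset.py): re-seeding both the random and NumPy RNGs before sampling
makes the selected requests a pure function of the seed, tokenizer, and
request count.

    import random
    from typing import List, Tuple

    import numpy as np


    def sample_requests(
        dataset: List[Tuple[str, List[int], int]], num_requests: int, seed: int
    ) -> List[Tuple[str, List[int], int]]:
        """Deterministically sample num_requests entries from the tokenized dataset."""
        # Re-seed both RNGs so repeated runs with the same arguments pick
        # exactly the same entries, in the same order.
        random.seed(seed)
        np.random.seed(seed)
        indices = np.random.choice(len(dataset), size=num_requests, replace=False)
        return [dataset[int(i)] for i in indices]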
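
For change 3, a sketch of the raw-record dump that the __main__.py hunk below
wires up; dump_raw_records is a hypothetical wrapper name, but the file naming
and JSON layout follow the diff (records are keyed by execution feature and
written as indented JSON).

    import json
    from typing import Any, Dict, List


    def dump_raw_records(alltime_records: Dict[str, List[Any]], output_path: str) -> str:
        """Write all raw request records next to the CSV report and return the path."""
        # "report.csv" -> "report_debug_dump.log"; other outputs just get the suffix.
        base = output_path[:-4] if output_path.endswith(".csv") else output_path
        debug_dump_filepath = base + "_debug_dump.log"
        with open(debug_dump_filepath, "w", encoding="utf-8") as file:
            json.dump(alltime_records, file, indent=4)
        return debug_dump_filepath
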
MasterJH5574 authored Oct 4, 2024
1 parent f51c0f6 commit 42aaa7f
Showing 5 changed files with 124 additions and 36 deletions.
45 changes: 40 additions & 5 deletions python/mlc_llm/bench/__main__.py
@@ -1,8 +1,9 @@
"""MLC LLM benchmark main entrance"""

import functools
import json
import random
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import requests
@@ -17,6 +18,7 @@
create_pipelines,
)
from mlc_llm.bench.request_record import (
RequestRecord,
convert_reports_to_df,
generate_metrics_summary,
pretty_print_report,
@@ -88,7 +90,7 @@ def run_pipeline(
dataset: Dataset,
tokenizer: AutoTokenizer,
args: argparse.argparse.Namespace,
) -> Dict[str, Any]:
) -> Tuple[Dict[str, Any], List[RequestRecord]]:
"""Run the pipeline with the given dataset and args. Return the benchmark report dict."""
random.seed(args.seed)
np.random.seed(args.seed)
@@ -100,10 +102,15 @@ def run_pipeline(
)
request_records = pipeline(request_records)
assert len(request_records) == args.num_requests
sorted_requests: List[RequestRecord] = [None] * args.num_requests
for request_record in request_records:
assert request_record.request_id is not None
assert sorted_requests[request_record.request_id] is None
sorted_requests[request_record.request_id] = request_record

request_records = MetricAnalyzer(tokenizer)(request_records)
report = generate_metrics_summary(request_records, args.num_requests, args.num_gpus)
return report
return report, sorted_requests


def query_mlc_server_metrics(host: str, port: int):
@@ -130,8 +137,17 @@ def _main():
f_create_api_endpoint = functools.partial(create_api_endpoint, args)
pipelines = create_pipelines(args, f_create_api_endpoint)
reports = []
for pipeline in pipelines:
report = run_pipeline(pipeline, dataset, tokenizer, args)
alltime_records = {}
for i, pipeline in enumerate(pipelines):
report, request_records = run_pipeline(pipeline, dataset, tokenizer, args)
exec_feature = (
json.dumps(report["exec_feature"])
if report["exec_feature"] is not None
else f"pipeline{i}"
)
alltime_records[exec_feature] = [
request_record.model_dump() for request_record in request_records
]
reports.append(report)
pretty_print_report(report)
query_mlc_server_metrics(args.host, args.port)
@@ -141,6 +157,13 @@ def _main():
print(df)
df.to_csv(args.output, index=False)
logger.info("Benchmark results dumped to file %s", args.output)
if args.debug_dump:
debug_dump_filepath = (
args.output[:-4] if args.output.endswith(".csv") else args.output
) + "_debug_dump.log"
with open(debug_dump_filepath, "w", encoding="utf-8") as file:
json.dump(alltime_records, file, indent=4)
logger.info("Debug log dumped to file %s", debug_dump_filepath)

if mlc_server is not None:
with mlc_server:
@@ -288,6 +311,12 @@ def _main():
default=1.0,
help="The top-p value for sampling. Default to 1.",
)
parser.add_argument(
"--ignore-eos",
default=False,
action="store_true",
help='Whether to set the "ignore_eos" field.',
)
parser.add_argument(
"--num-process-workers",
type=int,
@@ -329,6 +358,12 @@ def _main():
help="Whether to enable cuda profile on server. "
"The --mlc-model-lib path should be provided when enabling this option.",
)
parser.add_argument(
"--debug-dump",
default=False,
action="store_true",
help="Whether to dump all request record raw data to file.",
)
parser.add_argument(
"--multi-round",
default=False,
47 changes: 37 additions & 10 deletions python/mlc_llm/bench/api_endpoint.py
@@ -64,7 +64,7 @@ async def __aenter__(self) -> Self:
async def __aexit__(self, exc_type, exc_value, tb) -> None:
await self.client.close()

async def __call__( # pylint: disable=too-many-branches,too-many-statements
async def __call__( # pylint: disable=too-many-branches,too-many-statements,too-many-locals
self, request_record: RequestRecord
) -> RequestRecord:
payload = request_record.chat_cmpl.model_dump()
@@ -89,6 +90,7 @@ async def __call__( # pylint: disable=too-many-branches,too-many-statements

try:
async with self.client.post(self.url, json=payload, headers=self.headers) as response:
assert response.status == 200, await response.text()
if payload["stream"]:
async for chunk in response.content:
chunk = chunk.strip()
@@ -143,7 +144,8 @@ async def __call__( # pylint: disable=too-many-branches,too-many-statements
# pylint: enable=line-too-long
# fmt: on
except Exception: # pylint: disable=broad-except
logger.info("Error sending request: %s", traceback.format_exc())
error_msg = "API endpoint errored when sending request: " + traceback.format_exc()
logger.info(error_msg)
finish_time = time.monotonic()
request_record.output_str = generated_text
request_record.first_chunk_output_str = first_chunk_output_str
@@ -157,13 +159,19 @@ async def __call__( # pylint: disable=too-many-branches,too-many-statements
server_metrics=server_metrics,
exec_feature=request_record.metrics.exec_feature,
)
request_record.error_msg = error_msg
return request_record

finish_time = time.monotonic()
request_record.output_str = generated_text
request_record.first_chunk_output_str = first_chunk_output_str
success = True
error_msg = None
if len(generated_text) == 0:
success = False
error_msg = "Empty generated text."
request_record.metrics = Metrics(
success=len(generated_text) > 0,
success=success,
start_time=start_time,
finish_time=finish_time,
end_to_end_latency_s=finish_time - start_time,
@@ -172,6 +180,7 @@ async def __call__( # pylint: disable=too-many-branches,too-many-statements
server_metrics=server_metrics,
exec_feature=request_record.metrics.exec_feature,
)
request_record.error_msg = error_msg
return request_record


@@ -230,6 +239,7 @@ async def __call__( # pylint: disable=too-many-branches,too-many-statements
and request_record.chat_cmpl.debug_config.ignore_eos
):
payload["ignore_eos"] = True
payload["debug_config"] = {"ignore_eos": True}

generated_text = ""
first_chunk_output_str = ""
@@ -238,6 +248,7 @@ async def __call__( # pylint: disable=too-many-branches,too-many-statements

try:
async with self.client.post(self.url, json=payload, headers=self.headers) as response:
assert response.status == 200, await response.text()
if payload["stream"]:
async for chunk in response.content:
chunk = chunk.strip()
@@ -260,7 +271,8 @@ async def __call__( # pylint: disable=too-many-branches,too-many-statements
data = await response.json()
generated_text = data["choices"][0]["message"]["content"]
except Exception: # pylint: disable=broad-except
logger.info("Error sending request: %s", traceback.format_exc())
error_msg = "API endpoint errored when sending request: " + traceback.format_exc()
logger.info(error_msg)
finish_time = time.monotonic()
request_record.output_str = generated_text
request_record.first_chunk_output_str = first_chunk_output_str
@@ -274,13 +286,19 @@ async def __call__( # pylint: disable=too-many-branches,too-many-statements
server_metrics=None,
exec_feature=request_record.metrics.exec_feature,
)
request_record.error_msg = error_msg
return request_record

finish_time = time.monotonic()
request_record.output_str = generated_text
request_record.first_chunk_output_str = first_chunk_output_str
success = True
error_msg = None
if len(generated_text) == 0:
success = False
error_msg = "Empty generated text."
request_record.metrics = Metrics(
success=len(generated_text) > 0,
success=success,
start_time=start_time,
finish_time=finish_time,
end_to_end_latency_s=finish_time - start_time,
@@ -289,6 +307,7 @@ async def __call__( # pylint: disable=too-many-branches,too-many-statements
server_metrics=None,
exec_feature=request_record.metrics.exec_feature,
)
request_record.error_msg = error_msg
return request_record


@@ -316,7 +335,7 @@ async def __aenter__(self) -> Self:
async def __aexit__(self, exc_type, exc_value, tb) -> None:
await self.client.close()

async def __call__( # pylint: disable=too-many-branches
async def __call__( # pylint: disable=too-many-branches,too-many-locals,too-many-statements
self, request_record: RequestRecord
) -> RequestRecord:
assert len(request_record.chat_cmpl.messages) == 1
@@ -337,7 +356,7 @@ async def __call__( # pylint: disable=too-many-branches
request_record.chat_cmpl.debug_config is not None
and request_record.chat_cmpl.debug_config.ignore_eos
):
payload["ignore_eos"] = True
payload["min_length"] = payload["max_tokens"]
if self.timeout is not None and "timeout" not in payload:
payload["timeout"] = self.timeout

@@ -349,7 +368,7 @@ async def __call__( # pylint: disable=too-many-branches

try:
async with self.client.post(url, json=payload) as response:
assert response.status == 200, response.reason
assert response.status == 200, await response.text()
if payload["stream"]:
async for chunk in response.content:
chunk = chunk.strip()
@@ -370,7 +389,8 @@ async def __call__( # pylint: disable=too-many-branches
data = await response.json()
generated_text = data["text_output"]
except Exception: # pylint: disable=broad-except
logger.info("Error sending request: %s", traceback.format_exc())
error_msg = "API endpoint errored when sending request: " + traceback.format_exc()
logger.info(error_msg)
finish_time = time.monotonic()
request_record.output_str = generated_text
request_record.first_chunk_output_str = first_chunk_output_str
@@ -383,20 +403,27 @@ async def __call__( # pylint: disable=too-many-branches
time_to_first_token_s=time_to_first_token_s,
exec_feature=request_record.metrics.exec_feature,
)
request_record.error_msg = error_msg
return request_record

finish_time = time.monotonic()
request_record.output_str = generated_text
request_record.first_chunk_output_str = first_chunk_output_str
success = True
error_msg = None
if len(generated_text) == 0:
success = False
error_msg = "Empty generated text."
request_record.metrics = Metrics(
success=len(generated_text) > 0,
success=success,
start_time=start_time,
finish_time=finish_time,
end_to_end_latency_s=finish_time - start_time,
input_tokens=request_record.metrics.input_tokens,
time_to_first_token_s=time_to_first_token_s,
exec_feature=request_record.metrics.exec_feature,
)
request_record.error_msg = error_msg
return request_record


16 changes: 13 additions & 3 deletions python/mlc_llm/bench/dataset.py
@@ -46,7 +46,7 @@ def __init__(self, dataset_path: str, tokenizer: AutoTokenizer) -> None:
_dataset = [
(data["conversations"][0]["value"], data["conversations"][1]["value"])
for data in raw_dataset
if len(data["conversations"]) >= 2
if len(data["conversations"]) >= 2 and data["conversations"][0]["from"] == "human"
]
# Tokenize the prompts and completions.
self.tokenizer = tokenizer
@@ -56,16 +56,21 @@ def __init__(self, dataset_path: str, tokenizer: AutoTokenizer) -> None:
prompts,
truncation=True,
max_length=min(tokenizer.model_max_length, self.truncate_length),
add_special_tokens=False,
).input_ids
)
completions = [completion for _, completion in _dataset]
completion_token_ids = tokenizer(
completions,
truncation=True,
max_length=min(tokenizer.model_max_length, self.truncate_length),
add_special_tokens=False,
).input_ids
self._tokenized_dataset: List[Tuple[str, List[int], int]] = []
for i in range(len(_dataset)):
if len(prompt_token_ids[i]) < 4:
# Filter out sequences that are too short
continue
self._tokenized_dataset.append(
(prompts[i], prompt_token_ids[i], len(completion_token_ids[i]))
)
@@ -140,6 +145,7 @@ def __init__(self, dataset_path: str, num_requests: int, tokenizer: AutoTokenize
untokenized_data,
truncation=True,
max_length=min(tokenizer.model_max_length, self.truncate_length),
add_special_tokens=False,
).input_ids
tokenized_data_lengths = [len(tokens) for tokens in tokenized_data]
self.dataset: List[Tuple[str, List[int], int]] = list(
@@ -169,7 +175,9 @@ def generate_request_records( # pylint: disable=too-many-arguments,too-many-loc
"Don't generate eos tokens:\n\n"
)

remaining_token_length = input_length - len(self.tokenizer.encode(prompt))
remaining_token_length = input_length - len(
self.tokenizer.encode(prompt, add_special_tokens=False)
)

random.shuffle(self.dataset)

@@ -219,7 +227,9 @@ def __init__(self, tokenizer: AutoTokenizer) -> None:
}
num_tokens = 0
for message in messages:
num_tokens += len(self.tokenizer.encode(message["content"]))
num_tokens += len(
self.tokenizer.encode(message["content"], add_special_tokens=False)
)
self.dataset.append((messages, schema, num_tokens))

def generate_request_records(