[Bench] Support applying chat template (#2961)
This PR supports applying the tokenizer chat template in benchmarking. It must be enabled manually via the flag `--apply-chat-template`; a minimal sketch of what template application looks like follows the list below.

`--apply-chat-template` is not supported in the following cases:

* when `--input-len` is also specified, since the input text would then need truncation.
* when the dataset is not ShareGPT.
* when the tokenizer does not have `chat_template` defined.
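
For illustration, here is a minimal standalone sketch (not part of this PR) of the `transformers` call that the template application boils down to; the model name is a placeholder chosen only because its tokenizer ships a `chat_template`:

from transformers import AutoTokenizer

# Placeholder model id -- assumed to define a chat_template in its tokenizer config.
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
assert getattr(tokenizer, "chat_template", None) is not None

prompt = "Explain what a KV cache is in one sentence."
# Wrap the raw ShareGPT prompt as a single-turn user message, as dataset.py does below.
templated = tokenizer.apply_chat_template(
    [{"role": "user", "content": prompt}],
    add_generation_prompt=True,
    tokenize=False,
)
print(templated)  # the prompt surrounded by the model's role/turn markers

Since the templated text is longer than the raw prompt and wraps it in role/turn markers, truncating it to a fixed `--input-len` would cut through those markers, which is presumably why the two options are mutually exclusive.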
MasterJH5574 authored Oct 4, 2024
1 parent 42aaa7f commit 7c1f3c3
Showing 3 changed files with 46 additions and 7 deletions.
7 changes: 7 additions & 0 deletions python/mlc_llm/bench/__main__.py
@@ -317,6 +317,13 @@ def _main():
action="store_true",
help='Whether to set the "ignore_eos" field.',
)
parser.add_argument(
"--apply-chat-template",
default=False,
action="store_true",
help="Whether to apply chat template to the request input text. "
'It is not supported when "--input-len" is specified.',
)
parser.add_argument(
"--num-process-workers",
type=int,
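
For reference, a tiny standalone sketch of how a `store_true` flag such as the one added above behaves (the parser here is a stand-in, not the benchmark's real parser): omitting the flag leaves the value at `False`, passing it sets it to `True`.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--apply-chat-template",
    default=False,
    action="store_true",
    help="Whether to apply chat template to the request input text.",
)

args = parser.parse_args([])                         # flag omitted
assert args.apply_chat_template is False
args = parser.parse_args(["--apply-chat-template"])  # flag given
assert args.apply_chat_template is True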
32 changes: 30 additions & 2 deletions python/mlc_llm/bench/dataset.py
@@ -38,8 +38,12 @@ class ShareGPTDataset(Dataset):  # pylint: disable=too-few-public-methods
"""The dataset class for ShareGPT dataset."""

_tokenized_dataset: List[Tuple[str, List[int], int]]
apply_chat_template: bool

def __init__(self, dataset_path: str, tokenizer: AutoTokenizer) -> None:
def __init__(
self, dataset_path: str, tokenizer: AutoTokenizer, apply_chat_template: bool
) -> None:
self.apply_chat_template = apply_chat_template
with open(dataset_path, encoding="utf-8") as f:
raw_dataset = json.load(f)
# Filter out the conversations with less than 2 turns.
@@ -51,6 +55,19 @@ def __init__(self, dataset_path: str, tokenizer: AutoTokenizer) -> None:
         # Tokenize the prompts and completions.
         self.tokenizer = tokenizer
         prompts = [prompt for prompt, _ in _dataset]
+        if apply_chat_template:
+            assert (
+                getattr(tokenizer, "chat_template", None) is not None
+            ), '"--apply-chat-template" is set but the tokenizer does not have chat template.'
+            prompts = [
+                tokenizer.apply_chat_template(
+                    [{"role": "user", "content": prompt}],
+                    add_generation_prompt=True,
+                    tokenize=False,
+                )
+                for prompt in prompts
+            ]
+
         prompt_token_ids = list(
             tokenizer(
                 prompts,
@@ -82,6 +99,11 @@ def generate_request_records(
         input_len_std: float = 0.0,
         output_len_std: float = 0.0,
     ) -> List[RequestRecord]:
+        if self.apply_chat_template:
+            assert (
+                input_len is None
+            ), '"--apply-chat-template" is not supported when "--input-len" is specified.'
+
         request_records = []
         for prompt, input_token_ids, output_length in self._tokenized_dataset:
             input_length = len(input_token_ids)
@@ -479,9 +501,15 @@ def create_dataset(args: argparse.Namespace, tokenizer: AutoTokenizer) -> "Dataset":
             'Please specify the dataset kind via "--dataset".'
         )
     if args.dataset == "sharegpt":
-        return ShareGPTDataset(args.dataset_path, tokenizer)
+        return ShareGPTDataset(args.dataset_path, tokenizer, args.apply_chat_template)
     if args.dataset == "llmperf":
+        assert (
+            args.apply_chat_template is False
+        ), "LLMPerf dataset does not support applying chat template"
         return LLMPerfDataset(args.dataset_path, args.num_requests * 4, tokenizer)
     if args.dataset == "json-mode-eval":
+        assert (
+            args.apply_chat_template is False
+        ), "JSON mode evaluation does not support applying chat template"
         return JSONModeEvalDataset(tokenizer)
     raise ValueError(f"Unrecognized dataset {args.dataset}")
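
To make the dataset/flag compatibility concrete, here is a standalone sketch; `check_chat_template_support` is a hypothetical helper that mirrors the asserts above and does not import mlc_llm:

from types import SimpleNamespace

def check_chat_template_support(args) -> None:
    # Mirrors the checks above: only the ShareGPT dataset may apply the chat template.
    if args.dataset != "sharegpt":
        assert args.apply_chat_template is False, (
            f"{args.dataset} dataset does not support applying chat template"
        )

check_chat_template_support(SimpleNamespace(dataset="sharegpt", apply_chat_template=True))  # passes
try:
    check_chat_template_support(SimpleNamespace(dataset="llmperf", apply_chat_template=True))
except AssertionError as err:
    print(err)  # "llmperf dataset does not support applying chat template"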
14 changes: 9 additions & 5 deletions python/mlc_llm/bench/request_record.py
@@ -236,6 +236,8 @@ def _print(report: Dict[str, Any], server_metrics: bool):  # pylint: disable=too

     input_tokens = report["input_tokens"]
     print(" Input Tokens ".center(50, "-"))
+    print(f"{'Mean:':<40} {input_tokens['mean']:<1}")
+    print(f"{'Stddev:':<40} {input_tokens['stddev']:<1}")
     print(f"{'P25:':<40} {input_tokens['quantiles']['p25']:<1}")
     print(f"{'P50:':<40} {input_tokens['quantiles']['p50']:<1}")
     print(f"{'P95:':<40} {input_tokens['quantiles']['p95']:<1}")
@@ -244,11 +246,13 @@ def _print(report: Dict[str, Any], server_metrics: bool):  # pylint: disable=too

     output_tokens = report["output_tokens"]
     print(" Output Tokens ".center(50, "-"))
-    print(f"{'P25:':<40} {output_tokens['quantiles']['p25']:<10}")
-    print(f"{'P50:':<40} {output_tokens['quantiles']['p50']:<10}")
-    print(f"{'P95:':<40} {output_tokens['quantiles']['p95']:<10}")
-    print(f"{'Min:':<40} {output_tokens['min']:<10}")
-    print(f"{'Max:':<40} {output_tokens['max']:<10}")
+    print(f"{'Mean:':<40} {output_tokens['mean']:<1}")
+    print(f"{'Stddev:':<40} {output_tokens['stddev']:<1}")
+    print(f"{'P25:':<40} {output_tokens['quantiles']['p25']:<1}")
+    print(f"{'P50:':<40} {output_tokens['quantiles']['p50']:<1}")
+    print(f"{'P95:':<40} {output_tokens['quantiles']['p95']:<1}")
+    print(f"{'Min:':<40} {output_tokens['min']:<1}")
+    print(f"{'Max:':<40} {output_tokens['max']:<1}")
 
     print("=" * 50)

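
As a side note, the report lines above left-align the label to 40 columns, and the `:<1` value format effectively adds no padding; a tiny sketch with made-up numbers:

output_tokens = {"mean": 128.4, "stddev": 33.1}
print(f"{'Mean:':<40} {output_tokens['mean']:<1}")      # "Mean:" padded to 40 columns, then 128.4
print(f"{'Stddev:':<40} {output_tokens['stddev']:<1}")  # "Stddev:" padded to 40 columns, then 33.1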
