Extend sampler tests (#200)
* Add test cases for temperature

* Fix temperature parameter verification

* Extend tests on temperature

* Fix import

* Fix test for logprobs after rebase

* Add tests for repetition penalty

* Refactor test for penalties

* Fix types + update tests for logprobs after rebase

* Fix type of error in tests

* Fix tests for temperature

* Update test for penalties

* Extend tests for top_p, top_k

* Remove irrelevant todo

* Extend test for logit_bias

* Remove underlines to execute with pytest

* Add test for num_sequences

* format + lint

* Add test for inspecting behaviour of logprobs depending on temperature

* Test with mixed greedy and random sampling requests + fix after merge

* Fix get_sampling_state

* Redesign test penalties

* Set top_k as vocab_size when -1 + simplify test for penalties

* Add pytest parametrization

* Update mixed test

* Skip broken test

* Remove debug print

* Corrections according to review comments + add simple test for logit_bias with engine

* Fix indices in logit_bias since it starts from 0

* Update test for penalties

* Add greedy sampling case in test for penalties

* Remove debug lines
Ailurus1 authored Feb 22, 2024
Parent: a377c3c · Commit: d66880c
Showing 4 changed files with 542 additions and 245 deletions.
18 changes: 13 additions & 5 deletions serve/mlc_serve/engine/sampling_params.py
@@ -88,6 +88,8 @@ def __post_init__(self):
             self._verify_greedy_sampling()
         if not self.logprobs:
             self.top_logprobs = 0
+        if self.top_k == -1:
+            self.top_k = self.vocab_size
 
     def verify(self) -> None:
         if not -2.0 <= self.presence_penalty <= 2.0:
@@ -99,15 +101,15 @@ def verify(self) -> None:
                 "frequency_penalty must be in [-2, 2], got "
                 f"{self.frequency_penalty}."
             )
-        if self.temperature < 0.0:
+        if not 0.0 <= self.temperature <= 2.0:
             raise ValueError(
-                f"temperature must be non-negative, got {self.temperature}."
+                f"temperature must be in [0, 2], got {self.temperature}."
             )
         if not 0.0 < self.top_p <= 1.0:
             raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.")
 
         if not isinstance(self.top_k, int):
-            raise ValueError(f"top_k must be integer.")
+            raise TypeError(f"top_k must be integer.")
 
         if self.top_k < -1 or self.top_k == 0:
             raise ValueError(
@@ -119,8 +121,10 @@ def verify(self) -> None:
                 raise ValueError(
                     f"logit bias must be in [-100, 100], got {bias} for token {token}."
                 )
-            if not 1 <= token <= self.vocab_size:
-                raise ValueError(f"token id must be in [1, vocab_size]")
+            if not isinstance(token, int):
+                raise ValueError(f"token id must be an integer")
+            if not 0 <= token < self.vocab_size:
+                raise ValueError(f"token id must be in [0, vocab_size)")
 
         if self.repetition_penalty <= 0:
             raise ValueError(
@@ -132,6 +136,10 @@ def verify(self) -> None:
                 raise ValueError(
                     f"top_logprobs must be between 0 and {LOGPROB_TOP_K_MAX}, got {self.top_logprobs}."
                 )
+            if not isinstance(self.top_logprobs, int):
+                raise TypeError(
+                    "top_logprobs must be integer"
+                )
 
     def _verify_greedy_sampling(self) -> None:
         if self.top_p < 1.0 - _SAMPLING_EPS:
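Taken together, the parameter checks changed in this file behave roughly as in the standalone sketch below. This is an illustration only; field handling, defaults, and error messages in the real SamplingParams dataclass may differ.

# Standalone restatement of the validation rules above, for illustration only;
# this is not the project's actual SamplingParams class.
def check_params(temperature, top_k, vocab_size, logit_bias=None):
    if top_k == -1:
        top_k = vocab_size  # -1 now means "consider the whole vocabulary"
    if not isinstance(top_k, int):
        raise TypeError("top_k must be integer.")
    if not 0.0 <= temperature <= 2.0:  # previously only "non-negative" was enforced
        raise ValueError(f"temperature must be in [0, 2], got {temperature}.")
    for token, bias in (logit_bias or {}).items():
        if not -100 <= bias <= 100:
            raise ValueError(f"logit bias must be in [-100, 100], got {bias}.")
        if not isinstance(token, int):
            raise ValueError("token id must be an integer")
        if not 0 <= token < vocab_size:  # token ids are 0-based now
            raise ValueError("token id must be in [0, vocab_size)")
    return top_k

assert check_params(temperature=0.7, top_k=-1, vocab_size=32000) == 32000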
6 changes: 2 additions & 4 deletions serve/mlc_serve/model/sampler.py
@@ -172,8 +172,6 @@ def from_lists(
                 device="cpu",
                 pin_memory=True,
             )
-            # Convert 1-based index to 0-based
-            logit_bias_indices -= 1
             logit_bias_values = torch.tensor(
                 list_logit_bias_values,
                 dtype=dtype,
@@ -546,8 +544,8 @@ def _is_safe_to_sample(prob_like):
             assert sampling_state.sampling_params[batch_idx].logprobs
             top_k = sampling_state.sampling_params[batch_idx].top_logprobs
             logprob_infos[batch_idx] = RawLogprobsInfo(
-                current_token_id=next_token,
-                current_logprob=logprobs[batch_idx][next_token],
+                current_token_id=int(next_token),
+                current_logprob=float(logprobs[batch_idx][next_token]),
                 top_token_ids=top_tokens[idx][:top_k],
                 top_logprobs=top_logprobs[idx][:top_k],
             )
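With the 1-based-to-0-based conversion removed, the sampler can use logit_bias keys directly as vocabulary indices. A minimal sketch of that application step on a dense logits tensor; the names and shapes are illustrative and do not reproduce the sampler's actual from_lists/adjustment code.

import torch

# Illustration: logit_bias keys are already 0-based token ids, so they index
# the logits tensor directly -- no "indices -= 1" shift is needed.
vocab_size = 8
logits = torch.zeros(vocab_size)
logit_bias = {2: -100.0, 5: 3.0}  # token_id -> additive bias

bias_indices = torch.tensor(list(logit_bias.keys()), dtype=torch.long)
bias_values = torch.tensor(list(logit_bias.values()), dtype=logits.dtype)
logits[bias_indices] += bias_values  # biased logits then go through softmax/sampling

print(logits)  # bias applied at positions 2 and 5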
129 changes: 128 additions & 1 deletion serve/tests/unittest/test_engine_with_samplers.py
@@ -11,7 +11,7 @@
 from mlc_serve.utils import get_default_mlc_serve_argparser, postproc_mlc_serve_args, create_mlc_engine
 import random
 from pydantic import BaseModel
-from typing import List
+from typing import List, Callable
 
 
 def create_request(
@@ -22,6 +22,7 @@ def create_request(
     pre_pen,
     max_tokens,
     stop,
+    num_sequences=1,
     ignore_eos=False,
     top_logprobs=0,
     logprobs=False,
@@ -41,6 +42,7 @@ def create_request(
             json_schema=json_schema,
         ),
         stopping_criteria=StoppingCriteria(max_tokens=max_tokens, stop_sequences=stop),
+        num_sequences=num_sequences,
         debug_options=DebugOptions(ignore_eos=ignore_eos),
     )
 
@@ -210,6 +212,45 @@ def _test_stop(
                 )
                 assert found == 1, f"{gen_txt!r}, matches: {found}"
 
+def _test_logit_bias(
+    engine,
+    num_requests=10
+):
+    prompt = "Repeat only one of the following words: hi, hello"
+    requests = []
+    for n in range(num_requests):
+        requests.append(
+            create_request(
+                idx=str(n),
+                prompt=prompt,
+                temp=0.8,
+                freq_pen=0,
+                pre_pen=0,
+                max_tokens=10,
+                stop="\n",
+                logit_bias={
+                    engine.tokenizer.encode("hi")[0]: -100.0,
+                    engine.tokenizer.encode("Hi")[0]: -100.0
+                }
+            )
+        )
+
+    engine.add(requests)
+    generated = ["" for _ in range(num_requests)]
+
+    while engine.has_pending_requests():
+        results = engine.step()
+        for res in results.outputs:
+            assert len(res.sequences) == 1
+            seq = res.sequences[0]
+            req_id = int(res.request_id)
+
+            if seq.delta:
+                generated[int(res.request_id)] += seq.delta
+
+            if seq.is_finished:
+                gen_txt = generated[req_id]
+                assert "hi" not in gen_txt and "Hi" not in gen_txt
 
 def _test_logprobs(
     engine,
@@ -257,6 +298,46 @@ def _test_logprobs(
                 )
             generated[int(res.request_id)] += seq.delta
 
+    # If temperature is increasing then difference between
+    # boundaries of range of top logprobs in response must decrease
+    temperatures = [0.2, 1.1, 2.0]
+    mean_bounds_diff = [0 for _ in range(num_requests * len(temperatures))]
+    for idx, temp in enumerate(temperatures):
+        requests = [
+            create_request(
+                idx=str(n),
+                prompt=random.choice(prompts),
+                temp=temp,
+                freq_pen=0,
+                pre_pen=0,
+                max_tokens=300,
+                stop=None,
+                ignore_eos=True,
+                logprobs=True,
+                top_logprobs=5
+            )
+            for n in range(num_requests)
+        ]
+        engine.add(requests)
+
+        while engine.has_pending_requests():
+            results = engine.step()
+            for res in results.outputs:
+                seq = res.sequences[0]
+                req = requests[int(res.request_id)]
+
+                if not seq.is_finished:
+                    mean_bounds_diff[idx * num_requests + int(res.request_id)] += \
+                        seq.logprob_info[0].top_logprobs[0].logprob \
+                        - seq.logprob_info[0].top_logprobs[4].logprob
+                else:
+                    mean_bounds_diff[idx * num_requests + int(res.request_id)] /= seq.num_generated_tokens
+
+    for num_req_batch in range(num_requests):
+        for idx in range(1, len(temperatures)):
+            assert mean_bounds_diff[idx * num_requests + num_req_batch] < \
+                mean_bounds_diff[(idx - 1) * num_requests + num_req_batch]
+
 
 def _test_logprobs_mixed_requests(
     engine,
@@ -301,6 +382,48 @@ def _test_logprobs_mixed_requests(
                 assert len(seq.logprob_info) == 0
             generated[int(res.request_id)] += seq.delta
 
+def _test_num_sequences(
+    engine,
+    num_requests=5,
+):
+    prompt = "Write a merge sort program in Python."
+    requests = []
+    num_sequences = [2 * i for i in range(1, num_requests + 1)]
+    for n, num_seq in enumerate(num_sequences):
+        requests.append(
+            create_request(
+                idx=str(n),
+                prompt=prompt,
+                temp=0.6,
+                freq_pen=0,
+                pre_pen=0,
+                stop=None,
+                max_tokens=300,
+                ignore_eos=False,
+                num_sequences=num_seq
+            )
+        )
+    engine.add(requests)
+
+    generated = [[""] * num_seq for _, num_seq in zip(range(num_requests), num_sequences)]
+    unique_sequences = [set() for _ in range(num_requests)]
+    while engine.has_pending_requests():
+        results = engine.step()
+        for idx, res in enumerate(results.outputs):
+            assert len(res.sequences) == num_sequences[idx]
+            for seq_id, seq in enumerate(res.sequences):
+                req_id = int(res.request_id)
+
+                if seq.delta:
+                    generated[int(req_id)][seq_id] += seq.delta
+
+                if seq.is_finished:
+                    unique_sequences[req_id].add(generated[req_id][seq_id])
+
+    for idx, response in enumerate(unique_sequences):
+        assert num_sequences[idx] == len(response)
+
+
 
 # These three models are used in _test_json_mode
 class France(BaseModel):
@@ -407,6 +530,8 @@ def _test_json_mode(
         # _test_stop(staging_engine)
         _test_logprobs(staging_engine)
         _test_logprobs_mixed_requests(staging_engine)
+        _test_num_sequences(staging_engine)
+        _test_logit_bias(staging_engine)
         _test_json_mode(staging_engine)
         # These tests are broken since we are now imposing no length limit
         # if max_tokens = None. The tests do not finish in a reasonable time.
@@ -422,6 +547,8 @@ def _test_json_mode(
         _test_stop(sync_engine)
         _test_logprobs(sync_engine)
         _test_logprobs_mixed_requests(sync_engine)
+        _test_num_sequences(sync_engine)
+        _test_logit_bias(sync_engine)
         _test_json_mode(sync_engine)
         # These tests are broken since we are now imposing no length limit
         # if max_tokens = None. The tests do not finish in a reasonable time.
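The new temperature assertion in _test_logprobs relies on a simple property of log-softmax: dividing logits by a larger temperature compresses their differences, so the gap between the highest and the fifth-highest logprob shrinks. A toy illustration of that property, with arbitrary example numbers rather than anything produced by the engine:

import torch

# Toy check of the property asserted in _test_logprobs: the gap between the
# top-1 and top-5 logprobs scales as 1/temperature, so it shrinks as
# temperature grows. The logits below are arbitrary example values.
logits = torch.tensor([4.0, 3.0, 2.5, 1.0, 0.5, -1.0])

for temp in (0.2, 1.1, 2.0):
    logprobs = torch.log_softmax(logits / temp, dim=-1)
    top = torch.topk(logprobs, k=5).values
    print(f"T={temp}: top1 - top5 gap = {(top[0] - top[4]).item():.3f}")
# The printed gap decreases monotonically with temperature, which is what the
# mean_bounds_diff assertion checks per request.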