Adding chat completion task to endpoint models #281

Open
wants to merge 11 commits into base: main
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -88,7 +88,7 @@ nanotron = [
]
tensorboardX = ["tensorboardX"]
quality = ["ruff==v0.2.2","pre-commit"]
tests = ["pytest==7.4.0"]
tests = ["pytest==7.4.0", "docker"]
dev = ["lighteval[accelerate,quality,tests]"]
extended_tasks = [
"langdetect", # ifeval
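Note on the test extra: the new "docker" dependency is presumably the Python Docker SDK, added so the test suite can launch a local inference container for the endpoint-model tests. A minimal sketch of such a pytest fixture, assuming that intent (the image, model id and port are placeholders, not values taken from this PR):

import docker
import pytest


@pytest.fixture(scope="session")
def tgi_endpoint():
    """Start a throwaway text-generation-inference container and yield its URL."""
    client = docker.from_env()
    container = client.containers.run(
        "ghcr.io/huggingface/text-generation-inference:latest",  # placeholder image
        command="--model-id HuggingFaceH4/zephyr-7b-beta",  # placeholder model
        ports={"80/tcp": 8080},
        detach=True,
    )
    try:
        yield "http://localhost:8080"
    finally:
        container.stop()
        container.remove()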
56 changes: 38 additions & 18 deletions src/lighteval/models/abstract_model.py
@@ -25,7 +25,8 @@
from typing import Optional, Union

import torch
from transformers import AutoTokenizer, BatchEncoding
from huggingface_hub import ChatCompletionInputMessage
from transformers import BatchEncoding, PreTrainedTokenizerBase

from lighteval.models.model_output import (
GenerativeMultiturnResponse,
@@ -34,14 +35,15 @@
LoglikelihoodSingleTokenResponse,
)
from lighteval.tasks.requests import (
Conversation,
GreedyUntilMultiTurnRequest,
GreedyUntilRequest,
LoglikelihoodRequest,
LoglikelihoodRollingRequest,
LoglikelihoodSingleTokenRequest,
RequestType,
)
from lighteval.utils.utils import EnvConfig
from lighteval.utils.utils import EnvConfig, as_list


TokenSequence = Union[list[int], torch.LongTensor, torch.Tensor, BatchEncoding]
@@ -74,7 +76,7 @@ def cleanup(self):

@property
@abstractmethod
def tokenizer(self) -> AutoTokenizer:
def tokenizer(self) -> PreTrainedTokenizerBase:
raise NotImplementedError

@property
@@ -156,24 +158,42 @@ def loglikelihood_single_token(
return NotImplemented

# Tokenization utils
def tok_encode(self, str_to_encode: str | list[str], add_special_tokens: Optional[bool] = None) -> TokenSequence:
def tok_encode(
self,
input: str | list[str] | ChatCompletionInputMessage | Conversation | list[Conversation],
add_special_tokens: Optional[bool] = None,
) -> TokenSequence:
if add_special_tokens is None:
add_special_tokens = self.add_special_tokens
if isinstance(str_to_encode, str):
return self.tokenizer.encode(str_to_encode, add_special_tokens=add_special_tokens)
return self.tokenizer(
str_to_encode,
padding=True,
add_special_tokens=add_special_tokens,
return_tensors="pt",
)

def tok_encode_pair(self, context, continuation):
if isinstance(input, str):
return self.tokenizer.encode(input, add_special_tokens=add_special_tokens)
elif isinstance(input, ChatCompletionInputMessage) or isinstance(input[0], ChatCompletionInputMessage):
return self.tokenizer.apply_chat_template(as_list(input), add_special_tokens=add_special_tokens)
elif isinstance(input, list) and isinstance(input[0], str):
return self.tokenizer(
input,
padding=True,
add_special_tokens=add_special_tokens,
return_tensors="pt",
)
else:
return self.tokenizer.apply_chat_template(
input,
add_special_tokens=add_special_tokens,
padding=True,
return_tensors="pt",
return_dict=True,
)
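
For reference, a rough usage sketch of the dispatch above, assuming a concrete LightevalModel instance named `model` whose tokenizer defines a chat template (the prompts and messages are made up):

from huggingface_hub import ChatCompletionInputMessage

# Plain strings keep the old behaviour: tokenizer.encode / tokenizer(...) with padding.
single = model.tok_encode("The quick brown fox")             # list[int]
batch = model.tok_encode(["first prompt", "second prompt"])  # padded BatchEncoding

# Chat-style inputs are routed through apply_chat_template instead.
conversation = [
    ChatCompletionInputMessage(role="user", content="What is the capital of France?"),
    ChatCompletionInputMessage(role="assistant", content="Paris."),
]
chat_ids = model.tok_encode(conversation)                    # token ids with the template applied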

def tok_encode_pair(self, context: str | Conversation, continuation: str | ChatCompletionInputMessage):
"""Encodes a context, continuation pair by taking care of the spaces in between."""
n_spaces = len(context) - len(context.rstrip())
if n_spaces > 0:
continuation = context[-n_spaces:] + continuation
context = context[:-n_spaces]
if isinstance(context, str):
n_spaces = len(context) - len(context.rstrip())
if n_spaces > 0:
continuation = context[-n_spaces:] + continuation
context = context[:-n_spaces]
else:
continuation = [continuation]
whole_enc = self.tok_encode(context + continuation)
context_enc = self.tok_encode(context)
context_enc_len = len(context_enc)
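The pair encoder keeps its whitespace handling for plain strings; for chat inputs it simply appends the continuation message to the conversation before templating. A rough illustration with made-up messages (`model` again stands for any concrete LightevalModel, and the function is assumed to return the context and continuation encodings, which this hunk does not show):

context = [ChatCompletionInputMessage(role="user", content="2 + 2 =")]
continuation = ChatCompletionInputMessage(role="assistant", content="4")

# Internally: tok_encode(context + [continuation]) vs. tok_encode(context); the
# continuation tokens are whatever the longer encoding adds after the context.
context_enc, continuation_enc = model.tok_encode_pair(context, continuation)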
36 changes: 22 additions & 14 deletions src/lighteval/models/base_model.py
@@ -544,16 +544,27 @@ def greedy_until(

context = [c.context for c in batch]

# See doc https://huggingface.co/docs/transformers/v4.38.2/en/pad_truncation#padding-and-truncation
# Will do left truncation and padding, as defined when creating the tokenizer
tokenized = self.tokenizer(
context,
truncation="longest_first", # we truncate to the model max length if needed
padding="longest", # we pad to the longest sequence
return_tensors="pt",
max_length=self.max_length - 1, # we always allow minimum one token of generation
add_special_tokens=self.add_special_tokens,
).to(self.device)
if self.use_chat_template:
tokenized = self.tokenizer.apply_chat_template(
context,
truncation="longest_first",
padding="longest",
return_tensors="pt",
max_length=self.max_length - 1,
add_special_tokens=self.add_special_tokens,
return_dict=True,
).to(self.device)
else:
# See doc https://huggingface.co/docs/transformers/v4.38.2/en/pad_truncation#padding-and-truncation
# Will do left truncation and padding, as defined when creating the tokenizer
tokenized = self.tokenizer(
context,
truncation="longest_first", # we truncate to the model max length if needed
padding="longest", # we pad to the longest sequence
return_tensors="pt",
max_length=self.max_length - 1, # we always allow minimum one token of generation
add_special_tokens=self.add_special_tokens,
).to(self.device)

# The main question for this step is the following:
# Would we rather truncate the prompt to allow generation to go to max_new_tokens, at the risk
@@ -579,10 +590,7 @@
input_ids=tokenized["input_ids"],
input_lengths=[len(item == 1) for item in tokenized["attention_mask"]],
input_mask=tokenized["attention_mask"],
truncated=[
len(c) - tokenized["input_ids"].shape[1] if len(c) > tokenized["input_ids"].shape[1] else 0
for c in context
],
truncated=[max(len(c.tokenized_context) - tokenized["input_ids"].shape[1], 0) for c in batch],
padded=[sum(mask == 0) for mask in tokenized["attention_mask"]],
)

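For context, apply_chat_template(..., return_dict=True) yields the same input_ids/attention_mask mapping that the plain tokenizer(...) call does, which is what lets the rest of greedy_until stay untouched, and the new truncated count compares each request's pre-tokenized context length against what actually fit in the batch. A standalone sketch of that path, with the model name and prompt as placeholders:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")  # placeholder

conversation = [{"role": "user", "content": "Say hello."}]
encoded = tokenizer.apply_chat_template(
    [conversation],      # greedy_until passes the whole batch of conversations
    padding=True,
    truncation=True,
    max_length=2047,     # max_length - 1: always leave room for one generated token
    return_tensors="pt",
    return_dict=True,    # dict with "input_ids" and "attention_mask"
)

# Mirrors the new `truncated` computation: untruncated prompt length (the PR uses
# len(c.tokenized_context)) minus what survived tokenization, clamped at zero.
prompt_len = len(tokenizer.apply_chat_template(conversation))
truncated = max(prompt_len - encoded["input_ids"].shape[1], 0)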