Adding vllm backend for inference (#7)
peteriz authored Aug 19, 2024
1 parent 7a37d1a commit 22d61e5
Showing 7 changed files with 137 additions and 0 deletions.
17 changes: 17 additions & 0 deletions configs/inference-vllm.yaml
@@ -0,0 +1,17 @@
model:
  _target_: ragfoundry.models.vllm.VLLMInference
  model_name_or_path: "facebook/opt-125m"
  llm_params:
    dtype: auto
  generation:
    temperature: 0.5
    top_p: 0.95
    seed: 1911
  num_gpus: 1

data_file: my-processed-data.jsonl
generated_file: model-predictions.jsonl
input_key: prompt
generation_key: output
target_key: answers
limit:
32 changes: 32 additions & 0 deletions docs/inference.md
@@ -89,3 +89,35 @@ python inference.py -cp configs/paper -cn inference-asqa \
model.lora_path=./path/to/lora/checkpoint
```

## Running Inference with vLLM Backend

To achieve potentially faster inference, you can run inference using the vLLM backend. The process works the same as described above, with a few extra arguments that are passed through to the vLLM engine.

Here is an example of an inference configuration using the vLLM engine:

```yaml
model:
  _target_: ragfoundry.models.vllm.VLLMInference
  model_name_or_path: "facebook/opt-125m"
  llm_params:
    dtype: auto
  generation:
    temperature: 0.5
    top_p: 0.95
    seed: 1911
  num_gpus: 1

data_file: my-processed-data.jsonl
generated_file: model-predictions.jsonl
input_key: prompt
generation_key: output
target_key: answers
limit:
```

The main differences in this configuration are:

- `ragfoundry.models.vllm.VLLMInference`: the class that runs inference through the vLLM engine.
- `llm_params`: optional vLLM arguments passed to the `LLM` class.
- `generation`: optional arguments that define the generation policy; the supported keywords are those of vLLM's `SamplingParams` (see the sketch below).
- `num_gpus`: the number of GPUs to use during inference.
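
For orientation, here is a minimal sketch (not part of the commit itself) of how these keys flow into vLLM: `llm_params` is unpacked into the `LLM` constructor together with `num_gpus` (as `tensor_parallel_size`), and `generation` is unpacked into `SamplingParams`. The prompt below is purely illustrative.

```python
from vllm import LLM, SamplingParams

# Values mirror the example configuration above.
llm_params = {"dtype": "auto"}                                  # config: llm_params
generation = {"temperature": 0.5, "top_p": 0.95, "seed": 1911}  # config: generation

# num_gpus maps to vLLM's tensor_parallel_size.
llm = LLM(model="facebook/opt-125m", tensor_parallel_size=1, **llm_params)
sampling_params = SamplingParams(**generation)

outputs = llm.generate(["What is retrieval-augmented generation?"], sampling_params)
print(outputs[0].outputs[0].text)
```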
1 change: 1 addition & 0 deletions docs/reference/models/vllm.md
@@ -0,0 +1 @@
::: ragfoundry.models.vllm
1 change: 1 addition & 0 deletions docs/reference/utils.md
@@ -0,0 +1 @@
::: ragfoundry.utils
2 changes: 2 additions & 0 deletions mkdocs.yml
@@ -185,8 +185,10 @@ nav:
  - Models:
      - Transformers: "reference/models/hf.md"
      - OpenAI: "reference/models/openai_executor.md"
      - vLLM: "reference/models/vllm.md"
  - Evaluation:
      - Base: "reference/evaluation/base.md"
      - Metrics: "reference/evaluation/metrics.md"
      - DeepEval: "reference/evaluation/deep.md"
  - Utils: "reference/utils.md"

75 changes: 75 additions & 0 deletions ragfoundry/models/vllm.py
@@ -0,0 +1,75 @@
import logging
from pathlib import Path
from typing import Dict
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer, AutoConfig

from ragfoundry.utils import check_package_installed


logger = logging.getLogger(__name__)


class VLLMInference:
    """
    Initializes a vLLM-based inference engine.

    Args:
        model_name_or_path (str): The name or path of the model.
        instruction (Path): Path to the instruction file.
        instruct_in_prompt (bool): Whether to include the instruction in the prompt, for models without a system role.
        template (Path): Path to a prompt template file if the tokenizer does not include a chat template. Optional.
        num_gpus (int, optional): The number of GPUs to use. Defaults to 1.
        llm_params (Dict, optional): Additional parameters for the LLM model. Supports all parameters defined by the vLLM LLM engine. Defaults to an empty dictionary.
        generation (Dict, optional): Additional parameters for text generation. Supports all the keywords of vLLM's `SamplingParams`. Defaults to an empty dictionary.
    """

    def __init__(
        self,
        model_name_or_path: str,
        instruction: Path,
        instruct_in_prompt: bool = False,
        template: Path = None,
        num_gpus: int = 1,
        llm_params: Dict = {},
        generation: Dict = {},
    ):
        check_package_installed(
            "vllm",
            "please refer to the vLLM website for installation instructions, or run: pip install vllm",
        )

        self.instruct_in_prompt = instruct_in_prompt
        self.template = open(template).read() if template else None
        self.instruction = open(instruction).read()
        logger.info(f"Using the following instruction: {self.instruction}")

        self.sampling_params = SamplingParams(**generation)
        self.llm = LLM(model=model_name_or_path, tensor_parallel_size=num_gpus, **llm_params)

        if not self.template:
            # The tokenizer and model config are needed to build chat-formatted
            # prompts and to bound the prompt length in `generate`.
            self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
            self.config = AutoConfig.from_pretrained(model_name_or_path)

    def generate(self, prompt: str) -> str:
        """
        Generates text based on the given prompt.
        """
        if self.template:
            # Use the user-provided prompt template directly.
            prompt = self.template.format(instruction=self.instruction, query=prompt)
        else:
            if self.instruct_in_prompt:
                # The model has no system role: prepend the instruction to the user prompt.
                prompt = self.instruction + "\n" + prompt
                messages = [{"role": "user", "content": prompt}]
            else:
                messages = [
                    {"role": "system", "content": self.instruction},
                    {"role": "user", "content": prompt},
                ]

            prompt = self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
                truncation=True,
                max_length=(self.config.max_position_embeddings - self.sampling_params.max_tokens),
            )

        output = self.llm.generate(prompt, self.sampling_params)
        return output[0].outputs[0].text
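
As a brief note, a hypothetical direct usage of the class above might look as follows; the instruction file path, question, and generation settings are illustrative only and assume the model's tokenizer ships a chat template (otherwise pass `template=`).

```python
from ragfoundry.models.vllm import VLLMInference

# Hypothetical example; paths and values are illustrative only.
engine = VLLMInference(
    model_name_or_path="facebook/opt-125m",
    instruction="path/to/instruction.txt",   # assumed instruction file
    instruct_in_prompt=False,
    num_gpus=1,
    llm_params={"dtype": "auto"},
    generation={"temperature": 0.5, "top_p": 0.95, "max_tokens": 128},
)

print(engine.generate("What is the capital of France?"))
```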
9 changes: 9 additions & 0 deletions ragfoundry/utils.py
@@ -0,0 +1,9 @@
def check_package_installed(package_name: str, optional_msg: str = ""):
    """
    Check if a package is installed.
    """
    import importlib.util

    if importlib.util.find_spec(package_name) is None:
        raise ImportError(f"{package_name} package is not installed; {optional_msg}")
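
This helper guards optional dependencies before they are used, as in the vLLM module above. A minimal usage sketch:

```python
from ragfoundry.utils import check_package_installed

# Raises ImportError with the hint below if vLLM is not installed.
check_package_installed("vllm", "run: pip install vllm")
```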
