Skip to content

Commit

Permalink
Import after verification, pre-commit
Browse files (browse the repository at this point in the history)
  • Loading branch information
danielfleischer committed Aug 22, 2024
1 parent 3de7693 commit d9b9c67
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 12 deletions.
3 changes: 1 addition & 2 deletions evaluation.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -17,9 +17,8 @@ def setup_wandb(args: dict):
"""
WANDB integration for tracking evaluations.
"""
from wandb.wandb_run import Run

import wandb
from wandb.wandb_run import Run

env = {key: os.getenv(key) for key in os.environ}
run: Run = wandb.init(
Expand Down
15 changes: 10 additions & 5 deletions ragfoundry/models/vllm.py
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,10 @@
import logging
from pathlib import Path
from typing import Dict
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer, AutoConfig

from ragfoundry.utils import check_package_installed
from transformers import AutoConfig, AutoTokenizer

from ragfoundry.utils import check_package_installed

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -38,14 +37,18 @@ def __init__(
"vllm",
"please refer to vLLM website for installation instructions, or run: pip install vllm",
)
from vllm import LLM, SamplingParams

logger.info(f"Using the following instruction: {self.instruction}")

self.instruct_in_prompt = instruct_in_prompt
self.template = open(template).read() if template else None
self.instruction = open(instruction).read()

self.sampling_params = SamplingParams(**generation)
self.llm = LLM(model=model_name_or_path, tensor_parallel_size=num_gpus, **llm_params)
self.llm = LLM(
model=model_name_or_path, tensor_parallel_size=num_gpus, **llm_params
)
if self.instruct_in_prompt:
self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
self.config = AutoConfig.from_pretrained(self.model_name)
Expand All @@ -68,7 +71,9 @@ def generate(self, prompt: str) -> str:
tokenize=False,
add_generation_prompt=True,
truncation=True,
max_length=(self.config.max_position_embeddings - self.sampling_param.max_tokens),
max_length=(
self.config.max_position_embeddings - self.sampling_param.max_tokens
),
)

output = self.llm.generate(prompt, self.sampling_params)
Expand Down
6 changes: 3 additions & 3 deletions ragfoundry/utils.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -2,8 +2,8 @@ def check_package_installed(package_name: str, optional_msg: str = ""):
"""
Check if a package is installed.
"""

import importlib.util

if importlib.util.find_spec(package_name) is None:
raise ImportError(f"{package_name} package is not installed; {optional_msg}")
raise ImportError(f"{package_name} package is not installed; {optional_msg}")
3 changes: 1 addition & 2 deletions training.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -3,14 +3,13 @@
from pathlib import Path

import hydra
import wandb
from datasets import load_dataset
from hydra.utils import to_absolute_path
from omegaconf import OmegaConf
from transformers import TrainingArguments
from trl import DataCollatorForCompletionOnlyLM, SFTTrainer

import wandb

logger = logging.getLogger(__name__)


Expand Down

0 comments on commit d9b9c67

Please sign in to comment.