diff --git a/model_service/model_service.py b/model_service/model_service.py
index 36997bc..d466e27 100644
--- a/model_service/model_service.py
+++ b/model_service/model_service.py
@@ -27,6 +27,7 @@ class ModelService():
     def __init__(self, model_instance_id: str, model_type: str, model_variant: str, model_path: str, gateway_host: str, gateway_port: int, master_host: str, master_port: int) -> None:
         """
         Args:
+            model_instance_id (str): Unique identifier for model instance
             model_type (str): Type of model to load
             model_variant (str): Variant of model to load
@@ -59,7 +60,7 @@ def run(self):
         if model.rank == 0:
             logger.info(f"Starting model service for {self.model_type} on rank {model.rank}")
-            #Placeholder static triton config for now
+            # Placeholder static triton config for now
             triton_config = TritonConfig(http_address="0.0.0.0", http_port=self.master_port, log_verbose=4)
             triton_workspace = Path("/tmp") / Path("pytriton") / Path("".join(random.choices(string.ascii_uppercase + string.ascii_lowercase + string.digits, k=16)))
             with Triton(config=triton_config, workspace=triton_workspace) as triton:
diff --git a/model_service/models/falcon/accelerate_config.yaml b/model_service/models/falcon/accelerate_config.yaml
new file mode 100644
index 0000000..363b5ae
--- /dev/null
+++ b/model_service/models/falcon/accelerate_config.yaml
@@ -0,0 +1,21 @@
+compute_environment: LOCAL_MACHINE
+deepspeed_config:
+  gradient_accumulation_steps: 1
+  offload_optimizer_device: cpu
+  offload_param_device: cpu
+  zero3_init_flag: true
+  zero3_save_16bit_model: true
+  zero_stage: 3
+distributed_type: DEEPSPEED
+downcast_bf16: 'no'
+machine_rank: 0
+main_training_function: main
+mixed_precision: 'no'
+num_machines: 1
+num_processes: 2
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
diff --git a/model_service/models/falcon/launch_falcon-40b.slurm b/model_service/models/falcon/launch_falcon-40b.slurm
index 8f622d6..811eeac 100644
--- a/model_service/models/falcon/launch_falcon-40b.slurm
+++ b/model_service/models/falcon/launch_falcon-40b.slurm
@@ -1,14 +1,22 @@
 #!/bin/bash
 #SBATCH --mail-type=END,FAIL
-#SBATCH --mem=167G
+#SBATCH --mem=0
 #SBATCH --partition=a40
 #SBATCH --qos=llm
 #SBATCH --nodes=1
-#SBATCH --gpus-per-node=2
+#SBATCH --ntasks=1
+#SBATCH --gpus-per-node=4
 #SBATCH --cpus-per-task=32
 #SBATCH --output=falcon-40b_service.%j.out
 #SBATCH --error=falcon-40b_service.%j.err
 
+echo "SLURM_JOB_NODELIST"=$SLURM_JOB_NODELIST
+echo "SLURM_JOB_PARTITION"=$SLURM_JOB_PARTITION
+echo "SLURM_NNODES"=$SLURM_NNODES
+echo "SLURM_GPUS_ON_NODE"=$SLURM_GPUS_ON_NODE
+echo "SLURM_CPUS_ON_NODE"=$SLURM_CPUS_ON_NODE
+echo "SLURM_PROCID"=$SLURM_PROCID
+
 model_service_dir=$1
 gateway_host=$2
 gateway_port=$3
@@ -18,14 +26,27 @@
 model_path="/model_checkpoint"
 
 source /opt/lmod/lmod/init/profile
 module load singularity-ce/3.8.2
-export MASTER_ADDR=$(hostname -I | awk '{print $1}')
-export CUDA_VISIBLE_DEVICES=0,1 #,2,3 # Add 2 and 3 if changed to 4 gpus
-# export NCCL_IB_DISABLE=1
+
+MASTER_ADDR=$(hostname -I | awk '{print $1}')
+MASTER_PORT=8855
+
+export CUDA_VISIBLE_DEVICES=0,1,2,3
+
+GPUS_PER_NODE=4
+NNODES=$SLURM_NNODES
+NUM_PROCESSES=$(expr $NNODES \* $GPUS_PER_NODE)
+echo "NNODES"=$NNODES
+echo "NUM_PROCESSES"=$NUM_PROCESSES
 
 # Send registration request to gateway
 curl -X POST -H "Content-Type: application/json" -d '{"host": "'"$MASTER_ADDR"':50112"}' http://$gateway_host:$gateway_port/models/instances/$SLURM_JOB_NAME/register
 
 echo $MASTER_ADDR
 
-singularity exec --nv --bind /checkpoint,/scratch,/ssd003,/ssd005,$model_chkp_dir:$model_path /ssd005/projects/llm/falcon-hf.sif \
-/usr/bin/python3 -s \
-    $model_service_dir/model_service.py --model_type falcon --model_variant 40b --model_path $model_path --model_instance_id $SLURM_JOB_NAME --gateway_host $gateway_host --gateway_port $gateway_port --master_host $MASTER_ADDR --master_port 50112
+# Use --containall flag to avoid using host python env
+singularity exec --containall --nv --bind /checkpoint,/scratch,/ssd003,/ssd005,$model_chkp_dir:$model_path /ssd005/projects/llm/falcon-hf-ds.sif \
+    accelerate launch --config_file $model_service_dir/models/falcon/accelerate_config.yaml \
+    $model_service_dir/model_service.py \
+    --model_type falcon --model_variant 40b --model_path $model_path \
+    --model_instance_id $SLURM_JOB_NAME \
+    --gateway_host $gateway_host --gateway_port $gateway_port \
+    --master_host $MASTER_ADDR --master_port 50112
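Every worker that accelerate launch starts from this script inherits its coordinates through the standard launcher environment variables, which is what the updated model code reads further down in this diff. A minimal sketch of what each process sees (illustrative only; the variable names are the standard torchrun/accelerate ones, not something defined by this repo):

    import os

    # Standard variables set by `accelerate launch` / torchrun for every worker process.
    world_rank = int(os.environ.get("RANK", "0"))        # global rank across all nodes
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))  # rank on this node; used to pick a GPU
    world_size = int(os.environ.get("WORLD_SIZE", "1"))  # total number of processes

    # Only one process should bind the Triton endpoint, mirroring the
    # model.rank == 0 check in model_service.py above.
    if world_rank == 0:
        print(f"serving on rank 0 of {world_size} (local rank {local_rank})")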
diff --git a/model_service/models/falcon/launch_falcon-7b.slurm b/model_service/models/falcon/launch_falcon-7b.slurm
index eedf312..4ceea20 100644
--- a/model_service/models/falcon/launch_falcon-7b.slurm
+++ b/model_service/models/falcon/launch_falcon-7b.slurm
@@ -1,15 +1,22 @@
 #!/bin/bash
 #SBATCH --mail-type=END,FAIL
-#SBATCH --mem=128G
+#SBATCH --mem=0
 #SBATCH --partition=a40
 #SBATCH --qos=llm
 #SBATCH --nodes=1
 #SBATCH --ntasks=1
-#SBATCH --gpus-per-node=1
+#SBATCH --gpus-per-node=2
 #SBATCH --cpus-per-task=8
 #SBATCH --output=falcon-7b_service.%j.out
 #SBATCH --error=falcon-7b_service.%j.err
 
+echo "SLURM_JOB_NODELIST"=$SLURM_JOB_NODELIST
+echo "SLURM_JOB_PARTITION"=$SLURM_JOB_PARTITION
+echo "SLURM_NNODES"=$SLURM_NNODES
+echo "SLURM_GPUS_ON_NODE"=$SLURM_GPUS_ON_NODE
+echo "SLURM_CPUS_ON_NODE"=$SLURM_CPUS_ON_NODE
+echo "SLURM_PROCID"=$SLURM_PROCID
+
 model_service_dir=$1
 gateway_host=$2
 gateway_port=$3
@@ -19,13 +26,27 @@
 model_path="/model_checkpoint"
 
 source /opt/lmod/lmod/init/profile
 module load singularity-ce/3.8.2
-export MASTER_ADDR=$(hostname -I | awk '{print $1}')
-export PYTHONPATH=/usr/bin/python3
+
+MASTER_ADDR=$(hostname -I | awk '{print $1}')
+MASTER_PORT=8855
+
+export CUDA_VISIBLE_DEVICES=0,1
+
+GPUS_PER_NODE=2
+NNODES=$SLURM_NNODES
+NUM_PROCESSES=$(expr $NNODES \* $GPUS_PER_NODE)
+echo "NNODES"=$NNODES
+echo "NUM_PROCESSES"=$NUM_PROCESSES
 
 # Send registration request to gateway
 curl -X POST -H "Content-Type: application/json" -d '{"host": "'"$MASTER_ADDR"':50116"}' http://$gateway_host:$gateway_port/models/instances/$SLURM_JOB_NAME/register
 
 echo $MASTER_ADDR
 
-singularity exec --nv --bind /checkpoint,/scratch,/ssd003,/ssd005,$model_chkp_dir:$model_path /ssd005/projects/llm/falcon-hf.sif \
-    /usr/bin/python3 -s \
-    $model_service_dir/model_service.py --model_type falcon --model_variant 7b --model_path $model_path --model_instance_id $SLURM_JOB_NAME --gateway_host $gateway_host --gateway_port $gateway_port --master_host $MASTER_ADDR --master_port 50116
+# Use --containall flag to avoid using host python env
+singularity exec --containall --nv --bind /checkpoint,/ssd003,/ssd005,$model_chkp_dir:$model_path /ssd005/projects/llm/falcon-hf-ds.sif \
+    accelerate launch --config_file $model_service_dir/models/falcon/accelerate_config.yaml \
+    $model_service_dir/model_service.py \
+    --model_type falcon --model_variant 7b --model_path $model_path \
+    --model_instance_id $SLURM_JOB_NAME \
+    --gateway_host $gateway_host --gateway_port $gateway_port \
+    --master_host $MASTER_ADDR --master_port 50116
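Both launch scripts register the instance with the gateway via curl before starting the model. For reference, this is the equivalent request in Python, as a rough sketch (the helper name is hypothetical; the endpoint and JSON payload are exactly the ones the scripts use, with 50112 and 50116 as the per-variant master ports):

    import requests

    def register_instance(gateway_host, gateway_port, job_name, master_addr, master_port):
        # Mirrors: curl -X POST -H "Content-Type: application/json" \
        #   -d '{"host": "<addr>:<port>"}' http://<gateway>/models/instances/<job>/register
        url = f"http://{gateway_host}:{gateway_port}/models/instances/{job_name}/register"
        resp = requests.post(url, json={"host": f"{master_addr}:{master_port}"}, timeout=10)
        resp.raise_for_status()
        return resp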
diff --git a/model_service/models/falcon/model.py b/model_service/models/falcon/model.py
index 2690c70..0c4f8e3 100644
--- a/model_service/models/falcon/model.py
+++ b/model_service/models/falcon/model.py
@@ -10,10 +10,11 @@
 from ..abstract_model import AbstractModel
 
 from pytriton.decorators import batch
-from pytriton.model_config import ModelConfig, Tensor
+from pytriton.model_config import ModelConfig, Tensor  # DynamicBatcher, QueuePolicy
 from accelerate import init_empty_weights, load_checkpoint_and_dispatch, infer_auto_device_map
 from accelerate.utils.modeling import get_balanced_memory
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, GenerationConfig
+from accelerate import Accelerator
 
 logger = logging.getLogger("kaleidoscope.model_service.falcon")
@@ -43,40 +44,38 @@ def __init__(self, model_type, model_variant):
 
     def load(self, model_path):
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        world_rank = int(os.getenv("RANK"))
+        local_rank = int(os.getenv("LOCAL_RANK"))
+        world_size = int(os.getenv("WORLD_SIZE"))
+        logger.info(f"Rank: {world_rank}")
+        logger.info(f"Local rank: {local_rank}")
+        logger.info(f"World size: {world_size}")
+
+        self.device = torch.device("cuda:{}".format(local_rank) if torch.cuda.is_available() else "cpu")
 
         self.load_model_cfg(os.path.join(self.model_cfg_path, "model_config.json"))
 
-        if self.model_variant == "40b":
-            local_rank = int(os.getenv("LOCAL_RANK", "0"))
-            world_size = torch.cuda.device_count()
-            logger.info(f"Rank: {local_rank}")
-            logger.info(f"World size: {world_size}")
+        # Load model on main device for each node
+        if local_rank == 0:
 
-            logger.debug(f"Torch dtype: {self.model_cfg['torch_dtype']}")
             config = AutoConfig.from_pretrained(
-                model_path, trust_remote_code=self.model_cfg["trust_remote_code"], torch_dtype=self.model_cfg["torch_dtype"])
+                model_path, trust_remote_code=self.model_cfg["trust_remote_code"], torch_dtype=self.model_cfg["torch_dtype"])
             with init_empty_weights():
-                model = self.model_class.from_config(config, trust_remote_code=self.model_cfg["trust_remote_code"], torch_dtype=self.model_cfg["torch_dtype"])
+                model = self.model_class.from_config(config, trust_remote_code=self.model_cfg["trust_remote_code"], torch_dtype=self.model_cfg["torch_dtype"])
             model.tie_weights()
 
-            # Configure memory per device and get device map
-            max_memory = {idx: "40GiB" for idx in range(world_size)}
-            max_memory.update({"cpu": "120GiB"})
-            device_map = infer_auto_device_map(model, max_memory, no_split_module_classes=["MLP", "DecoderLayer"])
-            logging.debug(f"Max memory: {max_memory}")
-            logging.debug(f"Device map: {device_map}")
+            device_map = "balanced_low_0"
 
             self.model = load_checkpoint_and_dispatch(
-                model, model_path, device_map=device_map, dtype=self.model_cfg["torch_dtype"])
-        else:
-            self.model = self.model_class.from_pretrained(model_path, **self.model_cfg) # TODO: .eval()?
-            self.model.to(self.device)
-
+                model, model_path, device_map=device_map, dtype=self.model_cfg["torch_dtype"], no_split_module_classes=["MLP", "DecoderLayer"])
+
         self.tokenizer = self.tokenizer_class.from_pretrained(model_path, **self.tokenizer_cfg)
         self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
 
         self.model_path = model_path
 
+        logger.debug(f"Deepspeed?: {os.environ.get('ACCELERATE_USE_DEEPSPEED', False)}")
+        logger.debug(f"Deepspeed Zero Stage?: {os.environ.get('ACCELERATE_DEEPSPEED_ZERO_STAGE', None)}")
+
     def load_model_cfg(self, cfg_file):
         """Load model and tokenzer config"""
@@ -131,7 +130,7 @@ def bind(self, triton):
 
     @property
     def rank(self):
-        return 0
+        return int(os.getenv("RANK"))
 
     @batch
@@ -172,9 +171,9 @@ def generate(self, inputs):
         transition_scores = self.model.compute_transition_scores(
             outputs.sequences, outputs.scores, normalize_logits=True)
         generated_ids = outputs.sequences
-        # remove input tokens
+        # Remove input tokens
         generated_ids = generated_ids[:, input_tokens_size:]
-        # replace token_id 0 with a special token so that it is removed while decoding - EOS
+        # Replace token_id 0 with a special token so that it is removed while decoding - EOS
         generated_ids[generated_ids==0] = int(self.tokenizer.eos_token_id)
         generations = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
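The rewritten load() follows the standard accelerate big-model idiom: build an empty skeleton on the meta device, then let load_checkpoint_and_dispatch place the real weights across GPUs. A condensed, self-contained sketch of that pattern (the checkpoint path and dtype below are placeholders mirroring the values carried in model_config.json, not something this diff defines):

    import torch
    from accelerate import init_empty_weights, load_checkpoint_and_dispatch
    from transformers import AutoConfig, AutoModelForCausalLM

    checkpoint_dir = "/model_checkpoint"  # placeholder path

    config = AutoConfig.from_pretrained(checkpoint_dir, trust_remote_code=True, torch_dtype=torch.bfloat16)
    with init_empty_weights():
        # No memory is allocated here; parameters live on the meta device.
        model = AutoModelForCausalLM.from_config(config, trust_remote_code=True, torch_dtype=torch.bfloat16)
    model.tie_weights()

    # "balanced_low_0" spreads layers across GPUs while keeping GPU 0 lightly
    # loaded, leaving headroom for generation buffers on the rank-0 process.
    model = load_checkpoint_and_dispatch(
        model,
        checkpoint_dir,
        device_map="balanced_low_0",
        dtype=torch.bfloat16,
        no_split_module_classes=["MLP", "DecoderLayer"],
    )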
diff --git a/model_service/models/falcon/model_config.json b/model_service/models/falcon/model_config.json
index 836fcff..22c7a9d 100644
--- a/model_service/models/falcon/model_config.json
+++ b/model_service/models/falcon/model_config.json
@@ -3,7 +3,7 @@
     "model": {
         "torch_dtype": "bfloat16",
         "trust_remote_code": true,
-        "device_map": "auto"
+        "device_map": null
     },
     "tokenizer": {
         "use_fast": false,
@@ -15,7 +15,7 @@
     "model": {
         "torch_dtype": "bfloat16",
         "trust_remote_code": true,
-        "device_map": "auto"
+        "device_map": null
    },
     "tokenizer": {
         "use_fast": false,
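With accelerate now owning device placement, the per-variant config keeps the device_map key but nulls it out rather than removing it. Purely as an illustration of how such a field might be consumed (the load_model_cfg helper referenced in model.py is not part of this diff, so the file structure and function below are assumptions):

    import json
    import torch

    def read_model_kwargs(cfg_file, variant):
        # Hypothetical reader: map the dtype string to a torch dtype and drop
        # null entries such as "device_map": null so they are never forwarded.
        with open(cfg_file) as f:
            model_cfg = json.load(f)[variant]["model"]
        model_cfg["torch_dtype"] = getattr(torch, model_cfg["torch_dtype"])
        return {k: v for k, v in model_cfg.items() if v is not None}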
diff --git a/model_service/triton/falcon_client.py b/model_service/triton/falcon_client.py
index 799412f..29a1ce1 100755
--- a/model_service/triton/falcon_client.py
+++ b/model_service/triton/falcon_client.py
@@ -21,6 +21,9 @@
 
 from pytriton.client import ModelClient
 
+from web.utils.triton import TritonClient
+
+
 logger = logging.getLogger("triton.falcon_client")
 
@@ -58,75 +61,30 @@ def main():
     log_level = logging.DEBUG if args.verbose else logging.INFO
     logging.basicConfig(level=log_level, format="%(asctime)s - %(levelname)s - %(name)s: %(message)s")
 
-    # sequence = np.array(
-    #     [
-    #         ["Show me the meaning of "],
-    #         ["I would love to learn cook the Asian street food"],
-    #         ["Carnival in Rio de Janeiro"],
-    #         ["William Shakespeare was a great writer"],
-    #     ]
-    # )
-
-    sequence = np.array(
-        [
-            ["William Shakespeare was a great writer"],
-            ["William Shakespeare was a great writer"],
-            ["William Shakespeare was a great writer"],
-            ["William Shakespeare was a great writer"],
-            ["William Shakespeare was a great writer"],
-            ["William Shakespeare was a great writer"],
-            ["William Shakespeare was a great writer"],
-            ["William Shakespeare was a great writer"],
-        ]
-    )
-
-    sequence = np.char.encode(sequence, "utf-8")
-    logger.info(f"Sequence: {sequence}")
-
-    batch_size = sequence.shape[0]
-    def _param(dtype, value):
-        if bool(value):
-            return np.ones((batch_size, 1), dtype=dtype) * value
-        else:
-            return np.zeros((batch_size, 1), dtype=dtype)
+
+    num_tokens = 32
 
-    num_tokens = 8
-    gen_params = {
-        "max_tokens": _param(np.int64, num_tokens),
-        "do_sample": _param(np.bool_, False),
-        "temperature": _param(np.float64, 0.7),
-    }
+    # Using TritonClient
+    prompts = ["William Shakespeare was a great writer"]*8
+    inputs = {
+        "prompts": prompts,
+        "max_tokens": num_tokens
+    }
+    model_name = "falcon-7b"
+    batch_size = len(prompts)
+    host = args.url.lstrip("http://")
 
-    model_name = "falcon-40b_generation"
-
-    # logger.info(f"Waiting for response...")
-    # start_time = time.time()
-    # with ModelClient(args.url, model_name, init_timeout_s=args.init_timeout_s) as client:
-    #     for req_idx in range(1, args.iterations + 1):
-    #         logger.info(f"Sending request ({req_idx}).")
-    #         result_dict = client.infer_batch(
-    #             prompts=sequence, **gen_params)
-    #         logger.info(f"Result: {result_dict} for request ({req_idx}).")
-    #     time_taken = time.time() - start_time
-    #     logger.info(f"Total time taken: {time_taken:.2f} secs")
-    #     token_per_sec = (num_tokens*batch_size)/time_taken
-    #     logger.info(f"tokens/sec: {token_per_sec:.2f}")
+    triton_client = TritonClient(host)
+    start_time = time.time()
+    generation = triton_client.infer(model_name, inputs, task="generation")
+    print(generation)
+    time_taken = time.time() - start_time
 
+    # Common logging for both methods
+    logger.info(f"Total time taken: {time_taken:.2f} secs")
+    token_per_sec = (num_tokens*batch_size)/time_taken
+    logger.info(f"tokens/sec: {token_per_sec:.2f}")
 
-    # benchmark
-    n_runs = 5
-    run_times = []
-    for run_idx in range(n_runs):
-        start_time = time.time()
-        with ModelClient(args.url, model_name, init_timeout_s=args.init_timeout_s) as client:
-            for req_idx in range(1, args.iterations + 1):
-                logger.info(f"Sending request ({req_idx}).")
-                result_dict = client.infer_batch(
-                    prompts=sequence, **gen_params)
-        run_times.append(time.time() - start_time)
-    mean_token_per_sec = np.mean([(num_tokens*batch_size)/elm for elm in run_times])
-    logger.info(f"seq_len: {num_tokens}, tokens/sec: {mean_token_per_sec:.2f}")
 
 if __name__ == "__main__":
     main()
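One detail worth flagging in the rewritten client: str.lstrip strips a set of characters rather than a literal prefix, so args.url.lstrip("http://") will also remove a leading h, t, or p from the host name itself. A small helper (hypothetical name, shown only to illustrate the safer pattern):

    def strip_scheme(url: str) -> str:
        # "http://host".lstrip("http://") would also eat the leading "h" of "host";
        # strip the scheme prefix explicitly instead.
        for prefix in ("http://", "https://"):
            if url.startswith(prefix):
                return url[len(prefix):]
        return url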
diff --git a/model_service/triton/falcon_client.sh b/model_service/triton/falcon_client.sh
index c7fa923..c675619 100755
--- a/model_service/triton/falcon_client.sh
+++ b/model_service/triton/falcon_client.sh
@@ -13,10 +13,10 @@
 if [ ! -z "$2" ]; then
     server_port=$2
 fi
 
+export PYTHONPATH=$PYTHONPATH:/h/odige/triton_multi_node/kaleidoscope
+
 source /opt/lmod/lmod/init/profile
 module load singularity-ce/3.8.2
 
 echo "Sending request to http://$server_host:$server_port"
 
 singularity exec --bind /ssd005 /ssd005/projects/llm/falcon-hf.sif /usr/bin/python3 -s ~/triton_multi_node/kaleidoscope/model_service/triton/falcon_client.py --url http://$server_host:$server_port
-# singularity exec --bind /scratch /scratch/ssd002/projects/opt_test/triton/pytriton_falcon/pytriton_falcon.sif /usr/bin/python3 -s ~/triton_multi_node/kaleidoscope/model_service/triton/falcon_client.py --url http://$server_host:$server_port
-
diff --git a/web/utils/triton.py b/web/utils/triton.py
index 4fd6040..5cdc9b9 100644
--- a/web/utils/triton.py
+++ b/web/utils/triton.py
@@ -1,4 +1,3 @@
-from flask import current_app
 import numpy as np
 import tritonclient.http as httpclient
 from tritonclient.utils import np_to_triton_dtype, triton_to_np_dtype
@@ -29,7 +28,6 @@ def prepare_prompts_tensor(prompts):
     triton_dtype = "BYTES"
     input = _str_list2numpy(value)
-    # np.array(value, dtype=bytes)
 
     tensor = httpclient.InferInput(name, input.shape, triton_dtype)
     tensor.set_data_from_numpy(input)
@@ -59,7 +57,6 @@ def prepare_inputs(inputs, inputs_config):
 
     inputs_wrapped = [prepare_prompts_tensor(prompts)]
 
-    current_app.logger.info(f"Input args: {inputs}")
     for input in inputs.items():
         try:
             inputs_wrapped.append(prepare_param_tensor(input, inputs_config, batch_size))
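For context on the TritonClient helper the client now imports: Triton represents string inputs as UTF-8 encoded BYTES tensors, which is what prepare_prompts_tensor above builds with httpclient.InferInput. A self-contained sketch of the same pattern, assuming an input named "prompts" as in the client code (everything else is illustrative):

    import numpy as np
    import tritonclient.http as httpclient

    def prompts_to_tensor(prompts):
        # Shape [batch, 1] object array of UTF-8 bytes, declared as a BYTES input.
        arr = np.array([[p.encode("utf-8")] for p in prompts], dtype=np.object_)
        tensor = httpclient.InferInput("prompts", arr.shape, "BYTES")
        tensor.set_data_from_numpy(arr)
        return tensor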