by Grayson Adkins, updated March 21, 2024
This notebook focuses on serving custom language models on a variety of GPUs. I explore quantization considerations, GPU selection, and the popular inference toolkits TGI and vLLM, and include the Python code I use for querying models. I also evaluate performance, particularly tokens per second, and discuss other metrics, including latency and throughput.
You want to:
The Python scripts used in this notebook are available in the ai-cookbook repo on my GitHub.
I've also made available free, one-click Runpod templates for both TGI (here) and vLLM (here).
Model size here is the number of parameters of the neural network. This number correlates to the amount of GPU memory needed to load the model, which in turn affects your hosting costs (discussed in the next section).
Bigger ≠ Better - Generally speaking, the bigger the model, the "smarter" it is, but bigger is not always better. Big models are also more expensive to serve and respond more slowly, and depending on your task, you may not need all those extra smarts anyway, or low latency may matter more for your use case. For example, routine tasks like intelligent auto-complete or transforming data from one format to another are well suited to a smaller model like Llama 2 7B, whereas tasks involving complex logic, such as extracting data from multiple sources and synthesizing novel observations, are better suited to a model like Llama 2 70B.
Quantization or Full-Precision? - It's possible to get the smarts of a big model in a smaller package (i.e. one that requires less GPU memory) by quantizing the model. With quantization, you can reduce a full-precision model, e.g. 32- or 16-bit, to a lower precision like 8- or 4-bit, so that the model fits on a single GPU or even a laptop. Keep in mind that the model will lose precision and (possibly) performance, but, again, depending on your task, the difference may not matter or even be noticeable at all. See the QLoRA notebook for quantized fine-tuning examples.
# Llama 2 at 16-bit:
7 billion parameters x 16 bits / (8 bits per byte) = 14 GB
vs.
# Llama 2 quantized (4-bit):
7 billion parameters x 4 bits / (8 bits per byte) = 3.5 GB
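To sanity-check these numbers in code, here's a minimal sketch of the same arithmetic (my own helper, not part of any toolkit):

def weight_memory_gb(num_params: float, bits_per_param: int) -> float:
    """Rough GPU memory needed just to hold the model weights."""
    return num_params * (bits_per_param / 8) / 1e9

print(weight_memory_gb(7e9, 16))  # Llama 2 7B at 16-bit -> ~14.0 GB
print(weight_memory_gb(7e9, 4))   # Llama 2 7B at 4-bit  -> ~3.5 GB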
Inference toolkits give you a way to deploy a "production-ready" inference API, without having to implement a lot of the nitty-gritty details yourself.
TGI vs. vLLM - If you choose to quantize, I currently recommend using either Text Generation Inference (TGI) from Hugging Face or vLLM to configure your model API, depending on which quantization method you want to use. TGI supports EETQ for 8-bit quantization and bitsandbytes for 4-bit, both of which let you quantize on the fly. vLLM supports AWQ quantization, among others, but your model will need to be pre-quantized before loading it. That's because AWQ is data dependent, i.e. it uses a calibration dataset to determine which model weights are most important, instead of quantizing them all equally. You can find pre-quantized AWQ models on Hugging Face, particularly from the user TheBloke.
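For example, one of TheBloke's pre-quantized AWQ checkpoints can be loaded directly with the transformers library. This is just a sketch for local experimentation (it assumes a recent transformers release with the autoawq and accelerate packages installed, and uses TheBloke/Llama-2-7B-Chat-AWQ as an illustrative model); the serving sections below rely on TGI and vLLM instead:

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "TheBloke/Llama-2-7B-Chat-AWQ"  # pre-quantized AWQ checkpoint on Hugging Face
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Weights are already 4-bit AWQ, so no on-the-fly quantization is needed
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")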
Concurrency - Serving many concurrent users will increase your inference latency and decrease your tokens per second. A quantized model (e.g. 4-bit AWQ) can also be slower at inference time than bf16, so if inference speed is important, consider serving at full bf16 precision instead of quantizing. Similarly, speculative decoding (i.e. the --speculate option in TGI) takes more memory and will therefore add latency under many concurrent requests.
The following specs should be considered when selecting a GPU, as they will affect which models are supported, inference speed, and cost.
Memory (VRAM) - Higher GPU memory lets you fit a bigger model. However, while full-precision Llama 7B technically fits in 14 GB of memory, in practice it needs ~15-16 GB to perform inference. That's because the GPU must also store the input sequence, the sequence history (i.e. the KV cache), current layer activations, and other data. The longer your model context is, the more VRAM headroom you'll need (see the sketch after this list).
Computational Speed (FLOPS) - For text generation applications like chatbots, higher FLOPS will allow your system to process more tokens per second. This is often the bottleneck in deployed systems. As a rule of thumb, 10-15 tokens per second is about average reading speed, so that's typically sufficient for streaming text back to your users. Other applications, though, may require more FLOPS.
Memory Bandwidth (GB/s) - The speed at which the GPU can move data from VRAM into its compute units. This is also a common bottleneck during inference. Higher memory bandwidth allows your model to support more concurrent requests.
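Here's a rough back-of-the-envelope sketch of how context length and concurrency eat into VRAM (assumptions: 16-bit weights and KV cache, and Llama 2 7B's approximate architecture of 32 layers, 32 KV heads, and head dimension 128; real servers add further overhead for activations and CUDA buffers):

def kv_cache_gb(context_len, batch_size, n_layers=32, n_heads=32, head_dim=128, bytes_per_val=2):
    """KV cache size: 2 (keys + values) x layers x heads x head_dim x bytes, per token per sequence."""
    per_token_bytes = 2 * n_layers * n_heads * head_dim * bytes_per_val
    return context_len * batch_size * per_token_bytes / 1e9

weights_gb = 7e9 * 2 / 1e9                 # Llama 2 7B at 16-bit: ~14 GB
print(weights_gb + kv_cache_gb(4096, 1))   # one 4k-token sequence: ~16 GB
print(weights_gb + kv_cache_gb(4096, 8))   # eight concurrent 4k-token sequences: ~31 GB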
Don't just default to the Big 3 cloud providers (AWS, Azure, Google Cloud). There is tight competition right now among MLOps startups flush with VC dollars. That means low prices for you and me. For example, in my experience, specialized providers like Runpod or Lambda Labs are cheaper than the hyperscale clouds. Still, there are other factors to consider:
GPU Availability - The current workhorses of deep learning are the NVIDIA H100, A100, A6000, and A4000 series, with the A100 being the most popular. Because of that popularity, however, their availability in some clouds is often limited. You might also consider the NVIDIA L40S or T4.
Auxiliary Cost - You should also consider auxiliary hosting costs such as data egress, storage, networking, logging, etc. Some clouds include these expenses in their serving costs, while others bill them à la carte. These will be a fraction of your total spend, but they can add up.
Pre-Configuration - Many cloud providers are increasingly offering pre-configured environments tailored to model hosting. For example, in this notebook, I use Runpod, which handles the networking and provides a custom domain and TLS certificate to provide an out-of-the-box HTTPS endpoint to query my deployed model.
Uptime - There is high variance among cloud providers in terms of hardware quality, uptime, connectivity, UX, etc. For example, not all will provide uptime guarantees (known as Service-Level Agreements, or SLAs). Persistent downtime is unfortunately common among some providers, so reputation here is key. Many cloud providers offer uptime guarantees for an added cost or with a contract.
In this section, we'll use the open-source inference toolkit TGI to serve our model. TGI implements several advanced features to improve inference speed and make our API production ready. In fact, Hugging Face uses TGI to serve several of their own production APIs.
TGI is a toolkit from Hugging Face that simplifies creating production-ready inference APIs for any public or private models available on Hugging Face or from a local repository. It implements a bunch of features, notably:
This is what we're going to deploy:
Model: Mixtral 8x7b Instruct
Quantization: AWQ
Minimum GPU memory (VRAM): 48 GB
Recommended GPUs: A6000, A100, H100
Inference toolkit: TGI
If you're using Runpod.io, you can use this one-click template to deploy your API in seconds. The sections below discuss configuration details, regardless of which cloud provider you choose.
I'll use Runpod.io to serve my model, but the general steps outlined here will work for any cloud provider.
Mixtral 8x7b requires at least 48 GB of VRAM, so you'll need an A6000, A100, or H100.
I'll test in four different configurations using the following GPUs:
1x RTX A6000 48GB VRAM ($0.79/hr from Runpod)
1x A100 80GB VRAM ($1.89/hr from Runpod)
1x H100 PCIe 80GB VRAM ($3.89/hr from Runpod)
1x H100 SXM5 80GB VRAM ($4.69/hr from Runpod)
The easiest way to use TGI is via the official Docker image, available on GitHub Container Registry (ghcr). This image bundles the TGI server with all its dependencies and ensures it is compatible with your chosen cloud provider's runtime environment. It takes as input the model you want to serve, along with any options you want, such as quantization method, max tokens, speculative decoding, etc.
Important
To enable GPU access for your container, you'll need to install the nvidia-container-toolkit on your server and set the --gpus flag in your docker run command, as shown below. Some providers (like Runpod) pre-install this dependency so you don't have to.
Container image: ghcr.io/huggingface/text-generation-inference:1.4
Container disk (temporary): 5 GB
Volume disk (persistent): 50 GB
Volume mount path: /workspace
Port: 8080 (omit to use ssh instead)
Environment variables:
HUGGINGFACE_HUB_CACHE=/workspace
HF_HUB_ENABLE_HF_TRANSFER=1
HUGGING_FACE_HUB_TOKEN=<replace with your access token>
(required for private models)
Container start command:
docker run --gpus all --shm-size 1g -p 8080:80 \
-v $PWD/data:/data \
ghcr.io/huggingface/text-generation-inference:1.4 \
--model-id TheBloke/Llama-2-70B-chat-AWQ --quantize awq \
--max-input-length 2048 --max-total-tokens 4096 \
--max-batch-prefill-tokens 4096 --speculate 3 \
--trust-remote-code
--gpus all
- Enables GPU access; specify a number or all
--shm-size 1g
- Size of the shared memory device allotted to the container
--trust-remote-code
- Serve a custom model with weights and implementation available on Hugging Face Hub
-p 8080:80
- Maps host port 8080 to container port 80, where TGI is listening
--max-input-length 2048
- Max tokens in user prompt
--max-total-tokens 4096
- "Memory budget" of a client request; max of prompt + generated output
--max-batch-prefill-tokens 4096
- Max tokens in prefill (kv caching) operation
--quantize awq
- Quantization method (if desired)
--speculate 3
- For speculative decoding, the number of tokens to speculate ahead
-v $PWD/data:/data
- Share a volume with the container to avoid re-downloading weights on every run
--model-id TheBloke/Llama-2-70B-chat-AWQ
- Local model path or Hugging Face model ID
For a full list of flags see Hugging Face docs.
Runpod handles all the networking, domains, and TLS certs for you, so that you can immediately make queries to your model once the container is up, using the Pod ID:
https://{YOUR_POD_ID}-8080.proxy.runpod.net
Note
If you're not using Runpod or a service that similarly handles the networking for you, you may additionally have to configure a static IP address, load balancer, TLS certificate, DNS records, etc. before querying your model. As these steps vary greatly by cloud provider, they are not covered in this article.
However, the server will not have curl
installed, so you'll need to first connect via ssh
and install it:
apt update && apt install -y curl
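Before sending a full generation request, you can optionally confirm the server is up. Recent TGI versions expose lightweight /health and /info routes (check the TGI docs for your version if these return 404):

curl https://{YOUR_POD_ID}-8080.proxy.runpod.net/health
curl https://{YOUR_POD_ID}-8080.proxy.runpod.net/info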
You can then query the API as follows:
curl https://{YOUR_POD_ID}-8080.proxy.runpod.net/generate \
-X POST \
-d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \
-H 'Content-Type: application/json'
Or use the /generate_stream
endpoint for streaming, as in the sketch below. You can also use Python to make requests, as demonstrated later in this notebook.
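Here is a minimal streaming sketch using the requests library. It assumes the same pod URL as above; TGI streams server-sent events, with each line prefixed by data: and carrying a token object:

import json
import requests

# Replace <POD_ID> with your Runpod Pod ID
url = "https://<POD_ID>-8080.proxy.runpod.net/generate_stream"
payload = {"inputs": "What is Deep Learning?", "parameters": {"max_new_tokens": 50}}

with requests.post(url, json=payload, stream=True) as resp:
    for line in resp.iter_lines():
        if line and line.startswith(b"data:"):
            event = json.loads(line[len(b"data:"):])  # strip the SSE "data:" prefix
            print(event["token"]["text"], end="", flush=True)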
This script tests the response time of our model API. Note that this script assumes you are using TGI to configure your API as described above.
import os
import subprocess
import json
import time
from termcolor import colored
from tenacity import retry, wait_random_exponential, stop_after_attempt
from transformers import AutoTokenizer
# Model should be available on Hugging Face or stored locally
# If private, be sure to add HUGGING_FACE_ACCESS_TOKEN to environment variables
model = 'casperhansen/mixtral-instruct-awq'
# For Runpod with TGI. Replace <POD_ID> with your Runpod Pod ID
api_endpoint = "https://<POD-ID>-8080.proxy.runpod.net"
tgi_api_base = api_endpoint + '/generate'
tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
# # Optional: manually set the prompt template (not recommended). Below is the Vicuna 1.1
# # prompt format; system messages are not supported.
# tokenizer.chat_template = "{% set sep = ' ' %}{% set sep2 = '</s>' %}{{ 'A chat between a curious user and an artificial intelligence assistant.\n\nThe assistant gives helpful, detailed, and polite answers to user questions.\n\n' }}{% if messages[0]['role'] == 'system' %}{{ '' }}{% set start_index = 1 %}{% else %}{% set start_index = 0 %}{% endif %}{% for i in range(start_index, messages|length) %}{% if messages[i]['role'] == 'user' %}{{ 'USER:\n' + messages[i]['content'].strip() + (sep if i % 2 == start_index else sep2) }}{% elif messages[i]['role'] == 'assistant' %}{{ 'ASSISTANT:\n' + messages[i]['content'].strip() + (sep if i % 2 == start_index else sep2) }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'ASSISTANT:\n' }}{% endif %}"
# # OPTION TO MANUALLY FORMAT MESSAGES (INSTEAD OF USING tokenizer.apply_chat_template)
# B_SYS = "<<SYS>>\n"
# E_SYS = "\n<</SYS>>\n\n"
# B_INST = "[INST] "
# E_INST = " [/INST]\n\n"
# BOS_token = "<s>"
# EOS_token = "</s>"
# def format_messages(messages):
# formatted_string = ''
# formatted_string += BOS_token
# formatted_string += B_INST
# for message in messages:
# if message['role'] == 'system':
# formatted_string += B_SYS
# formatted_string += message['content']
# formatted_string += E_SYS
# elif message['role'] in ['user']:
# formatted_string += message['content']
# formatted_string += E_INST
# elif message['role'] in ['assistant']:
# formatted_string += message['content']
# formatted_string += EOS_token
# formatted_string += BOS_token
# formatted_string += B_INST
# return formatted_string
@retry(wait=wait_random_exponential(multiplier=1, max=40), stop=stop_after_attempt(3))
def chat_completion_request_runpod(messages):
# formatted_messages = format_messages(messages)
formatted_messages = tokenizer.apply_chat_template(messages, tokenize=False, \
add_generation_prompt=True)
# print(formatted_messages)
# Properly escape the string for JSON
json_payload = json.dumps({
"inputs": formatted_messages,
"parameters": {
"max_new_tokens": 500,
"do_sample": False,
# "stop": ["<step>"] #required for codellama 70b
}})
start_time = time.time() # Start timing
try:
# Execute the curl command
curl_command = f"""
curl -s {tgi_api_base} \
-X POST \
-d '{json_payload}' \
-H 'Content-Type: application/json'
"""
response = subprocess.run(curl_command, shell=True, check=True, \
stdout=subprocess.PIPE)
response_time = time.time() - start_time # Calculate response time
response = response.stdout.decode()
response = json.loads(response).get("generated_text", "No generated text found")
# # Log the first and last 25 characters and the response time
# print(f"Response Time: {response_time} seconds")
# print(f"Start of Response: {response[:25]}")
# print(f"End of Response: {response[-25:]}")
# Calculate tokens per second
tokens_generated = len(response)/4 # Rough estimate: ~4 characters per token
tokens_per_second = tokens_generated / response_time if response_time > 0 else 0
prompt_tokens = len(tokenizer.encode(formatted_messages)) # Count prompt tokens with the tokenizer
# Print prompt and generated tokens, time taken, and tokens per second
print(f"Total Time: {response_time:.2f} seconds")
print(f"Prompt Tokens: {prompt_tokens:.2f}")
print(f"Tokens Generated: {tokens_generated:.2f}")
print(f"Tokens per Second: {tokens_per_second:.2f}")
return response
except subprocess.CalledProcessError as e:
print("Unable to generate ChatCompletion response")
print(f"Exception: {e}")
return str(e)
def pretty_print_conversation(messages):
role_to_color = {
"system": "red",
"user": "green",
"assistant": "blue",
"tool": "magenta",
}
for message in messages:
if message["role"] == "system":
print(colored(f"system: {message['content']}\n", role_to_color[message["role"]]))
elif message["role"] == "user":
print(colored(f"user: {message['content']}\n", role_to_color[message["role"]]))
elif message["role"] == "assistant" and message.get("function_call"):
print(colored(f"assistant: {message['function_call']}\n", \
role_to_color[message["role"]]))
elif message["role"] == "assistant" and not message.get("function_call"):
print(colored(f"assistant: {message['content']}\n", role_to_color[message["role"]]))
elif message["role"] == "tool":
print(colored(f"function ({message['name']}): {message['content']}\n", \
role_to_color[message["role"]]))
# Chat
messages = []
# messages.append({"role": "system", "content": "You are a helpful assistant."})
messages.append({"role": "user", "content": "Write a long essay on the topic of spring."})
# messages.append({"role": "user", "content": "Write a short piece of python code to add up the first 10 prime fibonacci numbers."})
chat_response = chat_completion_request_runpod(messages)
messages.append({"role": "assistant", "content": chat_response})
pretty_print_conversation(messages)
Prompt:
"Write a long essay on the topic of spring."
Max New Tokens:
500
Concurrent Requests:
1
GPU | Latency (s) | Tokens Per Second |
---|---|---|
1x RTX A6000 (48GB VRAM) | 29.05 | 17.21 |
1x A100 (80GB VRAM) | 25.20 | 19.84 |
1x H100 PCIe (80GB VRAM) | 25.76 | 19.41 |
1x H100 SXM5 (80GB VRAM) | 24.64 | 20.29 |
Response:
user:
Write a long essay on the topic of spring.
assistant:
Title: The Wonders of Spring: A Season of Renewal and Growth
Spring, the season of renewal and growth, is a time of great beauty
and transformation in the natural world. It is a time when the earth
seems to come back to life after the long, cold winter, and when the
first signs of new life begin to emerge. In this essay, we will
explore the many wonders of spring, from the blooming of flowers and
the return of migratory birds, to the changing behavior of animals
and the rebirth of the landscape.[...]
This script tests the response time of our model API for a given number of requests per second. Note that this script assumes you are using TGI to configure your API as described above.
Prompt:
"Write a long essay on the topic of spring."
Max New Tokens:
500
Concurrent Requests:
25, 50, 100
import threading
import os
import json
import time
import requests
from tenacity import retry, wait_random_exponential, stop_after_attempt
from transformers import AutoTokenizer
# Model should be available on Hugging Face
# If private, be sure to add HUGGING_FACE_ACCESS_TOKEN to environment variables
model = 'casperhansen/mixtral-instruct-awq'
# For Runpod with TGI. Replace <POD_ID> with your Runpod Pod ID
api_endpoint = "https://<POD-ID>-8080.proxy.runpod.net"
tgi_api_base = api_endpoint + '/generate'
tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
# # Manually adjust the prompt. Not Recommended. Here is Vicuna 1.1 prompt format. System messages not supported.
# tokenizer.chat_template = "{% set sep = ' ' %}{% set sep2 = '</s>' %}{{ 'A chat between a curious user and an artificial intelligence assistant.\n\nThe assistant gives helpful, detailed, and polite answers to user questions.\n\n' }}{% if messages[0]['role'] == 'system' %}{{ '' }}{% set start_index = 1 %}{% else %}{% set start_index = 0 %}{% endif %}{% for i in range(start_index, messages|length) %}{% if messages[i]['role'] == 'user' %}{{ 'USER:\n' + messages[i]['content'].strip() + (sep if i % 2 == start_index else sep2) }}{% elif messages[i]['role'] == 'assistant' %}{{ 'ASSISTANT:\n' + messages[i]['content'].strip() + (sep if i % 2 == start_index else sep2) }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'ASSISTANT:\n' }}{% endif %}"
# # OPTION TO MANUALLY FORMAT MESSAGES (INSTEAD OF USING tokenizer.apply_chat_template)
# B_SYS = "<<SYS>>\n"
# E_SYS = "\n<</SYS>>\n\n"
# B_INST = "[INST] "
# E_INST = " [/INST]\n\n"
# BOS_token = "<s>"
# EOS_token = "</s>"
# def format_messages(messages):
# formatted_string = ''
# formatted_string += BOS_token
# formatted_string += B_INST
# for message in messages:
# if message['role'] == 'system':
# formatted_string += B_SYS
# formatted_string += message['content']
# formatted_string += E_SYS
# elif message['role'] in ['user']:
# formatted_string += message['content']
# formatted_string += E_INST
# elif message['role'] in ['assistant']:
# formatted_string += message['content']
# formatted_string += EOS_token
# formatted_string += BOS_token
# formatted_string += B_INST
# return formatted_string
# @retry(wait=wait_random_exponential(multiplier=1, max=40), stop=stop_after_attempt(3))
def chat_completion_request_threaded(messages, request_number):
# formatted_messages = format_messages(messages)
formatted_messages = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
json_payload = {"inputs": formatted_messages, "parameters": {"max_new_tokens": 500, "do_sample": False}}
start_time = time.time() # Start timing
try:
response = requests.post(tgi_api_base, json=json_payload)
response_time = time.time() - start_time # Calculate response time
if response.status_code == 200:
response_content = response.json().get("generated_text", "No generated text found")
else:
raise Exception(f"Request failed with status code {response.status_code}")
# print(response_content)
# Calculate tokens per second
tokens_generated = len(response_content) / 4
tokens_per_second = tokens_generated / response_time if response_time > 0 else 0
# Print time taken and tokens per second for each request
print(f"Request #{request_number}: Total Time: {response_time:.2f} seconds, Tokens per Second: {tokens_per_second:.2f}")
return response_content
except Exception as e:
print(f"Unable to generate ChatCompletion response for Request #{request_number}")
print(f"Exception: {e}")
return str(e)
def send_request_every_x_seconds(interval, total_requests):
for i in range(total_requests):
threading.Timer(interval * i, send_request, args=(i+1,)).start()
def send_request(request_number):
messages = [
{"role": "user", "content": "Write a long essay on the topic of spring."}
]
chat_completion_request_threaded(messages, request_number)
# Start sending requests every x seconds
send_request_every_x_seconds(0.125, 12) # Modify as needed for your use case
GPU | Concurrent Requests | Average Latency (s) | Average Tokens Per Second | Cost ($/hr) |
---|---|---|---|---|
1x RTX A6000 (48GB) | 25 | 46.51 | 11.87 | 0.79 |
1x RTX A6000 (48GB) | 50 | 82.87 | 6.71 | 0.79 |
1x RTX A6000 (48GB) | 100 | Timeout Error | Timeout Error | 0.79 |
1x A100 (80GB) | 25 | 43.67 | 12.84 | 1.89 |
1x A100 (80GB) | 50 | 57.15 | 9.66 | 1.89 |
1x A100 (80GB) | 100 | 85.80 | 6.51 | 1.89 |
1x H100 PCIe (80GB) | 25 | 45.69 | 12.04 | 3.89 |
1x H100 PCIe (80GB) | 50 | 62.77 | 8.76 | 3.89 |
1x H100 PCIe (80GB) | 100 | 99.53 | 5.56 | 3.89 |
1x H100 SXM5 (80GB) | 25 | 35.43 | 15.71 | 4.69 |
1x H100 SXM5 (80GB) | 50 | 48.66 | 11.38 | 4.69 |
1x H100 SXM5 (80GB) | 100 | 72.40 | 7.59 | 4.69 |
Summary of Results:
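One takeaway from the table is that per-request throughput drops as concurrency rises, so a useful way to compare GPUs is cost per generated token. Here's a rough sketch, using the hourly prices and per-request tokens-per-second figures above and assuming the pod stays saturated:

def dollars_per_million_tokens(price_per_hr, per_request_tps, concurrent_requests):
    """Approximate cost per 1M generated tokens at a given level of concurrency."""
    aggregate_tps = per_request_tps * concurrent_requests  # total tokens/s across all requests
    return price_per_hr / (aggregate_tps * 3600) * 1e6

# Example with figures from the table (RTX A6000 at 25 concurrent requests)
print(dollars_per_million_tokens(0.79, 11.87, 25))  # ~$0.74 per 1M generated tokens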
In this section, we'll use the open-source inference toolkit vLLM to serve our model.
Like TGI, vLLM implements several advanced features to improve inference speed and make our API production ready:
This is what we're going to deploy:
Model: Mixtral 8x7b Instruct
Quantization: AWQ
Minimum GPU memory (VRAM): 48 GB
Recommended GPUs: A6000, A100, H100
Inference toolkit: vLLM
If you're using Runpod.io, you can use this one-click template to deploy your API in seconds. The sections below discuss configuration details, regardless of which cloud provider you choose.
We'll use the same GPUs as in the TGI tests, except for the NVIDIA H100 PCIe, which produced results comparable to the A100 at more than twice the price. I therefore skip the H100 PCIe in this section.
The easiest way to use vLLM is via the official Docker image, available on Docker Hub. This image bundles the vLLM server with all its dependencies and ensures it is compatible with your chosen cloud provider's runtime environment. It takes as input the model you want to serve, along with any options you want, such as quantization method, max model length, etc.
Important
To enable GPU access for your container, you'll need to install the nvidia-container-toolkit on your server and set the --gpus flag in your docker run command, as shown below. Some providers (like Runpod) pre-install this dependency so you don't have to.
Container image: vllm/vllm-openai:latest
Container disk (temporary): 10 GB
Volume disk (persistent): 50 GB
Volume mount path: /root/.cache/huggingface
Port: 8000 (omit to use ssh instead)
Environment variables:
HUGGING_FACE_HUB_TOKEN=<replace with your access token>
(required for private models)
Container start command:
docker run --gpus all -p 8000:8000 vllm/vllm-openai:latest \
--model casperhansen/mixtral-instruct-awq --quantization awq \
--dtype half --enforce-eager --max-model-len 4096 --port 8000
--max-model-len 4096
- Model context length.
--quantization awq
- Quantization method
--dtype
- Data type for model weights and activations
--port 8000
- The port where your app is listening
--model casperhansen/mixtral-instruct-awq
- Name or path of the Hugging Face model to use
For a full list of flags see vLLM docs.
Runpod handles all the networking, domains, and TLS certs for you, so that you can immediately make queries to your model once the container is up, using the Pod ID:
https://{YOUR_POD_ID}-8000.proxy.runpod.net
Note
If you're not using Runpod or a service that similarly handles the networking for you, you may additionally have to configure a static IP address, load balancer, TLS certificate, DNS records, etc. before querying your model. As these steps vary greatly by cloud provider, they are not covered in this article.
However, the server will not have curl
installed, so you'll need to first connect via ssh
and install it:
apt update && apt install -y curl
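Because the vLLM image runs an OpenAI-compatible server, you can first confirm the model loaded by listing the available models:

curl https://{YOUR_POD_ID}-8000.proxy.runpod.net/v1/models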
You can then query the API as follows (the OpenAI-compatible server expects the /v1/completions or /v1/chat/completions routes rather than TGI's /generate):
curl https://{YOUR_POD_ID}-8000.proxy.runpod.net/v1/completions \
-X POST \
-d '{"model": "casperhansen/mixtral-instruct-awq", "prompt": "What is Deep Learning?", "max_tokens": 20}' \
-H 'Content-Type: application/json'
Or use Python to make requests, as demonstrated in the next section.
This script tests the response time of our model API. Note that this script assumes you are using vLLM to configure your API as described above.
from openai import OpenAI
import os
import time
from dotenv import load_dotenv
from termcolor import colored
model = 'casperhansen/mixtral-instruct-awq'
# For Runpod with vLLM. Replace <POD_ID> with your Runpod Pod ID
api_endpoint = "https://<POD-ID>-8000.proxy.runpod.net"
openai_api_base = api_endpoint + '/v1'
# Initialize the OpenAI client
client = OpenAI(
api_key="EMPTY", # Replace with your actual API key if required
base_url=openai_api_base,
)
def chat_completion_request_openai(messages, client):
start_time = time.time() # Start timing
# Create chat completions using the OpenAI client
chat_response = client.chat.completions.create(
model=model,
messages=messages,
temperature=0,
max_tokens=500
)
response_time = time.time() - start_time # Calculate response time
# Extract the completion text from the response
if chat_response.choices:
completion_text = chat_response.choices[0].message.content
else:
completion_text = None
# Calculate tokens per second
prompt_tokens = chat_response.usage.prompt_tokens if completion_text else 0
tokens_generated = chat_response.usage.completion_tokens if completion_text else 0
tokens_per_second = tokens_generated / response_time if response_time > 0 else 0
# print(chat_response)
# Print time taken and tokens per second
print(f"Total Time: {response_time:.2f} seconds")
print(f"Prompt Tokens: {prompt_tokens:.2f}")
print(f"Tokens Generated: {tokens_generated:.2f}")
print(f"Tokens per Second: {tokens_per_second:.2f}")
return completion_text
def pretty_print_conversation(messages):
role_to_color = {
"system": "red",
"user": "green",
"assistant": "blue",
"tool": "magenta",
}
for message in messages:
color = role_to_color.get(message["role"], "grey")
print(colored(f"{message['role']}: {message['content']}\n", color))
# Test the function
messages = [
{"role": "user", "content": "Write a long essay on the topic of spring."}
]
chat_response = chat_completion_request_openai(messages, client)
messages.append({"role": "assistant", "content": chat_response})
pretty_print_conversation(messages)
Prompt:
"Write a long essay on the topic of spring."
Max New Tokens:
500
Concurrent Requests:
1
Note: I did not test vLLM with the H100 PCIe 80GB VRAM after it performed on par with the much cheaper A100 in the TGI tests.
GPU | Latency (s) | Tokens Per Second |
---|---|---|
1x RTX A6000 (48GB VRAM) | 30.09 | 16.62 |
1x A100 (80GB VRAM) | 29.12 | 17.17 |
1x H100 SXM5 (80GB VRAM) | 27.96 | 17.89 |
Response:
user:
Write a long essay on the topic of spring.
assistant:
Title: The Wonders of Spring: A Season of Renewal and Growth
Spring, the season of renewal and growth, is a time of great beauty
and transformation in the natural world. It is a time when the earth
seems to come back to life after the long, cold winter, and when the
first signs of new life begin to emerge. In this essay, we will
explore the many wonders of spring, from the blooming of flowers and
the return of migratory birds, to the changing behavior of animals
and the rebirth of the landscape.[...]
vLLM supports the OpenAI API format, so we can use the official OpenAI Python client to make requests.
from openai import OpenAI
import os
import time
import threading
# from termcolor import colored # Uncomment if you wish to use colored output
# Model should be available on Hugging Face or stored locally
# If private, be sure to add HUGGING_FACE_ACCESS_TOKEN to environment variables
model = 'casperhansen/mixtral-instruct-awq'
# For Runpod with vLLM. Replace <POD_ID> with your Runpod Pod ID
api_endpoint = "https://<POD-ID>-8000.proxy.runpod.net"
openai_api_base = api_endpoint + '/v1'
# Initialize the OpenAI client
client = OpenAI(
api_key="EMPTY", # Replace with your actual API key if required
base_url=openai_api_base,
)
def chat_completion_request_openai(messages, client, request_number):
start_time = time.time() # Start timing
# Create chat completions using the OpenAI client
chat_response = client.chat.completions.create(
model=model,
messages=messages,
temperature=0,
max_tokens=500
)
response_time = time.time() - start_time # Calculate response time
# Extract the completion text from the response
if chat_response.choices:
completion_text = chat_response.choices[0].message.content
else:
completion_text = None
# Calculate tokens per second
prompt_tokens = chat_response.usage.prompt_tokens if completion_text else 0
tokens_generated = chat_response.usage.completion_tokens if completion_text else 0
tokens_per_second = tokens_generated / response_time if response_time > 0 else 0
# Print header and response details
print(f"\n---------- Request #{request_number} ----------")
print(f"Total Time Taken: {response_time:.2f} seconds")
print(f"Prompt tokens: {prompt_tokens:.2f}")
print(f"Tokens generated: {tokens_generated:.2f}")
print(f"Tokens per Second: {tokens_per_second:.2f}\n")
return completion_text
def send_request_every_x_seconds():
for i in range(12):
threading.Timer(0.125 * i, send_request, args=(i+1,)).start()
def send_request(request_number):
messages = [
{"role": "user", "content": "Write a long essay on the topic of spring."}
]
chat_completion_request_openai(messages, client, request_number)
# Start sending requests every x seconds
send_request_every_x_seconds()
GPU | Concurrent Requests | Average Latency (s) | Average Tokens Per Second | Cost ($/hr) |
---|---|---|---|---|
1x RTX A6000 (48GB) | 25 | 48.63 | 10.32 | 0.79 |
1x RTX A6000 (48GB) | 50 | 82.87 | 6.71 | 0.79 |
1x RTX A6000 (48GB) | 100 | 502 Error | 502 Error | 0.79 |
1x A100 (80GB) | 25 | 45.44 | 10.86 | 1.89 |
1x A100 (80GB) | 50 | 72.12 | 6.89 | 1.89 |
1x A100 (80GB) | 100 | 502 Error | 502 Error | 1.89 |
1x H100 SXM5 (80GB) | 25 | 39.25 | 12.57 | 4.69 |
1x H100 SXM5 (80GB) | 50 | 56.42 | 8.83 | 4.69 |
1x H100 SXM5 (80GB) | 100 | 91.05 | 5.55 | 4.69 |