by Grayson Adkins, updated February 6, 2024
This notebook demonstrates fine-tuning of an open-source model (Mistral 7B). It leverages the transformers and PEFT libraries from Hugging Face for quantization, LoRA, and training, and a custom-built dataset for function calling.
This notebook builds on the basic fine-tuning example by introducing the following innovations: masking the loss on prompt tokens so the model is graded only on its answers, and fine-tuning with a stop sequence so that responses stay concise.
The source code for this notebook is available in the ai-cookbook repo on my GitHub.
Typically, a model is graded on its prediction of the next token in both the question and the answer. However, our primary goal is for the model to give thoughtful attention to the question, while its performance is graded solely on how well it predicts the answer; these two goals are achieved with attention masks and loss masks, respectively.
An attention mask tells the model which parts of the input text (e.g., a question or a context) it should pay attention to. It helps the model focus on the relevant information and ignore irrelevant portions of the input. An attention mask is simply a sequence of 1s and 0s that is multiplied element-wise with the input sequence IDs—resulting in a new input sequence where irrelevant tokens are zeroed out (i.e. masked).
{'input_ids': tensor([[9204, 18, 3763, 456, 222, 13563, 22580, 584]]),
'attention_mask': tensor([[1, 1, 1, 0, 1, 1, 0, 1]])}
{'result': tensor([[9204, 18, 3763, 0, 222, 13563, 0, 584]])}
As an example, we usually want to make sure that PAD tokens are masked.
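To make that concrete, here is a minimal sketch that reproduces the masking illustration above, using the same token IDs:
import torch
# Token IDs and attention mask from the illustration above
input_ids = torch.tensor([[9204, 18, 3763, 456, 222, 13563, 22580, 584]])
attention_mask = torch.tensor([[1, 1, 1, 0, 1, 1, 0, 1]])
# Positions with a 0 in the mask are zeroed out (i.e. masked)
result = input_ids * attention_mask
print(result)  # tensor([[ 9204, 18, 3763, 0, 222, 13563, 0, 584]])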
A loss mask is used to calculate the loss or error during training. It specifies which parts of the model's output should be considered when computing the loss. When training a model, we take the losses and multiply them by the loss mask.
To improve model performance, in this notebook we mask the losses associated with the prompt, ensuring the model focuses on answering the question rather than on predicting the next tokens of the question itself.
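As a rough sketch of what applying a loss mask looks like (this helper is illustrative only, not the exact training loop used later): compute the per-token cross-entropy with reduction='none', multiply by the loss mask, and average only over the unmasked (answer) positions.
import torch
import torch.nn.functional as F

def masked_loss(logits, labels, loss_mask):
    # Per-token cross-entropy: one loss value per position
    per_token_loss = F.cross_entropy(
        logits.view(-1, logits.size(-1)), labels.view(-1), reduction="none"
    ).view(labels.shape)
    # Zero out the prompt positions, then average over the answer positions only
    masked = per_token_loss * loss_mask
    return masked.sum() / loss_mask.sum().clamp(min=1)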
Have you ever noticed how verbose some models are? By fine-tuning with a stop sequence, such as USER:, we can teach the model to be more concise:
{
prompt: "What is the stock price of Apple?\n\nBOT:",
completion: "Apple stock price is $188.04.\n\nUSER: ",
},
...
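At inference time you still need to cut generation off when the stop sequence appears. Here is a minimal sketch using the transformers StoppingCriteria API (model, tokenizer, and inputs are whatever you have loaded; this is not part of the training code below):
from transformers import StoppingCriteria, StoppingCriteriaList

class StopOnSequence(StoppingCriteria):
    def __init__(self, tokenizer, stop_string="USER:"):
        self.tokenizer = tokenizer
        self.stop_string = stop_string
    def __call__(self, input_ids, scores, **kwargs):
        # Stop as soon as the decoded text ends with the stop string
        text = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
        return text.rstrip().endswith(self.stop_string)

# outputs = model.generate(
#     **inputs,
#     max_new_tokens=256,
#     stopping_criteria=StoppingCriteriaList([StopOnSequence(tokenizer)]),
# )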
Mistral 7B Instruct is an instruction fine-tuned version of Mistral 7B available on Hugging Face.
Per the HF model card:
Instruction format
The template used to build a prompt for the Instruct model is defined as follows:
<s> [INST] Instruction [/INST] Model answer</s> [INST] Follow-up instruction [/INST]
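For example, a single-turn prompt can be assembled by hand as in the sketch below (the tokenizer adds the <s> BOS token itself when add_special_tokens=True, which is how it's handled later in this notebook; tokenizer.apply_chat_template produces the same format):
instruction = "What is the capital of France?"  # illustrative instruction
prompt = f"[INST] {instruction} [/INST]"
# inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
# Or, equivalently:
# messages = [{"role": "user", "content": instruction}]
# input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt")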
Model architecture
This instruction model is based on Mistral-7B-v0.1, a transformer model with the following architecture choices:
- Grouped-Query Attention
- Sliding-Window Attention
- Byte-fallback BPE tokenizer
write permissions.
# # Print GPU info
# gpu_info = !nvidia-smi
# gpu_info = '\n'.join(gpu_info)
# if gpu_info.find('failed') >= 0:
#     print('Not connected to a GPU')
# else:
#     print(gpu_info)
# # Print VRAM
# from psutil import virtual_memory
# ram_gb = virtual_memory().total / 1e9
# print('Your runtime has {:.1f} gigabytes of available RAM\n'.format(ram_gb))
# if ram_gb < 20:
#     print('Not using a high-RAM runtime')
# else:
#     print('You are using a high-RAM runtime!')
# Authenticate to Hugging Face to pull and push models
!pip install huggingface_hub -q
from huggingface_hub import notebook_login
notebook_login()
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv
# (Optional) Configure Weights & Biases (wandb) to track training runs
!pip install wandb -q -U
import wandb
wandb.login()
wandb: Currently logged in as: gadkins. Use `wandb login --relogin` to force relogin
True
# base_model = "./Mistral-7B-Instruct-v0.1-function-calling-v2"
# base_model = "meta-llama/Llama-2-7b-hf"
# base_model = "meta-llama/Llama-2-7b-chat-hf"
# base_model = "meta-llama/Llama-2-13b-chat-hf"
# base_model = "codellama/CodeLlama-34b-Instruct-hf"
# base_model = "meta-llama/Llama-2-70b-chat-hf"
base_model = "mistralai/Mistral-7B-Instruct-v0.1"
# base_model = "deepseek-ai/deepseek-coder-1.3b-instruct"
# base_model = "deepseek-ai/deepseek-coder-6.7b-instruct"
# base_model = "deepseek-ai/deepseek-coder-33b-instruct"
# base_model = "larryvrh/Yi-34B-200K-Llamafied"
# base_model = "./Yi-34B-200K-Llamafied-chat-SFT"
# base_model = "openchat/openchat_3.5"
# base_model = "SUSTech/SUS-Chat-34B"
# base_model = "mistralai/Mixtral-8x7B-Instruct-v0.1"
# base_model = "microsoft/phi-2"
cache_dir = '' # Initialise the cache_dir to an empty string.
# (Optionally, you can set Google Drive as the cache_dir below)
# stable versions
!python -m pip install --upgrade pip
!pip install -U -q transformers
!pip install -q -U bitsandbytes
!pip install -q -U peft
!pip install -q -U accelerate
!pip install -q datasets
!pip install -q -U scipy
!pip install -q -U trl
!pip install -U flash-attn -q
Requirement already satisfied: pip in /usr/local/lib/python3.10/dist-packages (23.1.2) Collecting pip Downloading pip-24.0-py3-none-any.whl (2.1 MB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.1/2.1 MB 31.6 MB/s eta 0:00:00 Installing collected packages: pip Attempting uninstall: pip Found existing installation: pip 23.1.2 Uninstalling pip-23.1.2: Successfully uninstalled pip-23.1.2 Successfully installed pip-24.0 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 129.4/129.4 kB 3.4 MB/s eta 0:00:00 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 8.4/8.4 MB 95.7 MB/s eta 0:00:00 WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 105.0/105.0 MB 19.9 MB/s eta 0:00:00 WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 183.4/183.4 kB 6.8 MB/s eta 0:00:00 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 270.9/270.9 kB 13.9 MB/s eta 0:00:00 WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 507.1/507.1 kB 17.6 MB/s eta 0:00:00 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 115.3/115.3 kB 10.9 MB/s eta 0:00:00 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 134.8/134.8 kB 12.7 MB/s eta 0:00:00 WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 60.4/60.4 kB 2.0 MB/s eta 0:00:00 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 38.4/38.4 MB 55.6 MB/s eta 0:00:00 ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts. lida 0.0.10 requires fastapi, which is not installed. lida 0.0.10 requires kaleido, which is not installed. lida 0.0.10 requires python-multipart, which is not installed. lida 0.0.10 requires uvicorn, which is not installed. WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 150.9/150.9 kB 6.6 MB/s eta 0:00:00 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 79.8/79.8 kB 7.3 MB/s eta 0:00:00 WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.5/2.5 MB 31.6 MB/s eta 0:00:00 Preparing metadata (setup.py) ... 
done ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 44.6/44.6 kB 3.8 MB/s eta 0:00:00 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 307.2/307.2 kB 22.6 MB/s eta 0:00:00 Building wheel for flash-attn (setup.py) ... done WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline, AutoConfig
import transformers
import torch
from torch.utils.data import DataLoader, Dataset
from google.colab import drive
drive.mount('/content/drive')
Mounted at /content/drive
import os
cache_dir = "/content/drive/My Drive/huggingface_cache"
os.makedirs(cache_dir, exist_ok=True) # Ensure the directory exists
# https://stackoverflow.com/questions/56081324/why-are-google-colab-shell-commands-not-working
import locale
def getpreferredencoding(do_setlocale = True):
    return "UTF-8"
locale.getpreferredencoding = getpreferredencoding
Note about quantization:
In this section, we have the option to load a quantized version of the model (see the QLoRA notebook for quantization details) to reduce the computation requirements such that it will fit on a free T4 GPU in Google Colab. If cost is most important to you, then I recommend this option—just uncomment the quantization_config option below.
However, I've observed slightly better performance in function-calling fine-tunes when using models at full precision. Note that if you use full precision, you'll need a larger GPU such as an A100. If you're using Google Colab, you'll need to upgrade to Pro or use another service like RunPod or Lambda Labs (which are a bit cheaper).
# QLoRA config
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16
)
# Instantiate model
model = AutoModelForCausalLM.from_pretrained(
base_model,
# quantization_config=bnb_config, # Uncomment to use quantized version
torch_dtype=torch.bfloat16,
device_map="auto",
trust_remote_code=True,
#attn_implementation="flash_attention_2", # Supported in Ampere GPUs or newer
cache_dir=cache_dir
)
/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:88: UserWarning: The secret `HF_TOKEN` does not exist in your Colab secrets. To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session. You will be able to reuse this secret in all of your notebooks. Please note that authentication is recommended but still optional to access public models or datasets. warnings.warn(
Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]
# # Required for certain tokenizers like Yi
# !pip install sentencepiece -q -U
tokenizer = AutoTokenizer.from_pretrained(base_model, cache_dir=cache_dir, trust_remote_code=True)
# tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1", cache_dir=cache_dir)
print("EOS token:", tokenizer.eos_token)
print("EOS token id:", tokenizer.eos_token_id)
EOS token: </s> EOS token id: 2
# If pad token is None, we'll need to set one in the next section
print("Pad token: ", tokenizer.pad_token)
print("Pad token ID: ", tokenizer.pad_token_id)
Pad token: None Pad token ID: None
# Padding to the right of (i.e. after) the prompt and response gives better results
tokenizer.padding_side='right'
print(tokenizer)
LlamaTokenizerFast(name_or_path='mistralai/Mistral-7B-Instruct-v0.1', vocab_size=32000, model_max_length=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>'}, clean_up_tokenization_spaces=False), added_tokens_decoder={ 0: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True), 1: AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True), 2: AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True), }
Some models already have a pad token set. You can see whether they do or don't from the tokenizer print statement above. If that's the case, then you don't need to do anything further.
If no pad token exists, then you have three options:
- (Recommended) Use the <unk> token (i.e. "unknown") to pad—note that this assumes the <unk> token exists in the vocab.
- Use the EOS token to pad (this is the fallback in Option 1 below).
- Add a dedicated <pad> token to the vocab and resize the model's token embeddings (Option 2 below).
## (Recommended) OPTION 1
# If <unk> is in the tokenizer, set the pad token to <unk>
# Else, set pad token to EOS token
if '<unk>' in tokenizer.get_vocab():
    print('Found \'<unk>\' token in tokenizer. Using \'<unk>\' for pad.')
    # Set the pad token
    tokenizer.pad_token = '<unk>'
else:
    print(f'Using EOS token, \'{tokenizer.eos_token}\', for padding')
    tokenizer.pad_token = tokenizer.eos_token
## OPTION 2
# # Check if the pad token is already in the tokenizer vocabulary
# if '<pad>' not in tokenizer.get_vocab():
# print('pad token not in the tokenizer')
# # Add the pad token
# tokenizer.add_tokens(['<pad>'])
# # Set the pad token
# tokenizer.pad_token = '<pad>'
# # Resize token embeddings
# model.resize_token_embeddings(len(tokenizer))
Found '<unk>' token in tokenizer. Using '<unk>' for pad.
# Update pad token id in model and its config
model.pad_token_id = tokenizer.pad_token_id
model.config.pad_token_id = tokenizer.pad_token_id
# Check if they are equal
assert model.pad_token_id == tokenizer.pad_token_id, "The model's pad token ID \
does not match the tokenizer's pad token ID!"
# Print the pad token ids
print('Tokenizer pad token ID:', tokenizer.pad_token_id)
print('Model pad token ID:', model.pad_token_id)
print('Model config pad token ID:', model.config.pad_token_id)
print('Number of tokens now in tokenizer:', len(tokenizer))
Tokenizer pad token ID: 0 Model pad token ID: 0 Model config pad token ID: 0 Number of tokens now in tokenizer: 32000
# Print model configuration
print(model.config)
MistralConfig { "_name_or_path": "mistralai/Mistral-7B-Instruct-v0.1", "architectures": [ "MistralForCausalLM" ], "attention_dropout": 0.0, "bos_token_id": 1, "eos_token_id": 2, "hidden_act": "silu", "hidden_size": 4096, "initializer_range": 0.02, "intermediate_size": 14336, "max_position_embeddings": 32768, "model_type": "mistral", "num_attention_heads": 32, "num_hidden_layers": 32, "num_key_value_heads": 8, "pad_token_id": 0, "rms_norm_eps": 1e-05, "rope_theta": 10000.0, "sliding_window": 4096, "tie_word_embeddings": false, "torch_dtype": "bfloat16", "transformers_version": "4.37.2", "use_cache": true, "vocab_size": 32000 }
# Sample string
# sample_string = ['hello [/INST]', 'my good friend</s>']
sample_string = ['Caio!']
# Tokenize the stringified JSON object
encoded_sample = tokenizer(sample_string, truncation=True, padding=True, max_length=1024, return_tensors='pt', add_special_tokens=True)
BOS_token_id = tokenizer.bos_token_id
EOS_token_id = tokenizer.eos_token_id
BOS_token = tokenizer.decode([BOS_token_id])
EOS_token = tokenizer.decode([EOS_token_id])
print(f"Beginning of the sequence: {sample_string[0]} (BOS token: {BOS_token}, id: {BOS_token_id})")
print(f"End of the sequence: {sample_string[-1]} (EOS token: {EOS_token}, id: {EOS_token_id})")
token_count = encoded_sample['input_ids'].shape[1]  # number of tokens in the encoded string
print(f"Tokens in the string: {token_count}")
print(f"Token IDs: {encoded_sample}")
# Decode the input_ids
decoded_sample = tokenizer.decode(encoded_sample['input_ids'][0], skip_special_tokens=False)
# Print the decoded string
print(f"Decoded string: {decoded_sample}")
# Print the attention mask
print(f"Attention mask: {encoded_sample['attention_mask']}")
Beginning of the sequence: Caio! (BOS token: <s>, id: 1) End of the sequence: Caio! (EOS token: </s>, id: 2) Tokens in the string: 4 Token IDs: {'input_ids': tensor([[ 1, 11013, 691, 28808]]), 'attention_mask': tensor([[1, 1, 1, 1]])} Decoded string: <s> Caio! Attention mask: tensor([[1, 1, 1, 1]])
# # If loading with adapters
# # Note: Instead, it's often faster to download base model then add adapters
# from peft import PeftModel
# # adapter_model = f'{base_model}' + '-function-calling-adapters' # replace
# # Load peft model with adapters
# model = PeftModel.from_pretrained(
# model,
# adapter_model,
# )
# To reduce VRAM usage (supported by most models)
model.gradient_checkpointing_enable()
# If using quantized model
# from peft import prepare_model_for_kbit_training
# model = prepare_model_for_kbit_training(model)
# Print list of modules
print(model.state_dict().keys())
odict_keys(['model.embed_tokens.weight', 'model.layers.0.self_attn.q_proj.weight', 'model.layers.0.self_attn.k_proj.weight', 'model.layers.0.self_attn.v_proj.weight', 'model.layers.0.self_attn.o_proj.weight', 'model.layers.0.mlp.gate_proj.weight', 'model.layers.0.mlp.up_proj.weight', 'model.layers.0.mlp.down_proj.weight', 'model.layers.0.input_layernorm.weight', 'model.layers.0.post_attention_layernorm.weight', 'model.layers.1.self_attn.q_proj.weight', 'model.layers.1.self_attn.k_proj.weight', 'model.layers.1.self_attn.v_proj.weight', 'model.layers.1.self_attn.o_proj.weight', 'model.layers.1.mlp.gate_proj.weight', 'model.layers.1.mlp.up_proj.weight', 'model.layers.1.mlp.down_proj.weight', 'model.layers.1.input_layernorm.weight', 'model.layers.1.post_attention_layernorm.weight', 'model.layers.2.self_attn.q_proj.weight', 'model.layers.2.self_attn.k_proj.weight', 'model.layers.2.self_attn.v_proj.weight', 'model.layers.2.self_attn.o_proj.weight', 'model.layers.2.mlp.gate_proj.weight', 'model.layers.2.mlp.up_proj.weight', 'model.layers.2.mlp.down_proj.weight', 'model.layers.2.input_layernorm.weight', 'model.layers.2.post_attention_layernorm.weight', 'model.layers.3.self_attn.q_proj.weight', 'model.layers.3.self_attn.k_proj.weight', 'model.layers.3.self_attn.v_proj.weight', 'model.layers.3.self_attn.o_proj.weight', 'model.layers.3.mlp.gate_proj.weight', 'model.layers.3.mlp.up_proj.weight', 'model.layers.3.mlp.down_proj.weight', 'model.layers.3.input_layernorm.weight', 'model.layers.3.post_attention_layernorm.weight', 'model.layers.4.self_attn.q_proj.weight', 'model.layers.4.self_attn.k_proj.weight', 'model.layers.4.self_attn.v_proj.weight', 'model.layers.4.self_attn.o_proj.weight', 'model.layers.4.mlp.gate_proj.weight', 'model.layers.4.mlp.up_proj.weight', 'model.layers.4.mlp.down_proj.weight', 'model.layers.4.input_layernorm.weight', 'model.layers.4.post_attention_layernorm.weight', 'model.layers.5.self_attn.q_proj.weight', 'model.layers.5.self_attn.k_proj.weight', 'model.layers.5.self_attn.v_proj.weight', 'model.layers.5.self_attn.o_proj.weight', 'model.layers.5.mlp.gate_proj.weight', 'model.layers.5.mlp.up_proj.weight', 'model.layers.5.mlp.down_proj.weight', 'model.layers.5.input_layernorm.weight', 'model.layers.5.post_attention_layernorm.weight', 'model.layers.6.self_attn.q_proj.weight', 'model.layers.6.self_attn.k_proj.weight', 'model.layers.6.self_attn.v_proj.weight', 'model.layers.6.self_attn.o_proj.weight', 'model.layers.6.mlp.gate_proj.weight', 'model.layers.6.mlp.up_proj.weight', 'model.layers.6.mlp.down_proj.weight', 'model.layers.6.input_layernorm.weight', 'model.layers.6.post_attention_layernorm.weight', 'model.layers.7.self_attn.q_proj.weight', 'model.layers.7.self_attn.k_proj.weight', 'model.layers.7.self_attn.v_proj.weight', 'model.layers.7.self_attn.o_proj.weight', 'model.layers.7.mlp.gate_proj.weight', 'model.layers.7.mlp.up_proj.weight', 'model.layers.7.mlp.down_proj.weight', 'model.layers.7.input_layernorm.weight', 'model.layers.7.post_attention_layernorm.weight', 'model.layers.8.self_attn.q_proj.weight', 'model.layers.8.self_attn.k_proj.weight', 'model.layers.8.self_attn.v_proj.weight', 'model.layers.8.self_attn.o_proj.weight', 'model.layers.8.mlp.gate_proj.weight', 'model.layers.8.mlp.up_proj.weight', 'model.layers.8.mlp.down_proj.weight', 'model.layers.8.input_layernorm.weight', 'model.layers.8.post_attention_layernorm.weight', 'model.layers.9.self_attn.q_proj.weight', 'model.layers.9.self_attn.k_proj.weight', 'model.layers.9.self_attn.v_proj.weight', 
'model.layers.9.self_attn.o_proj.weight', 'model.layers.9.mlp.gate_proj.weight', 'model.layers.9.mlp.up_proj.weight', 'model.layers.9.mlp.down_proj.weight', 'model.layers.9.input_layernorm.weight', 'model.layers.9.post_attention_layernorm.weight', 'model.layers.10.self_attn.q_proj.weight', 'model.layers.10.self_attn.k_proj.weight', 'model.layers.10.self_attn.v_proj.weight', 'model.layers.10.self_attn.o_proj.weight', 'model.layers.10.mlp.gate_proj.weight', 'model.layers.10.mlp.up_proj.weight', 'model.layers.10.mlp.down_proj.weight', 'model.layers.10.input_layernorm.weight', 'model.layers.10.post_attention_layernorm.weight', 'model.layers.11.self_attn.q_proj.weight', 'model.layers.11.self_attn.k_proj.weight', 'model.layers.11.self_attn.v_proj.weight', 'model.layers.11.self_attn.o_proj.weight', 'model.layers.11.mlp.gate_proj.weight', 'model.layers.11.mlp.up_proj.weight', 'model.layers.11.mlp.down_proj.weight', 'model.layers.11.input_layernorm.weight', 'model.layers.11.post_attention_layernorm.weight', 'model.layers.12.self_attn.q_proj.weight', 'model.layers.12.self_attn.k_proj.weight', 'model.layers.12.self_attn.v_proj.weight', 'model.layers.12.self_attn.o_proj.weight', 'model.layers.12.mlp.gate_proj.weight', 'model.layers.12.mlp.up_proj.weight', 'model.layers.12.mlp.down_proj.weight', 'model.layers.12.input_layernorm.weight', 'model.layers.12.post_attention_layernorm.weight', 'model.layers.13.self_attn.q_proj.weight', 'model.layers.13.self_attn.k_proj.weight', 'model.layers.13.self_attn.v_proj.weight', 'model.layers.13.self_attn.o_proj.weight', 'model.layers.13.mlp.gate_proj.weight', 'model.layers.13.mlp.up_proj.weight', 'model.layers.13.mlp.down_proj.weight', 'model.layers.13.input_layernorm.weight', 'model.layers.13.post_attention_layernorm.weight', 'model.layers.14.self_attn.q_proj.weight', 'model.layers.14.self_attn.k_proj.weight', 'model.layers.14.self_attn.v_proj.weight', 'model.layers.14.self_attn.o_proj.weight', 'model.layers.14.mlp.gate_proj.weight', 'model.layers.14.mlp.up_proj.weight', 'model.layers.14.mlp.down_proj.weight', 'model.layers.14.input_layernorm.weight', 'model.layers.14.post_attention_layernorm.weight', 'model.layers.15.self_attn.q_proj.weight', 'model.layers.15.self_attn.k_proj.weight', 'model.layers.15.self_attn.v_proj.weight', 'model.layers.15.self_attn.o_proj.weight', 'model.layers.15.mlp.gate_proj.weight', 'model.layers.15.mlp.up_proj.weight', 'model.layers.15.mlp.down_proj.weight', 'model.layers.15.input_layernorm.weight', 'model.layers.15.post_attention_layernorm.weight', 'model.layers.16.self_attn.q_proj.weight', 'model.layers.16.self_attn.k_proj.weight', 'model.layers.16.self_attn.v_proj.weight', 'model.layers.16.self_attn.o_proj.weight', 'model.layers.16.mlp.gate_proj.weight', 'model.layers.16.mlp.up_proj.weight', 'model.layers.16.mlp.down_proj.weight', 'model.layers.16.input_layernorm.weight', 'model.layers.16.post_attention_layernorm.weight', 'model.layers.17.self_attn.q_proj.weight', 'model.layers.17.self_attn.k_proj.weight', 'model.layers.17.self_attn.v_proj.weight', 'model.layers.17.self_attn.o_proj.weight', 'model.layers.17.mlp.gate_proj.weight', 'model.layers.17.mlp.up_proj.weight', 'model.layers.17.mlp.down_proj.weight', 'model.layers.17.input_layernorm.weight', 'model.layers.17.post_attention_layernorm.weight', 'model.layers.18.self_attn.q_proj.weight', 'model.layers.18.self_attn.k_proj.weight', 'model.layers.18.self_attn.v_proj.weight', 'model.layers.18.self_attn.o_proj.weight', 'model.layers.18.mlp.gate_proj.weight', 
'model.layers.18.mlp.up_proj.weight', 'model.layers.18.mlp.down_proj.weight', 'model.layers.18.input_layernorm.weight', 'model.layers.18.post_attention_layernorm.weight', 'model.layers.19.self_attn.q_proj.weight', 'model.layers.19.self_attn.k_proj.weight', 'model.layers.19.self_attn.v_proj.weight', 'model.layers.19.self_attn.o_proj.weight', 'model.layers.19.mlp.gate_proj.weight', 'model.layers.19.mlp.up_proj.weight', 'model.layers.19.mlp.down_proj.weight', 'model.layers.19.input_layernorm.weight', 'model.layers.19.post_attention_layernorm.weight', 'model.layers.20.self_attn.q_proj.weight', 'model.layers.20.self_attn.k_proj.weight', 'model.layers.20.self_attn.v_proj.weight', 'model.layers.20.self_attn.o_proj.weight', 'model.layers.20.mlp.gate_proj.weight', 'model.layers.20.mlp.up_proj.weight', 'model.layers.20.mlp.down_proj.weight', 'model.layers.20.input_layernorm.weight', 'model.layers.20.post_attention_layernorm.weight', 'model.layers.21.self_attn.q_proj.weight', 'model.layers.21.self_attn.k_proj.weight', 'model.layers.21.self_attn.v_proj.weight', 'model.layers.21.self_attn.o_proj.weight', 'model.layers.21.mlp.gate_proj.weight', 'model.layers.21.mlp.up_proj.weight', 'model.layers.21.mlp.down_proj.weight', 'model.layers.21.input_layernorm.weight', 'model.layers.21.post_attention_layernorm.weight', 'model.layers.22.self_attn.q_proj.weight', 'model.layers.22.self_attn.k_proj.weight', 'model.layers.22.self_attn.v_proj.weight', 'model.layers.22.self_attn.o_proj.weight', 'model.layers.22.mlp.gate_proj.weight', 'model.layers.22.mlp.up_proj.weight', 'model.layers.22.mlp.down_proj.weight', 'model.layers.22.input_layernorm.weight', 'model.layers.22.post_attention_layernorm.weight', 'model.layers.23.self_attn.q_proj.weight', 'model.layers.23.self_attn.k_proj.weight', 'model.layers.23.self_attn.v_proj.weight', 'model.layers.23.self_attn.o_proj.weight', 'model.layers.23.mlp.gate_proj.weight', 'model.layers.23.mlp.up_proj.weight', 'model.layers.23.mlp.down_proj.weight', 'model.layers.23.input_layernorm.weight', 'model.layers.23.post_attention_layernorm.weight', 'model.layers.24.self_attn.q_proj.weight', 'model.layers.24.self_attn.k_proj.weight', 'model.layers.24.self_attn.v_proj.weight', 'model.layers.24.self_attn.o_proj.weight', 'model.layers.24.mlp.gate_proj.weight', 'model.layers.24.mlp.up_proj.weight', 'model.layers.24.mlp.down_proj.weight', 'model.layers.24.input_layernorm.weight', 'model.layers.24.post_attention_layernorm.weight', 'model.layers.25.self_attn.q_proj.weight', 'model.layers.25.self_attn.k_proj.weight', 'model.layers.25.self_attn.v_proj.weight', 'model.layers.25.self_attn.o_proj.weight', 'model.layers.25.mlp.gate_proj.weight', 'model.layers.25.mlp.up_proj.weight', 'model.layers.25.mlp.down_proj.weight', 'model.layers.25.input_layernorm.weight', 'model.layers.25.post_attention_layernorm.weight', 'model.layers.26.self_attn.q_proj.weight', 'model.layers.26.self_attn.k_proj.weight', 'model.layers.26.self_attn.v_proj.weight', 'model.layers.26.self_attn.o_proj.weight', 'model.layers.26.mlp.gate_proj.weight', 'model.layers.26.mlp.up_proj.weight', 'model.layers.26.mlp.down_proj.weight', 'model.layers.26.input_layernorm.weight', 'model.layers.26.post_attention_layernorm.weight', 'model.layers.27.self_attn.q_proj.weight', 'model.layers.27.self_attn.k_proj.weight', 'model.layers.27.self_attn.v_proj.weight', 'model.layers.27.self_attn.o_proj.weight', 'model.layers.27.mlp.gate_proj.weight', 'model.layers.27.mlp.up_proj.weight', 'model.layers.27.mlp.down_proj.weight', 
'model.layers.27.input_layernorm.weight', 'model.layers.27.post_attention_layernorm.weight', 'model.layers.28.self_attn.q_proj.weight', 'model.layers.28.self_attn.k_proj.weight', 'model.layers.28.self_attn.v_proj.weight', 'model.layers.28.self_attn.o_proj.weight', 'model.layers.28.mlp.gate_proj.weight', 'model.layers.28.mlp.up_proj.weight', 'model.layers.28.mlp.down_proj.weight', 'model.layers.28.input_layernorm.weight', 'model.layers.28.post_attention_layernorm.weight', 'model.layers.29.self_attn.q_proj.weight', 'model.layers.29.self_attn.k_proj.weight', 'model.layers.29.self_attn.v_proj.weight', 'model.layers.29.self_attn.o_proj.weight', 'model.layers.29.mlp.gate_proj.weight', 'model.layers.29.mlp.up_proj.weight', 'model.layers.29.mlp.down_proj.weight', 'model.layers.29.input_layernorm.weight', 'model.layers.29.post_attention_layernorm.weight', 'model.layers.30.self_attn.q_proj.weight', 'model.layers.30.self_attn.k_proj.weight', 'model.layers.30.self_attn.v_proj.weight', 'model.layers.30.self_attn.o_proj.weight', 'model.layers.30.mlp.gate_proj.weight', 'model.layers.30.mlp.up_proj.weight', 'model.layers.30.mlp.down_proj.weight', 'model.layers.30.input_layernorm.weight', 'model.layers.30.post_attention_layernorm.weight', 'model.layers.31.self_attn.q_proj.weight', 'model.layers.31.self_attn.k_proj.weight', 'model.layers.31.self_attn.v_proj.weight', 'model.layers.31.self_attn.o_proj.weight', 'model.layers.31.mlp.gate_proj.weight', 'model.layers.31.mlp.up_proj.weight', 'model.layers.31.mlp.down_proj.weight', 'model.layers.31.input_layernorm.weight', 'model.layers.31.post_attention_layernorm.weight', 'model.norm.weight', 'lm_head.weight'])
print(model)
MistralForCausalLM( (model): MistralModel( (embed_tokens): Embedding(32000, 4096) (layers): ModuleList( (0-31): 32 x MistralDecoderLayer( (self_attn): MistralAttention( (q_proj): Linear(in_features=4096, out_features=4096, bias=False) (k_proj): Linear(in_features=4096, out_features=1024, bias=False) (v_proj): Linear(in_features=4096, out_features=1024, bias=False) (o_proj): Linear(in_features=4096, out_features=4096, bias=False) (rotary_emb): MistralRotaryEmbedding() ) (mlp): MistralMLP( (gate_proj): Linear(in_features=4096, out_features=14336, bias=False) (up_proj): Linear(in_features=4096, out_features=14336, bias=False) (down_proj): Linear(in_features=14336, out_features=4096, bias=False) (act_fn): SiLU() ) (input_layernorm): MistralRMSNorm() (post_attention_layernorm): MistralRMSNorm() ) ) (norm): MistralRMSNorm() ) (lm_head): Linear(in_features=4096, out_features=32000, bias=False) )
# # If extending model context
# def set_added_trainable_params(model):
#     """
#     Sets the parameters with names containing "embed" or "norm" as trainable.
#     """
#     trainable_params_dict = {}
#     for name, param in model.named_parameters():
#         if "embed" in name or "norm" in name: #for most models
#         # if "ln" in name or "embd" in name: #for Phi-2
#             param.requires_grad_()
#             trainable_params_dict[name] = param
#     return trainable_params_dict
def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable %: {100 * trainable_params / all_param}"
    )
from peft import LoraConfig, get_peft_model
# Initialize LoRA configuration
config = LoraConfig(
# Lower rank results in smaller update matrices with fewer trainable params
r=8, # Use 8 for models of 7B params or larger, else 128
lora_alpha=32,
target_modules=[
# "Wqkv", #for Phi-2
# "fc1", #for Phi-2
# "fc2" #for Phi-2
"self_attn.q_proj",
"self_attn.k_proj",
"self_attn.v_proj",
"self_attn.o_proj",
# "self_attn.rotary_emb.inv_freq",
"mlp.gate_proj",
"mlp.up_proj",
"mlp.down_proj",
# "input_layernorm.weight",
# "post_attention_layernorm.weight",
# "model.norm.weight",
# "lm_head.weight"
],
lora_dropout=0.1,
bias="none",
task_type="CAUSAL_LM"
)
# Apply LoRA to the model
model = get_peft_model(model, config)
# # Set added parameters with names containing "embed" or "norm" as trainable.
# # Recommended if you are extending an LLM's context window.
# set_added_trainable_params(model)
# Print out the number of trainable parameters
print_trainable_parameters(model)
trainable params: 20971520 || all params: 7262703616 || trainable %: 0.2887563792882719
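That trainable-parameter count checks out against the architecture printed above: a LoRA adapter on a linear layer with shapes (in, out) adds r * (in + out) parameters, and we target seven linear layers in each of the 32 decoder blocks. A quick sanity check:
r = 8
num_layers = 32
# (in_features, out_features) of the targeted linear layers, from print(model) above
targets = {
    "q_proj": (4096, 4096),
    "k_proj": (4096, 1024),
    "v_proj": (4096, 1024),
    "o_proj": (4096, 4096),
    "gate_proj": (4096, 14336),
    "up_proj": (4096, 14336),
    "down_proj": (14336, 4096),
}
per_layer = sum(r * (fan_in + fan_out) for fan_in, fan_out in targets.values())
print(per_layer * num_layers)  # 20971520, matching the printout above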
Each function in the dataset is stored as JSON in its own file. All functions follow OpenAI's function-metadata format.
{
"type": "function",
"function": {
"name": "function_name",
"description": "function description",
"parameters": {
"type": "object",
"properties": {
"property_1": {
"type": "property_type", //#e.g. string
"description": "property description"
},
"property_2": {
"type": "property_type", //#e.g. string
"description": "property description"
}
},
"required": ["property_1","property_2"]
}
},
"samplePromptResponsePairs": [
{
"prompt": "sample_prompt",
"response": {
"name": "generate_password",
"arguments": {
"property_1": "property_value",
"property_2": "property_value"
}
}
},
...
]
}
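To give a sense of how this metadata gets used: the function objects are serialized into a JSON array and injected into the prompt ahead of the user's question. The dataset loaded below already ships this pre-assembled in its functionList column, so the helper here is only an illustrative sketch (the file names are hypothetical):
import json

def build_function_list(function_files):
    # Read per-function JSON files and serialize their metadata as one JSON array
    functions = []
    for path in function_files:
        with open(path) as f:
            spec = json.load(f)
        # Only the metadata is shown to the model, not the sample prompt/response pairs
        functions.append({"type": spec["type"], "function": spec["function"]})
    return json.dumps(functions, indent=4)

# function_list = build_function_list(["search_arxiv.json", "get_current_weather.json"])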
!pip install -q -U datasets
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv
from datasets import load_dataset
# From Hugging Face Hub
data = load_dataset(
"Trelis/function_calling_v3"
)
Downloading readme: 0%| | 0.00/8.93k [00:00<?, ?B/s]
Downloading data: 0%| | 0.00/104k [00:00<?, ?B/s]
Downloading data: 0%| | 0.00/7.83k [00:00<?, ?B/s]
Downloading data: 0%| | 0.00/32.3k [00:00<?, ?B/s]
Downloading data: 0%| | 0.00/11.6k [00:00<?, ?B/s]
Generating train split: 0 examples [00:00, ? examples/s]
Generating validation split: 0 examples [00:00, ? examples/s]
Generating test split: 0 examples [00:00, ? examples/s]
print(data)
DatasetDict({ train: Dataset({ features: ['functionList', 'userPrompt', 'assistantResponse'], num_rows: 66 }) validation: Dataset({ features: ['functionList', 'userPrompt', 'assistantResponse'], num_rows: 19 }) test: Dataset({ features: ['functionList', 'userPrompt', 'assistantResponse'], num_rows: 7 }) })
class TextDataset(Dataset):
    def __init__(self, encodings, response_lengths, input_lengths):
        self.encodings = encodings
        self.response_lengths = response_lengths
        self.input_lengths = input_lengths

    def __getitem__(self, idx):
        item = {key: val[idx].clone().detach() for key, val in self.encodings.items()}
        # Set labels to be the same as input_ids
        item["labels"] = item["input_ids"].clone()
        # Calculate the start and end positions of the response
        response_start_position = self.input_lengths[idx]
        response_end_position = self.input_lengths[idx] + self.response_lengths[idx]
        # Create a loss mask that covers only the response tokens
        item["loss_mask"] = torch.zeros_like(item["input_ids"])
        item["loss_mask"][response_start_position:response_end_position] = 1
        # Shift the loss mask to the left by one position
        shifted_loss_mask = torch.cat([item["loss_mask"][1:], torch.tensor([0])])
        item["loss_mask"] = shifted_loss_mask
        # Shift the labels to the left by one position
        item["labels"][:-1] = item["input_ids"][1:]
        # Make the label after the response an EOS token
        item["labels"][response_end_position - 1] = tokenizer.eos_token_id
        # Set the corresponding position to 1 in the loss mask
        item["loss_mask"][response_end_position - 1] = 1
        return item

    def __len__(self):
        return len(self.encodings["input_ids"])
# Define the function start and end strings
# \n\n is added at the end during training to avoid different tokenizations of
# the E_INST string with whatever follows.
B_FUNC, E_FUNC = "You have access to the following functions. Use them if required:\n\n", "\n\n"
# Define the user prompt start and end strings
# B_INST, E_INST = "GPT4 Correct User: ", "<|end_of_turn|>GPT4 Correct Assistant:" # OpenChat style
B_INST, E_INST = "[INST] ", " [/INST]" # Llama 2 or Mistral style
# B_INST, E_INST = "Instruct:", "\nOutput:" # Phi 2
# B_INST, E_INST = "\n### Instruction:\n", "\n### Response:\n" # DeepSeek Coder style
# B_INST, E_INST = "Human: ", " Assistant:" # Yi style for function calling, no training space
# B_INST, E_INST = "### Human: ", "\n\n### Assistant: " # SUSChat
def prepare_dataset(dataset, tokenizer):
    # Create the formatted text with the correct roles for each part of the dialogue
    formatted_dataset = dataset.map(
        lambda x: {
            "input_text": "".join([
                f"{B_INST}{B_FUNC}{x['functionList'].strip()}{E_FUNC}",
                f"{x['userPrompt'].strip()}{E_INST}\n\n",
                f"{x['assistantResponse'].strip()}",  # EOS token appended in TextDataset
            ]),
            "response_text": "".join([
                f"{x['assistantResponse'].strip()}",  # EOS token appended in TextDataset
            ]),
        }
    )

    # Tokenize the full dialogues (with padding and special tokens)
    encodings = tokenizer(
        [dialogue["input_text"] for dialogue in formatted_dataset],
        truncation=True, padding=True, max_length=1024,
        return_tensors='pt', add_special_tokens=True
    )

    # Tokenize each response on its own, without padding or special tokens,
    # for the purpose of calculating its length
    response_lengths = [
        len(tokenizer.encode(dialogue["response_text"], truncation=True,
                             max_length=1024, padding=False,
                             add_special_tokens=False))
        for dialogue in formatted_dataset
    ]

    # Tokenize each input on its own, without padding but with the initial
    # special token, for the purpose of calculating its length
    total_lengths = [
        len(tokenizer.encode(dialogue["input_text"], truncation=True,
                             max_length=1024, padding=False,
                             add_special_tokens=True))
        for dialogue in formatted_dataset
    ]

    input_lengths = [
        total_length - response_length
        for total_length, response_length in zip(total_lengths, response_lengths)
    ]

    # Create TextDataset
    text_dataset = TextDataset(encodings, response_lengths, input_lengths)

    return text_dataset
# Apply function to your datasets
train_dataset = prepare_dataset(data['train'], tokenizer)
test_dataset = prepare_dataset(data['test'], tokenizer)
validation_dataset = prepare_dataset(data['validation'], tokenizer)
Map: 0%| | 0/66 [00:00<?, ? examples/s]
Map: 0%| | 0/7 [00:00<?, ? examples/s]
Map: 0%| | 0/19 [00:00<?, ? examples/s]
# Print the number of items in the dataset
print(f"Number of samples in the dataset: {len(train_dataset)}")
# Get a sample item
sample_item = train_dataset[1] # replace with the index of any sample
# Print the dimensions of the sample item
print(f"Dimensions of input_ids: {sample_item['input_ids'].shape}")
print(f"Dimensions of attention_mask: {sample_item['attention_mask'].shape}")
print(f"Dimensions of loss_mask: {sample_item['loss_mask'].shape}")
print(f"Dimensions of labels: {sample_item['labels'].shape}")
# Print some tokens from the start and end of the sample
num_tokens_to_print = 336 # replace with the number of tokens you want to print
print("\nTokens at the start of the sample:")
print(sample_item['input_ids'][:num_tokens_to_print].tolist())
print(tokenizer.convert_ids_to_tokens(sample_item['input_ids'][:num_tokens_to_print].tolist()))
print("\nLabels at the start of the sample:")
print(sample_item['labels'][:num_tokens_to_print].tolist())
print(tokenizer.convert_ids_to_tokens(sample_item['labels'][:num_tokens_to_print].tolist()))
print("Attention mask at the start of the sample:")
print(sample_item['attention_mask'][:num_tokens_to_print].tolist())
print("Loss mask at the start of the sample:")
print(sample_item['loss_mask'][:num_tokens_to_print].tolist())
print("\nTokens at the end of the sample:")
print(sample_item['input_ids'][-num_tokens_to_print:].tolist())
print(tokenizer.convert_ids_to_tokens(sample_item['input_ids'][-num_tokens_to_print:].tolist()))
print("\nLabels at the end of the sample:")
print(sample_item['labels'][-num_tokens_to_print:].tolist())
print(tokenizer.convert_ids_to_tokens(sample_item['labels'][-num_tokens_to_print:].tolist()))
print("Attention mask at the end of the sample:")
print(sample_item['attention_mask'][-num_tokens_to_print:].tolist())
print("Loss mask at the end of the sample:")
print(sample_item['loss_mask'][-num_tokens_to_print:].tolist())
Number of samples in the dataset: 66 Dimensions of input_ids: torch.Size([677]) Dimensions of attention_mask: torch.Size([677]) Dimensions of loss_mask: torch.Size([677]) Dimensions of labels: torch.Size([677]) Tokens at the start of the sample: [1, 733, 16289, 28793, 995, 506, 2735, 298, 272, 2296, 5572, 28723, 5938, 706, 513, 3030, 28747, 13, 13, 28792, 13, 2287, 371, 13, 5390, 345, 1123, 1264, 345, 2628, 548, 13, 5390, 345, 2628, 1264, 371, 13, 17422, 345, 861, 1264, 345, 2360, 28730, 283, 28744, 449, 548, 13, 17422, 345, 6518, 1264, 345, 7009, 354, 3332, 10374, 356, 1010, 28814, 449, 28723, 6746, 938, 302, 5771, 28725, 3994, 304, 5457, 12765, 390, 7658, 298, 5175, 3471, 2373, 272, 5709, 9191, 13, 17422, 345, 11438, 1264, 371, 13, 1417, 28705, 345, 1123, 1264, 345, 2814, 548, 13, 1417, 28705, 345, 10723, 1264, 371, 13, 359, 2287, 345, 3385, 1264, 371, 13, 359, 5390, 345, 1123, 1264, 345, 1427, 548, 13, 359, 5390, 345, 6518, 1264, 345, 1014, 3472, 5709, 1423, 28739, 13, 359, 2287, 443, 13, 1417, 28705, 1630, 13, 1417, 28705, 345, 10893, 1264, 733, 13, 359, 2287, 345, 3385, 28739, 13, 1417, 28705, 4709, 13, 17422, 443, 13, 5390, 443, 13, 2287, 1630, 13, 2287, 371, 13, 5390, 345, 1123, 1264, 345, 2628, 548, 13, 5390, 345, 2628, 1264, 371, 13, 17422, 345, 861, 1264, 345, 527, 28730, 3022, 28730, 769, 1223, 548, 13, 17422, 345, 6518, 1264, 345, 1458, 272, 1868, 8086, 297, 264, 2078, 4723, 548, 13, 17422, 345, 11438, 1264, 371, 13, 1417, 28705, 345, 1123, 1264, 345, 2814, 548, 13, 1417, 28705, 345, 10723, 1264, 371, 13, 359, 2287, 345, 2733, 1264, 371, 13, 359, 5390, 345, 1123, 1264, 345, 1427, 548, 13, 359, 5390, 345, 6518, 1264, 345, 1014, 2990, 304, 2939, 28725, 317, 28723, 28721, 28723, 22263, 28725, 11170, 28739, 13, 359, 2287, 1630, 13, 359, 2287, 345, 5306, 1264, 371, 13, 359, 5390, 345, 1123, 1264, 345, 1427, 548, 13, 359, 5390, 345, 6518, 1264, 345, 25241, 466, 5028, 354, 272, 8086, 28723, 19641, 28747, 464, 28717, 1190, 3170, 647, 464, 28722, 18657, 12307, 21236, 13, 359, 2287, 443, 13, 1417, 28705, 1630, 13, 1417, 28705, 345, 10893, 1264, 733, 13, 359, 2287, 345, 2733, 28739] ['<s>', '▁[', 'INST', ']', '▁You', '▁have', '▁access', '▁to', '▁the', '▁following', '▁functions', '.', '▁Use', '▁them', '▁if', '▁required', ':', '<0x0A>', '<0x0A>', '[', '<0x0A>', '▁▁▁', '▁{', '<0x0A>', '▁▁▁▁▁▁▁', '▁"', 'type', '":', '▁"', 'function', '",', '<0x0A>', '▁▁▁▁▁▁▁', '▁"', 'function', '":', '▁{', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁', '▁"', 'name', '":', '▁"', 'search', '_', 'ar', 'x', 'iv', '",', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁', '▁"', 'description', '":', '▁"', 'Search', '▁for', '▁research', '▁papers', '▁on', '▁Ar', 'X', 'iv', '.', '▁Make', '▁use', '▁of', '▁AND', ',', '▁OR', '▁and', '▁NOT', '▁operators', '▁as', '▁appropriate', '▁to', '▁join', '▁terms', '▁within', '▁the', '▁query', '.",', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁', '▁"', 'parameters', '":', '▁{', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁', '▁"', 'type', '":', '▁"', 'object', '",', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁', '▁"', 'properties', '":', '▁{', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁▁▁', '▁"', 'query', '":', '▁{', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁▁▁▁▁▁▁', '▁"', 'type', '":', '▁"', 'string', '",', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁▁▁▁▁▁▁', '▁"', 'description', '":', '▁"', 'The', '▁search', '▁query', '▁string', '"', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁▁▁', '▁}', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁', '▁},', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁', '▁"', 'required', '":', '▁[', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁▁▁', '▁"', 'query', '"', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁', '▁]', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁', '▁}', '<0x0A>', 
'▁▁▁▁▁▁▁', '▁}', '<0x0A>', '▁▁▁', '▁},', '<0x0A>', '▁▁▁', '▁{', '<0x0A>', '▁▁▁▁▁▁▁', '▁"', 'type', '":', '▁"', 'function', '",', '<0x0A>', '▁▁▁▁▁▁▁', '▁"', 'function', '":', '▁{', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁', '▁"', 'name', '":', '▁"', 'get', '_', 'current', '_', 'we', 'ather', '",', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁', '▁"', 'description', '":', '▁"', 'Get', '▁the', '▁current', '▁weather', '▁in', '▁a', '▁given', '▁location', '",', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁', '▁"', 'parameters', '":', '▁{', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁', '▁"', 'type', '":', '▁"', 'object', '",', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁', '▁"', 'properties', '":', '▁{', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁▁▁', '▁"', 'location', '":', '▁{', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁▁▁▁▁▁▁', '▁"', 'type', '":', '▁"', 'string', '",', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁▁▁▁▁▁▁', '▁"', 'description', '":', '▁"', 'The', '▁city', '▁and', '▁country', ',', '▁e', '.', 'g', '.', '▁Dublin', ',', '▁Ireland', '"', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁▁▁', '▁},', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁▁▁', '▁"', 'unit', '":', '▁{', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁▁▁▁▁▁▁', '▁"', 'type', '":', '▁"', 'string', '",', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁▁▁▁▁▁▁', '▁"', 'description', '":', '▁"', 'Measure', 'ment', '▁unit', '▁for', '▁the', '▁weather', '.', '▁Options', ':', "▁'", 'c', 'els', 'ius', "',", "▁'", 'f', 'ahren', 'heit', '\'"', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁▁▁', '▁}', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁', '▁},', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁', '▁"', 'required', '":', '▁[', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁▁▁', '▁"', 'location', '"'] Labels at the start of the sample: [733, 16289, 28793, 995, 506, 2735, 298, 272, 2296, 5572, 28723, 5938, 706, 513, 3030, 28747, 13, 13, 28792, 13, 2287, 371, 13, 5390, 345, 1123, 1264, 345, 2628, 548, 13, 5390, 345, 2628, 1264, 371, 13, 17422, 345, 861, 1264, 345, 2360, 28730, 283, 28744, 449, 548, 13, 17422, 345, 6518, 1264, 345, 7009, 354, 3332, 10374, 356, 1010, 28814, 449, 28723, 6746, 938, 302, 5771, 28725, 3994, 304, 5457, 12765, 390, 7658, 298, 5175, 3471, 2373, 272, 5709, 9191, 13, 17422, 345, 11438, 1264, 371, 13, 1417, 28705, 345, 1123, 1264, 345, 2814, 548, 13, 1417, 28705, 345, 10723, 1264, 371, 13, 359, 2287, 345, 3385, 1264, 371, 13, 359, 5390, 345, 1123, 1264, 345, 1427, 548, 13, 359, 5390, 345, 6518, 1264, 345, 1014, 3472, 5709, 1423, 28739, 13, 359, 2287, 443, 13, 1417, 28705, 1630, 13, 1417, 28705, 345, 10893, 1264, 733, 13, 359, 2287, 345, 3385, 28739, 13, 1417, 28705, 4709, 13, 17422, 443, 13, 5390, 443, 13, 2287, 1630, 13, 2287, 371, 13, 5390, 345, 1123, 1264, 345, 2628, 548, 13, 5390, 345, 2628, 1264, 371, 13, 17422, 345, 861, 1264, 345, 527, 28730, 3022, 28730, 769, 1223, 548, 13, 17422, 345, 6518, 1264, 345, 1458, 272, 1868, 8086, 297, 264, 2078, 4723, 548, 13, 17422, 345, 11438, 1264, 371, 13, 1417, 28705, 345, 1123, 1264, 345, 2814, 548, 13, 1417, 28705, 345, 10723, 1264, 371, 13, 359, 2287, 345, 2733, 1264, 371, 13, 359, 5390, 345, 1123, 1264, 345, 1427, 548, 13, 359, 5390, 345, 6518, 1264, 345, 1014, 2990, 304, 2939, 28725, 317, 28723, 28721, 28723, 22263, 28725, 11170, 28739, 13, 359, 2287, 1630, 13, 359, 2287, 345, 5306, 1264, 371, 13, 359, 5390, 345, 1123, 1264, 345, 1427, 548, 13, 359, 5390, 345, 6518, 1264, 345, 25241, 466, 5028, 354, 272, 8086, 28723, 19641, 28747, 464, 28717, 1190, 3170, 647, 464, 28722, 18657, 12307, 21236, 13, 359, 2287, 443, 13, 1417, 28705, 1630, 13, 1417, 28705, 345, 10893, 1264, 733, 13, 359, 2287, 345, 2733, 28739, 13] ['▁[', 'INST', ']', '▁You', '▁have', '▁access', '▁to', '▁the', '▁following', '▁functions', 
'.', '▁Use', '▁them', '▁if', '▁required', ':', '<0x0A>', '<0x0A>', '[', '<0x0A>', '▁▁▁', '▁{', '<0x0A>', '▁▁▁▁▁▁▁', '▁"', 'type', '":', '▁"', 'function', '",', '<0x0A>', '▁▁▁▁▁▁▁', '▁"', 'function', '":', '▁{', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁', '▁"', 'name', '":', '▁"', 'search', '_', 'ar', 'x', 'iv', '",', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁', '▁"', 'description', '":', '▁"', 'Search', '▁for', '▁research', '▁papers', '▁on', '▁Ar', 'X', 'iv', '.', '▁Make', '▁use', '▁of', '▁AND', ',', '▁OR', '▁and', '▁NOT', '▁operators', '▁as', '▁appropriate', '▁to', '▁join', '▁terms', '▁within', '▁the', '▁query', '.",', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁', '▁"', 'parameters', '":', '▁{', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁', '▁"', 'type', '":', '▁"', 'object', '",', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁', '▁"', 'properties', '":', '▁{', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁▁▁', '▁"', 'query', '":', '▁{', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁▁▁▁▁▁▁', '▁"', 'type', '":', '▁"', 'string', '",', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁▁▁▁▁▁▁', '▁"', 'description', '":', '▁"', 'The', '▁search', '▁query', '▁string', '"', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁▁▁', '▁}', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁', '▁},', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁', '▁"', 'required', '":', '▁[', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁▁▁', '▁"', 'query', '"', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁', '▁]', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁', '▁}', '<0x0A>', '▁▁▁▁▁▁▁', '▁}', '<0x0A>', '▁▁▁', '▁},', '<0x0A>', '▁▁▁', '▁{', '<0x0A>', '▁▁▁▁▁▁▁', '▁"', 'type', '":', '▁"', 'function', '",', '<0x0A>', '▁▁▁▁▁▁▁', '▁"', 'function', '":', '▁{', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁', '▁"', 'name', '":', '▁"', 'get', '_', 'current', '_', 'we', 'ather', '",', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁', '▁"', 'description', '":', '▁"', 'Get', '▁the', '▁current', '▁weather', '▁in', '▁a', '▁given', '▁location', '",', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁', '▁"', 'parameters', '":', '▁{', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁', '▁"', 'type', '":', '▁"', 'object', '",', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁', '▁"', 'properties', '":', '▁{', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁▁▁', '▁"', 'location', '":', '▁{', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁▁▁▁▁▁▁', '▁"', 'type', '":', '▁"', 'string', '",', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁▁▁▁▁▁▁', '▁"', 'description', '":', '▁"', 'The', '▁city', '▁and', '▁country', ',', '▁e', '.', 'g', '.', '▁Dublin', ',', '▁Ireland', '"', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁▁▁', '▁},', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁▁▁', '▁"', 'unit', '":', '▁{', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁▁▁▁▁▁▁', '▁"', 'type', '":', '▁"', 'string', '",', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁▁▁▁▁▁▁', '▁"', 'description', '":', '▁"', 'Measure', 'ment', '▁unit', '▁for', '▁the', '▁weather', '.', '▁Options', ':', "▁'", 'c', 'els', 'ius', "',", "▁'", 'f', 'ahren', 'heit', '\'"', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁▁▁', '▁}', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁', '▁},', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁', '▁"', 'required', '":', '▁[', '<0x0A>', '▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁', '▁▁▁', '▁"', 'location', '"', '<0x0A>'] Attention mask at the start of the sample: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] Loss mask at the start of the sample: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] Tokens at the end of the sample: [17422, 443, 13, 5390, 443, 13, 2287, 443, 13, 28793, 13, 13, 9607, 863, 19808, 15648, 2172, 4296, 28804, 733, 28748, 16289, 28793, 13, 13, 1014, 17008, 5016, 302, 19808, 15648, 349, 521, 6206, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] ['▁▁▁▁▁▁▁▁▁▁▁', '▁}', '<0x0A>', '▁▁▁▁▁▁▁', '▁}', '<0x0A>', '▁▁▁', '▁}', '<0x0A>', ']', '<0x0A>', '<0x0A>', 'Where', '▁did', '▁fortune', '▁cookies', '▁orig', 'inate', '?', '▁[', '/', 'INST', ']', '<0x0A>', '<0x0A>', 'The', '▁precise', '▁origin', '▁of', '▁fortune', '▁cookies', '▁is', '▁un', 'clear', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', 
'<unk>', '<unk>', ... (remaining padding positions truncated) ...]
Labels at the end of the sample: [443, 13, 5390, 443, 13, 2287, 443, 13, 28793, 13, 13, 9607, 863, 19808, 15648, 2172, 4296, 28804, 733, 28748, 16289, 28793, 13, 13, 1014, 17008, 5016, 302, 19808, 15648, 349, 521, 6206, 2, 0, 0, ... (padding label IDs truncated) ...]
['▁}', '<0x0A>', '▁▁▁▁▁▁▁', '▁}', '<0x0A>', '▁▁▁', '▁}', '<0x0A>', ']', '<0x0A>', '<0x0A>', 'Where', '▁did', '▁fortune', '▁cookies', '▁orig', 'inate', '?', '▁[', '/', 'INST', ']', '<0x0A>', '<0x0A>', 'The', '▁precise', '▁origin', '▁of', '▁fortune', '▁cookies', '▁is', '▁un', 'clear', '</s>', '<unk>', ... (padding tokens truncated) ...]
Attention mask at the end of the sample: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, ... (padding positions truncated) ...]
Loss mask at the end of the sample: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, ... (padding positions truncated) ...]
import textwrap
wrapper = textwrap.TextWrapper(width=80)
import re # import regular expressions module
import gc # import Python's garbage collection module
def generate(index, data_split="test"):
    """Generate a response for the sample at `index` of the given data split and
    print it alongside the correct (ground-truth) assistant response."""
    functionList = data[data_split][index]['functionList']
    user_prompt = data[data_split][index]['userPrompt']
    correct_answer = data[data_split][index]['assistantResponse']

    # model.config.use_cache = True # Unsure this is needed

    # Format your prompt template
    prompt = f"{B_INST}{B_FUNC}{functionList.strip()}{E_FUNC}{user_prompt.strip()}{E_INST}\n\n"

    print(f"Using the {data_split} data split.\n\nPrompt:")
    print(prompt)

    inputs = tokenizer([prompt], return_tensors="pt").to("cuda")
    if "token_type_ids" in inputs:
        del inputs["token_type_ids"]

    # print(f'model is on: {next(model.parameters()).device}') # Debug
    # print(f'input_ids is on: {inputs["input_ids"].device}') # Debug

    output = model.generate(**inputs,
                            max_new_tokens=200,
                            # do_sample=False,
                            pad_token_id=tokenizer.pad_token_id,
                            eos_token_id=tokenizer.eos_token_id,
                            # temperature=0.01,
                            # top_k=0
                            )
    print()

    # Slice off the prompt tokens so only the model's response is decoded
    output_text = tokenizer.decode(output[0, len(inputs.input_ids[0]):],
                                   skip_special_tokens=False)
    output_text = re.sub('\n+', '\n', output_text)  # remove excessive newlines

    print("**Generated Assistant Response:**")
    print(output_text)
    print()
    print("**Correct Assistant Response:**")
    print(correct_answer)
    print()

    # Clear GPU cache and run garbage collection
    torch.cuda.empty_cache()  # Clear GPU cache
    gc.collect()  # Run garbage collection
Before fine-tuning the model, let's take a look at how the base model responds to the held-out test prompts.
Notice that the model should respond with just a function name and its arguments, e.g. { "name": "get_big_stocks", "arguments": { "number": "5" } }, yet instead it tries to write the code itself and includes lots of extra words.
# Run validation before training
for index in range(len(test_dataset)):
    print(f'---Running index {index}---')
    generate(index, "test")
---Running index 0--- Using the test data split. Prompt: [INST] You have access to the following functions. Use them if required: [ { "type": "function", "function": { "name": "get_stock_price", "description": "Get the stock price of an array of stocks", "parameters": { "type": "object", "properties": { "names": { "type": "array", "items": { "type": "string" }, "description": "An array of stocks" } }, "required": [ "names" ] } } }, { "type": "function", "function": { "name": "get_big_stocks", "description": "Get the names of the largest N stocks by market cap", "parameters": { "type": "object", "properties": { "number": { "type": "integer", "description": "The number of largest stocks to get the names of, e.g. 25" }, "region": { "type": "string", "description": "The region to consider, can be \"US\" or \"World\"." } }, "required": [ "number" ] } } } ] Get the names of the five largest stocks by market cap [/INST] **Generated Assistant Response:** To get the names of the five largest stocks by market cap, we can use the `get_big_stocks` function with the following parameters: * `number`: 5 (to get the names of the five largest stocks) * `region`: "World" (to consider all stocks in the world) Here's the code to do that: ```python import requests def get_big_stocks(number, region): url = f"https://api.alphavantage.co/query?function=TIME_SERIES_DAILY&symbol={region}&apikey=YOUR_API_KEY" response = requests.get(url) data = response.json() market_cap = {} for symbol in data["Time Series (Daily)"].keys(): market_cap[symbol] = data["Time Series ( **Correct Assistant Response:** { "name": "get_big_stocks", "arguments": { "number": "5" } } ---Running index 1--- Using the test data split. Prompt: [INST] You have access to the following functions. Use them if required: [ { "type": "function", "function": { "name": "get_big_stocks", "description": "Get the names of the largest N stocks by market cap", "parameters": { "type": "object", "properties": { "number": { "type": "integer", "description": "The number of largest stocks to get the names of, e.g. 25" }, "region": { "type": "string", "description": "The region to consider, can be \"US\" or \"World\"." } }, "required": [ "number" ] } } }, { "type": "function", "function": { "name": "get_stock_price", "description": "Get the stock price of an array of stocks", "parameters": { "type": "object", "properties": { "names": { "type": "array", "items": { "type": "string" }, "description": "An array of stocks" } }, "required": [ "names" ] } } } ] Get the names of the five largest stocks in the US by market cap [/INST] **Generated Assistant Response:** Here's the code to get the names of the five largest stocks in the US by market cap: ```python import requests def get_big_stocks(params): url = "https://api.alphavantage.co/query?function=TIME_SERIES_DAILY_ADJUSTED&symbol={}&apikey=YOUR_API_KEY".format(params["names"][0]) response = requests.get(url) data = response.json() return data["Time Series (Daily)"] def get_stock_price(params): url = "https://api.alphavantage.co/query?function=GLOBAL_QUOTE&symbol={}&apikey=YOUR_API_KEY".format(params["names"][0]) response = requests.get(url) data = response **Correct Assistant Response:** { "name": "get_big_stocks", "arguments": { "number": "5", "region": "US" } } ---Running index 2--- Using the test data split. Prompt: [INST] You have access to the following functions. 
Use them if required: [ { "type": "function", "function": { "name": "get_big_stocks", "description": "Get the names of the largest N stocks by market cap", "parameters": { "type": "object", "properties": { "number": { "type": "integer", "description": "The number of largest stocks to get the names of, e.g. 25" }, "region": { "type": "string", "description": "The region to consider, can be \"US\" or \"World\"." } }, "required": [ "number" ] } } }, { "type": "function", "function": { "name": "get_stock_price", "description": "Get the stock price of an array of stocks", "parameters": { "type": "object", "properties": { "names": { "type": "array", "items": { "type": "string" }, "description": "An array of stocks" } }, "required": [ "names" ] } } } ] Get the stock prices of the ten largest stocks in the world [/INST] **Generated Assistant Response:** To get the stock prices of the ten largest stocks in the world, you can use the following steps: 1. Get the names of the ten largest stocks in the world using the `get_big_stocks` function. 2. Use the `get_stock_price` function to get the stock prices of the names obtained in step 1. Here's the code to implement this: ```python import requests def get_big_stocks(params): url = "https://api.alphavantage.co/query?function=TIME_SERIES_DAILY_ADJUSTED&symbol={}&apikey=YOUR_API_KEY".format(params["names"][0]) response = requests.get(url) data = response.json() return data["Time Series (Daily)"] def get_stock_price( **Correct Assistant Response:** { "name": "get_big_stocks", "arguments": { "number": "10", "region": "World" } } ---Running index 3--- Using the test data split. Prompt: [INST] You have access to the following functions. Use them if required: [ { "type": "function", "function": { "name": "get_stock_price", "description": "Get the stock price of an array of stocks", "parameters": { "type": "object", "properties": { "names": { "type": "array", "items": { "type": "string" }, "description": "An array of stocks" } }, "required": [ "names" ] } } }, { "type": "function", "function": { "name": "get_big_stocks", "description": "Get the names of the largest N stocks by market cap", "parameters": { "type": "object", "properties": { "number": { "type": "integer", "description": "The number of largest stocks to get the names of, e.g. 25" }, "region": { "type": "string", "description": "The region to consider, can be \"US\" or \"World\"." } }, "required": [ "number" ] } } } ] Get the price of Apple's stock [/INST] **Generated Assistant Response:** To get the price of Apple's stock, you can use the `get_stock_price` function. Here's an example of how you can use it: ``` const stockNames = ["AAPL"]; const stockPrices = get_stock_price({ names: stockNames }); console.log(stockPrices); ``` This will output the current price of Apple's stock. Note that you will need to have an API key to use this function.</s> **Correct Assistant Response:** { "name": "get_stock_price", "arguments": { "names": [ "Apple" ] } } ---Running index 4--- Using the test data split. Prompt: [INST] You have access to the following functions. 
Use them if required: [ { "type": "function", "function": { "name": "get_stock_price", "description": "Get the stock price of an array of stocks", "parameters": { "type": "object", "properties": { "names": { "type": "array", "items": { "type": "string" }, "description": "An array of stocks" } }, "required": [ "names" ] } } }, { "type": "function", "function": { "name": "get_big_stocks", "description": "Get the names of the largest N stocks by market cap", "parameters": { "type": "object", "properties": { "number": { "type": "integer", "description": "The number of largest stocks to get the names of, e.g. 25" }, "region": { "type": "string", "description": "The region to consider, can be \"US\" or \"World\"." } }, "required": [ "number" ] } } } ] Greetings! [/INST] **Generated Assistant Response:** Hello! How can I assist you today?</s> **Correct Assistant Response:** Greetings to you too! ---Running index 5--- Using the test data split. Prompt: [INST] You have access to the following functions. Use them if required: [ { "type": "function", "function": { "name": "get_big_stocks", "description": "Get the names of the largest N stocks by market cap", "parameters": { "type": "object", "properties": { "number": { "type": "integer", "description": "The number of largest stocks to get the names of, e.g. 25" }, "region": { "type": "string", "description": "The region to consider, can be \"US\" or \"World\"." } }, "required": [ "number" ] } } }, { "type": "function", "function": { "name": "get_stock_price", "description": "Get the stock price of an array of stocks", "parameters": { "type": "object", "properties": { "names": { "type": "array", "items": { "type": "string" }, "description": "An array of stocks" } }, "required": [ "names" ] } } } ] shop [/INST] **Generated Assistant Response:** To get the names of the largest 25 stocks by market cap in the US region, you can use the following code: ``` const bigStocks = await get_big_stocks({ number: 25, region: "US" }); console.log(bigStocks); ``` To get the stock price of an array of stocks, you can use the following code: ``` const stockPrices = await get_stock_price({ names: ["AAPL", "GOOG", "AMZN"] }); console.log(stockPrices); ``` Note that the `get_stock_price` function requires an array of stock names as input, and returns an object with the stock prices as keys and their corresponding values.</s> **Correct Assistant Response:** What would you like to know about shops? ---Running index 6--- Using the test data split. Prompt: [INST] You have access to the following functions. Use them if required: [ { "type": "function", "function": { "name": "get_stock_price", "description": "Get the stock price of an array of stocks", "parameters": { "type": "object", "properties": { "names": { "type": "array", "items": { "type": "string" }, "description": "An array of stocks" } }, "required": [ "names" ] } } }, { "type": "function", "function": { "name": "get_big_stocks", "description": "Get the names of the largest N stocks by market cap", "parameters": { "type": "object", "properties": { "number": { "type": "integer", "description": "The number of largest stocks to get the names of, e.g. 25" }, "region": { "type": "string", "description": "The region to consider, can be \"US\" or \"World\"." } }, "required": [ "number" ] } } } ] What are the planets in our solar system? [/INST] **Generated Assistant Response:** I'm sorry, but the provided functions do not seem to be related to the question about the planets in our solar system. 
Can you please provide more context or clarify your question?</s> **Correct Assistant Response:** The planets are Mercury, Venus, Earth, Mars, Jupiter, Saturn, Uranus and Neptune.
import torch.nn as nn
from torch.utils.data import DataLoader  # used by the custom dataloader methods below
class CustomTrainer(transformers.Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        # Define number of tokens to display
        # Displays actual and predicted token info at end of each sequence
        num_tokens = 25

        labels = inputs.pop("labels")

        # # Get first hundred label IDs for each sequence in the batch
        # first_hundred_label_ids = labels[:, :200]
        # # Convert to tokens
        # first_hundred_tokens = [tokenizer.convert_ids_to_tokens(label_ids)
        #                         for label_ids in first_hundred_label_ids]
        # # Print them
        # for batch_idx, tokens in enumerate(first_hundred_tokens):
        #     print(f"First 200 decoded tokens for sequence {batch_idx + 1}: {tokens}")

        loss_mask = inputs.pop("loss_mask")

        # Forward pass
        outputs = model(**inputs)
        logits = outputs.logits

        # Check for NaN in logits and labels
        if torch.isnan(logits).any():
            print("NaN detected in logits")
            print(logits)

        # Convert logits to probabilities using softmax function
        probs = nn.functional.softmax(logits, dim=-1)

        # Get the most probable tokens
        predicted_token_ids = torch.argmax(probs, dim=-1)

        # Compute the loss
        loss_fct = nn.CrossEntropyLoss(reduction='none')
        losses = loss_fct(logits.view(-1, self.model.config.vocab_size), labels.view(-1))

        # Reshaping the losses to have dimensions [batch_size, seq_length]
        losses = losses.view(-1, inputs['input_ids'].size(1))

        # Apply the loss mask
        masked_loss = losses * loss_mask

        # Check for NaN in losses and zero in loss_mask.sum()
        if torch.isnan(losses).any():
            print("NaN detected in losses")
            # print(losses)

        if loss_mask.sum() == 0:
            print("Sum of loss_mask is zero")
            return (torch.tensor(0).to(loss_mask.device), outputs) \
                if return_outputs else torch.tensor(0).to(loss_mask.device)  # Early return

        # Aggregate the masked losses
        # Normalize by the number of tokens considered + epsilon to prevent
        # division by zero
        loss = masked_loss.sum() / (loss_mask.sum() + 1e-9)

        # Print formatted tokens
        batch_size, seq_length = inputs['input_ids'].size()
        # num_tokens = len(inputs['input_ids'][0])

        # # Useful for debugging training
        # # Recommend training a small number of steps
        # print("-" * 120)
        # print(f"Token analysis for last {num_tokens} tokens:")
        # header_format = "{:<10}{:<20}{:<20}{:<20}{:<20}{:<30}{:<30}".format("Index", "Input Token", "Predicted Token", "True Token", "Loss Mask", "Raw Loss", "Masked Loss")
        # for batch_idx in range(batch_size):
        #     input_tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][batch_idx])  # Using batch_idx
        #     predicted_tokens = tokenizer.convert_ids_to_tokens(predicted_token_ids[batch_idx])  # Using batch_idx
        #     true_tokens = tokenizer.convert_ids_to_tokens(labels[batch_idx])  # Using batch_idx
        #     print(f"\nBatch {batch_idx + 1} of {batch_size}:")
        #     print(header_format)
        #     for i in range(-num_tokens, 0, 1):
        #         index = seq_length + i  # Correct index based on sequence length
        #         print("{:<10}{:<20}{:<20}{:<20}{:<20.1f}{:<30.6f}{:<30.6f}".format(index, input_tokens[index], predicted_tokens[index], true_tokens[index], loss_mask[batch_idx, i].item(), losses[batch_idx, i], masked_loss[batch_idx, i]))
        # print("-" * 120)

        return (loss, outputs) if return_outputs else loss

    def get_train_dataloader(self):
        train_dataset = self.train_dataset
        data_collator = self.data_collator

        dataloader_params = {
            "batch_size": self.args.train_batch_size,
            "collate_fn": data_collator,
            "num_workers": self.args.dataloader_num_workers,
            "pin_memory": self.args.dataloader_pin_memory,
        }

        if not isinstance(train_dataset, torch.utils.data.IterableDataset):
            dataloader_params["sampler"] = self._get_train_sampler()
            dataloader_params["drop_last"] = self.args.dataloader_drop_last

        return DataLoader(train_dataset, **dataloader_params)

    def get_eval_dataloader(self, eval_dataset=None):
        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
        if eval_dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")

        data_collator = self.data_collator

        # Parameters for the DataLoader
        dataloader_params = {
            "batch_size": self.args.eval_batch_size,
            "collate_fn": data_collator,
            "num_workers": self.args.dataloader_num_workers,
            "pin_memory": self.args.dataloader_pin_memory,
        }

        # If your dataset isn't an instance of torch's IterableDataset, you can
        # provide sampler and drop_last
        if not isinstance(eval_dataset, torch.utils.data.IterableDataset):
            dataloader_params["sampler"] = self._get_eval_sampler(eval_dataset)
            # Typically we don't drop the last batch for evaluation
            dataloader_params["drop_last"] = False

        return DataLoader(eval_dataset, **dataloader_params)


class CustomDataCollator:  # Needed if the EOS token is included in training
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, batch):
        input_ids = torch.stack([item['input_ids'] for item in batch])
        attention_mask = torch.stack([item['attention_mask'] for item in batch])
        labels = torch.stack([item['labels'] for item in batch])
        loss_mask = torch.stack([item['loss_mask'] for item in batch])

        # # Debugging: print details of the first sequence in the batch
        # num_elements_to_view = 20  # Number of last elements to view
        # # Decoding the input_ids
        # decoded_input_tokens = self.tokenizer.convert_ids_to_tokens(input_ids[0].tolist())
        # print("Debugging last", num_elements_to_view, "elements of the first sequence in the batch:")
        # print("{:<20}{:<20}{:<20}{:<20}".format("Token", "Input ID", "Label", "Loss Mask"))
        # for i in range(-num_elements_to_view, 0, 1):
        #     print("{:<20}{:<20}{:<20}{:<20}".format(decoded_input_tokens[i], input_ids[0, i].item(), labels[0, i].item(), loss_mask[0, i].item()))

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels,
            'loss_mask': loss_mask
        }
data_collator = CustomDataCollator(tokenizer)
trainer = CustomTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=validation_dataset,
    args=transformers.TrainingArguments(
        # max_steps=1,
        num_train_epochs=1,  # Larger models typically only need 1 epoch
        per_device_train_batch_size=1,
        gradient_accumulation_steps=1,
        evaluation_strategy="steps",
        max_grad_norm=1,
        warmup_ratio=0.1,
        eval_steps=0.2,
        learning_rate=1e-4,  # 1e-4 for LoRA
        # learning_rate=1e-5, # 1e-5 for full fine-tuning
        # fp16=True, # Use if your GPU does not support bfloat16 (e.g., older than A100, A6000, or H100)
        bf16=True,
        logging_steps=1,
        output_dir="outputs",
        # optim="paged_adamw_8bit", # For training in 4bit (quantized)
        optim="adamw_torch",  # For training in full fp16/bf16 precision
        lr_scheduler_type='constant',
        hub_private_repo=True
    ),
    data_collator=data_collator,
    # data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
model.config.use_cache = False # Silence warnings (Set to True for inference!)
trainer.train()
torch.cuda.empty_cache()
/content/wandb/run-20240206_060238-j0g979ti
/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py:429: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants. warnings.warn(
Step | Training Loss | Validation Loss |
---|---|---|
14 | 0.109000 | 0.818482 |
28 | 0.030600 | 0.809609 |
42 | 0.001300 | 0.767155 |
56 | 0.151700 | 0.759370 |
model.config.use_cache = True
model.eval()
PeftModelForCausalLM( (base_model): LoraModel( (model): MistralForCausalLM( (model): MistralModel( (embed_tokens): Embedding(32000, 4096) (layers): ModuleList( (0-31): 32 x MistralDecoderLayer( (self_attn): MistralAttention( (q_proj): lora.Linear( (base_layer): Linear(in_features=4096, out_features=4096, bias=False) (lora_dropout): ModuleDict( (default): Dropout(p=0.1, inplace=False) ) (lora_A): ModuleDict( (default): Linear(in_features=4096, out_features=8, bias=False) ) (lora_B): ModuleDict( (default): Linear(in_features=8, out_features=4096, bias=False) ) (lora_embedding_A): ParameterDict() (lora_embedding_B): ParameterDict() ) (k_proj): lora.Linear( (base_layer): Linear(in_features=4096, out_features=1024, bias=False) (lora_dropout): ModuleDict( (default): Dropout(p=0.1, inplace=False) ) (lora_A): ModuleDict( (default): Linear(in_features=4096, out_features=8, bias=False) ) (lora_B): ModuleDict( (default): Linear(in_features=8, out_features=1024, bias=False) ) (lora_embedding_A): ParameterDict() (lora_embedding_B): ParameterDict() ) (v_proj): lora.Linear( (base_layer): Linear(in_features=4096, out_features=1024, bias=False) (lora_dropout): ModuleDict( (default): Dropout(p=0.1, inplace=False) ) (lora_A): ModuleDict( (default): Linear(in_features=4096, out_features=8, bias=False) ) (lora_B): ModuleDict( (default): Linear(in_features=8, out_features=1024, bias=False) ) (lora_embedding_A): ParameterDict() (lora_embedding_B): ParameterDict() ) (o_proj): lora.Linear( (base_layer): Linear(in_features=4096, out_features=4096, bias=False) (lora_dropout): ModuleDict( (default): Dropout(p=0.1, inplace=False) ) (lora_A): ModuleDict( (default): Linear(in_features=4096, out_features=8, bias=False) ) (lora_B): ModuleDict( (default): Linear(in_features=8, out_features=4096, bias=False) ) (lora_embedding_A): ParameterDict() (lora_embedding_B): ParameterDict() ) (rotary_emb): MistralRotaryEmbedding() ) (mlp): MistralMLP( (gate_proj): lora.Linear( (base_layer): Linear(in_features=4096, out_features=14336, bias=False) (lora_dropout): ModuleDict( (default): Dropout(p=0.1, inplace=False) ) (lora_A): ModuleDict( (default): Linear(in_features=4096, out_features=8, bias=False) ) (lora_B): ModuleDict( (default): Linear(in_features=8, out_features=14336, bias=False) ) (lora_embedding_A): ParameterDict() (lora_embedding_B): ParameterDict() ) (up_proj): lora.Linear( (base_layer): Linear(in_features=4096, out_features=14336, bias=False) (lora_dropout): ModuleDict( (default): Dropout(p=0.1, inplace=False) ) (lora_A): ModuleDict( (default): Linear(in_features=4096, out_features=8, bias=False) ) (lora_B): ModuleDict( (default): Linear(in_features=8, out_features=14336, bias=False) ) (lora_embedding_A): ParameterDict() (lora_embedding_B): ParameterDict() ) (down_proj): lora.Linear( (base_layer): Linear(in_features=14336, out_features=4096, bias=False) (lora_dropout): ModuleDict( (default): Dropout(p=0.1, inplace=False) ) (lora_A): ModuleDict( (default): Linear(in_features=14336, out_features=8, bias=False) ) (lora_B): ModuleDict( (default): Linear(in_features=8, out_features=4096, bias=False) ) (lora_embedding_A): ParameterDict() (lora_embedding_B): ParameterDict() ) (act_fn): SiLU() ) (input_layernorm): MistralRMSNorm() (post_attention_layernorm): MistralRMSNorm() ) ) (norm): MistralRMSNorm() ) (lm_head): Linear(in_features=4096, out_features=32000, bias=False) ) ) )
# Run validation
for index in range(len(test_dataset)):
    print(f'---Running index {index}---')
    generate(index, "test")
---Running index 0--- Using the test data split. Prompt: [INST] You have access to the following functions. Use them if required: [ { "type": "function", "function": { "name": "get_stock_price", "description": "Get the stock price of an array of stocks", "parameters": { "type": "object", "properties": { "names": { "type": "array", "items": { "type": "string" }, "description": "An array of stocks" } }, "required": [ "names" ] } } }, { "type": "function", "function": { "name": "get_big_stocks", "description": "Get the names of the largest N stocks by market cap", "parameters": { "type": "object", "properties": { "number": { "type": "integer", "description": "The number of largest stocks to get the names of, e.g. 25" }, "region": { "type": "string", "description": "The region to consider, can be \"US\" or \"World\"." } }, "required": [ "number" ] } } } ] Get the names of the five largest stocks by market cap [/INST] **Generated Assistant Response:** { "name": "get_big_stocks", "arguments": { "number": 5 } }</s> **Correct Assistant Response:** { "name": "get_big_stocks", "arguments": { "number": "5" } } ---Running index 1--- Using the test data split. Prompt: [INST] You have access to the following functions. Use them if required: [ { "type": "function", "function": { "name": "get_big_stocks", "description": "Get the names of the largest N stocks by market cap", "parameters": { "type": "object", "properties": { "number": { "type": "integer", "description": "The number of largest stocks to get the names of, e.g. 25" }, "region": { "type": "string", "description": "The region to consider, can be \"US\" or \"World\"." } }, "required": [ "number" ] } } }, { "type": "function", "function": { "name": "get_stock_price", "description": "Get the stock price of an array of stocks", "parameters": { "type": "object", "properties": { "names": { "type": "array", "items": { "type": "string" }, "description": "An array of stocks" } }, "required": [ "names" ] } } } ] Get the names of the five largest stocks in the US by market cap [/INST] **Generated Assistant Response:** { "name": "get_big_stocks", "arguments": { "number": 5 } }</s> **Correct Assistant Response:** { "name": "get_big_stocks", "arguments": { "number": "5", "region": "US" } } ---Running index 2--- Using the test data split. Prompt: [INST] You have access to the following functions. Use them if required: [ { "type": "function", "function": { "name": "get_big_stocks", "description": "Get the names of the largest N stocks by market cap", "parameters": { "type": "object", "properties": { "number": { "type": "integer", "description": "The number of largest stocks to get the names of, e.g. 25" }, "region": { "type": "string", "description": "The region to consider, can be \"US\" or \"World\"." } }, "required": [ "number" ] } } }, { "type": "function", "function": { "name": "get_stock_price", "description": "Get the stock price of an array of stocks", "parameters": { "type": "object", "properties": { "names": { "type": "array", "items": { "type": "string" }, "description": "An array of stocks" } }, "required": [ "names" ] } } } ] Get the stock prices of the ten largest stocks in the world [/INST] **Generated Assistant Response:** { "name": "get_big_stocks", "arguments": { "number": 10 } }</s> **Correct Assistant Response:** { "name": "get_big_stocks", "arguments": { "number": "10", "region": "World" } } ---Running index 3--- Using the test data split. Prompt: [INST] You have access to the following functions. 
Use them if required: [ { "type": "function", "function": { "name": "get_stock_price", "description": "Get the stock price of an array of stocks", "parameters": { "type": "object", "properties": { "names": { "type": "array", "items": { "type": "string" }, "description": "An array of stocks" } }, "required": [ "names" ] } } }, { "type": "function", "function": { "name": "get_big_stocks", "description": "Get the names of the largest N stocks by market cap", "parameters": { "type": "object", "properties": { "number": { "type": "integer", "description": "The number of largest stocks to get the names of, e.g. 25" }, "region": { "type": "string", "description": "The region to consider, can be \"US\" or \"World\"." } }, "required": [ "number" ] } } } ] Get the price of Apple's stock [/INST] **Generated Assistant Response:** { "name": "get_stock_price", "arguments": { "names": [ "AAPL" ] } }</s> **Correct Assistant Response:** { "name": "get_stock_price", "arguments": { "names": [ "Apple" ] } } ---Running index 4--- Using the test data split. Prompt: [INST] You have access to the following functions. Use them if required: [ { "type": "function", "function": { "name": "get_stock_price", "description": "Get the stock price of an array of stocks", "parameters": { "type": "object", "properties": { "names": { "type": "array", "items": { "type": "string" }, "description": "An array of stocks" } }, "required": [ "names" ] } } }, { "type": "function", "function": { "name": "get_big_stocks", "description": "Get the names of the largest N stocks by market cap", "parameters": { "type": "object", "properties": { "number": { "type": "integer", "description": "The number of largest stocks to get the names of, e.g. 25" }, "region": { "type": "string", "description": "The region to consider, can be \"US\" or \"World\"." } }, "required": [ "number" ] } } } ] Greetings! [/INST] **Generated Assistant Response:** Hello! How can I assist you today?</s> **Correct Assistant Response:** Greetings to you too! ---Running index 5--- Using the test data split. Prompt: [INST] You have access to the following functions. Use them if required: [ { "type": "function", "function": { "name": "get_big_stocks", "description": "Get the names of the largest N stocks by market cap", "parameters": { "type": "object", "properties": { "number": { "type": "integer", "description": "The number of largest stocks to get the names of, e.g. 25" }, "region": { "type": "string", "description": "The region to consider, can be \"US\" or \"World\"." } }, "required": [ "number" ] } } }, { "type": "function", "function": { "name": "get_stock_price", "description": "Get the stock price of an array of stocks", "parameters": { "type": "object", "properties": { "names": { "type": "array", "items": { "type": "string" }, "description": "An array of stocks" } }, "required": [ "names" ] } } } ] shop [/INST] **Generated Assistant Response:** { "name": "get_big_stocks", "arguments": { "number": 25 } }</s> **Correct Assistant Response:** What would you like to know about shops? ---Running index 6--- Using the test data split. Prompt: [INST] You have access to the following functions. 
Use them if required: [ { "type": "function", "function": { "name": "get_stock_price", "description": "Get the stock price of an array of stocks", "parameters": { "type": "object", "properties": { "names": { "type": "array", "items": { "type": "string" }, "description": "An array of stocks" } }, "required": [ "names" ] } } }, { "type": "function", "function": { "name": "get_big_stocks", "description": "Get the names of the largest N stocks by market cap", "parameters": { "type": "object", "properties": { "number": { "type": "integer", "description": "The number of largest stocks to get the names of, e.g. 25" }, "region": { "type": "string", "description": "The region to consider, can be \"US\" or \"World\"." } }, "required": [ "number" ] } } } ] What are the planets in our solar system? [/INST] **Generated Assistant Response:** The planets in our solar system are: Mercury, Venus, Earth, Mars, Jupiter, Saturn, Uranus, and Neptune</s> **Correct Assistant Response:** The planets are Mercury, Venus, Earth, Mars, Jupiter, Saturn, Uranus and Neptune.
# Extract the last portion of the base_model
base_model_name = base_model.split("/")[-1]
adapter_model = f"gadkins/{base_model_name}-function-calling-adapters"
new_model = f"gadkins/{base_model_name}-function-calling" # Your HF account
print(f"Adapter Model: {adapter_model}\nNew Model: {new_model}")
Adapter Model: gadkins/Mistral-7B-Instruct-v0.1-function-calling-adapters New Model: gadkins/Mistral-7B-Instruct-v0.1-function-calling
# (Optional) Create repo + branch for gguf and awq
from huggingface_hub import HfApi, create_branch, create_repo
# Initialize the HfApi class
api = HfApi()
create_repo(new_model, private=False)
create_branch(new_model, repo_type="model", branch="gguf")
# create_branch(new_model, repo_type="model", branch="awq")
# create_branch(new_model, repo_type="model", branch="gptq")
# model.config._name_or_path="gadkins/Yi-34B-200K-Llamafied-chat-SFT"
# print(model.config._name_or_path)
# Save the model
model.save_pretrained(adapter_model, push_to_hub=True, use_auth_token=True)
# Push the model to the hub
# model.push_to_hub(adapter_model, use_auth_token=True)
# # ## Reload the base model if needed (you might need a pro subscription, since a high-RAM
# # ## environment is required to load the full original model, not the quantized one).
# # ## You may prefer device_map='auto' instead of 'cpu' if you have a GPU.
# # ## If you trained in full precision (not quantized), you may not need to reload the model;
# # ## you can merge and unload directly.
# # ## For very large models you may need to restart the kernel and reload the base model,
# # ## as there may not be enough space on the GPU.
# # from transformers import AutoModelForCausalLM, PretrainedConfig
# # import torch
# # model = AutoModelForCausalLM.from_pretrained(base_model, device_map='auto', trust_remote_code=True, torch_dtype=torch.float16, cache_dir=cache_dir)

# from peft import PeftModel
# # Load the PEFT model with the new adapters
# model = PeftModel.from_pretrained(
#     model,
#     './gadkins/Yi-34B-200K-Llamafied-chat-SFT-function-calling-adapters-v2',
# )
model = model.merge_and_unload() # merge adapters with the base model.
# (Optional) Allows you to save the model locally to do inference without downloading
model.save_pretrained(f"gadkins/{base_model_name}-function-calling-v3")
model.push_to_hub(new_model, token=True, max_shard_size="10GB", safe_serialization=True)
README.md: 0%| | 0.00/5.18k [00:00<?, ?B/s]
Upload 2 LFS files: 0%| | 0/2 [00:00<?, ?it/s]
model-00002-of-00002.safetensors: 0%| | 0.00/4.54G [00:00<?, ?B/s]
model-00001-of-00002.safetensors: 0%| | 0.00/9.94G [00:00<?, ?B/s]
CommitInfo(commit_url='https://huggingface.co/gadkins/Mistral-7B-Instruct-v0.1-function-calling/commit/ca16aa19012184d3f5721a3a4ba7829876a385d1', commit_message='Upload MistralForCausalLM', commit_description='', oid='ca16aa19012184d3f5721a3a4ba7829876a385d1', pr_url=None, pr_revision=None, pr_num=None)
import os
import requests
from huggingface_hub import HfApi
def download_file_from_huggingface(model_id, filename, save_path):
    url = f"https://huggingface.co/{model_id}/resolve/main/{filename}"
    r = requests.get(url)
    if r.status_code != 200:
        print(f"Failed to download {filename}. HTTP Status Code: {r.status_code}")
        return False
    with open(os.path.join(save_path, filename), 'wb') as f:
        f.write(r.content)
    return True


def main():
    # Files to download and upload
    files_to_process = ["tokenizer.model", "README.md"]

    # Directory to save the downloaded files
    save_path = "./models"
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    # Initialize HfApi class
    api = HfApi()

    # Specify the repository where you want to upload the files
    repo_id = new_model  # Assuming new_model is in the format "username/repo"

    for filename in files_to_process:
        # Download the file
        success = download_file_from_huggingface(base_model, filename, save_path)
        if success:
            print(f"Successfully downloaded {filename}")
        else:
            print(f"Failed to download {filename}")
            continue  # Skip uploading if download failed

        # File path to upload
        local_file_path = os.path.join(save_path, filename)

        # Upload the file
        api.upload_file(
            path_or_fileobj=local_file_path,
            path_in_repo=filename,  # Using filename directly, adjust as needed
            repo_id=repo_id,
            repo_type="model",  # Assuming it's a model; can be "dataset" or "space" as well
        )
        print(f"Uploaded {filename} to {repo_id}")


if __name__ == "__main__":
    main()
Successfully downloaded tokenizer.model Uploaded tokenizer.model to gadkins/Mistral-7B-Instruct-v0.1-function-calling Successfully downloaded README.md Uploaded README.md to gadkins/Mistral-7B-Instruct-v0.1-function-calling
This is a more advanced step that allows you to customize a chat template for function calling.
Typically you start by grabbing the chat_template from the base model's tokenizer_config.json and pasting it into the cell below. You then customize that template to include the function_metadata, function_call, and function_response roles. One example is shown below, but it won't be correct for all models.
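If you'd rather read the template straight from the repo file (the next cell prints the same thing from the already-loaded tokenizer), here is a minimal sketch. It assumes the base_model variable defined earlier and uses huggingface_hub's hf_hub_download:
# Minimal sketch: fetch the base model's tokenizer_config.json and print its
# existing chat_template as a starting point for customization.
# Assumes `base_model` holds the Hub repo id used earlier in this notebook.
import json
from huggingface_hub import hf_hub_download

config_path = hf_hub_download(repo_id=base_model, filename="tokenizer_config.json")
with open(config_path) as f:
    tokenizer_config = json.load(f)

print(tokenizer_config.get("chat_template"))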
print(tokenizer.chat_template)
{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}
print(tokenizer.bos_token)
print(tokenizer.eos_token)
<s> </s>
import json
function_metadata = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "This function gets the current weather in a given city",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "The city, e.g., San Francisco"
},
"format": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"description": "The temperature unit to use."
}
},
"required": ["city"]
}
}
},
{
"type": "function",
"function": {
"name": "get_clothes",
"description": "This function provides a suggestion of clothes to wear based on the current weather",
"parameters": {
"type": "object",
"properties": {
"temperature": {
"type": "string",
"description": "The temperature, e.g., 15 C or 59 F"
},
"condition": {
"type": "string",
"description": "The weather condition, e.g., 'Cloudy', 'Sunny', 'Rainy'"
}
},
"required": ["temperature", "condition"]
}
}
}
]
# Comment out later messages to test various stages of generation.
sample_messages = [
# {
# "role": "system",
# "content": "you are a helpful assistant"
# },
{
"role": "function_metadata",
"content": "FUNCTION_METADATA"
},
{
"role": "user",
"content": "What is the current weather in London?"
},
# {
# "role": "function_call",
# "content": "{\n \"name\": \"get_current_weather\",\n \"arguments\": {\n \"city\": \"London\"\n }\n}</s>"
# },
# {
# "role": "function_response",
# "content": "{\n \"temperature\": \"15 C\",\n \"condition\": \"Cloudy\"\n}"
# },
# {
# "role": "assistant",
# "content": "The current weather in London is Cloudy with a temperature of 15 Celsius.</s>"
# },
# {
# "role": "user",
# "content": "That's great. Now say hello."
# },
# {
# "role": "assistant",
# "content": "Hello!</s>"
# }
]
# Iterate through each message in the list
for message in sample_messages:
    if message['role'] == 'function_metadata':
        # Replace the 'FUNCTION_METADATA' placeholder with the JSON-serialized function metadata
        message['content'] = message['content'].replace('FUNCTION_METADATA', json.dumps(function_metadata, indent=4))
# Llama 2 templates / Mistral
tokenizer.chat_template = """{{ bos_token }} [INST] {% for message in messages %}{% if message['role'] == 'system' %}<<SYS>>\n{{ message['content'] }}\n<</SYS>>\n\n{% elif message['role'] == 'function_metadata' %}You have access to the following functions. Use them if required:\n\n{{ message['content'] }}\n\n{% elif message['role'] == 'user' %}{{ message['content'] }} [/INST]\n\n{% elif message['role'] == 'assistant' %}{{ message['content'] }} [INST] {% elif message['role'] == 'function_call' %}{{ message['content'] }} [INST] {% elif message['role'] == 'function_response' %}Here is the response to the function call. If helpful, use it to respond to my question:\n\n{{ message['content'] }} [/INST]\n\n{% endif %}{% endfor %}"""
print(tokenizer.chat_template)
{{ bos_token }} [INST] {% for message in messages %}{% if message['role'] == 'system' %}<<SYS>> {{ message['content'] }} <</SYS>> {% elif message['role'] == 'function_metadata' %}You have access to the following functions. Use them if required: {{ message['content'] }} {% elif message['role'] == 'user' %}{{ message['content'] }} [/INST] {% elif message['role'] == 'assistant' %}{{ message['content'] }} [INST] {% elif message['role'] == 'function_call' %}{{ message['content'] }} [INST] {% elif message['role'] == 'function_response' %}Here is the response to the function call. If helpful, use it to respond to my question: {{ message['content'] }} [/INST] {% endif %}{% endfor %}
# View the template applied without tokenization
prompt = tokenizer.apply_chat_template(sample_messages, tokenize=False)
print(prompt)
<s> [INST] You have access to the following functions. Use them if required: [ { "type": "function", "function": { "name": "get_current_weather", "description": "This function gets the current weather in a given city", "parameters": { "type": "object", "properties": { "city": { "type": "string", "description": "The city, e.g., San Francisco" }, "format": { "type": "string", "enum": [ "celsius", "fahrenheit" ], "description": "The temperature unit to use." } }, "required": [ "city" ] } } }, { "type": "function", "function": { "name": "get_clothes", "description": "This function provides a suggestion of clothes to wear based on the current weather", "parameters": { "type": "object", "properties": { "temperature": { "type": "string", "description": "The temperature, e.g., 15 C or 59 F" }, "condition": { "type": "string", "description": "The weather condition, e.g., 'Cloudy', 'Sunny', 'Rainy'" } }, "required": [ "temperature", "condition" ] } } } ] What is the current weather in London? [/INST]
## Test generation
inputs = tokenizer([prompt], return_tensors="pt").to("cuda")
if "token_type_ids" in inputs:
del inputs["token_type_ids"]
# print(f'model is on: {next(model.parameters()).device}') # Debug line
# print(f'input_ids is on: {inputs["input_ids"].device}') # Debug line
output = model.generate(**inputs,
max_new_tokens=200,
do_sample=False,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
# temperature=0.01,
# top_k=0
)
print()
# Subtract the length of input_ids from output to get only the model's response
output_text = tokenizer.decode(output[0, len(inputs.input_ids[0]):], skip_special_tokens=False)
print(output_text)
{ "name": "get_current_weather", "arguments": { "city": "London" } }</s>
# Optional, but allows you to save the tokenizer locally so you can run inference immediately without downloading
tokenizer.save_pretrained(f"gadkins/{base_model_name}-function-calling-v3")
('gadkins/Mistral-7B-Instruct-v0.1-function-calling-v3/tokenizer_config.json', 'gadkins/Mistral-7B-Instruct-v0.1-function-calling-v3/special_tokens_map.json', 'gadkins/Mistral-7B-Instruct-v0.1-function-calling-v3/tokenizer.model', 'gadkins/Mistral-7B-Instruct-v0.1-function-calling-v3/added_tokens.json', 'gadkins/Mistral-7B-Instruct-v0.1-function-calling-v3/tokenizer.json')
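As a quick check (for example, in a fresh session), the merged weights and tokenizer can be reloaded directly from those local folders rather than from the Hub. This is a minimal sketch, assuming the same local path used in the save_pretrained calls above:
# Minimal sketch: reload the merged model and tokenizer from the local folder
# saved above (the path is an assumption based on the earlier save_pretrained calls).
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

local_dir = f"gadkins/{base_model_name}-function-calling-v3"
tokenizer_local = AutoTokenizer.from_pretrained(local_dir)
model_local = AutoModelForCausalLM.from_pretrained(
    local_dir,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)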
# Push the tokenizer
tokenizer.push_to_hub(new_model, token=True)

## RELOAD IF NEEDED (NOT RECOMMENDED if tokenizer.chat_template was updated)
# from transformers import AutoTokenizer
# # Reload the tokenizer from the base model so you don't end up with an off-size tokenizer with added pad tokens.
# tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
tokenizer.model: 0%| | 0.00/493k [00:00<?, ?B/s]
CommitInfo(commit_url='https://huggingface.co/gadkins/Mistral-7B-Instruct-v0.1-function-calling/commit/dfe2015a6083826389d11212b55d530816a0e0c6', commit_message='Upload tokenizer', commit_description='', oid='dfe2015a6083826389d11212b55d530816a0e0c6', pr_url=None, pr_revision=None, pr_num=None)