by Grayson Adkins, updated May 13, 2024
This notebook demonstrates visual captioning (i.e., generating text descriptions of images) with moondream, a small multi-modal model. I fine-tune moondream using LoRA on a custom chess-piece dataset to improve its performance on a downstream chess task.
This notebook is based on moondream's fine-tuning example on GitHub and on Trelis Research's video on fine-tuning multi-modal LLaVA vision-and-language models.
!python -m pip install --upgrade pip
%pip install torch transformers timm einops datasets bitsandbytes accelerate -q
!pip install peft -q
!pip install flash-attn -q
!pip install wandb -q
!pip install hf_transfer -q
# NOW RESTART THE RUNTIME IN ORDER TO BENEFIT FROM INCREASED DOWNLOAD SPEEDS.
# RUN THE CELLS BELOW.
import os
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
# Change DEVICE to 'mps' if you're on an M1 Mac, or 'cpu' if you don't have a
# GPU. Note that fine-tuning on CPU will be very slow.
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
model_slug="vikhyatk/moondream2"
DEVICE = "cuda"
DTYPE = torch.float32 if DEVICE == "cpu" else torch.bfloat16 # CPU doesn't support float16; on GPU, use bfloat16 (well supported on Ampere and newer architectures).
MD_REVISION = "2024-04-02"
use_4bit = False
use_lora = True # must be True when training with 4-bit quantization
set_other_trainable = True # to set embed layers trainable (fully trainable, not LoRA)
quantization_config = None
if use_4bit:
from transformers import BitsAndBytesConfig
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=DTYPE
)
quantization_config = bnb_config
tokenizer = AutoTokenizer.from_pretrained(model_slug, revision=MD_REVISION)
model = AutoModelForCausalLM.from_pretrained(
model_slug,
revision=MD_REVISION,
trust_remote_code=True,
attn_implementation="flash_attention_2" if DEVICE == "cuda" else None,
torch_dtype=DTYPE,
device_map={"": DEVICE},
cache_dir='',
quantization_config=quantization_config
)
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Here we're using a dataset created by Trelis Research, which consists of image-caption pairs of chess pieces.
# Import necessary libraries
from datasets import load_dataset
from IPython.display import display
# Load the dataset from Hugging Face
dataset = load_dataset("Trelis/chess_pieces")
sample = dataset['train'][1]
img = sample['image']
img = img.resize((224, 224))
# Display the image and caption
display(img)
print(f"Caption: {sample['caption']}")
Caption: A white rook.
from torch.utils.data import Dataset
from datasets import load_dataset
class ChessDataset(Dataset):
def __init__(self, split='train'):
self.data = load_dataset(
"Trelis/chess_pieces",
# revision="refs/convert/parquet",
)[split]
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
sample = self.data[idx]
return {
"image": sample["image"], # Should be a PIL image
"qa": [
{
"question": "What do you see?",
"answer": sample["caption"],
}
]
}
datasets = {
"train": ChessDataset("train"),
# "val": ChessDataset("validation"),
"test": ChessDataset("test"),
}
print(datasets['train'][0])
{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=3024x4032 at 0x7F85764FE740>, 'qa': [{'question': 'What do you see?', 'answer': 'A single white rook.'}]}
# Display model details
print(model)
Moondream(
  (vision_encoder): VisionEncoder(
    (encoder): EncoderWrapper(
      (model): ModuleDict(
        (visual): VisionTransformer(
          (patch_embed): LinearPatchEmbedding(
            (linear): Linear(in_features=588, out_features=1152, bias=True)
          )
          (blocks): Sequential(
            (0-26): 27 x VitBlock(
              (attn): Attention(
                (qkv): Linear(in_features=1152, out_features=3456, bias=True)
                (proj): Linear(in_features=1152, out_features=1152, bias=True)
              )
              (mlp): MLP(
                (fc1): Linear(in_features=1152, out_features=4304, bias=True)
                (act): GELU(approximate='tanh')
                (fc2): Linear(in_features=4304, out_features=1152, bias=True)
              )
              (norm1): LayerNorm((1152,), eps=1e-05, elementwise_affine=True)
              (norm2): LayerNorm((1152,), eps=1e-05, elementwise_affine=True)
            )
          )
          (norm): LayerNorm((1152,), eps=1e-05, elementwise_affine=True)
        )
      )
    )
    (projection): VisionProjection(
      (mlp): MLP(
        (fc1): Linear(in_features=1152, out_features=8192, bias=True)
        (act): GELU(approximate='tanh')
        (fc2): Linear(in_features=8192, out_features=2048, bias=True)
      )
    )
    (preprocess): Compose(
      Resize(size=[378, 378], interpolation=InterpolationMode.BICUBIC, antialias=warn)
      ToImage()
      ToDtype(scale=True)
      Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=False)
    )
  )
  (text_model): PhiForCausalLM(
    (transformer): PhiModel(
      (embd): Embedding(
        (wte): Embedding(51200, 2048)
      )
      (embed_dropout): Dropout(p=0.0, inplace=False)
      (h): ModuleList(
        (0-23): 24 x PhiDecoderLayer(
          (mixer): PhiFlashAttention2(
            (Wqkv): Linear(in_features=2048, out_features=6144, bias=True)
            (out_proj): Linear(in_features=2048, out_features=2048, bias=True)
            (rotary_emb): PhiRotaryEmbedding()
          )
          (mlp): PhiMLP(
            (activation_fn): NewGELUActivation()
            (fc1): Linear(in_features=2048, out_features=8192, bias=True)
            (fc2): Linear(in_features=8192, out_features=2048, bias=True)
          )
          (ln): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (resid_dropout): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (lm_head): CausalLMHead(
      (ln): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
      (linear): Linear(in_features=2048, out_features=51200, bias=True)
    )
  )
)
# if use_4bit:
# from peft import prepare_model_for_kbit_training
# model.gradient_checkpointing_enable()
# model = prepare_model_for_kbit_training(model)
lora_alpha = 32
lora_rank = 64
## Apply LoRA (if use_lora is True in the config)
if use_lora:
from peft import LoraConfig
lora_config = LoraConfig(
r=lora_rank,
lora_alpha=lora_alpha,
target_modules=[
'proj','fc1','fc2',
'Wqkv','out_proj'
],
lora_dropout=0.1, # Example value, adjust as needed
bias="none", # Example setting, adjust as needed
task_type="CAUSAL_LM",
# modules_to_save=['lm_head','embd'], # only works with the HF Trainer, not this custom training loop
)
if use_lora:
from peft import get_peft_model
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
trainable params: 74,422,272 || all params: 1,931,904,880 || trainable%: 3.8522741347389733
print(model.config)
MoondreamConfig {
  "_name_or_path": "vikhyatk/moondream2",
  "architectures": [
    "Moondream"
  ],
  "auto_map": {
    "AutoConfig": "vikhyatk/moondream2--configuration_moondream.MoondreamConfig",
    "AutoModelForCausalLM": "vikhyatk/moondream2--moondream.Moondream"
  },
  "model_type": "moondream1",
  "phi_config": {
    "model_type": "phi"
  },
  "torch_dtype": "bfloat16",
  "transformers_version": "4.40.1"
}
# List to hold the names of the trainable parameters
if set_other_trainable:
trainable_params_names = ['lm_head','embd']
# trainable_params_names = None
# Set modules to be trainable
for n, p in model.named_parameters():
if any(k in n for k in trainable_params_names):
p.requires_grad_(True)
# else:
# p.requires_grad_(False) # Optional: Set the rest to be not trainable
# Make a dictionary of trainable parameters
trainable_params = {n: p for n, p in model.named_parameters() if p.requires_grad}
# Convert trainable_params to state_dict format
trainable_params_state_dict = {n: p.data for n, p in trainable_params.items()}
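Note that calling save_pretrained on the PEFT-wrapped model (as done after training below) only writes the LoRA adapter weights, so the embedding and lm_head weights made trainable here would need to be saved separately. Here's a minimal sketch, using plain torch.save and an illustrative file path:
# Sketch: persist the non-LoRA trainable weights (embeddings and lm_head) separately.
# The file path is illustrative.
import os
os.makedirs("checkpoints", exist_ok=True)
torch.save(trainable_params_state_dict, "checkpoints/other_trainable.pt")
# To restore later: model.load_state_dict(torch.load("checkpoints/other_trainable.pt"), strict=False)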
from IPython.display import display
model.eval()
sample = datasets['train'][1]
img = sample['image'].resize((224, 224))
display(img)
for qa in sample['qa']:
print('Question:', qa['question'])
print('Ground Truth:', qa['answer'])
print('model:', model.answer_question(
model.encode_image(sample['image']),
qa['question'],
tokenizer=tokenizer,
))
Next, let's set the hyperparameters for fine-tuning.
# Number of times to repeat the training dataset
## Too large and the model may overfit or experience catastrophic forgetting
## Too small and the model may underfit and have poor performance
EPOCHS = 3
# Number of samples to process in each batch. Set this to the highest value that doesn't cause an
# out-of-memory error. Decrease it if you're running out of memory. Batch size 8 currently uses around
# 15 GB of GPU memory during fine-tuning.
BATCH_SIZE = 4
# Number of batches to process before updating the model. You can use this to simulate a higher batch
# size than your GPU can handle. Set this to 1 to disable gradient accumulation.
GRAD_ACCUM_STEPS = 1
# The intuition is that a larger effective batch size lets you take bigger, more stable steps (see the learning-rate scaling note below).
# Learning rate for the Adam optimizer. Needs to be tuned on a case-by-case basis. As a general rule
# of thumb, increase it by 1.4 times each time you double the effective batch size.
#
# Source: https://www.cs.princeton.edu/~smalladi/blog/2024/01/22/SDEs-ScalingRules/
#
# Note that we linearly warm the learning rate up from 0.1 * LR to LR over the first 10% of the
# training run, and then decay it back to 0.1 * LR over the last 90% of the training run using a
# cosine schedule.
# LR = 3e-5 # default value
LR = 1.5e-5
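# Illustrative arithmetic for the rule of thumb above (not from the original notebook):
# with an effective batch size of 4 and LR = 1.5e-5, doubling the effective batch size
# to 8 (e.g. GRAD_ACCUM_STEPS = 2) would suggest LR ~= 1.4 * 1.5e-5 ~= 2.1e-5.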
if use_lora:
LR_scaling = lora_alpha / (lora_rank**0.5)
print("Using an LR scaling for LoRA adapters of: ", LR_scaling)
# Whether to use Weights and Biases for logging training metrics.
USE_WANDB = True
# Eval steps
eval_freq = 0.25 # fraction of total training steps between evaluations (0.25 = evaluate 4 times per run)
Using an LR scaling for LoRA adapters of: 4.0
The next few blocks set up the data collator, data loaders, optimizer, and learning-rate schedule, and then run the training loop.
# import os
# import getpass
# # Securely ask for the API key
# api_key = getpass.getpass("Please enter your wandb API key: ")
# # Use the API key to login to wandb
# os.system(f"wandb login --relogin {api_key}")
from torch.utils.data import DataLoader
from bitsandbytes.optim import Adam8bit
import math
from einops import rearrange
from tqdm import tqdm
ANSWER_EOS = "<|endoftext|>"
# Number of tokens used to represent each image.
IMG_TOKENS = 729
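# (Where 729 comes from: the vision encoder resizes images to 378x378 and splits them
# into 14x14-pixel patches, giving (378 / 14)^2 = 27 * 27 = 729 patch tokens.)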
def collate_fn(batch):
images = [sample['image'] for sample in batch]
images = torch.stack(model.vision_encoder.preprocess(images))
images = rearrange(images,
"b c (h p1) (w p2) -> b (h w) (c p1 p2)",
p1=14, p2=14)
labels_acc = []
tokens_acc = []
for sample in batch:
toks = [tokenizer.bos_token_id]
labs = [-100] * (IMG_TOKENS + 1)
for qa in sample['qa']:
q_t = tokenizer(
f"\n\nQuestion: {qa['question']}\n\nAnswer:",
add_special_tokens=False
).input_ids
toks.extend(q_t)
labs.extend([-100] * len(q_t))
a_t = tokenizer(
f" {qa['answer']}{ANSWER_EOS}",
add_special_tokens=False
).input_ids
toks.extend(a_t)
labs.extend(a_t)
tokens_acc.append(toks)
labels_acc.append(labs)
max_len = -1
for labels in labels_acc:
max_len = max(max_len, len(labels))
attn_mask_acc = []
for i in range(len(batch)):
len_i = len(labels_acc[i])
pad_i = max_len - len_i
labels_acc[i].extend([-100] * pad_i)
tokens_acc[i].extend([tokenizer.eos_token_id] * pad_i)
attn_mask_acc.append([1] * len_i + [0] * pad_i)
return (
images.to(dtype=DTYPE),
torch.stack([torch.tensor(t, dtype=torch.long) for t in tokens_acc]),
torch.stack([torch.tensor(l, dtype=torch.long) for l in labels_acc]),
torch.stack([torch.tensor(a, dtype=torch.bool) for a in attn_mask_acc]),
)
dataloaders = {
"train": DataLoader(
datasets["train"],
batch_size=BATCH_SIZE,
shuffle=True,
collate_fn=collate_fn,
),
"test": DataLoader(
datasets["test"],
batch_size=3,
collate_fn=collate_fn,
),
}
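# Optional sanity check (not in the original notebook): pull one collated batch and look
# at the shapes. With the settings above, images should come out as (BATCH_SIZE, 729, 588):
# 27x27 patches per image, each flattened to 3 * 14 * 14 = 588 values.
sample_images, sample_tokens, sample_labels, sample_mask = next(iter(dataloaders["train"]))
print(sample_images.shape, sample_tokens.shape, sample_labels.shape, sample_mask.shape)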
total_steps = EPOCHS * len(dataloaders["train"]) // GRAD_ACCUM_STEPS
eval_steps = int(total_steps * eval_freq)
def compute_loss(batch):
images, tokens, labels, attn_mask = batch
images = images.to(DEVICE)
tokens = tokens.to(DEVICE)
labels = labels.to(DEVICE)
attn_mask = attn_mask.to(DEVICE)
with torch.no_grad():
img_embs = model.vision_encoder.encoder(images)
img_embs = model.vision_encoder.projection(img_embs)
tok_embs = model.text_model.get_input_embeddings()(tokens)
inputs_embeds = torch.cat((tok_embs[:, 0:1, :], img_embs, tok_embs[:, 1:, :]), dim=1)
outputs = model.text_model(
inputs_embeds=inputs_embeds,
labels=labels,
attention_mask=attn_mask,
)
return outputs.loss
# Cosine learning rate schedule.
def lr_schedule(step, max_steps):
x = step / max_steps
if x < 0.1:
return 0.1 * LR + 0.9 * LR * x / 0.1
else:
return 0.1 * LR + 0.9 * LR * (1 + math.cos(math.pi * (x - 0.1))) / 2
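# Worked example of the cosine schedule above with LR = 1.5e-5 (illustrative):
#   x = 0.05 (mid-warmup)   -> 0.55 * LR ~= 8.3e-6
#   x = 0.10 (warmup ends)  -> LR = 1.5e-5
#   x = 1.00 (end of run)   -> ~0.12 * LR ~= 1.8e-6 (decayed back toward 0.1 * LR)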
# # Constant learning rate schedule.
# def lr_schedule(step, max_steps):
# x = step / max_steps
# if x < 0.1:
# return 0.1 * LR + 0.9 * LR * x / 0.1
# else:
# return LR
model.text_model.train()
model.text_model.transformer.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant":False},) # enable gradient checkpointing; use_reentrant=False avoids "no grad" errors during backprop
## For fine-tuning LoRA params
# Collect each LoRA parameter exactly once (iterating named_parameters avoids the
# duplicates you would get from nested module names that all contain "lora").
lora_params = [p for name, p in model.named_parameters() if "lora" in name and p.requires_grad]
# To fine-tune all lora params (which can include the vision model)
optimizer = Adam8bit(
[
{"params": lora_params},
],
lr=LR * 0.1,
betas=(0.9, 0.95),
eps=1e-6
)
# # For fine-tuning all text model params
# optimizer = Adam8bit(
# [
# {"params": model.text_model.parameters()},
# ],
# # [{"params": lora_params}],
# lr=LR * 0.1,
# betas=(0.9, 0.95),
# eps=1e-6
# )
if USE_WANDB:
import wandb
wandb.init(
project="model-ft",
config={
"EPOCHS": EPOCHS,
"BATCH_SIZE": BATCH_SIZE,
"GRAD_ACCUM_STEPS": GRAD_ACCUM_STEPS,
"LR": LR,
}
)
i = 0
for epoch in range(EPOCHS):
for batch in tqdm(dataloaders["train"], desc=f"Epoch {epoch + 1}/{EPOCHS}"):
i += 1
loss = compute_loss(batch)
loss.backward()
if i % GRAD_ACCUM_STEPS == 0:
optimizer.step()
optimizer.zero_grad()
lr = lr_schedule(i / GRAD_ACCUM_STEPS, total_steps)
for param_group in optimizer.param_groups:
if param_group['params'] == lora_params:
param_group['lr'] = lr * LR_scaling # Apply scaling only to lora_params
else:
param_group['lr'] = lr # Apply base lr to all other params
if i % eval_steps == 0 and USE_WANDB:
# Calculate validation loss
val_loss = 0
for val_batch in tqdm(dataloaders["test"], desc="Validation"):
with torch.no_grad():
val_loss += compute_loss(val_batch).item()
val_loss /= len(dataloaders["test"])
if USE_WANDB:
wandb.log({
"loss/train": loss.item(),
"lr": optimizer.param_groups[0]['lr']
} | ({"loss/val": val_loss} if i % eval_steps == 0 else {}))
if USE_WANDB:
wandb.finish()
wandb: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
wandb: Appending key for api.wandb.ai to your netrc file: /root/.netrc
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...
Epoch 1/3: 100%|██████████| 12/12 [00:14<00:00, 1.23s/it]
Epoch 2/3: 100%|██████████| 12/12 [00:14<00:00, 1.20s/it]
Epoch 3/3: 100%|██████████| 12/12 [00:15<00:00, 1.26s/it]
Run history:
loss/train █▇▆▆▆▆▅▆▄▃▅▄▃▃▃▃▂▃▂▃▂▃▂▂▂▂▁▂▂▃▂▂▂▂▂▂
loss/val   █▂▁▁
lr         ▃▅▇██████▇▇▇▇▇▆▆▆▆▅▅▅▄▄▄▃▃▃▃▂▂▂▂▁▁▁▁

Run summary:
loss/train 1.18496
loss/val   1.53571
lr         1e-05
model.save_pretrained(f"checkpoints/{model_slug}-ft")
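The call above saves only the LoRA adapter weights and config under checkpoints/. As a rough sketch (not part of the original notebook), the adapter could later be re-attached to a freshly loaded base model in a new session:
# Sketch: reload the fine-tuned adapter in a fresh session (paths and variable names are illustrative).
# from peft import PeftModel
# base_model = AutoModelForCausalLM.from_pretrained(
#     model_slug, revision=MD_REVISION, trust_remote_code=True,
#     torch_dtype=DTYPE, device_map={"": DEVICE},
# )
# reloaded = PeftModel.from_pretrained(base_model, f"checkpoints/{model_slug}-ft")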
Now that training has completed, let's inspect the model's answers on a few test samples. (The exact-match accuracy computation is left commented out below, since free-form captions rarely match the reference text verbatim; a more lenient check is sketched at the end.)
model.eval()
correct = 0
for i, sample in enumerate(datasets['test']):
md_answer = model.answer_question(
model.encode_image(sample['image']),
sample['qa'][0]['question'],
tokenizer=tokenizer,
)
# if md_answer == sample['qa'][0]['answer']:
# correct += 1
if i < 3:
display(sample['image'])
print('Question:', sample['qa'][0]['question'])
print('Ground Truth:', sample['qa'][0]['answer'])
print('model:', md_answer)
# print(f"\n\nAccuracy: {correct / len(datasets['test']) * 100:.2f}%")
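Exact string matching is a brittle metric for generated captions, which is why the accuracy lines above are commented out. A more forgiving alternative (just a sketch, not from the original notebook) is to count an answer as correct when every word of the ground-truth caption appears in the model's output:
# Sketch: lenient keyword-based accuracy. A prediction counts as correct if every word of
# the reference caption (ignoring case and punctuation) appears in the model's answer.
def keyword_match(prediction, reference):
    pred = prediction.lower()
    ref_words = [w.strip(".,") for w in reference.lower().split()]
    return all(w in pred for w in ref_words)

# Example: keyword_match("A white rook on a wooden table.", "A white rook.") -> True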