Visionmodellen fine-tunen

Multi-modale modellen met Hugging Face

James Chapman

Curriculum Manager, DataCamp

Doel van het fine-tunen van visionmodellen

 

  • Nieuwe klassen, bijv. binaire classificatie van echte vs. AI-gegenereerde afbeeldingen
  • Nieuwe domeinen, bijv. röntgenfoto's

Algemene afbeelding voor pretraining van een zeehond

Voorbeeld van een AI-gegenereerde afbeelding

1 https://image-net.org/index.php
Multi-modale modellen met Hugging Face

Visionmodellen fine-tunen

Voorbeeld van een AI-gegenereerde afbeelding

 

  1. Pas modeloutput aan op de nieuwe voorspellingen
  2. Bereid de dataset voor op training
  3. Stel trainingsopties in
  4. Trainen!
Multi-modale modellen met Hugging Face

Modelupdates

from datasets import load_dataset
dataset = load_dataset("ideepankarsharma2003/Midjourney_v6_Classification_small_shuf
fled")['train']

data_splits = dataset.train_test_split(test_size=0.2, seed=42)
labels = data_splits["train"].features["label"].names
label2id, id2label = dict(), dict() for i, label in enumerate(labels): label2id[label] = str(i) id2label[str(i)] = label
Multi-modale modellen met Hugging Face

Modelupdates

from transformers import AutoModelForImageClassification
checkpoint = "google/mobilenet_v2_1.0_224"
model = AutoModelForImageClassification.from_pretrained(
    checkpoint,
    num_labels=len(labels),

id2label=id2label, label2id=label2id,
ignore_mismatched_sizes=True
)
Multi-modale modellen met Hugging Face

Datasetvoorbereiding

from transformers import AutoImageProcessor
image_processor = AutoImageProcessor.from_pretrained(checkpoint)


from torchvision.transforms import Compose, Normalize, ToTensor
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
transform = Compose([ToTensor(), normalize])
def transforms(examples): examples["pixel_values"] = [transform(img.convert("RGB")) for img in examples["image"]] del examples["image"] return examples
dataset = dataset.with_transform(transforms)
Multi-modale modellen met Hugging Face

Getransformeerde data plotten

import matplotlib.pyplot as plt
plt.imshow(dataset["train"][0]["pixel_values"].permute(1, 2, 0))
plt.show()

Getransformeerde afbeelding uit de nieuwe dataset

Multi-modale modellen met Hugging Face

Training

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="dataset_finetune",

learning_rate=6e-5,
gradient_accumulation_steps=4,
num_train_epochs=3,
push_to_hub=False )
from transformers import Trainer,
    DefaultDataCollator

data_collator = DefaultDataCollator()

trainer = Trainer(

model=model,
args=training_args,
train_dataset=dataset["train"], eval_dataset=dataset["test"],
processing_class=image_processor,
data_collator=data_collator
)
Multi-modale modellen met Hugging Face

Evaluatie

predictions = trainer.predict(dataset["test"])
predictions.metrics["test_accuracy"]
0.455
trainer.train()
{..., 'eval_accuracy': 0.93, ...}
Multi-modale modellen met Hugging Face

Laten we oefenen!

Multi-modale modellen met Hugging Face

Preparing Video For Download...