{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Python executable: /workspace/synthetic-derm/.venv/bin/python\n", "Virtual environment: True\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/workspace/synthetic-derm/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "✓ synderm package is available\n", " Package location: /workspace/synthetic-derm/synderm/__init__.py\n" ] } ], "source": [ "# Verify environment setup\n", "import sys\n", "print(f\"Python executable: {sys.executable}\")\n", "print(f\"Virtual environment: {'.venv' in sys.executable}\")\n", "\n", "# Test synderm import\n", "try:\n", " import synderm\n", " print(\"✓ synderm package is available\")\n", " print(f\" Package location: {synderm.__file__}\")\n", "except ImportError as e:\n", " print(f\"✗ synderm import failed: {e}\")\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Vignette: Augmenting Your Classifier with Synthetic Images" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/workspace/synthetic-derm/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.optim as optim\n", "from torch.utils.data import DataLoader, Dataset\n", "from torchvision import datasets, transforms, models\n", "from synderm.generation.generate import generate_synthetic_dataset\n", "from synderm.utils.utils import synthetic_train_val_split\n", "from webdataset import WebDataset, RandomMix\n", "from huggingface_hub import get_token\n", "from huggingface_hub import HfApi\n", "import matplotlib.pyplot as plt\n", "import webdataset as wds\n", "from pathlib import Path\n", "from PIL import Image\n", "import pandas as pd\n", "import random\n", "import os\n", "import json\n", "import io\n", "import re" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/workspace/synthetic-derm\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/workspace/synthetic-derm/.venv/lib/python3.10/site-packages/IPython/core/magics/osm.py:417: UserWarning: This is now an optional IPython functionality, setting dhist requires you to install the `pickleshare` library.\n", " self.shell.db['dhist'] = compress_dhist(dhist)[-100:]\n" ] } ], "source": [ "# Set path to root directory of package\n", "%cd ../../../" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Introduction\n", "\n", "This notebook demonstrates how to augment a small dermatology dataset with our large collection of synthetic images. We start by loading a sample dermatology dataset. Its images are also synthetic, but we treat them as real images for the purposes of this vignette. 
You should replace this dataset with your own dataset (adjusting the labels/format as necessary).\n", "\n", "After we load in these images, we will select the desired labels from the synthetic-derm training dataset hosted on [HuggingFace](https://huggingface.co/datasets/tbuckley/synthetic-derm-1M-train). We will then mix these in with our real images, and use a subset of images for validation." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Load your dataset\n", "\n", "First, create a Torch dataset based on the structure of your data. We provide a sample dataset called \"sample_derm_dataset,\" with a folder for \"train\" and \"val.\" Each folder is organized into subfolders for each label (similar to ImageNet). For use with this package, it is standard to return dictionary entries containing a \"label\" and \"image\" (PIL) field." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "class SampleDataset(Dataset):\n", " def __init__(self, dataset_dir, split=\"train\"):\n", " self.dataset_dir = Path(dataset_dir)\n", " self.image_paths = []\n", " self.labels = []\n", " self.split = split\n", "\n", " # Walk through class folders\n", " data_dir = self.dataset_dir / self.split\n", " for class_name in os.listdir(data_dir):\n", " class_dir = data_dir / class_name\n", " if not class_dir.is_dir():\n", " continue\n", " \n", " # Get all png images in this class folder\n", " for img_name in os.listdir(class_dir):\n", " if img_name.lower().endswith('.png'):\n", " self.image_paths.append(class_dir / img_name)\n", " self.labels.append(class_name)\n", " \n", " # Shuffle the dataset\n", " indices = list(range(len(self.image_paths)))\n", " random.shuffle(indices)\n", " self.image_paths = [self.image_paths[i] for i in indices]\n", " self.labels = [self.labels[i] for i in indices]\n", "\n", " def __len__(self):\n", " return len(self.image_paths)\n", "\n", " def __getitem__(self, idx):\n", " image_path = 
self.image_paths[idx]\n", " label = self.labels[idx]\n", " \n", " # Load and convert image to RGB\n", " image = Image.open(image_path).convert('RGB')\n", " image_name = image_path.stem\n", "\n", " return {\"id\": image_name, \"image\": image, \"label\": label}\n" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "train_data = SampleDataset(\"sample_derm_dataset\")\n", "test_data = SampleDataset(\"sample_derm_dataset\", split=\"val\")" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'id': '0064', 'image': , 'label': 'squamous-cell-carcinoma'}\n" ] } ], "source": [ "# Print a sample entry\n", "for item in train_data:\n", " print(item)\n", " break" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Loading the synthetic images from HuggingFace\n", "\n", "We will now load synthetic images using the train version of the dataset hosted on HuggingFace. This dataset contains 1 million images, separated into four generation methods: finetune-inpaint, finetune-text-to-image, pretrained-inpaint, and pretrained-text-to-image. These have already been shuffled, then broken into shards. This is ideal for training: no reshuffling is needed, and shards are loaded one at a time, which keeps memory use low.\n", "\n", "Based on the results in our paper, images produced from finetune-text-to-image perform the best, and this is also the largest split in the dataset. 
So, we will select this split of the dataset and use all of its shards. The last numbered shard is 133, and the full list can be viewed at [this link](https://huggingface.co/datasets/tbuckley/synthetic-derm-1M-train/tree/main/data).\n", "\n", "The following URL selects these shards:" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "# Select finetune-text-to-image shards\n", "url = \"https://huggingface.co/datasets/tbuckley/synthetic-derm-1M-train/resolve/main/data/shard-finetune-text-to-image-{00000..00133}.tar\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, we will list the labels we would like to include, and create a WebDataset pipeline to filter and format each entry as the dataset is iterated." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/workspace/synthetic-derm/.venv/lib/python3.10/site-packages/webdataset/compat.py:381: UserWarning: set WebDataset(shardshuffle=...) to a positive integer or 0 or False\n", " warnings.warn(\"set WebDataset(shardshuffle=...) 
to a positive integer or 0 or False\")\n" ] } ], "source": [ "LABELS = [\n", " \"allergic-contact-dermatitis\",\n", " \"basal-cell-carcinoma\",\n", " \"folliculitis\",\n", " \"lichen-planus\",\n", " \"lupus-erythematosus\",\n", " \"neutrophilic-dermatoses\",\n", " \"photodermatoses\",\n", " \"psoriasis\",\n", " \"sarcoidosis\",\n", " \"squamous-cell-carcinoma\"\n", "]\n", "\n", "def to_dict(sample):\n", " return {\n", " \"id\": sample[\"json\"][\"md5hash\"], \n", " \"image\": sample[\"png\"],\n", " \"label\": sample[\"json\"][\"label\"]\n", " }\n", "\n", "def select_label(sample):\n", " if sample[\"label\"] in LABELS:\n", " return sample\n", " else:\n", " return None\n", "\n", "# Create a WebDataset\n", "synthetic_data = (\n", " wds.WebDataset(url, shardshuffle=True)\n", " .shuffle(40000)\n", " .decode(\"pil\")\n", " .map(to_dict)\n", " .map(select_label)\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3. Mixing the real and synthetic training images\n", "\n", "Now, we need to combine our real and synthetic images for model training. We can use the convenient `RandomMix` class from the WebDataset package, which combines several PyTorch datasets and lets us specify a sampling weight for each one. \n", "\n", "We are going to create a dataset with a sampling weight of 1.5 for real data and 1.0 for synthetic data. Assuming `RandomMix` normalizes these weights, about 60% of the sampled images will be real and 40% synthetic. **We encourage you to try different mixing ratios for the best performance with your data.**" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "mixed_dataset = RandomMix([train_data, synthetic_data], [1.5, 1.0]) " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4. Model training and validation\n", "\n", "Finally, now that we have our dataset of real and synthetic images, we will train a PyTorch EfficientNet_V2_M model to classify these images. 
We will validate our model on the held-out set of real images." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "# Map each label string to a class index, and back\n", "unique_labels = sorted(set(LABELS))\n", "label_to_idx = {label: i for i, label in enumerate(unique_labels)}\n", "idx_to_label = {i: label for label, i in label_to_idx.items()}\n", "\n", "def collate_fn(batch):\n", " tfms = transforms.Compose([\n", " transforms.Resize((224, 224)), # Ensure all images have same size\n", " transforms.ToTensor(),\n", " ])\n", "\n", " images, labels = [], []\n", "\n", " for sample in batch:\n", " img = tfms(sample['image'])\n", " lbl = label_to_idx[sample['label']]\n", " images.append(img)\n", " labels.append(lbl)\n", "\n", " images = torch.stack(images, dim=0)\n", " labels = torch.tensor(labels)\n", "\n", " return images, labels\n", "\n", "train_loader = torch.utils.data.DataLoader(mixed_dataset, batch_size=32, collate_fn=collate_fn)\n", "val_loader = torch.utils.data.DataLoader(test_data, batch_size=32, collate_fn=collate_fn)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Downloading: \"https://download.pytorch.org/models/efficientnet_v2_m-dc08266a.pth\" to /root/.cache/torch/hub/checkpoints/efficientnet_v2_m-dc08266a.pth\n", "100%|██████████| 208M/208M [00:47<00:00, 4.61MB/s] \n" ] } ], "source": [ "# Start from ImageNet-pretrained weights\n", "model = models.efficientnet_v2_m(weights=models.EfficientNet_V2_M_Weights.IMAGENET1K_V1)\n", "\n", "# Replace the classifier head with one output per label\n", "num_ftrs = model.classifier[1].in_features\n", "model.classifier[1] = nn.Linear(num_ftrs, len(unique_labels))\n", "\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "model = model.to(device)\n", "\n", "criterion = nn.CrossEntropyLoss()\n", "optimizer = optim.Adam(model.parameters(), lr=0.001)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 1/10, Batch: 1, Train Loss: 
2.3465139865875244\n", "Epoch: 2/10, Batch: 1, Train Loss: 1.8198271989822388\n", "Epoch: 3/10, Batch: 1, Train Loss: 1.6083544492721558\n", "Epoch: 4/10, Batch: 1, Train Loss: 1.3128741979599\n", "Epoch: 5/10, Batch: 1, Train Loss: 0.9874697923660278\n", "Epoch: 6/10, Batch: 1, Train Loss: 0.8346286416053772\n", "Epoch: 7/10, Batch: 1, Train Loss: 0.7049002051353455\n", "Epoch: 8/10, Batch: 1, Train Loss: 0.33139896392822266\n", "Epoch: 9/10, Batch: 1, Train Loss: 0.6521413326263428\n", "Epoch: 10/10, Batch: 1, Train Loss: 0.4039866030216217\n" ] } ], "source": [ "num_epochs = 10\n", "for epoch in range(num_epochs):\n", " # Training\n", " model.train()\n", " train_loss = 0.0\n", " for i, data in enumerate(train_loader):\n", " imgs, lbls = data\n", " imgs, lbls = imgs.to(device), lbls.to(device)\n", "\n", " optimizer.zero_grad()\n", " outputs = model(imgs)\n", " loss = criterion(outputs, lbls)\n", " loss.backward()\n", " optimizer.step()\n", "\n", " train_loss += loss.item()\n", "\n", " if i % 100 == 0:\n", " print(f'Epoch: {epoch+1}/{num_epochs}, Batch: {i+1}, Train Loss: {loss.item()}')" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Validation Set Evaluation:\n", "Validation Loss: 2.9727\n", "Validation Accuracy: 37.19%\n", "\n", "Detailed Validation Metrics:\n", " precision recall f1-score support\n", "\n", "allergic-contact-dermatitis 0.3600 0.2812 0.3158 32\n", " basal-cell-carcinoma 0.8750 0.2188 0.3500 32\n", " folliculitis 0.5517 0.5000 0.5246 32\n", " lichen-planus 0.2474 0.7500 0.3721 32\n", " lupus-erythematosus 0.3095 0.4062 0.3514 32\n", " neutrophilic-dermatoses 0.5294 0.2812 0.3673 32\n", " photodermatoses 0.3469 0.5312 0.4198 32\n", " psoriasis 0.3333 0.0938 0.1463 32\n", " sarcoidosis 0.3103 0.2812 0.2951 32\n", " squamous-cell-carcinoma 0.8000 0.3750 0.5106 32\n", "\n", " accuracy 0.3719 320\n", " macro avg 0.4664 0.3719 0.3653 320\n", " weighted avg 
0.4664 0.3719 0.3653 320\n", "\n", "\n", "Validation Confusion Matrix:\n", "[[ 9 0 2 7 4 1 7 1 1 0]\n", " [ 0 7 0 22 0 1 1 0 0 1]\n", " [ 2 0 16 3 2 0 2 1 6 0]\n", " [ 1 0 1 24 1 0 3 1 1 0]\n", " [ 1 0 2 5 13 1 3 1 6 0]\n", " [ 1 0 0 9 6 9 4 0 3 0]\n", " [ 3 0 0 5 4 0 17 1 2 0]\n", " [ 2 0 6 5 7 2 5 3 1 1]\n", " [ 2 0 1 7 5 2 5 0 9 1]\n", " [ 4 1 1 10 0 1 2 1 0 12]]\n" ] } ], "source": [ "# Evaluation\n", "from sklearn.metrics import classification_report, confusion_matrix\n", "import numpy as np\n", "\n", "def evaluate_model(model, data_loader, device, criterion):\n", " model.eval()\n", " all_predictions = []\n", " all_labels = []\n", " running_loss = 0.0\n", " correct = 0\n", " total = 0\n", " \n", " with torch.no_grad():\n", " for data in data_loader:\n", " imgs, lbls = data\n", " imgs, lbls = imgs.to(device), lbls.to(device)\n", " \n", " outputs = model(imgs)\n", " loss = criterion(outputs, lbls)\n", " running_loss += loss.item()\n", " \n", " _, predicted = torch.max(outputs.data, 1)\n", " total += lbls.size(0)\n", " correct += (predicted == lbls).sum().item()\n", " \n", " all_predictions.extend(predicted.cpu().numpy())\n", " all_labels.extend(lbls.cpu().numpy())\n", " \n", " # Map numeric labels back to string labels\n", " all_predictions_str = [idx_to_label[pred] for pred in all_predictions]\n", " all_labels_str = [idx_to_label[lbl] for lbl in all_labels]\n", " \n", " # Calculate metrics\n", " loss = running_loss / len(data_loader)\n", " accuracy = 100 * correct / total\n", " \n", " # Generate detailed classification report\n", " report = classification_report(all_labels_str, all_predictions_str, digits=4)\n", " conf_matrix = confusion_matrix(all_labels_str, all_predictions_str)\n", " \n", " return loss, accuracy, report, conf_matrix\n", "\n", "print(\"\\nValidation Set Evaluation:\")\n", "val_loss, val_accuracy, val_report, val_conf_matrix = evaluate_model(model, val_loader, device, criterion)\n", "print(f\"Validation Loss: {val_loss:.4f}\")\n", 
"print(f\"Validation Accuracy: {val_accuracy:.2f}%\")\n", "print(\"\\nDetailed Validation Metrics:\")\n", "print(val_report)\n", "print(\"\\nValidation Confusion Matrix:\")\n", "print(val_conf_matrix)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Given that this is a sample dataset, this model appears to perform OK. We encourage you to use your own data, augmented with our large collection of synthetic images. For next steps, you can try training your model with and without data augmentation, trying different mixing ratios, and different models. Best of luck!" ] }, { "cell_type": "markdown", "metadata": {}, "source": [] } ], "metadata": { "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 2 }