{ "cells": [ { "cell_type": "markdown", "id": "8fe59b6f-3505-4dd6-9078-bd85243fcf24", "metadata": {}, "source": [ "# Train CycleGAN Notebook" ] }, { "cell_type": "markdown", "id": "c13b7b64-578c-4faa-b90c-3da08dde7164", "metadata": {}, "source": "This notebook runs a short training of a CycleGAN (built on stabilityai/sd-turbo) using MerLin's BosonSampler on day and night driving images (a small part of the BDD-100k dataset). The CycleGAN was pretrained on the whole dataset; the idea was then to add a BosonSampler before the decoder and to train only the decoder. This method has shown some results in a challenge organized by BMW and Airbus to turn daytime driving images into nighttime driving images. The original architecture is based on this GitHub repository: https://github.com/GaParmar/img2img-turbo. We ran a first training without any quantum layers and then a second one, starting from the pretrained weights and adding a quantum layer. This notebook is a small reproduction of the second training, with a dataset of only 200 images. The dataset is not included in the MerLin repository; however, you can specify where to download it with the _args.dataset_folder_ variable. The same goes for the outputs of the model, with the _args.output_dir_ variable. Be careful: access to a CUDA device is required to run this model. 
*To run this model you'll need to install the requirements present in the img2img-turbo-annotations folder.*" }, { "cell_type": "code", "execution_count": 1, "id": "77f978e3-90c3-4005-99de-1cb35096eb67", "metadata": {}, "outputs": [], "source": [ "import torch\n", "import gc\n", "import copy\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "from glob import glob\n", "import numpy as np\n", "import lpips\n", "from accelerate import Accelerator\n", "from accelerate.utils import set_seed\n", "from PIL import Image\n", "from torchvision import transforms\n", "from tqdm.auto import tqdm\n", "from transformers import AutoTokenizer, CLIPTextModel\n", "from diffusers.optimization import get_scheduler\n", "from peft.utils import get_peft_model_state_dict\n", "from cleanfid.fid import get_folder_features, build_feature_extractor, frechet_distance\n", "import vision_aided_loss\n", "from img2img_turbo_annotations.src_quantum.model import make_1step_sched\n", "from img2img_turbo_annotations.src_quantum.cyclegan_turbo import CycleGAN_Turbo, VAE_encode, VAE_decode, initialize_unet, initialize_vae\n", "from img2img_turbo_annotations.src_quantum.my_utils.training_utils import UnpairedDataset, build_transform, parse_args_unpaired_training, \\\n", " UnpairedDataset_Quantum, get_next_id, read_from_emb16, image_fail, read_from_emb32, load_small_dataset\n", "from img2img_turbo_annotations.src_quantum.my_utils.dino_struct import DinoStructureLoss\n", "import h5py\n", "import torch.nn.init as init\n", "import pandas as pd\n", "from img2img_turbo_annotations.src_quantum.BosonSampler import BosonSampler\n", "import time\n", "import shutil\n", "import argparse\n", "from types import SimpleNamespace\n", "import os\n", "from huggingface_hub import hf_hub_download\n", "import matplotlib.pyplot as plt\n" ] }, { "cell_type": "markdown", "id": "959c43c5-4500-4230-9de3-75aca5698016", "metadata": {}, "source": [ "### Adjust key parameters and variables" ] }, { "cell_type": 
"code", "execution_count": 2, "id": "f99fb95c-c17f-4183-a2fc-d91e7406a432", "metadata": {}, "outputs": [], "source": [ "def parse_args_unpaired_training():\n", " \"\"\"\n", " Parses command-line arguments used for configuring an unpaired session (CycleGAN-Turbo).\n", " This function sets up an argument parser to handle various training options.\n", "\n", " Returns:\n", " argparse.Namespace: The parsed command-line arguments.\n", " \"\"\"\n", "\n", " \n", " args = {}\n", " # fixed random seed\n", " args[\"seed\"] = 42\n", " # args for the loss function\n", " args[\"gan_disc_type\"] = \"vagan_clip\"\n", " args[\"gan_loss_type\"] = \"multilevel_sigmoid\"\n", " args[\"lambda_gan\"] = 0.5\n", " args[\"lambda_idt\"] = 1\n", " args[\"lambda_cycle\"] = 1\n", " args[\"lambda_cycle_lpips\"] = 10.0\n", " args[\"lambda_idt_lpips\"] = 1.0\n", "\n", " # args for dataset and dataloader options\n", " args[\"dataset_folder\"] = 'img2img_turbo_annotations/dataset'\n", " args[\"train_img_prep\"] = 'resize_128' \n", " args[\"val_img_prep\"] = 'resize_128' \n", " args[\"dataloader_num_workers\"] = 0 \n", " args[\"train_batch_size\"] = 2 \n", " args[\"max_train_epochs\"] = 20 \n", " args[\"max_train_steps\"] = 20 \n", " \n", " args[\"revision\"] = None \n", " args[\"variant\"] = None \n", " args[\"lora_rank_unet\"] = 128 \n", " args[\"lora_rank_vae\"] = 4 \n", " # args for validation and logging \n", " args[\"viz_freq\"] = 20 \n", " args[\"output_dir\"] = 'img2img_turbo_annotations/outputs'\n", " args[\"validation_steps\"] = 50 \n", " args[\"validation_num_images\"] = -1\n", " args[\"checkpointing_steps\"] = 1000\n", "\n", " # args for the optimization options\n", " args[\"learning_rate\"] = 1e-5\n", " args[\"adam_beta1\"] = 0.9\n", " args[\"adam_beta2\"] = 0.999\n", " args[\"adam_weight_decay\"] = 1e-2\n", " args[\"adam_epsilon\"] = 1e-08 \n", " args[\"max_grad_norm\"] = 10.0\n", " args[\"lr_scheduler\"] = \"constant\" \n", " args[\"lr_warmup_steps\"] = 500\n", " 
args[\"lr_num_cycles\"] = 1 \n", " args[\"lr_power\"] = 1.0 \n", " args[\"gradient_accumulation_steps\"] = 1 # memory saving options\n", " args[\"allow_tf32\"] = False\n", " args[\"gradient_checkpointing\"] = False\n", " args[\"enable_xformers_memory_efficient_attention\"] = True\n", "\n", "\n", " # dynamic conditional quantum embeddings with the VAE frozen\n", " args[\"quantum_dynamic\"] = True\n", " args[\"cl_comp\"] = False\n", " args[\"pretrained_model_path\"] = \"model_251.pkl\"\n", " args[\"quantum_dims\"] = (4, 16, 16)\n", " args[\"quantum_processes\"] = 2\n", " args[\"training_images\"] = 1.\n", " return SimpleNamespace(**args)" ] }, { "cell_type": "code", "execution_count": 3, "id": "8b91772f-a84e-4eb4-b184-31017274e380", "metadata": {}, "outputs": [], "source": [ "args = parse_args_unpaired_training()\n" ] }, { "cell_type": "markdown", "id": "ca883c16-dafe-49a8-918f-a8cf62b9aa47", "metadata": {}, "source": [ "### Loading Boson Sampler" ] }, { "cell_type": "code", "execution_count": 4, "id": "1d96af3f-a1ca-47e7-a5d0-96216830a708", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "boson sampler dims (2, 4, 16, 16)\n", "time to load model 3.349485158920288\n", "-- Boson Sampler defined --\n" ] } ], "source": [ "if args.quantum_dynamic: \n", " # Set the random seeds\n", " torch.manual_seed(args.seed)\n", " print(\"boson sampler dims\", (args.quantum_processes, ) + args.quantum_dims)\n", " boson_sampler = BosonSampler((args.quantum_processes, ) + args.quantum_dims)\n", " print(\"-- Boson Sampler defined --\")" ] }, { "cell_type": "code", "execution_count": 36, "id": "352580db-da51-4c09-af8c-332b357af67d", "metadata": {}, "outputs": [], "source": [ " accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)\n", "\n", "\n", "# set up id and name of experiment\n", "exp_id = get_next_id()\n", "os.makedirs(args.output_dir, exist_ok=True)\n", "args.exp_id = exp_id\n", "if 
accelerator.is_main_process:\n", " os.makedirs(os.path.join(args.output_dir, \"checkpoints\"), exist_ok=True)\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(\"stabilityai/sd-turbo\", subfolder=\"tokenizer\", revision=args.revision,\n", " use_fast=False, )\n", "noise_scheduler_1step = make_1step_sched()\n", "text_encoder = CLIPTextModel.from_pretrained(\"stabilityai/sd-turbo\", subfolder=\"text_encoder\").cuda()\n", "\n", "weight_dtype = torch.float32\n", "\n", "text_encoder.to(accelerator.device, dtype=weight_dtype)\n", "text_encoder.requires_grad_(False)\n", "if args.gan_disc_type == \"vagan_clip\":\n", " net_disc_a = vision_aided_loss.Discriminator(cv_type='clip', loss_type=args.gan_loss_type, device=\"cuda\")\n", " net_disc_a.cv_ensemble.requires_grad_(False) # Freeze feature extractor\n", " net_disc_b = vision_aided_loss.Discriminator(cv_type='clip', loss_type=args.gan_loss_type, device=\"cuda\")\n", " net_disc_b.cv_ensemble.requires_grad_(False) # Freeze feature extractor\n", "\n", "crit_cycle, crit_idt = torch.nn.L1Loss(), torch.nn.L1Loss()" ] }, { "cell_type": "markdown", "id": "db80c57f-f8ce-4855-b251-8bd6eb128fdd", "metadata": {}, "source": [ "### Load pretrained weights" ] }, { "cell_type": "code", "execution_count": 37, "id": "b2cc1780-455c-4f3b-ba98-ea1d5faeffdd", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "- building the model\n", "-----> Number of parameters in UNet: 865910724\n", "-- Loading from pretrained weights for quantum training\n", "Defining cyclegan_d on process 0\n", "-- Loading from pretrained path not None --\n", "--load_ckpt_from_state_dict--\n", "--- CycleGAN defined\n", "-- Conv2d: requires_grad=\n", "-- TOTAL parameters = 1165770722\n", "-- TOTAL trainable parameters = 129794664\n" ] } ], "source": [ "\n", "if args.quantum_dynamic or args.cl_comp:\n", " print(\"- building the model\")\n", " unet, l_modules_unet_encoder, l_modules_unet_decoder, l_modules_unet_others = initialize_unet(\n", 
" args.lora_rank_unet, return_lora_module_names=True)\n", " vae_a2b, vae_lora_target_modules = initialize_vae(args.lora_rank_vae, return_lora_module_names=True,\n", " dynamic=args.quantum_dynamic)\n", " print(\"-- Loading from pretrained weights for quantum training\")\n", " \n", " print(f\"Defining cyclegan_d on process {accelerator.process_index}\")\n", " model_path = hf_hub_download(repo_id=\"quandelagl/img-to-img-turbo-pretrained\", filename=args.pretrained_model_path)\n", " cyclegan_d = CycleGAN_Turbo(accelerator=accelerator, pretrained_path=model_path) #pretrained_path=args.quantum_start_path,\n", " print(\"--- CycleGAN defined\")\n", " vae_enc = cyclegan_d.vae_enc\n", " vae_dec = cyclegan_d.vae_dec\n", " vae_a2b = cyclegan_d.vae\n", " vae_b2a = cyclegan_d.vae_b2a\n", " unet = cyclegan_d.unet\n", "\n", "\n", " unet.conv_in.requires_grad_(True)\n", "\n", " weight_dtype = torch.float32\n", " \"\"\"vae_a2b.to(accelerator.device, dtype=weight_dtype)\n", " unet.to(accelerator.device, dtype=weight_dtype)\n", " vae_b2a = copy.deepcopy(vae_a2b)\"\"\"\n", "\n", " # freeze the VAE enc and detach it from the gradient\n", " vae_enc.requires_grad_(False) \n", " # train the post_quant_conv layer at the \n", " vae_a2b.post_quant_conv.requires_grad_(True) \n", " vae_b2a.post_quant_conv.requires_grad_(True) \n", "# get the trainable parameters (vae_a2b.encoder and vae_b2a.encoder should have 0 trainable parameters) \n", " params_gen = cyclegan_d.get_traininable_params(unet, vae_a2b, vae_b2a, boson_sampler, dynamic=True, quantum_training=True) \n", " print(f\"-- {vae_a2b.encoder.conv_in.__class__.__name__}: requires_grad={vae_a2b.encoder.conv_in.weight.requires_grad}\") \n", " \n", "\n", "\n", " \n", "\n", "\n", " # CLASSICAL NETWORKS with no initialization\n", "else:\n", " print(\"--- Classical CycleGAN training ---\")\n", " print(\"- building the model\")\n", " unet, l_modules_unet_encoder, l_modules_unet_decoder, l_modules_unet_others = initialize_unet(\n", " 
args.lora_rank_unet, return_lora_module_names=True)\n", " vae_a2b, vae_lora_target_modules = initialize_vae(args.lora_rank_vae, return_lora_module_names=True,\n", " dynamic=args.quantum_dynamic)\n", " vae_a2b.to(accelerator.device, dtype=weight_dtype)\n", " unet.to(accelerator.device, dtype=weight_dtype)\n", " unet.conv_in.requires_grad_(True)\n", " vae_b2a = copy.deepcopy(vae_a2b)\n", " params_gen = CycleGAN_Turbo.get_traininable_params(unet, vae_a2b, vae_b2a, boson_sampler)\n", " vae_enc = VAE_encode(vae_a2b, vae_b2a=vae_b2a)\n", " vae_dec = VAE_decode(vae_a2b, vae_b2a=vae_b2a)\n", "\n", " print(\n", " f\"-- {vae_a2b.encoder.conv_in.__class__.__name__}: requires_grad={vae_a2b.encoder.conv_in.weight.requires_grad}\")\n", " print(\n", " f\"-- For vae_enc = {sum(p.numel() for p in vae_enc.parameters() if p.requires_grad)}, - a2b = {sum(p.numel() for p in vae_a2b.encoder.parameters() if p.requires_grad)} and -b2a = {sum(p.numel() for p in vae_b2a.encoder.parameters() if p.requires_grad)}\")\n", " print(f\"-- For unet = {sum(p.numel() for p in unet.parameters() if p.requires_grad)}\")\n", " print(\n", " f\"-- For unet.conv_in = {sum(p.numel() for p in unet.conv_in.parameters() if p.requires_grad)} - unet.conv_out = {sum(p.numel() for p in unet.conv_out.parameters() if p.requires_grad)} \")\n", "print(f\"-- TOTAL parameters = {sum(p.numel() for p in unet.parameters())+sum(p.numel() for p in vae_a2b.parameters())+sum(p.numel() for p in vae_b2a.parameters())}\")\n", "print(\n", " f\"-- TOTAL trainable parameters = {sum(p.numel() for p in unet.parameters() if p.requires_grad) + sum(p.numel() for p in vae_a2b.parameters() if p.requires_grad) + sum(p.numel() for p in vae_b2a.parameters() if p.requires_grad)}\")\n", "if args.enable_xformers_memory_efficient_attention: \n", " unet.enable_xformers_memory_efficient_attention()\n", "\n", "if args.gradient_checkpointing:\n", " unet.enable_gradient_checkpointing()\n", "\n", "if args.allow_tf32:\n", " 
torch.backends.cuda.matmul.allow_tf32 = True\n", " " ] }, { "cell_type": "markdown", "id": "fe1a88aa-90f6-4015-8fb5-7588ee46562f", "metadata": {}, "source": [ "### Load and prepare dataset" ] }, { "cell_type": "code", "execution_count": 38, "id": "1aebae97-8546-4aee-9995-d0f80047d11f", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "loading small dataset\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n", "Saving train_a: 0%| | 0/200 [00:00 Length of dataset_train = 400\n", "- Data loaded\n", "- FOR TEST PURPOSE: working with 80 and 80\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "optimizer_gen = torch.optim.AdamW(params_gen, lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2),\n", " weight_decay=args.adam_weight_decay, eps=args.adam_epsilon, )\n", "\n", "params_disc = list(net_disc_a.parameters()) + list(net_disc_b.parameters())\n", "optimizer_disc = torch.optim.AdamW(params_disc, lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2),\n", " weight_decay=args.adam_weight_decay, eps=args.adam_epsilon, )\n", "\n", "_ = load_small_dataset(\"quandelagl/small-bdd-100k\", args.dataset_folder)\n", "print(args.dataset_folder)\n", "dataset_train = UnpairedDataset(dataset_folder=args.dataset_folder, image_prep=args.train_img_prep,\n", " split=\"train\", tokenizer=tokenizer, part = args.training_images)\n", "#here, we could lower the number of images used\n", "print(f\"- Dataset loaded\")\n", "\n", "print(f\"--> Length of dataset_train = {len(dataset_train)}\")\n", "train_dataloader = torch.utils.data.DataLoader(dataset_train, batch_size=args.train_batch_size, shuffle=True,\n", " num_workers=args.dataloader_num_workers)\n", "print(\"- Data loaded\")\n", "T_val = build_transform(args.val_img_prep)\n", "fixed_caption_src = dataset_train.fixed_caption_src\n", "fixed_caption_tgt = dataset_train.fixed_caption_tgt\n", "l_images_src_test = []\n", "for ext in 
[\"*.jpg\", \"*.jpeg\", \"*.png\", \"*.bmp\"]:\n", " l_images_src_test.extend(glob(os.path.join(args.dataset_folder, \"test_a\", ext)))\n", "l_images_src_test_2 = [el for el in l_images_src_test if os.path.basename(el)]# not in list_failures]\n", "l_images_src_test = l_images_src_test_2\n", "l_images_tgt_test = []\n", "for ext in [\"*.jpg\", \"*.jpeg\", \"*.png\", \"*.bmp\"]:\n", " l_images_tgt_test.extend(glob(os.path.join(args.dataset_folder, \"test_b\", ext)))\n", "\n", "l_images_src_test, l_images_tgt_test = sorted(l_images_src_test), sorted(l_images_tgt_test)\n", "#l_images_src_test, l_images_tgt_test = l_images_src_test[:100], l_images_tgt_test[:100]\n", "print(f\"- FOR TEST PURPOSE: working with {len(l_images_src_test)} and {len(l_images_tgt_test)}\")\n" ] }, { "cell_type": "code", "execution_count": 8, "id": "9b04631f-f5d6-4430-bdc1-42141fd6050a", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "995ca218f5c14ace8594c49f74d29e80", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/80 [00:00 B translation\n", " \"\"\"\n", " output_dir_ref = os.path.join(args.output_dir, \"fid_reference_a2b\")\n", " os.makedirs(output_dir_ref, exist_ok=True)\n", " # transform all images according to the validation transform and save them\n", " for _path in tqdm(l_images_tgt_test):\n", " _img = T_val(Image.open(_path).convert(\"RGB\"))\n", " outf = os.path.join(output_dir_ref, os.path.basename(_path)).replace(\".jpg\", \".png\")\n", " if not os.path.exists(outf):\n", " _img.save(outf)\n", " # compute the features for the reference images\n", " ref_features = get_folder_features(output_dir_ref, model=feat_model, num_workers=0, num=None,\n", " shuffle=False, seed=0, batch_size=8, device=torch.device(\"cuda\"),\n", " mode=\"clean\", custom_fn_resize=None, description=\"\", verbose=True,\n", " custom_image_tranform=None)\n", " a2b_ref_mu, a2b_ref_sigma = np.mean(ref_features, axis=0), np.cov(ref_features, 
rowvar=False)\n", " \"\"\"\n", " FID reference statistics for B -> A translation\n", " \"\"\"\n", " # transform all images according to the validation transform and save them\n", " output_dir_ref = os.path.join(args.output_dir, \"fid_reference_b2a\")\n", " os.makedirs(output_dir_ref, exist_ok=True)\n", " for _path in tqdm(l_images_src_test):\n", " _img = T_val(Image.open(_path).convert(\"RGB\"))\n", " outf = os.path.join(output_dir_ref, os.path.basename(_path)).replace(\".jpg\", \".png\")\n", " if not os.path.exists(outf):\n", " _img.save(outf)\n", " # compute the features for the reference images\n", " ref_features = get_folder_features(output_dir_ref, model=feat_model, num_workers=0, num=None,\n", " shuffle=False, seed=0, batch_size=8, device=torch.device(\"cuda\"),\n", " mode=\"clean\", custom_fn_resize=None, description=\"\", verbose=True,\n", " custom_image_tranform=None)\n", " b2a_ref_mu, b2a_ref_sigma = np.mean(ref_features, axis=0), np.cov(ref_features, rowvar=False)" ] }, { "cell_type": "code", "execution_count": 9, "id": "8f3a293d-9816-4376-88a3-566a414237f2", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/ubuntu/miniconda3/envs/merlin_notebook/lib/python3.10/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.\n", " warnings.warn(\n", "/home/ubuntu/miniconda3/envs/merlin_notebook/lib/python3.10/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=VGG16_Weights.IMAGENET1K_V1`. 
You can also use `weights=VGG16_Weights.DEFAULT` to get the most up-to-date weights.\n", " warnings.warn(msg)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "- Defining the scheduler for generators + discriminator\n", "Setting up [LPIPS] perceptual loss: trunk [vgg], v[0.1], spatial [off]\n", "Loading model from: /home/ubuntu/miniconda3/envs/merlin_notebook/lib/python3.10/site-packages/lpips/weights/v0.1/vgg.pth\n" ] }, { "data": { "text/plain": [ "LPIPS(\n", " (scaling_layer): ScalingLayer()\n", " (net): vgg16(\n", " (slice1): Sequential(\n", " (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (1): ReLU(inplace=True)\n", " (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (3): ReLU(inplace=True)\n", " )\n", " (slice2): Sequential(\n", " (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", " (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (6): ReLU(inplace=True)\n", " (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (8): ReLU(inplace=True)\n", " )\n", " (slice3): Sequential(\n", " (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", " (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (11): ReLU(inplace=True)\n", " (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (13): ReLU(inplace=True)\n", " (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (15): ReLU(inplace=True)\n", " )\n", " (slice4): Sequential(\n", " (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", " (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (18): ReLU(inplace=True)\n", " (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (20): ReLU(inplace=True)\n", " (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " 
(22): ReLU(inplace=True)\n", " )\n", " (slice5): Sequential(\n", " (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", " (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (25): ReLU(inplace=True)\n", " (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (27): ReLU(inplace=True)\n", " (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", " (29): ReLU(inplace=True)\n", " )\n", " )\n", " (lin0): NetLinLayer(\n", " (model): Sequential(\n", " (0): Dropout(p=0.5, inplace=False)\n", " (1): Conv2d(64, 1, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " )\n", " )\n", " (lin1): NetLinLayer(\n", " (model): Sequential(\n", " (0): Dropout(p=0.5, inplace=False)\n", " (1): Conv2d(128, 1, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " )\n", " )\n", " (lin2): NetLinLayer(\n", " (model): Sequential(\n", " (0): Dropout(p=0.5, inplace=False)\n", " (1): Conv2d(256, 1, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " )\n", " )\n", " (lin3): NetLinLayer(\n", " (model): Sequential(\n", " (0): Dropout(p=0.5, inplace=False)\n", " (1): Conv2d(512, 1, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " )\n", " )\n", " (lin4): NetLinLayer(\n", " (model): Sequential(\n", " (0): Dropout(p=0.5, inplace=False)\n", " (1): Conv2d(512, 1, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " )\n", " )\n", " (lins): ModuleList(\n", " (0): NetLinLayer(\n", " (model): Sequential(\n", " (0): Dropout(p=0.5, inplace=False)\n", " (1): Conv2d(64, 1, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " )\n", " )\n", " (1): NetLinLayer(\n", " (model): Sequential(\n", " (0): Dropout(p=0.5, inplace=False)\n", " (1): Conv2d(128, 1, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " )\n", " )\n", " (2): NetLinLayer(\n", " (model): Sequential(\n", " (0): Dropout(p=0.5, inplace=False)\n", " (1): Conv2d(256, 1, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " )\n", " )\n", " (3-4): 2 
x NetLinLayer(\n", " (model): Sequential(\n", " (0): Dropout(p=0.5, inplace=False)\n", " (1): Conv2d(512, 1, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " )\n", " )\n", " )\n", ")" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "print(\"- Defining the scheduler for generators + discriminator\")\n", "lr_scheduler_gen = get_scheduler(args.lr_scheduler, optimizer=optimizer_gen,\n", " num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,\n", " num_training_steps=args.max_train_steps * accelerator.num_processes,\n", " num_cycles=args.lr_num_cycles, power=args.lr_power)\n", "lr_scheduler_disc = get_scheduler(args.lr_scheduler, optimizer=optimizer_disc,\n", " num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,\n", " num_training_steps=args.max_train_steps * accelerator.num_processes,\n", " num_cycles=args.lr_num_cycles, power=args.lr_power)\n", "\n", "net_lpips = lpips.LPIPS(net='vgg')\n", "net_lpips.cuda()\n", "net_lpips.requires_grad_(False)" ] }, { "cell_type": "markdown", "id": "b9279d98-8069-47e5-8b2f-da181dd4e15d", "metadata": {}, "source": [ "### Defining the text embeddings" ] }, { "cell_type": "code", "execution_count": 10, "id": "34b802b5-971b-4b10-a1c8-3a7bc5ee90e3", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "- Defining the text embeddings\n" ] } ], "source": [ "\n", "fixed_a2b_tokens = \\\n", "tokenizer(fixed_caption_tgt, max_length=tokenizer.model_max_length, padding=\"max_length\", truncation=True,\n", " return_tensors=\"pt\").input_ids[0]\n", "fixed_a2b_emb_base = text_encoder(fixed_a2b_tokens.cuda().unsqueeze(0))[0].detach()\n", "fixed_b2a_tokens = \\\n", "tokenizer(fixed_caption_src, max_length=tokenizer.model_max_length, padding=\"max_length\", truncation=True,\n", " return_tensors=\"pt\").input_ids[0]\n", "fixed_b2a_emb_base = text_encoder(fixed_b2a_tokens.cuda().unsqueeze(0))[0].detach()\n", "del text_encoder, tokenizer # free up 
some memory\n", "\n", "unet, vae_enc, vae_dec, net_disc_a, net_disc_b = accelerator.prepare(unet, vae_enc, vae_dec, net_disc_a, net_disc_b) \n", "net_lpips, optimizer_gen, optimizer_disc, train_dataloader, lr_scheduler_gen, lr_scheduler_disc = accelerator.prepare(\n", " net_lpips, optimizer_gen, optimizer_disc, train_dataloader, lr_scheduler_gen, lr_scheduler_disc\n", ")" ] }, { "cell_type": "markdown", "id": "dcaaed4d-33af-436d-b5de-edda8ce54ba1", "metadata": {}, "source": [ "### Training the model with BosonSampler" ] }, { "cell_type": "code", "execution_count": null, "id": "0a33e5c2-cb70-4347-be4e-45b7946866e7", "metadata": {}, "outputs": [], "source": [ "if accelerator.is_main_process:\n", " config = dict(vars(args))\n", " config[\"run_name\"] = f\"qCycleGAN-{args.exp_id}\"\n", "# checking if the VAE_encoder is frozen on the pretrained weights\n", "first_epoch = 0\n", "global_step = 0\n", "progress_bar = tqdm(range(0, args.max_train_steps), initial=global_step, desc=\"Steps\",\n", " disable=not accelerator.is_local_main_process, )\n", "# turn off eff. 
attn for the disc\n", "for name, module in net_disc_a.named_modules():\n", " if \"attn\" in name:\n", " module.fused_attn = False\n", "for name, module in net_disc_b.named_modules():\n", " if \"attn\" in name:\n", " module.fused_attn = False\n", "\n", "for epoch in range(first_epoch, args.max_train_epochs):\n", "\n", " for step, batch in enumerate(train_dataloader):\n", " \n", " t_start = time.time()\n", " l_acc = [unet, net_disc_a, net_disc_b, vae_enc, vae_dec]\n", " with accelerator.accumulate(*l_acc):\n", " img_a = batch[\"pixel_values_src\"].to(dtype=weight_dtype)\n", " img_b = batch[\"pixel_values_tgt\"].to(dtype=weight_dtype)\n", "\n", " bsz = img_a.shape[0]\n", " fixed_a2b_emb = fixed_a2b_emb_base.repeat(bsz, 1, 1).to(dtype=weight_dtype)\n", " fixed_b2a_emb = fixed_b2a_emb_base.repeat(bsz, 1, 1).to(dtype=weight_dtype)\n", " timesteps = torch.tensor([noise_scheduler_1step.config.num_train_timesteps - 1] * bsz,\n", " device=img_a.device).long()\n", " \n", " # A -> fake B -> rec A\n", " if args.quantum_dynamic:\n", " boson_sampler.unitaries = None\n", " #print(\"Real A -> Fake B\")\n", " cyc_fake_b, q_emb_a = cyclegan_d.forward_with_networks_dynamic(img_a, \"a2b\", vae_enc, unet, vae_dec, \n", " noise_scheduler_1step, timesteps,\n", " fixed_a2b_emb, bs=boson_sampler, \n", " device = accelerator.device, accelerator = accelerator)\n", " if torch.isnan(cyc_fake_b).any(): \n", " global_step += 1\n", " continue \n", " cyc_rec_a, q_emb_cyc_fake_b = cyclegan_d.forward_with_networks_dynamic(cyc_fake_b, \"b2a\", vae_enc,\n", " unet, vae_dec,\n", " noise_scheduler_1step,\n", " timesteps,\n", " text_emb=fixed_b2a_emb,\n", " bs=boson_sampler,\n", " device = accelerator.device, accelerator = accelerator)\n", " if torch.isnan(cyc_rec_a).any():\n", " print(\"!!!!! 
JUMPING TO NEXT ITER during Fake B -> Rec A!!!!!\")\n", " global_step += 1\n", " continue\n", " else:\n", " cyc_fake_b = CycleGAN_Turbo.forward_with_networks(img_a, \"a2b\", vae_enc, unet, vae_dec,\n", " noise_scheduler_1step, timesteps, fixed_a2b_emb)\n", " cyc_rec_a = CycleGAN_Turbo.forward_with_networks(cyc_fake_b, \"b2a\", vae_enc, unet, vae_dec,\n", " noise_scheduler_1step, timesteps, fixed_b2a_emb)\n", " #print(f\"--loss cycle a-- (Process {accelerator.process_index})\")\n", " loss_cycle_a = crit_cycle(cyc_rec_a, img_a) * args.lambda_cycle\n", "\n", "\n", " loss_cycle_a += net_lpips(cyc_rec_a, img_a).mean() * args.lambda_cycle_lpips\n", " if args.quantum_dynamic:\n", " cyc_fake_a, q_emb_b = cyclegan_d.forward_with_networks_dynamic(img_b, \"b2a\", vae_enc, unet, vae_dec,\n", " noise_scheduler_1step, timesteps,\n", " fixed_b2a_emb, bs=boson_sampler,\n", " device = accelerator.device, accelerator = accelerator)\n", " if torch.isnan(cyc_fake_a).any():\n", " print(\"!!!!! JUMPING TO NEXT ITER during img B -> fake A !!!!!\")\n", " global_step += 1\n", " continue\n", " #print(\"fake A -> rec B\")\n", " cyc_rec_b, q_emb_cyc_fake_a = cyclegan_d.forward_with_networks_dynamic(cyc_fake_a, \"a2b\", vae_enc,\n", " unet, vae_dec,\n", " noise_scheduler_1step,\n", " timesteps,\n", " text_emb=fixed_a2b_emb,\n", " bs=boson_sampler,\n", " device = accelerator.device, accelerator = accelerator)\n", " if torch.isnan(cyc_rec_b).any():\n", " print(\"!!!!! 
JUMPING TO NEXT ITER during fake A -> rec B!!!!!\")\n", " global_step += 1\n", " continue\n", " else:\n", " cyc_fake_a = CycleGAN_Turbo.forward_with_networks(img_b, \"b2a\", vae_enc, unet, vae_dec,\n", " noise_scheduler_1step, timesteps, fixed_b2a_emb)\n", " cyc_rec_b = CycleGAN_Turbo.forward_with_networks(cyc_fake_a, \"a2b\", vae_enc, unet, vae_dec,\n", " noise_scheduler_1step, timesteps, fixed_a2b_emb)\n", " loss_cycle_b = crit_cycle(cyc_rec_b, img_b) * args.lambda_cycle\n", "\n", " loss_cycle_b += net_lpips(cyc_rec_b, img_b).mean() * args.lambda_cycle_lpips\n", " accelerator.backward(loss_cycle_a + loss_cycle_b, retain_graph=False)\n", " if accelerator.sync_gradients:\n", " accelerator.clip_grad_norm_(params_gen, args.max_grad_norm)\n", "\n", " optimizer_gen.step()\n", " lr_scheduler_gen.step()\n", " optimizer_gen.zero_grad()\n", "\n", " if args.quantum_dynamic:\n", " fake_a = cyclegan_d.forward_with_networks_dynamic(img_b, \"b2a\", vae_enc, unet, vae_dec,\n", " noise_scheduler_1step, timesteps, fixed_b2a_emb,\n", " q_emb=q_emb_b,\n", " device = accelerator.device)\n", " fake_b = cyclegan_d.forward_with_networks_dynamic(img_a, \"a2b\", vae_enc, unet, vae_dec,\n", " noise_scheduler_1step, timesteps, fixed_a2b_emb,\n", " q_emb=q_emb_a,\n", " device = accelerator.device)\n", " else:\n", " fake_a = CycleGAN_Turbo.forward_with_networks(img_b, \"b2a\", vae_enc, unet, vae_dec,\n", " noise_scheduler_1step, timesteps, fixed_b2a_emb)\n", " fake_b = CycleGAN_Turbo.forward_with_networks(img_a, \"a2b\", vae_enc, unet, vae_dec,\n", " noise_scheduler_1step, timesteps, fixed_a2b_emb)\n", "\n", " loss_gan_a = net_disc_a(fake_b, for_G=True).mean() * args.lambda_gan\n", " loss_gan_b = net_disc_b(fake_a, for_G=True).mean() * args.lambda_gan\n", " accelerator.backward(loss_gan_a + loss_gan_b, retain_graph=False)\n", " if accelerator.sync_gradients:\n", " accelerator.clip_grad_norm_(params_gen, args.max_grad_norm)\n", " optimizer_gen.step()\n", " lr_scheduler_gen.step()\n", " 
optimizer_gen.zero_grad()\n", " optimizer_disc.zero_grad()\n", "\n", " \"\"\"\n", " Identity Objective\n", " \"\"\"\n", " if args.quantum_dynamic:\n", " idt_a = cyclegan_d.forward_with_networks_dynamic(img_b, \"a2b\", vae_enc, unet, vae_dec,\n", " noise_scheduler_1step, timesteps, fixed_a2b_emb,\n", " q_emb=q_emb_b,\n", " device = accelerator.device)\n", " else:\n", " idt_a = CycleGAN_Turbo.forward_with_networks(img_b, \"a2b\", vae_enc, unet, vae_dec,\n", " noise_scheduler_1step, timesteps, fixed_a2b_emb)\n", " loss_idt_a = crit_idt(idt_a, img_b) * args.lambda_idt\n", "\n", " loss_idt_a += net_lpips(idt_a, img_b).mean() * args.lambda_idt_lpips\n", "\n", " if args.quantum_dynamic:\n", " idt_b = cyclegan_d.forward_with_networks_dynamic(img_a, \"b2a\", vae_enc, unet, vae_dec,\n", " noise_scheduler_1step, timesteps, fixed_b2a_emb,\n", " q_emb=q_emb_a,\n", " device = accelerator.device)\n", " else:\n", " idt_b = CycleGAN_Turbo.forward_with_networks(img_a, \"b2a\", vae_enc, unet, vae_dec,\n", " noise_scheduler_1step, timesteps, fixed_b2a_emb)\n", " loss_idt_b = crit_idt(idt_b, img_a) * args.lambda_idt\n", " loss_idt_b += net_lpips(idt_b, img_a).mean() * args.lambda_idt_lpips\n", " loss_g_idt = loss_idt_a + loss_idt_b\n", " accelerator.backward(loss_g_idt, retain_graph=False)\n", " if accelerator.sync_gradients:\n", " accelerator.clip_grad_norm_(params_gen, args.max_grad_norm)\n", " optimizer_gen.step()\n", " lr_scheduler_gen.step()\n", " optimizer_gen.zero_grad()\n", " \"\"\"\n", " Discriminator for task a->b and b->a (fake inputs)\n", " \"\"\"\n", " loss_D_A_fake = net_disc_a(fake_b.detach(), for_real=False).mean() * args.lambda_gan\n", " loss_D_B_fake = net_disc_b(fake_a.detach(), for_real=False).mean() * args.lambda_gan\n", "\n", " loss_D_fake = (loss_D_A_fake + loss_D_B_fake) * 0.5\n", " accelerator.backward(loss_D_fake, retain_graph=False)\n", " if accelerator.sync_gradients:\n", " params_to_clip = list(net_disc_a.parameters()) + list(net_disc_b.parameters())\n", " 
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)\n", " optimizer_disc.step()\n", " lr_scheduler_disc.step()\n", " optimizer_disc.zero_grad()\n", " \"\"\"\n", " Discriminator for task a->b and b->a (real inputs)\n", " \"\"\"\n", " loss_D_A_real = net_disc_a(img_b, for_real=True).mean() * args.lambda_gan\n", " loss_D_B_real = net_disc_b(img_a, for_real=True).mean() * args.lambda_gan\n", " loss_D_real = (loss_D_A_real + loss_D_B_real) * 0.5\n", " accelerator.backward(loss_D_real, retain_graph=False)\n", " if accelerator.sync_gradients:\n", " params_to_clip = list(net_disc_a.parameters()) + list(net_disc_b.parameters())\n", " accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)\n", " optimizer_disc.step()\n", " lr_scheduler_disc.step()\n", " optimizer_disc.zero_grad()\n", " print(f\"For this step {global_step}: time = {time.time() - t_start}\")\n", " logs = {}\n", " logs[\"cycle_a\"] = loss_cycle_a.detach().item()\n", " logs[\"cycle_b\"] = loss_cycle_b.detach().item()\n", " logs[\"gan_a\"] = loss_gan_a.detach().item()\n", " logs[\"gan_b\"] = loss_gan_b.detach().item()\n", " logs[\"disc_a\"] = loss_D_A_fake.detach().item() + loss_D_A_real.detach().item()\n", " logs[\"disc_b\"] = loss_D_B_fake.detach().item() + loss_D_B_real.detach().item()\n", " logs[\"idt_a\"] = loss_idt_a.detach().item()\n", " logs[\"idt_b\"] = loss_idt_b.detach().item()\n", "\n", " if accelerator.sync_gradients:\n", " progress_bar.update(1)\n", " global_step += 1\n", " if accelerator.is_main_process:\n", " eval_unet = accelerator.unwrap_model(unet)\n", " eval_vae_enc = accelerator.unwrap_model(vae_enc)\n", " eval_vae_dec = accelerator.unwrap_model(vae_dec)\n", "\n", " if global_step % args.checkpointing_steps == 1 or global_step == args.max_train_steps:\n", " outf = os.path.join(args.output_dir, \"checkpoints\", f\"model_{global_step}.pkl\")\n", " sd = {}\n", " sd[\"l_target_modules_encoder\"] = l_modules_unet_encoder\n", " sd[\"l_target_modules_decoder\"] = 
l_modules_unet_decoder\n", " sd[\"l_modules_others\"] = l_modules_unet_others\n", " sd[\"rank_unet\"] = args.lora_rank_unet\n", " sd[\"sd_encoder\"] = get_peft_model_state_dict(eval_unet, adapter_name=\"default_encoder\")\n", " sd[\"sd_decoder\"] = get_peft_model_state_dict(eval_unet, adapter_name=\"default_decoder\")\n", " sd[\"sd_other\"] = get_peft_model_state_dict(eval_unet, adapter_name=\"default_others\")\n", " sd[\"rank_vae\"] = args.lora_rank_vae\n", " sd[\"vae_lora_target_modules\"] = vae_lora_target_modules\n", " sd[\"sd_vae_enc\"] = eval_vae_enc.state_dict()\n", " sd[\"sd_vae_dec\"] = eval_vae_dec.state_dict()\n", " sd[\"quantum_params\"] = boson_sampler.model.state_dict()\n", " torch.save(sd, outf, _use_new_zipfile_serialization=False)\n", " gc.collect()\n", " torch.cuda.empty_cache()\n", "\n", " # compute val FID and DINO-Struct scores\n", " if global_step % args.validation_steps == 1:\n", " _timesteps = torch.tensor([noise_scheduler_1step.config.num_train_timesteps - 1] * 1, device=\"cuda\").long() \n", " net_dino = DinoStructureLoss() \n", " \"\"\"Evaluate \"A->B\" \"\"\" \n", " fid_output_dir = os.path.join(args.output_dir, f\"fid-{global_step}/samples_a2b\") \n", " os.makedirs(fid_output_dir, exist_ok=True) \n", " l_dino_scores_a2b = [] \n", " \n", " for idx, input_img_path in enumerate(tqdm(l_images_src_test)): \n", " if idx > args.validation_num_images and args.validation_num_images > 0: \n", " break \n", " outf = os.path.join(fid_output_dir, f\"{idx}.png\") \n", " with torch.no_grad(): \n", " # print(f\"input_img_path = {input_img_path}\") \n", " input_img = T_val(Image.open(input_img_path).convert(\"RGB\"))\n", " img_a = transforms.ToTensor()(input_img)\n", " img_a = transforms.Normalize([0.5], [0.5])(img_a).unsqueeze(0).cuda()\n", " src_name = os.path.basename(input_img_path)\n", " if args.quantum_dynamic:\n", " eval_fake_b, q_emb_a_val = cyclegan_d.forward_with_networks_dynamic(img_a, \"a2b\",\n", " eval_vae_enc,\n", " eval_unet,\n", " 
eval_vae_dec,\n", " noise_scheduler_1step,\n", " _timesteps,\n", " fixed_a2b_emb[\n", " 0:1],\n", " bs=boson_sampler,\n", " device = accelerator.device)\n", "\n", " else:\n", " eval_fake_b = CycleGAN_Turbo.forward_with_networks(img_a, \"a2b\", eval_vae_enc,\n", " eval_unet,\n", " eval_vae_dec,\n", " noise_scheduler_1step,\n", " _timesteps, fixed_a2b_emb[0:1])\n", " \n", " eval_fake_b_pil = transforms.ToPILImage()(eval_fake_b[0] * 0.5 + 0.5)\n", " \n", " eval_fake_b_pil.save(outf)\n", " a = net_dino.preprocess(input_img).unsqueeze(0).cuda()\n", " b = net_dino.preprocess(eval_fake_b_pil).unsqueeze(0).cuda()\n", " dino_ssim = net_dino.calculate_global_ssim_loss(a, b).item()\n", " l_dino_scores_a2b.append(dino_ssim)\n", " dino_score_a2b = np.mean(l_dino_scores_a2b)\n", " gen_features = get_folder_features(fid_output_dir, model=feat_model, num_workers=0, num=None,\n", " shuffle=False, seed=0, batch_size=8,\n", " device=torch.device(\"cuda\"),\n", " mode=\"clean\", custom_fn_resize=None, description=\"\",\n", " verbose=True,\n", " custom_image_tranform=None)\n", " ed_mu, ed_sigma = np.mean(gen_features, axis=0), np.cov(gen_features, rowvar=False)\n", " score_fid_a2b = frechet_distance(a2b_ref_mu, a2b_ref_sigma, ed_mu, ed_sigma)\n", " print(f\"step={global_step}, fid(a2b)={score_fid_a2b:.2f}, dino(a2b)={dino_score_a2b:.3f}\")\n", " # remove folder\n", " shutil.rmtree(fid_output_dir)\n", " \"\"\"\n", " compute FID for \"B->A\"\n", " \"\"\"\n", " fid_output_dir = os.path.join(args.output_dir, f\"fid-{global_step}/samples_b2a\")\n", " os.makedirs(fid_output_dir, exist_ok=True)\n", " l_dino_scores_b2a = []\n", " for idx, input_img_path in enumerate(tqdm(l_images_tgt_test)):\n", " if idx > args.validation_num_images and args.validation_num_images > 0:\n", " break\n", " outf = os.path.join(fid_output_dir, f\"{idx}.png\")\n", " with torch.no_grad():\n", " input_img = T_val(Image.open(input_img_path).convert(\"RGB\"))\n", " img_b = transforms.ToTensor()(input_img)\n", " img_b = 
transforms.Normalize([0.5], [0.5])(img_b).unsqueeze(0).cuda()\n", " src_name = os.path.basename(input_img_path)\n", " if args.quantum_dynamic:\n", " eval_fake_a, q_emb_b_val = cyclegan_d.forward_with_networks_dynamic(img_b, \"b2a\",\n", " eval_vae_enc,\n", " eval_unet,\n", " eval_vae_dec,\n", " noise_scheduler_1step,\n", " _timesteps,\n", " fixed_b2a_emb[\n", " 0:1],\n", " bs=boson_sampler,\n", " device = accelerator.device)\n", "\n", " else:\n", " eval_fake_a = CycleGAN_Turbo.forward_with_networks(img_b, \"b2a\", eval_vae_enc,\n", " eval_unet,\n", " eval_vae_dec,\n", " noise_scheduler_1step,\n", " _timesteps, fixed_b2a_emb[0:1])\n", "\n", "\n", " eval_fake_a_pil = transforms.ToPILImage()(eval_fake_a[0] * 0.5 + 0.5)\n", "\n", " # only save images for which idx is 0-100\n", " #if idx < 100:\n", " eval_fake_a_pil.save(outf)\n", " a = net_dino.preprocess(input_img).unsqueeze(0).cuda()\n", " b = net_dino.preprocess(eval_fake_a_pil).unsqueeze(0).cuda()\n", " dino_ssim = net_dino.calculate_global_ssim_loss(a, b).item()\n", " l_dino_scores_b2a.append(dino_ssim)\n", " dino_score_b2a = np.mean(l_dino_scores_b2a)\n", " gen_features = get_folder_features(fid_output_dir, model=feat_model, num_workers=0, num=None,\n", " shuffle=False, seed=0, batch_size=8,\n", " device=torch.device(\"cuda\"),\n", " mode=\"clean\", custom_fn_resize=None, description=\"\",\n", " verbose=True,\n", " custom_image_tranform=None)\n", " ed_mu, ed_sigma = np.mean(gen_features, axis=0), np.cov(gen_features, rowvar=False)\n", " score_fid_b2a = frechet_distance(b2a_ref_mu, b2a_ref_sigma, ed_mu, ed_sigma)\n", " print(f\"step={global_step}, fid(b2a)={score_fid_b2a}, dino(b2a)={dino_score_b2a:.3f}\")\n", " # remove folder\n", " shutil.rmtree(fid_output_dir)\n", " logs[\"val/fid_a2b\"], logs[\"val/fid_b2a\"] = score_fid_a2b, score_fid_b2a\n", " logs[\"val/dino_struct_a2b\"], logs[\"val/dino_struct_b2a\"] = dino_score_a2b, dino_score_b2a\n", " del net_dino # free up memory\n", "\n", " 
progress_bar.set_postfix(**logs)\n", " accelerator.log(logs, step=global_step)\n", " if global_step >= args.max_train_steps:\n", " break\n", " " ] }, { "cell_type": "markdown", "id": "31e6a8dd-7679-43d6-8977-adb76d6a10fe", "metadata": {}, "source": [ "### Example of a Base Image" ] }, { "cell_type": "code", "execution_count": 28, "id": "4bedb9bd-e7fd-4d38-b9bf-6fa231444eab", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-1.0..1.0].\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# img_a is normalized to [-1, 1] (the imshow warning recorded for this cell reports that range);\n", "# rescale to [0, 1] so matplotlib does not clip every negative value to black.\n", "# .float() keeps the NumPy conversion working if the tensor is half/bfloat16.\n", "# NOTE(review): the validation loop in the training cell also assigns img_a, so this may be a\n", "# validation image rather than the input that produced fake_b - confirm.\n", "base_image = img_a.detach().float().cpu().numpy()[0].transpose(1, 2, 0)\n", "base_image = (base_image * 0.5 + 0.5).clip(0.0, 1.0)\n", "plt.imshow(base_image)" ] }, { "cell_type": "markdown", "id": "abddb17a-fe58-4972-9283-ddee310b36bf", "metadata": {}, "source": [ "### Image generated by the model" ] }, { "cell_type": "code", "execution_count": null, "id": "fa90f233-f48e-499b-aad9-55266a4f7676", "metadata": {}, "outputs": [], "source": [ "# fake_b is presumably in [-1, 1] as well (the validation code rescales eval_fake_b with\n", "# * 0.5 + 0.5 before ToPILImage); apply the same mapping so imshow receives values in [0, 1].\n", "fake_image = fake_b.detach().float().cpu().numpy()[0].transpose(1, 2, 0)\n", "fake_image = (fake_image * 0.5 + 0.5).clip(0.0, 1.0)\n", "plt.imshow(fake_image)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.16" } }, "nbformat": 4, "nbformat_minor": 5 }