Deep Learning Series 38: The Dalle2 Model
1. Quick Start
1.1 Diffusion models
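A diffusion model starts from an original image and progressively adds noise to it. The model is then trained to reverse this process, reconstructing the image from noise step by step.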
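To make the noising and reconstruction steps concrete, here is a minimal NumPy sketch of the standard DDPM forward process. It assumes a linear β schedule with T = 1000 steps; these schedule values are an illustrative assumption, not necessarily the ones Dalle2 uses. Noising to step t has the closed form x_t = √ᾱ_t · x_0 + √(1 − ᾱ_t) · ε. If the noise ε is known (in practice it is predicted by a network), the same formula can be inverted to recover x_0.

```python
import numpy as np

def linear_beta_schedule(T, beta_start=1e-4, beta_end=0.02):
    """Linear noise schedule (DDPM-style; an assumption for illustration)."""
    return np.linspace(beta_start, beta_end, T)

def q_sample(x0, t, alpha_bar, noise):
    """Closed-form forward process: x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * noise."""
    return np.sqrt(alpha_bar[t]) * x0 + np.sqrt(1.0 - alpha_bar[t]) * noise

def predict_x0(x_t, t, alpha_bar, eps_pred):
    """Invert the forward formula: estimate x0 from x_t and a (predicted) noise."""
    return (x_t - np.sqrt(1.0 - alpha_bar[t]) * eps_pred) / np.sqrt(alpha_bar[t])

T = 1000
betas = linear_beta_schedule(T)
alpha_bar = np.cumprod(1.0 - betas)        # cumulative product of (1 - beta_s)

rng = np.random.default_rng(0)
x0 = rng.uniform(-1, 1, size=(3, 8, 8))    # a toy "image" scaled to [-1, 1]
noise = rng.standard_normal(x0.shape)

x_t = q_sample(x0, 500, alpha_bar, noise)           # halfway through the noising process
x0_hat = predict_x0(x_t, 500, alpha_bar, noise)     # exact here because the true noise is used
print(np.allclose(x0_hat, x0))   # True
print(alpha_bar[-1] < 1e-4)      # True: almost no signal is left at the last step
```

A trained model never sees the true ε, so in practice `predict_x0` is fed the network's noise estimate and the reconstruction is approximate. This is why sampling proceeds over many small steps.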
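For image decoding, Dalle2 uses a modified GLIDE model. Unlike a plain diffusion model, this decoder is additionally conditioned on the text embedding and on the CLIP image embedding.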
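The DALL·E 2 paper describes two ways this conditioning enters the GLIDE architecture. First, the CLIP image embedding is projected and added to the timestep embedding. Second, it is projected into four extra context tokens that are concatenated to the text encoder's output sequence. The NumPy sketch below only shows the tensor shapes involved. The dimensions, the random projection matrices and the choice to append (rather than prepend) the extra tokens are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
clip_dim, model_dim, n_extra = 512, 256, 4   # hypothetical sizes (4 extra tokens as in the paper)
text_len = 16                                # hypothetical caption length in tokens

# Learned projections in a real model; random here, for illustration only.
W_time = rng.standard_normal((clip_dim, model_dim)) * 0.02
W_tokens = rng.standard_normal((clip_dim, n_extra * model_dim)) * 0.02

def condition(timestep_emb, clip_image_emb, text_tokens):
    """Inject a CLIP image embedding into a GLIDE-style decoder's conditioning inputs."""
    # 1) project the CLIP embedding and add it to the timestep embedding
    t_emb = timestep_emb + clip_image_emb @ W_time
    # 2) project it into n_extra context tokens and concatenate them to the text-encoder outputs
    extra = (clip_image_emb @ W_tokens).reshape(n_extra, model_dim)
    context = np.concatenate([text_tokens, extra], axis=0)
    return t_emb, context

t_emb, context = condition(
    rng.standard_normal(model_dim),
    rng.standard_normal(clip_dim),
    rng.standard_normal((text_len, model_dim)),
)
print(t_emb.shape, context.shape)  # (256,) (20, 256)
```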
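1.2 The Dalle2 model
Dalle2 is built on CLIP. The pipeline works as follows. The CLIP text encoder turns the caption into a text embedding. A prior maps that text embedding to a CLIP image embedding. A decoder then generates the image from the image embedding. The prior is a diffusion model.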
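The data flow of this two-stage design can be sketched as plain function composition. The three components below are hypothetical stand-ins built from random NumPy calls, not trained models. They only let the shapes be followed end to end. 512 is the embedding size of CLIP ViT-B/32, and the 64×64 output matches the base resolution of the paper's decoder before its upsampling stages.

```python
import numpy as np

EMB_DIM = 512  # embedding size of CLIP ViT-B/32
rng = np.random.default_rng(0)

def generate(caption, text_encoder, prior, decoder):
    """Two-stage pipeline: caption -> CLIP text embedding -> CLIP image embedding -> image."""
    z_text = text_encoder(caption)   # CLIP text encoder (frozen)
    z_image = prior(z_text)          # prior: text embedding -> image embedding
    return decoder(z_image)          # decoder: image embedding -> image

# Hypothetical stand-ins, NOT trained models: they only reproduce the shapes.
def fake_text_encoder(caption):
    return rng.standard_normal(EMB_DIM)  # ignores the caption's content

def fake_prior(z_text):
    return z_text + 0.1 * rng.standard_normal(EMB_DIM)

def fake_decoder(z_image):
    # a 3-channel 64x64 "image" loosely tied to the first three embedding entries
    return np.tanh(rng.standard_normal((3, 64, 64)) + z_image[:3, None, None])

image = generate("a hedgehog using a calculator", fake_text_encoder, fake_prior, fake_decoder)
print(image.shape)  # (3, 64, 64)
```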
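Why is it designed this way? The paper's justification is empirical: the authors compared several ways of conditioning the decoder on the caption "a hedgehog using a calculator".
In the first variant, the caption alone is fed directly into the decoder.
In the second, the CLIP text embedding is additionally passed to the decoder (in place of an image embedding).
In the third, the diffusion prior first generates a CLIP image embedding, which is then given to the decoder. This full design produced the best samples, which is why Dalle2 keeps the prior.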
The quality of Dalle2's generated images is judged by human raters along three dimensions: caption similarity, photorealism and sample diversity.
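Human ratings like these are usually aggregated into a preference rate per dimension. The sketch below assumes pairwise judgments, where a rater picks the better of two models' samples ("A" or "B"). `preference_rates` is a hypothetical helper, and the judgment data is made up purely to illustrate the bookkeeping. It is not any reported result.

```python
from collections import Counter

def preference_rates(judgments):
    """Fraction of pairwise comparisons won by model A, per evaluation dimension.

    `judgments` is a list of (dimension, winner) pairs with winner in {"A", "B"}.
    """
    wins, totals = Counter(), Counter()
    for dimension, winner in judgments:
        totals[dimension] += 1
        if winner == "A":
            wins[dimension] += 1
    return {d: wins[d] / totals[d] for d in totals}

# Hypothetical labels, for illustration only
judgments = [
    ("photorealism", "A"), ("photorealism", "B"),
    ("caption similarity", "A"), ("caption similarity", "A"),
    ("diversity", "B"),
]
print(preference_rates(judgments))
# {'photorealism': 0.5, 'caption similarity': 1.0, 'diversity': 0.0}
```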
1.3 Diversity
The model can be used to generate multiple different images.
2. Training code
Install the library: pip install dalle2-pytorch
2.1 General workflow
First, train CLIP:
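```python
import torch
from dalle2_pytorch import CLIP
clip = CLIP(dim_text = 512, dim_image = 512, dim_latent = 512, num_text_tokens = 49408,  # README example values
            text_enc_depth = 1, text_seq_len = 256, text_heads = 8, visual_enc_depth = 1,
            visual_image_size = 256, visual_patch_size = 32, visual_heads = 8).cuda()
# mock data: a batch of 4 token sequences and 4 matching images
text = torch.randint(0, 49408, (4, 256)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()
loss = clip(text, images, return_loss = True)  # contrastive loss between text and images
loss.backward()
```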
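At its core, `clip(text, images, return_loss = True)` optimizes CLIP's symmetric contrastive (InfoNCE) loss: within a batch, each caption should score highest against its own image, and vice versa. Below is a NumPy sketch of that basic objective. It assumes a fixed temperature of 0.07, which is CLIP's initial value; the real model learns the temperature. The library's CLIP also supports optional extra loss terms that this sketch leaves out.

```python
import numpy as np

def log_softmax(x, axis):
    x = x - x.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))

def clip_loss(text_emb, image_emb, temperature=0.07):
    """Symmetric InfoNCE: the matching (text_i, image_i) pairs lie on the diagonal."""
    t = text_emb / np.linalg.norm(text_emb, axis=1, keepdims=True)
    v = image_emb / np.linalg.norm(image_emb, axis=1, keepdims=True)
    logits = t @ v.T / temperature                            # (batch, batch) scaled cosine similarities
    idx = np.arange(len(logits))
    loss_t2i = -log_softmax(logits, axis=1)[idx, idx].mean()  # each text must pick its image
    loss_i2t = -log_softmax(logits, axis=0)[idx, idx].mean()  # each image must pick its text
    return (loss_t2i + loss_i2t) / 2

rng = np.random.default_rng(0)
emb = rng.standard_normal((4, 512))
aligned = clip_loss(emb, emb)                                # perfectly matched pairs
mismatched = clip_loss(emb, rng.standard_normal((4, 512)))   # unrelated pairs
print(aligned < mismatched)  # True
```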
Next, train the decoder (conditioned on the CLIP image embedding), using a Unet as the decoder network:
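```python
import torch
from dalle2_pytorch import Unet, Decoder
# `clip` is the model trained in the previous step; its image embeddings condition the Unet
unet = Unet(dim = 128, image_embed_dim = 512, cond_dim = 128, channels = 3, dim_mults = (1, 2, 4, 8)).cuda()
decoder = Decoder(unet = unet, clip = clip, image_sizes = (256,), timesteps = 100).cuda()
# mock images; no captions are needed because the condition is CLIP's own image embedding
images = torch.randn(4, 3, 256, 256).cuda()
loss = decoder(images)
loss.backward()
```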
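A diffusion decoder of this kind is typically trained with the standard noise-prediction objective. The procedure picks a random timestep, noises the image with the closed-form forward process, and regresses the added noise, with the CLIP image embedding as extra input. The sketch below assumes ε-prediction and a linear schedule. `denoiser` is a hypothetical stand-in for the Unet. The all-zeros predictor only demonstrates that a model that has learned nothing scores a loss close to 1, the variance of the noise.

```python
import numpy as np

rng = np.random.default_rng(0)
T = 1000
alpha_bar = np.cumprod(1.0 - np.linspace(1e-4, 0.02, T))  # linear schedule (assumption)

def decoder_loss(denoiser, x0, image_emb):
    """One epsilon-prediction training loss for an embedding-conditioned decoder."""
    t = rng.integers(0, T)                                  # random timestep
    noise = rng.standard_normal(x0.shape)
    x_t = np.sqrt(alpha_bar[t]) * x0 + np.sqrt(1.0 - alpha_bar[t]) * noise
    pred = denoiser(x_t, t, image_emb)                      # network predicts the added noise
    return np.mean((pred - noise) ** 2)

def zero_denoiser(x_t, t, image_emb):
    """Hypothetical untrained model: always predicts zero noise."""
    return np.zeros_like(x_t)

x0 = rng.uniform(-1, 1, size=(4, 3, 64, 64))   # toy image batch in [-1, 1]
image_emb = rng.standard_normal((4, 512))      # toy CLIP image embeddings
loss = decoder_loss(zero_denoiser, x0, image_emb)
print(loss)  # close to 1.0
```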
Finally, train the prior (which generates a CLIP image embedding from the CLIP text embedding), using a diffusion model:
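```python
import torch
from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior
# mock paired data; `clip` (trained above) embeds both the text and the images
text = torch.randint(0, 49408, (4, 256)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()
prior_network = DiffusionPriorNetwork(dim = 512, depth = 6, dim_head = 64, heads = 8).cuda()
diffusion_prior = DiffusionPrior(net = prior_network, clip = clip, timesteps = 100, cond_drop_prob = 0.2).cuda()
loss = diffusion_prior(text, images)
loss.backward()
```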
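The paper reports one detail that sets the prior's objective apart. Instead of predicting the noise, the diffusion prior is trained to predict the unnoised CLIP image embedding directly, with a mean-squared-error loss. Below is a NumPy sketch of that x₀-prediction loss, assuming a linear schedule. `prior_net` is a hypothetical stand-in for the prior network.

```python
import numpy as np

rng = np.random.default_rng(0)
T = 1000
alpha_bar = np.cumprod(1.0 - np.linspace(1e-4, 0.02, T))  # linear schedule (assumption)

def prior_loss(prior_net, text_emb, image_emb):
    """x0-prediction loss: the network regresses the clean CLIP image embedding."""
    t = rng.integers(0, T)
    noise = rng.standard_normal(image_emb.shape)
    z_t = np.sqrt(alpha_bar[t]) * image_emb + np.sqrt(1.0 - alpha_bar[t]) * noise
    z_pred = prior_net(z_t, t, text_emb)          # predicted unnoised image embedding
    return np.mean((z_pred - image_emb) ** 2)     # MSE against the clean target, not the noise

def copy_text_prior(z_t, t, text_emb):
    """Hypothetical network that simply outputs its text condition."""
    return text_emb

emb = rng.standard_normal((4, 512))
loss = prior_loss(copy_text_prior, text_emb=emb, image_emb=emb)
print(loss)  # 0.0
```

The toy network scores zero loss only because the text and image embeddings are identical here. The example simply makes explicit that the regression target is the clean image embedding.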
2.2 Generating images
This step needs the trained DiffusionPrior and Decoder:
```python
from dalle2_pytorch import DALLE2
# combine the trained prior and decoder into one text-to-image model
dalle2 = DALLE2(
    prior = diffusion_prior,
    decoder = decoder
)
texts = ['glistening morning dew on a flower petal']
images = dalle2(texts)  # shape: (1, 3, 256, 256)
```
3. Online resources
3.1 Using an existing CLIP
Wrap OpenAI's pretrained CLIP with the OpenAIClipAdapter class and pass it to diffusion_prior and decoder for training:
```python
import torch
from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder, OpenAIClipAdapter

# openai pretrained clip - defaults to ViT-B/32
clip = OpenAIClipAdapter()

# mock data
text = torch.randint(0, 49408, (4, 256)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()

# prior networks (with transformer)
prior_network = DiffusionPriorNetwork(
    dim = 512,
    depth = 6,
    dim_head = 64,
    heads = 8
).cuda()

diffusion_prior = DiffusionPrior(
    net = prior_network,
    clip = clip,
    timesteps = 100,
    cond_drop_prob = 0.2
).cuda()

loss = diffusion_prior(text, images)
loss.backward()

# do above for many steps ...

# decoder (with unet)
unet1 = Unet(
    dim = 128,
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults = (1, 2, 4, 8)
).cuda()

unet2 = Unet(
    dim = 16,
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults = (1, 2, 4, 8, 16)
).cuda()

decoder = Decoder(
    unet = (unet1, unet2),
    image_sizes = (128, 256),
    clip = clip,
    timesteps = 100,
    image_cond_drop_prob = 0.1,
    text_cond_drop_prob = 0.5,
    condition_on_text_encodings = False  # set this to True if you wish to condition on text during training and sampling
).cuda()

for unet_number in (1, 2):
    # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well,
    # though it was hinted in the paper it didn't do much
    loss = decoder(images, unet_number = unet_number)
    loss.backward()

# do above for many steps

dalle2 = DALLE2(
    prior = diffusion_prior,
    decoder = decoder
)

images = dalle2(
    ['a butterfly trying to escape a tornado'],
    cond_scale = 2.  # classifier free guidance strength (> 1 would strengthen the condition)
)

# save your image (in this example, of size 256x256)
```
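The `cond_scale` argument above sets the strength of classifier-free guidance. The `cond_drop_prob`, `image_cond_drop_prob` and `text_cond_drop_prob` options randomly drop the condition during training, so the same network can also make an unconditional prediction. At sampling time the two predictions are combined by extrapolating from the unconditional one toward the conditional one. Below is a minimal NumPy sketch of that combination rule. `toy_model` is a hypothetical stand-in for the prior or decoder network, and the library's internals may differ in detail.

```python
import numpy as np

def guided_prediction(model, x_t, t, cond, cond_scale):
    """Classifier-free guidance: uncond + cond_scale * (cond - uncond)."""
    uncond = model(x_t, t, None)   # condition dropped, as learned via the *_drop_prob options
    cond_out = model(x_t, t, cond)
    return uncond + cond_scale * (cond_out - uncond)

def toy_model(x_t, t, cond):
    """Hypothetical denoiser: the condition simply shifts the prediction."""
    return x_t if cond is None else x_t + cond

x = np.zeros(3)
c = np.ones(3)
print(guided_prediction(toy_model, x, 0, c, 1.0))  # [1. 1. 1.]  plain conditional prediction
print(guided_prediction(toy_model, x, 0, c, 2.0))  # [2. 2. 2.]  condition amplified
```

With `cond_scale = 1` the result is just the conditional prediction. Values above 1 push further in the direction of the condition, which matches the comment on `cond_scale` in the code above.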
3.2 Using a ready-made prior model
See https://huggingface.co/zenglishuci/conditioned-prior, which hosts prior models of various sizes.
The code below loads such a prior model:
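```python
import torch
from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior, OpenAIClipAdapter
from dalle2_pytorch.trainer import DiffusionPriorTrainer

def load_diffusion_model(dprior_path, device, clip_choice):
    # the checkpoint is a dict holding "model", "optimizer" and "scaler" state dicts
    loaded_obj = torch.load(str(dprior_path), map_location='cpu')

    # CLIP embedding width: 512 for ViT-B/32, otherwise assumed to be 768 (e.g. ViT-L/14)
    if clip_choice == "ViT-B/32":
        dim = 512
    else:
        dim = 768

    # the architecture must match the checkpoint exactly (strict=True below)
    prior_network = DiffusionPriorNetwork(
        dim=dim,
        depth=12,
        dim_head=64,
        heads=12,
        normformer=True
    ).to(device)

    diffusion_prior = DiffusionPrior(
        net=prior_network,
        clip=OpenAIClipAdapter(clip_choice),
        image_embed_dim=dim,
        timesteps=1000,
        cond_drop_prob=0.1,
        loss_type="l2",
    ).to(device)

    diffusion_prior.load_state_dict(loaded_obj["model"], strict=True)

    # wrap in the trainer so the optimizer and AMP grad-scaler states can be restored as well
    diffusion_prior = DiffusionPriorTrainer(
        diffusion_prior = diffusion_prior,
        lr = 1.1e-4,
        wd = 6.02e-2,
        max_grad_norm = 0.5,
        amp = False,
    ).to(device)

    diffusion_prior.optimizer.load_state_dict(loaded_obj['optimizer'])
    diffusion_prior.scaler.load_state_dict(loaded_obj['scaler'])

    return diffusion_prior
```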
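The `else` branch above assumes that every CLIP model other than ViT-B/32 has a 768-dimensional embedding. That holds for ViT-L/14, but not for, e.g., the ResNet-based RN50, whose embedding is 1024-dimensional. A slightly more robust alternative is an explicit lookup table. The widths below are those of OpenAI's released CLIP models, and `clip_embed_dim` is a hypothetical helper name.

```python
# Output embedding width of OpenAI's released CLIP models.
CLIP_EMBED_DIMS = {
    "RN50": 1024, "RN101": 512, "RN50x4": 640, "RN50x16": 768, "RN50x64": 1024,
    "ViT-B/32": 512, "ViT-B/16": 512, "ViT-L/14": 768, "ViT-L/14@336px": 768,
}

def clip_embed_dim(clip_choice: str) -> int:
    """Hypothetical helper: look up the embedding width instead of defaulting to 768."""
    if clip_choice not in CLIP_EMBED_DIMS:
        raise ValueError(f"unknown CLIP model: {clip_choice!r}")
    return CLIP_EMBED_DIMS[clip_choice]

print(clip_embed_dim("ViT-B/32"), clip_embed_dim("ViT-L/14"))  # 512 768
```

Inside `load_diffusion_model`, the `if`/`else` could then become `dim = clip_embed_dim(clip_choice)`, which fails loudly on an unknown name instead of silently using 768.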
Source: IE06