Using and Training CLIP: Zero-Shot Classification with CLIP

CLIP Model

Table of Contents

  • CLIP Model
  • 1 Paper Overview
  • 1.1 Training Stage
  • 1.2 Inference Stage
  • 1.3 Pros and Cons
  • 1.4 Official Experimental Results
  • 2 Classification Tasks with CLIP
  • 2.1 Binary Classification: Recognizing Cups
  • 2.2 Face Attribute Classification (CelebA)
  • 3 Re-training CLIP

    1 Paper Overview

    Official website

    1.1 Training Stage

    [Figure: CLIP training-stage architecture, contrastive pre-training on image-text pairs]

    The model architecture has two parts, an image encoder and a text encoder. The image encoder can be, for example, ResNet-50, and the text encoder can be a Transformer.
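
    To get a concrete feel for the two encoders, the short snippet below lists the backbones shipped with the official openai/clip package and loads one of them. This is just an orientation sketch; the exact model list depends on the installed package version.

    import clip
    import torch

    # Backbones available in the openai/clip package, e.g. 'RN50', 'ViT-B/32', ...
    print(clip.available_models())

    # `model` bundles both the image encoder and the text encoder;
    # `preprocess` is the matching image transform.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load("ViT-B/32", device=device)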

    The training data consists of image-text pairs collected from social media on the web. During training, a batch of data is first passed through the text encoder and the image encoder to obtain text and image features. The inner products between all text and image features are then computed, yielding a matrix: viewed from the image side, each row acts as a classifier over the texts, and viewed from the text side, each column acts as a classifier over the images.

    Since the matching between texts and images within a batch is known, the objective is to maximize the inner products of matched image-text pairs, i.e. the diagonal elements of the matrix, while minimizing the inner products with unrelated features. The authors collected roughly 400 million such pairs from social media.
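
    The following is a minimal sketch of this symmetric contrastive objective (not the authors' exact implementation; a fixed temperature stands in for CLIP's learnable logit scale). Given L2-normalized image and text features for a batch of N matched pairs, the N×N inner-product matrix is scored row-wise (image to text) and column-wise (text to image) with cross-entropy against the diagonal:

    import torch
    import torch.nn.functional as F

    def clip_contrastive_loss(image_features, text_features, temperature=0.07):
        # image_features, text_features: (N, D) tensors, assumed L2-normalized
        logits = image_features @ text_features.T / temperature        # (N, N) similarity matrix
        targets = torch.arange(logits.shape[0], device=logits.device)  # matched pairs lie on the diagonal
        loss_i = F.cross_entropy(logits, targets)     # each row is a classifier over the texts
        loss_t = F.cross_entropy(logits.T, targets)   # each column is a classifier over the images
        return (loss_i + loss_t) / 2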

    1.2 Inference Stage

    [Figure: zero-shot classification with CLIP at inference time]

    At test time, a trained CLIP model can be applied directly to other datasets without fine-tuning. Similar to training, the image to be classified is first passed through the image encoder to obtain its features. Then, for every label in the target dataset (or any labels you define yourself), a corresponding piece of text is constructed; for example, dog in the figure above becomes "A photo of a dog", and so on. The texts are passed through the text encoder, the text features are compared with the image features by inner product, and the label whose text gives the largest inner product is the classification result. This completes zero-shot classification on the target task. (Section 2.1 below shows this pipeline in code.)

    1.3 Pros and Cons

  • Don't be overawed by its zero-shot ability; this is not true zero-shot! Across 400M image-text training pairs, the model has certainly seen plenty of images tagged with related text, and those images cover a much broader range than ImageNet, which is also why the method easily beats ImageNet-pretrained models in some higher-level domains such as clipart. But claiming that it crushes supervised methods smacks of clickbait.
  • Another intriguing point is that the method trains the image and text features jointly (thanks to @llll in the comments for the reminder; at first I thought only the image side was trained). My intuition is that pretrained text features are more reliable than pretrained visual features, so it is somewhat surprising that the authors gave up OpenAI's own huge pretrained language models. In particular, NLP pretrained models are far larger than vision ones, so freezing the text model might be the more practical approach?
  • The question that interests me most is how image and text should interact. Using the text encoding directly as the supervision signal for images is clearly very noisy; could one borrow ideas from captioning and related areas and let image and text interact repeatedly during encoding to improve results? Of course, the language model would then be too large to train efficiently. OpenAI could also go for brute force and train a large cross-modal model from scratch, but in that case a 400M-pair dataset might be too small.
  • Going deeper, NLP pretraining works well mainly because its pretext tasks are good, whereas CV is still struggling to find suitable pretext tasks. My biggest hope for cross-modal learning is that, with the help of NLP, we can define pretext tasks for CV. CLIP has taken the first step; there is still a long way to go.

    1.4 Official Experimental Results

    [Figure: experimental results reported by the authors]

    2 Classification Tasks with CLIP

    2.1 Binary Classification: Recognizing Cups

    import clip
    import torch
    from PIL import Image
    
    img_path = 'cup3.jpg'
    classes = ['cup', 'not_cup']
    
    # Load the model
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load('ViT-B/32', device)
    
    
    # Prepare the inputs
    image = Image.open(img_path)
    image_input = preprocess(image).unsqueeze(0).to(device)
    text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in classes]).to(device)  # build one text prompt per class
    
    # Encode features
    with torch.no_grad():
        image_features = model.encode_image(image_input)
        text_features = model.encode_text(text_inputs)
    
    # Pick the label with the highest similarity
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)
    similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)  # image-text cosine similarities, softmaxed over classes
    values, indices = similarity[0].topk(1)
    
    # Print the result
    print("\nTop predictions:\n")
    print('class: {} score: {:.2f}'.format(classes[indices.item()], values.item()))
    
    

    For other classification tasks, you only need to change classes.
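
    For instance, a three-way task just needs a different label list (the labels below are only an illustration); the prompt template "a photo of a {c}" stays unchanged:

    classes = ['cat', 'dog', 'bird']  # prompts become "a photo of a cat", etc.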

    2.2 Face Attribute Classification (CelebA)

    import clip
    import torch
    import torchvision
    import time
    
    device = "cuda" if torch.cuda.is_available() else "cpu"
    
    def model_load(model_name):
        # Load the model
        model, preprocess = clip.load(model_name, device)  # ViT-B/32 or RN50x16
        return model, preprocess
    
    def data_load(data_path):
        # Load the dataset and build one text prompt per attribute name
        celeba = torchvision.datasets.CelebA(root=data_path, split='test', download=True)
        text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in celeba.attr_names]).to(device)
        return celeba, text_inputs
    
    
    def test_model(start, end, celeba, text_inputs, model, preprocess):
        # Evaluate the model on samples start..end
        length = end - start + 1
        face_accuracy = 0
        face_score = 0
    
        for i, data in enumerate(celeba):
            face_result = 0
            if i < start:
                continue
            image, target = data
            image_input = preprocess(image).unsqueeze(0).to(device)
    
            with torch.no_grad():
                image_features = model.encode_image(image_input)
                text_features = model.encode_text(text_inputs)  # could be pre-computed once outside the loop
    
            image_features /= image_features.norm(dim=-1, keepdim=True)
            text_features /= text_features.norm(dim=-1, keepdim=True)
    
            text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
            top_score, top_label = text_probs.topk(6, dim=-1)
            for k, score in zip(top_label[0], top_score[0]):
                if k.item() < 40 and target[k.item()] == 1:
                    # The predicted attribute is among the ground-truth attributes
                    face_result = 1
                    face_score += score.item()
                    print('Predict right! The predicted is {}'.format(celeba.attr_names[k.item()]))
                else:
                    print('Predict false! The predicted is {}'.format(celeba.attr_names[k.item()]))
            face_accuracy += face_result
    
            if i == end:
                break
        face_score = face_score / length
        face_accuracy = face_accuracy / length
    
        return face_score, face_accuracy
    
    def main():
        start = 0
        end = 1000
        model_name = 'ViT-B/32'  # ViT-B/32 or RN50x16
        data_path = 'CELEBA'
    
        time_start = time.time()
        model, preprocess = model_load(model_name)
        celeba, text_inputs = data_load(data_path)
        face_score, face_accuracy = test_model(start, end, celeba, text_inputs, model, preprocess)
        time_end = time.time()
    
        print('The prediction:')
        print('face_accuracy: {:.2f} face_score: {:.2f}%'.format(face_accuracy, face_score * 100))
        print('running time: %.4f' % (time_end - time_start))
    
    if __name__ == '__main__':
        main()
    

    3 Re-training CLIP
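
    The script below fine-tunes CLIP with the same contrastive objective it was pre-trained with: the two folder names act as the captions for all images they contain, and a symmetric cross-entropy loss over logits_per_image and logits_per_text pulls matched image-caption pairs together within each batch.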

    from torch.utils.data import Dataset, DataLoader
    import torch
    import clip
    from torch import nn, optim
    from PIL import Image
    import os
    
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    
    class image_caption_dataset(Dataset):
        def __init__(self, df, preprocess):
            self.images = df["image"]
            self.caption = df["caption"]
            self.preprocess = preprocess
    
        def __len__(self):
            return len(self.caption)
    
        def __getitem__(self, idx):
            image = self.preprocess(Image.open(self.images[idx]))
            caption = self.caption[idx]
            return image, caption
    
    
    def load_data(cup_path, cupnot_path, batch_size, preprocess):
        # Build an image/caption table from the two folders; each folder name is used as the caption
        df = {'image': [], 'caption': []}
        cup_list = os.listdir(cup_path)
        cupnot_list = os.listdir(cupnot_path)
    
        caption = cup_path.split('/')[-1]
        for img in cup_list:
            img_path = os.path.join(cup_path, img)
            df['image'].append(img_path)
            df['caption'].append(caption)
    
        caption = cupnot_path.split('/')[-1]
        for img in cupnot_list:
            img_path = os.path.join(cupnot_path, img)
            df['image'].append(img_path)
            df['caption'].append(caption)
    
        dataset = image_caption_dataset(df, preprocess)
        train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)  # shuffle for training
        return train_dataloader
    
    
    def convert_models_to_fp32(model):
        # Cast parameters (and their gradients) back to fp32 so the optimizer step is stable
        for p in model.parameters():
            p.data = p.data.float()
            if p.grad is not None:
                p.grad.data = p.grad.data.float()
    
    
    def load_pretrain_model(model_path):
        # Load the pretrained model; jit must be False for training
        model, preprocess = clip.load(model_path, device=device, jit=False)
        if device == "cpu":
            model.float()
        else:
            clip.model.convert_weights(model)  # use fp16 weights on GPU
        return model, preprocess
    
    def train(epoch, batch_size, learning_rate, cup_path, cupnot_path):
        # Load the model
        model, preprocess = load_pretrain_model('ViT-B/32')
    
        # Load the dataset
        train_dataloader = load_data(cup_path, cupnot_path, batch_size, preprocess)
    
        # Losses and optimizer
        loss_img = nn.CrossEntropyLoss().to(device)
        loss_txt = nn.CrossEntropyLoss().to(device)
        optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.98), eps=1e-6, weight_decay=0.2)
    
        for i in range(epoch):
            for batch in train_dataloader:
                list_image, list_txt = batch  # list_image: batch of preprocessed image tensors; list_txt: matching captions
    
                texts = clip.tokenize(list_txt).to(device)
                images = list_image.to(device)
    
                logits_per_image, logits_per_text = model(images, texts)
                # The i-th image matches the i-th text, so the targets are the diagonal indices;
                # use the actual batch length in case the last batch is smaller than batch_size
                ground_truth = torch.arange(len(images), dtype=torch.long, device=device)
    
                # Backpropagation
                total_loss = (loss_img(logits_per_image, ground_truth) + loss_txt(logits_per_text, ground_truth)) / 2
                optimizer.zero_grad()
                total_loss.backward()
                if device == "cpu":
                    optimizer.step()
                else:
                    convert_models_to_fp32(model)  # step in fp32, then convert back to fp16
                    optimizer.step()
                    clip.model.convert_weights(model)
    
            print('[%d] loss: %.3f' % (i + 1, total_loss))
        os.makedirs('./model', exist_ok=True)
        torch.save(model, './model/model1.pkl')
    
    def main():
        epoch = 100
        batch_size = 6
        learning_rate = 5e-5
        cup_path = './data/It is photo with cup'
        cupnot_path = './data/It is photo without cup'
        train(epoch, batch_size, learning_rate, cup_path, cupnot_path)
    
    if __name__ == '__main__':
        main()
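
    After training, the saved checkpoint can be reused for the same kind of classification as in Section 2.1. The snippet below is a minimal sketch under the assumptions of the script above (the whole module was saved with torch.save, and the folder names double as class prompts):

    import clip
    import torch
    from PIL import Image

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # torch.save(model, ...) stored the whole module, so torch.load returns it directly
    # (newer PyTorch versions may require weights_only=False)
    model = torch.load('./model/model1.pkl', map_location=device)
    model.eval()
    _, preprocess = clip.load('ViT-B/32', device=device, jit=False)  # reuse the matching image transform

    captions = ['It is photo with cup', 'It is photo without cup']  # the folder-name captions used during training
    image_input = preprocess(Image.open('cup3.jpg')).unsqueeze(0).to(device)
    text_inputs = clip.tokenize(captions).to(device)

    with torch.no_grad():
        logits_per_image, _ = model(image_input, text_inputs)
        probs = logits_per_image.softmax(dim=-1)

    print('predicted:', captions[probs.argmax().item()])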
    

    Updated project files:

    「CLIP」https://www.aliyundrive.com/s/mM8n836Km5M (extraction code: te40)

    Source: 浅草夏洛洛
