STANet_pytorch代码问题汇总、附上裁剪图片代码(有问留言必答)

STANet_pytorch代码问题汇总、附上裁剪图片代码

  • 一、 STANet
  • 二、问题汇总与解答(如列不全,请留言)
  • 1.安装虚拟环境与相关库的的条件:
  • 2. 查看代码中的readme.md文件(里面有跑代码的方式、模型与数据集链接)
  • 3. python demo.py 的问题
  • 4.跑BAM、PAM代码中常见的Out of memory 2 (64、256)GiB 问题
  • 三、裁剪代码实现
  • 说明:
  • 一、 STANet

    因有一部分实验用到STANet网络,在网上找到相应的代码,花了大概一周一步步跳入坑、填坑的过程,苦于将其跑通,遂记录如下心得,希望能够帮助有需要的小伙伴避开“雷区”!

  • 文章源于:

  • 代码源于:

  • 大致看了一下该篇文章,网上有很多解读的博客,不做过多介绍,简而言之,该文章通过利用自注意力机制模块(BAM)和(多个BAM集成的PAM块),对遥感影像进行特征提取与训练, 通过对比两张不同时期的遥感图像,以深度学习的方法训练模型,最后能够“自动比对”找出同一区域,不同时间的变化情况。下图是STANet文章的截图。

  • 文章能够显著检测出遥感影像中变化的建筑物,可以应用于违章建筑拓展监测、乡村扶贫振兴和生态移民居住保障的风貌变化程度。

  • STANet文章的截图

  • 相关的数据集包括(train、每train一轮epch之后紧接着验证val集,还有训练结束之后,将保存的model进行测试的test集 (PS: 文章代码的测试部分,称为val,python val.py 就是测试,而不是验证))。

  • 每一个数据集中包括:

  • ———–| A:前一段时间的遥感图像(1024 * 1024);

  • ———–| B:后一段时间的相同区域的遥感图像(1024*1024) ;

  • ———–| label:标注好两幅遥感图像之间存在的变化,因为数据中考虑一个类别(建筑物)的变化情况,以二值图形式(黑白)进行展示(1024*1024))。
    !命名一定要一致!

  • 二、问题汇总与解答(如列不全,请留言)

  • 1.安装虚拟环境与相关库的的条件:

  • visdom=0.1.8.1 或者 修改可视化版本visdom=0.1.8.8;
  • 不然可能在测试的时候,会出现:AssertionError: X and Y should be the same shape
    
  • scipy=1.1.0:因为1.2.0版本的scipy没有 imread,也会报错。
  • 2. 查看代码中的readme.md文件(里面有跑代码的方式、模型与数据集链接)

  • 如果开始想python demo.py,先下载文章训练好的模型、LEVIR-CD数据集(README.md中有百度网盘、谷歌云盘这两种形式的链接)添加到相应的位置。
  • 在运行代码过程中,多半会出现no file 报错,就按照报错的提示,<创建报错路径>,再将模型或者数据集添加进去即可。
  • 3. python demo.py 的问题

  • TypeError: Cannot handle this data type: (1, 1, 64), |u :听说是因为Python版本问题:我的python=3.6.12没有问题。
  • 4.跑BAM、PAM代码中常见的Out of memory 2 (64、256)GiB 问题

  • 首先:【Out of memory 2 GIB】主要是显存不够,很有效的做法就是减低 batch_size 8 –>4;
  • 其次:降batch size 8 为 4 之后,运行代码,实验跑1 个epoch后, 紧跟的val就会出现 【Out of memory 256GiB】, 因为val验证的代码没有将1024裁剪为256, 服务器的计算资源不够。需要分别裁剪val 文件中的 A B label ,然后更改python train.py 后面的 val_data_path的路径到裁剪的val 文件夹(如 val_256)即可,代码后续放出。
  • 记得将后面测试的 test文件夹的图片也裁剪,同样地,分别裁剪 A B label,不裁剪可能会 【Out of memory 64 GiB】。
  • 三、裁剪代码实现

    
    import os
    import os.path as osp
    import sys
    from multiprocessing import Pool
    import numpy as np
    import cv2
    from PIL import Image
    import time
    from shutil import get_terminal_size
    
    
    sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
    
    
    def main():
        mode = 'pair'  # single (one input folder) | pair (extract corresponding GT and LR pairs)
        opt = {}
        opt['n_thread'] = 20
        opt['compression_level'] = 3  # 3 is the default value in cv2
        # CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer
        # compression time. If read raw images during training, use 0 for faster IO speed.
        if mode == 'single':
            opt['input_folder'] = './data/DIV2K/DIV2K_train_HR'
            opt['save_folder'] = './data/DIV2K/DIV2K800_sub'
            opt['crop_sz'] = 480  # the size of each sub-image
            opt['step'] = 240  # step of the sliding crop window
            opt['thres_sz'] = 48  # size threshold
            extract_signle(opt)
    
        elif mode == 'pair':
            GT_folder = '/home/cug210/data/Lover/code/STANet-master/LEVIR-CD/test/B'
            save_GT_folder = '/home/cug210/data/Lover/code/STANet-master/LEVIR-CD/test_256/B'
            crop_sz = 256  # the size of each sub-image (GT)
            step = 256  # step of the sliding crop window (GT)
            thres_sz = 256  # size threshold
          
            img_GT_list = _get_paths_from_images(GT_folder)
            
            print('process GT...')
            opt['input_folder'] = GT_folder
            opt['save_folder'] = save_GT_folder
            opt['crop_sz'] = crop_sz
            opt['step'] = step
            opt['thres_sz'] = thres_sz
            extract_signle(opt)
           
        else:
            raise ValueError('Wrong mode.')
    
    def extract_signle(opt):
        input_folder = opt['input_folder']
        save_folder = opt['save_folder']
        if not osp.exists(save_folder):
            os.makedirs(save_folder)
            print('mkdir [{:s}] ...'.format(save_folder))
        else:
            print('Folder [{:s}] already exists. Exit...'.format(save_folder))
            sys.exit(1)
        img_list = _get_paths_from_images(input_folder)
    
        def update(arg):
            pbar.update(arg)
    
        pbar = ProgressBar(len(img_list))
    
        pool = Pool(opt['n_thread'])
        for path in img_list:
            pool.apply_async(worker, args=(path, opt), callback=update)
        pool.close()
        pool.join()
        print('All subprocesses done.')
    
    
    def worker(path, opt):
        crop_sz = opt['crop_sz']
        step = opt['step']
        thres_sz = opt['thres_sz']
        img_name = osp.basename(path)
        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    
        n_channels = len(img.shape)
        if n_channels == 2:
            h, w = img.shape
        elif n_channels == 3:
            h, w, c = img.shape
        else:
            raise ValueError('Wrong image shape - {}'.format(n_channels))
    
        h_space = np.arange(0, h - crop_sz + 1, step)
        if h - (h_space[-1] + crop_sz) > thres_sz:
            h_space = np.append(h_space, h - crop_sz)
        w_space = np.arange(0, w - crop_sz + 1, step)
        if w - (w_space[-1] + crop_sz) > thres_sz:
            w_space = np.append(w_space, w - crop_sz)
    
        index = 0
        for x in h_space:
            for y in w_space:
                index += 1
                if n_channels == 2:
                    crop_img = img[x:x + crop_sz, y:y + crop_sz]
                else:
                    crop_img = img[x:x + crop_sz, y:y + crop_sz, :]
                crop_img = np.ascontiguousarray(crop_img)
                cv2.imwrite(
                    osp.join(opt['save_folder'],
                             img_name.replace('.png', '_s{:03d}.png'.format(index))), crop_img,
                    [cv2.IMWRITE_PNG_COMPRESSION, opt['compression_level']])
        return 'Processing {:s} ...'.format(img_name)
    
    
    class ProgressBar(object):
        '''A progress bar which can print the progress
        modified from https://github.com/hellock/cvbase/blob/master/cvbase/progress.py
        '''
    
        def __init__(self, task_num=0, bar_width=50, start=True):
            self.task_num = task_num
            max_bar_width = self._get_max_bar_width()
            self.bar_width = (bar_width if bar_width <= max_bar_width else max_bar_width)
            self.completed = 0
            if start:
                self.start()
    
        def _get_max_bar_width(self):
            terminal_width, _ = get_terminal_size()
            max_bar_width = min(int(terminal_width * 0.6), terminal_width - 50)
            if max_bar_width < 10:
                print('terminal width is too small ({}), please consider widen the terminal for better '
                      'progressbar visualization'.format(terminal_width))
                max_bar_width = 10
            return max_bar_width
    
        def start(self):
            if self.task_num > 0:
                sys.stdout.write('[{}] 0/{}, elapsed: 0s, ETA:\n{}\n'.format(
                    ' ' * self.bar_width, self.task_num, 'Start...'))
            else:
                sys.stdout.write('completed: 0, elapsed: 0s')
            sys.stdout.flush()
            self.start_time = time.time()
    
        def update(self, msg='In progress...'):
            self.completed += 1
            elapsed = time.time() - self.start_time + 1e-9
            fps = self.completed / elapsed
            if self.task_num > 0:
                percentage = self.completed / float(self.task_num)
                eta = int(elapsed * (1 - percentage) / percentage + 0.5)
                mark_width = int(self.bar_width * percentage)
                bar_chars = '>' * mark_width + '-' * (self.bar_width - mark_width)
                sys.stdout.write('\033[2F')  # cursor up 2 lines
                sys.stdout.write('\033[J')  # clean the output (remove extra chars since last display)
                sys.stdout.write('[{}] {}/{}, {:.1f} task/s, elapsed: {}s, ETA: {:5}s\n{}\n'.format(
                    bar_chars, self.completed, self.task_num, fps, int(elapsed + 0.5), eta, msg))
            else:
                sys.stdout.write('completed: {}, elapsed: {}s, {:.1f} tasks/s'.format(
                    self.completed, int(elapsed + 0.5), fps))
            sys.stdout.flush()
    
    
    # ###################
    # ### Data Utils ####
    # ###################
    IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP']
    
    def is_image_file(filename):
        return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
    
    
    def _get_paths_from_images(path):
        """get image path list from image folder"""
        assert osp.isdir(path), '{:s} is not a valid directory'.format(path)
        images = []
        for dirpath, _, fnames in sorted(os.walk(path)):
            for fname in sorted(fnames):
                print("..fname is:",fname)
    
                if is_image_file(fname):
                    img_path = os.path.join(dirpath, fname)
                    images.append(img_path)
        assert images, '{:s} has no valid image file'.format(path)
        return images
    
    
    if __name__ == '__main__':
        main()
    
    
    

    说明:

  • 只需要更改32-36行的信息:
  • 【原始文件夹路径】
  • 【保存的裁剪后图片的文件夹路径】
  • 【裁剪尺寸crop_size、位移尺寸step(两者相等,表示下一张图和第一张图没有重叠)】
  • 【阈值(thres_sz)设置为256,表示裁剪到最后,剩下不到256的残缺,就不裁剪了。】
  • 物联沃分享整理
    物联沃-IOTWORD物联网 » STANet_pytorch代码问题汇总、附上裁剪图片代码(有问留言必答)

    发表评论