博客
关于我
SSD框架训练自己的数据集
阅读量:108 次
发布时间:2019-02-26

本文共 43946 字,大约阅读时间需要 146 分钟。

demo中详细介绍了如何在VOC数据集上使用SSD进行物体检测的训练和验证。 本文介绍如何使用SSD实现对自己数据集的训练和验证过程,内容包括:
1 数据集的标注 2 数据集的转换 3 使用SSD如何训练 4 使用SSD如何测试

1 数据集的标注

   数据的标注使用BBox-Label-Tool工具,该工具使用python实现,使用简单方便。修改后的工具支持多label的标签标注。 该工具为每张图片生成一个与图片同名的txt标签文件, 格式是: 第一行为该图片中bounding box的数量(object_number), 之后每行一个bounding box, 依次为 className xmin ymin xmax ymax, 例如:
1
door 109 3 199 204

1.1 labelTool工具的使用说明

   BBox-Label-Tool工具实现较简单,原始的git版本使用起来有一些小问题,因此进行了简单的修改,修改后的版本(main.py)如下:
#-------------------------------------------------------------------------------# Name:        Object bounding box label tool# Purpose:     Label object bboxes for ImageNet Detection data# Author:      Qiushi# Created:     06/06/2014##-------------------------------------------------------------------------------from __future__ import divisionfrom Tkinter import *import tkMessageBoxfrom PIL import Image, ImageTkimport osimport globimport random# colors for the bboxesCOLORS = ['red', 'blue', 'yellow', 'pink', 'cyan', 'green', 'black']# image sizes for the examplesSIZE = 256, 256classLabels=['mat', 'door', 'sofa', 'chair', 'table', 'bed', 'ashcan', 'shoe']class LabelTool():    def __init__(self, master):        # set up the main frame        self.parent = master        self.parent.title("LabelTool")        self.frame = Frame(self.parent)        self.frame.pack(fill=BOTH, expand=1)        self.parent.resizable(width = False, height = False)        # initialize global state        self.imageDir = ''        self.imageList= []        self.egDir = ''        self.egList = []        self.outDir = ''        self.cur = 0        self.total = 0        self.category = 0        self.imagename = ''        self.labelfilename = ''        self.tkimg = None        # initialize mouse state        self.STATE = {}        self.STATE['click'] = 0        self.STATE['x'], self.STATE['y'] = 0, 0        # reference to bbox        self.bboxIdList = []        self.bboxId = None        self.bboxList = []        self.hl = None        self.vl = None        self.currentClass = ''        # ----------------- GUI stuff ---------------------        # dir entry & load        self.label = Label(self.frame, text = "Image Dir:")        self.label.grid(row = 0, column = 0, sticky = E)        self.entry = Entry(self.frame)        self.entry.grid(row = 0, column = 1, sticky = W+E)        self.ldBtn = Button(self.frame, text = "Load", command = self.loadDir)        self.ldBtn.grid(row = 0, column = 2, sticky = 
W+E)        # main panel for labeling        self.mainPanel = Canvas(self.frame, cursor='tcross')        self.mainPanel.bind("
", self.mouseClick) self.mainPanel.bind("
", self.mouseMove) self.parent.bind("
", self.cancelBBox) # press
to cancel current bbox self.parent.bind("s", self.cancelBBox) self.parent.bind("a", self.prevImage) # press 'a' to go backforward self.parent.bind("d", self.nextImage) # press 'd' to go forward self.mainPanel.grid(row = 1, column = 1, rowspan = 4, sticky = W+N) # showing bbox info & delete bbox self.lb1 = Label(self.frame, text = 'Bounding boxes:') self.lb1.grid(row = 1, column = 2, sticky = W+N) self.listbox = Listbox(self.frame, width = 22, height = 12) self.listbox.grid(row = 2, column = 2, sticky = N) self.btnDel = Button(self.frame, text = 'Delete', command = self.delBBox) self.btnDel.grid(row = 3, column = 2, sticky = W+E+N) self.btnClear = Button(self.frame, text = 'ClearAll', command = self.clearBBox) self.btnClear.grid(row = 4, column = 2, sticky = W+E+N) #select class type self.classPanel = Frame(self.frame) self.classPanel.grid(row = 5, column = 1, columnspan = 10, sticky = W+E) label = Label(self.classPanel, text = 'class:') label.grid(row = 5, column = 1, sticky = W+N) self.classbox = Listbox(self.classPanel, width = 4, height = 2) self.classbox.grid(row = 5,column = 2) for each in range(len(classLabels)): function = 'select' + classLabels[each] print classLabels[each] btnMat = Button(self.classPanel, text = classLabels[each], command = getattr(self, function)) btnMat.grid(row = 5, column = each + 3) # control panel for image navigation self.ctrPanel = Frame(self.frame) self.ctrPanel.grid(row = 6, column = 1, columnspan = 2, sticky = W+E) self.prevBtn = Button(self.ctrPanel, text='<< Prev', width = 10, command = self.prevImage) self.prevBtn.pack(side = LEFT, padx = 5, pady = 3) self.nextBtn = Button(self.ctrPanel, text='Next >>', width = 10, command = self.nextImage) self.nextBtn.pack(side = LEFT, padx = 5, pady = 3) self.progLabel = Label(self.ctrPanel, text = "Progress: / ") self.progLabel.pack(side = LEFT, padx = 5) self.tmpLabel = Label(self.ctrPanel, text = "Go to Image No.") self.tmpLabel.pack(side = LEFT, padx = 5) self.idxEntry = 
Entry(self.ctrPanel, width = 5) self.idxEntry.pack(side = LEFT) self.goBtn = Button(self.ctrPanel, text = 'Go', command = self.gotoImage) self.goBtn.pack(side = LEFT) # example pannel for illustration self.egPanel = Frame(self.frame, border = 10) self.egPanel.grid(row = 1, column = 0, rowspan = 5, sticky = N) self.tmpLabel2 = Label(self.egPanel, text = "Examples:") self.tmpLabel2.pack(side = TOP, pady = 5) self.egLabels = [] for i in range(3): self.egLabels.append(Label(self.egPanel)) self.egLabels[-1].pack(side = TOP) # display mouse position self.disp = Label(self.ctrPanel, text='') self.disp.pack(side = RIGHT) self.frame.columnconfigure(1, weight = 1) self.frame.rowconfigure(10, weight = 1) # for debugging## self.setImage()## self.loadDir() def loadDir(self, dbg = False): if not dbg: s = self.entry.get() self.parent.focus() self.category = int(s) else: s = r'D:\workspace\python\labelGUI'## if not os.path.isdir(s):## tkMessageBox.showerror("Error!", message = "The specified dir doesn't exist!")## return # get image list self.imageDir = os.path.join(r'./Images', '%d' %(self.category)) self.imageList = glob.glob(os.path.join(self.imageDir, '*.jpg')) if len(self.imageList) == 0: print 'No .JPEG images found in the specified dir!' 
return # set up output dir self.outDir = os.path.join(r'./Labels', '%d' %(self.category)) if not os.path.exists(self.outDir): os.mkdir(self.outDir) labeledPicList = glob.glob(os.path.join(self.outDir, '*.txt')) for label in labeledPicList: data = open(label, 'r') if '0\n' == data.read(): data.close() continue data.close() picture = label.replace('Labels', 'Images').replace('.txt', '.jpg') if picture in self.imageList: self.imageList.remove(picture) # default to the 1st image in the collection self.cur = 1 self.total = len(self.imageList) self.loadImage() print '%d images loaded from %s' %(self.total, s) def loadImage(self): # load image imagepath = self.imageList[self.cur - 1] self.img = Image.open(imagepath) self.imgSize = self.img.size self.tkimg = ImageTk.PhotoImage(self.img) self.mainPanel.config(width = max(self.tkimg.width(), 400), height = max(self.tkimg.height(), 400)) self.mainPanel.create_image(0, 0, image = self.tkimg, anchor=NW) self.progLabel.config(text = "%04d/%04d" %(self.cur, self.total)) # load labels self.clearBBox() self.imagename = os.path.split(imagepath)[-1].split('.')[0] labelname = self.imagename + '.txt' self.labelfilename = os.path.join(self.outDir, labelname) bbox_cnt = 0 if os.path.exists(self.labelfilename): with open(self.labelfilename) as f: for (i, line) in enumerate(f): if i == 0: bbox_cnt = int(line.strip()) continue tmp = [int(t.strip()) for t in line.split()]## print tmp self.bboxList.append(tuple(tmp)) tmpId = self.mainPanel.create_rectangle(tmp[0], tmp[1], \ tmp[2], tmp[3], \ width = 2, \ outline = COLORS[(len(self.bboxList)-1) % len(COLORS)]) self.bboxIdList.append(tmpId) self.listbox.insert(END, '(%d, %d) -> (%d, %d)' %(tmp[0], tmp[1], tmp[2], tmp[3])) self.listbox.itemconfig(len(self.bboxIdList) - 1, fg = COLORS[(len(self.bboxIdList) - 1) % len(COLORS)]) def saveImage(self): with open(self.labelfilename, 'w') as f: f.write('%d\n' %len(self.bboxList)) for bbox in self.bboxList: f.write(' '.join(map(str, bbox)) + '\n') print 
'Image No. %d saved' %(self.cur) def mouseClick(self, event): if self.STATE['click'] == 0: self.STATE['x'], self.STATE['y'] = event.x, event.y #self.STATE['x'], self.STATE['y'] = self.imgSize[0], self.imgSize[1] else: x1, x2 = min(self.STATE['x'], event.x), max(self.STATE['x'], event.x) y1, y2 = min(self.STATE['y'], event.y), max(self.STATE['y'], event.y) if x2 > self.imgSize[0]: x2 = self.imgSize[0] if y2 > self.imgSize[1]: y2 = self.imgSize[1] self.bboxList.append((self.currentClass, x1, y1, x2, y2)) self.bboxIdList.append(self.bboxId) self.bboxId = None self.listbox.insert(END, '(%d, %d) -> (%d, %d)' %(x1, y1, x2, y2)) self.listbox.itemconfig(len(self.bboxIdList) - 1, fg = COLORS[(len(self.bboxIdList) - 1) % len(COLORS)]) self.STATE['click'] = 1 - self.STATE['click'] def mouseMove(self, event): self.disp.config(text = 'x: %d, y: %d' %(event.x, event.y)) if self.tkimg: if self.hl: self.mainPanel.delete(self.hl) self.hl = self.mainPanel.create_line(0, event.y, self.tkimg.width(), event.y, width = 2) if self.vl: self.mainPanel.delete(self.vl) self.vl = self.mainPanel.create_line(event.x, 0, event.x, self.tkimg.height(), width = 2) if 1 == self.STATE['click']: if self.bboxId: self.mainPanel.delete(self.bboxId) self.bboxId = self.mainPanel.create_rectangle(self.STATE['x'], self.STATE['y'], \ event.x, event.y, \ width = 2, \ outline = COLORS[len(self.bboxList) % len(COLORS)]) def cancelBBox(self, event): if 1 == self.STATE['click']: if self.bboxId: self.mainPanel.delete(self.bboxId) self.bboxId = None self.STATE['click'] = 0 def delBBox(self): sel = self.listbox.curselection() if len(sel) != 1 : return idx = int(sel[0]) self.mainPanel.delete(self.bboxIdList[idx]) self.bboxIdList.pop(idx) self.bboxList.pop(idx) self.listbox.delete(idx) def clearBBox(self): for idx in range(len(self.bboxIdList)): self.mainPanel.delete(self.bboxIdList[idx]) self.listbox.delete(0, len(self.bboxList)) self.bboxIdList = [] self.bboxList = [] def selectmat(self): self.currentClass = 'mat' 
self.classbox.delete(0,END) self.classbox.insert(0, 'mat') self.classbox.itemconfig(0,fg = COLORS[0]) def selectdoor(self): self.currentClass = 'door' self.classbox.delete(0,END) self.classbox.insert(0, 'door') self.classbox.itemconfig(0,fg = COLORS[0]) def selectsofa(self): self.currentClass = 'sofa' self.classbox.delete(0,END) self.classbox.insert(0, 'sofa') self.classbox.itemconfig(0,fg = COLORS[0]) def selectchair(self): self.currentClass = 'chair' self.classbox.delete(0,END) self.classbox.insert(0, 'chair') self.classbox.itemconfig(0,fg = COLORS[0]) def selecttable(self): self.currentClass = 'table' self.classbox.delete(0,END) self.classbox.insert(0, 'table') self.classbox.itemconfig(0,fg = COLORS[0]) def selectbed(self): self.currentClass = 'bed' self.classbox.delete(0,END) self.classbox.insert(0, 'bed') self.classbox.itemconfig(0,fg = COLORS[0]) def selectashcan(self): self.currentClass = 'ashcan' self.classbox.delete(0,END) self.classbox.insert(0, 'ashcan') self.classbox.itemconfig(0,fg = COLORS[0]) def selectshoe(self): self.currentClass = 'shoe' self.classbox.delete(0,END) self.classbox.insert(0, 'shoe') self.classbox.itemconfig(0,fg = COLORS[0]) def prevImage(self, event = None): self.saveImage() if self.cur > 1: self.cur -= 1 self.loadImage() def nextImage(self, event = None): self.saveImage() if self.cur < self.total: self.cur += 1 self.loadImage() def gotoImage(self): idx = int(self.idxEntry.get()) if 1 <= idx and idx <= self.total: self.saveImage() self.cur = idx self.loadImage()## def setImage(self, imagepath = r'test2.png'):## self.img = Image.open(imagepath)## self.tkimg = ImageTk.PhotoImage(self.img)## self.mainPanel.config(width = self.tkimg.width())## self.mainPanel.config(height = self.tkimg.height())## self.mainPanel.create_image(0, 0, image = self.tkimg, anchor=NW)if __name__ == '__main__': root = Tk() tool = LabelTool(root) root.mainloop()

  使用方法:   

     (1) 在BBox-Label-Tool/Images目录下创建保存图片的目录, 目录以数字命名(BBox-Label-Tool/Images/1), 然后将待标注的图片copy到1这个目录下;

     (2) 在BBox-Label-Tool目录下执行命令   python main.py

     (3) 在工具界面上, Image Dir 框中输入需要标记的目录名(比如 1), 然后点击load按钮, 工具自动将Images/1目录下的图片加载进来;

      需要说明一下, 如果目录中的图片已经标注过,点击load时不会被重新加载进来.

     (4) 该工具支持多类别标注, 画bounding box框标定之前,需要先点击类别按钮选定类别,然后再画框(第一次点击确定框的一个角, 第二次点击确定对角; 画框过程中按s键可取消当前框).

     (5) 一张图片标注完后, 点击Next>>按钮, 标注下一张图片,  图片label成功后,会在BBox-Label-Tool/Labels对应的目录下生成与图片文件名对应的label文件.

2 数据集的转换

  caffe训练使用LMDB格式的数据,ssd框架中提供了voc数据格式转换成LMDB格式的脚本。 所以实践中先将BBox-Label-Tool标注的数据转换成voc数据格式,然后再转换成LMDB格式。 2.1 voc数据格式

(1)Annotations中保存的是xml格式的label信息
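<annotation>
    <folder>VOC2007</folder>
    <filename>1.jpg</filename>
    <source>
        <database>My Database</database>
        <annotation>VOC2007</annotation>
        <image>flickr</image>
        <flickrid>NULL</flickrid>
    </source>
    <owner>
        <flickrid>NULL</flickrid>
        <name>idaneel</name>
    </owner>
    <size>
        <width>320</width>
        <height>240</height>
        <depth>3</depth>
    </size>
    <segmented>0</segmented>
    <object>
        <name>door</name>
        <pose>Unspecified</pose>
        <truncated>0</truncated>
        <difficult>0</difficult>
        <bndbox>
            <xmin>109</xmin>
            <ymin>3</ymin>
            <xmax>199</xmax>
            <ymax>204</ymax>
        </bndbox>
    </object>
</annotation>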

(2)ImageSets目录下的Main目录里存放的是用于表示训练图片集和测试图片集的列表文件(trainval.txt 和 test.txt, 每行一个图片编号)

(3)JPEGImages目录下存放所有图片集

(4)label目录下保存的是BBox-Label-Tool工具标注好的bounding box坐标文件, 该目录下的文件就是待转换的label标签文件。
2.2 Label转换成VOC数据格式
BBox-Label-Tool工具标注好的bounding box坐标文件需要转换成VOC数据格式. 具体的转换过程包括两个步骤: (1)将BBox-Label-Tool以txt格式保存的bounding box信息转换成VOC数据格式下的xml表示; (2)生成用于训练的数据集和用于测试的数据集。 用python实现了上述两个步骤的转换。 createXml.py  完成txt到xml的转换;  执行脚本./createXml.py
#!/usr/bin/env pythonimport osimport sysimport cv2from itertools import islicefrom xml.dom.minidom import Documentlabels='label'imgpath='JPEGImages/'xmlpath_new='Annotations/'foldername='VOC2007'def insertObject(doc, datas):    obj = doc.createElement('object')    name = doc.createElement('name')    name.appendChild(doc.createTextNode(datas[0]))    obj.appendChild(name)    pose = doc.createElement('pose')    pose.appendChild(doc.createTextNode('Unspecified'))    obj.appendChild(pose)    truncated = doc.createElement('truncated')    truncated.appendChild(doc.createTextNode(str(0)))    obj.appendChild(truncated)    difficult = doc.createElement('difficult')    difficult.appendChild(doc.createTextNode(str(0)))    obj.appendChild(difficult)    bndbox = doc.createElement('bndbox')        xmin = doc.createElement('xmin')    xmin.appendChild(doc.createTextNode(str(datas[1])))    bndbox.appendChild(xmin)        ymin = doc.createElement('ymin')                    ymin.appendChild(doc.createTextNode(str(datas[2])))    bndbox.appendChild(ymin)                    xmax = doc.createElement('xmax')                    xmax.appendChild(doc.createTextNode(str(datas[3])))    bndbox.appendChild(xmax)                    ymax = doc.createElement('ymax')        if  '\r' == str(datas[4])[-1] or '\n' == str(datas[4])[-1]:        data = str(datas[4])[0:-1]    else:        data = str(datas[4])    ymax.appendChild(doc.createTextNode(data))    bndbox.appendChild(ymax)    obj.appendChild(bndbox)                    return objdef create():    for walk in os.walk(labels):        for each in walk[2]:            fidin=open(walk[0] + '/'+ each,'r')            objIndex = 0            for data in islice(fidin, 1, None):                        objIndex += 1                data=data.strip('\n')                datas = data.split(' ')                if 5 != len(datas):                    print 'bounding box information error'                    continue                pictureName = each.replace('.txt', 
'.jpg')                imageFile = imgpath + pictureName                img = cv2.imread(imageFile)                imgSize = img.shape                if 1 == objIndex:                    xmlName = each.replace('.txt', '.xml')                    f = open(xmlpath_new + xmlName, "w")                    doc = Document()                    annotation = doc.createElement('annotation')                    doc.appendChild(annotation)                                        folder = doc.createElement('folder')                    folder.appendChild(doc.createTextNode(foldername))                    annotation.appendChild(folder)                                        filename = doc.createElement('filename')                    filename.appendChild(doc.createTextNode(pictureName))                    annotation.appendChild(filename)                                        source = doc.createElement('source')                                    database = doc.createElement('database')                    database.appendChild(doc.createTextNode('My Database'))                    source.appendChild(database)                    source_annotation = doc.createElement('annotation')                    source_annotation.appendChild(doc.createTextNode(foldername))                    source.appendChild(source_annotation)                    image = doc.createElement('image')                    image.appendChild(doc.createTextNode('flickr'))                    source.appendChild(image)                    flickrid = doc.createElement('flickrid')                    flickrid.appendChild(doc.createTextNode('NULL'))                    source.appendChild(flickrid)                    annotation.appendChild(source)                                        owner = doc.createElement('owner')                    flickrid = doc.createElement('flickrid')                    flickrid.appendChild(doc.createTextNode('NULL'))                    owner.appendChild(flickrid)                    name = 
doc.createElement('name')                    name.appendChild(doc.createTextNode('idaneel'))                    owner.appendChild(name)                    annotation.appendChild(owner)                                        size = doc.createElement('size')                    width = doc.createElement('width')                    width.appendChild(doc.createTextNode(str(imgSize[1])))                    size.appendChild(width)                    height = doc.createElement('height')                    height.appendChild(doc.createTextNode(str(imgSize[0])))                    size.appendChild(height)                    depth = doc.createElement('depth')                    depth.appendChild(doc.createTextNode(str(imgSize[2])))                    size.appendChild(depth)                    annotation.appendChild(size)                                        segmented = doc.createElement('segmented')                    segmented.appendChild(doc.createTextNode(str(0)))                    annotation.appendChild(segmented)                                annotation.appendChild(insertObject(doc, datas))                else:                    annotation.appendChild(insertObject(doc, datas))            try:                f.write(doc.toprettyxml(indent = '    '))                f.close()                fidin.close()            except:                pass             if __name__ == '__main__':    create()

  createTest.py 生成训练集和测试集标识文件; 执行脚本

  ./createTest.py %startID% %endID% %testNumber%

#!/usr/bin/env pythonimport osimport sysimport randomtry:    start = int(sys.argv[1])    end = int(sys.argv[2])    test = int(sys.argv[3])    allNum = end-start+1except:    print 'Please input picture range'    print './createTest.py 1 1500 500'    os._exit(0)b_list = range(start,end)blist_webId = random.sample(b_list, test)blist_webId = sorted(blist_webId) allFile = []testFile = open('ImageSets/Main/test.txt', 'w')trainFile = open('ImageSets/Main/trainval.txt', 'w')for i in range(allNum):    allFile.append(i+1)for test in blist_webId:    allFile.remove(test)    testFile.write(str(test) + '\n')   for train in allFile:    trainFile.write(str(train) + '\n')testFile.close()trainFile.close()

说明: 由于原始的BBox-Label-Tool实现相对简单,该工具每次只能对一个类别进行打标签,所以原来的转换脚本

每一次也是对一个类别进行数据的转换,这个问题后续需要优化改进。

优化后的BBox-Label-Tool工具,支持多类别标定,生成的label文件中增加了类别名称信息。

使用时修改classLabels,改写成自己的类别, 修改后的工具代码参见1.1中的main.py 

2.3  VOC数据转换成LMDB数据

  SSD提供了VOC数据到LMDB数据的转换脚本 data/VOC0712/create_list.sh 和 ./data/VOC0712/create_data.sh,这两个脚本是完全针对VOC0712目录下的数据进行的转换。   实现中为了不破坏VOC0712目录下的数据内容,针对我们自己的数据集,修改了上面这两个脚本, 将脚本中涉及到VOC0712的信息替换成我们自己的目录信息。 在处理我们的数据集时,将VOC0712替换成indoor。 具体的步骤如下:   (1) 在 $HOME/data/VOCdevkit目录下创建indoor目录,该目录中存放自己转换完成的VOC数据集;   (2) $CAFFE_ROOT/examples目录下创建indoor目录; (3) $CAFFE_ROOT/data目录下创建indoor目录,同时将data/VOC0712下的create_list.sh,create_data.sh,labelmap_voc.prototxt 这三个文件copy到indoor目录下,分别重命名为create_list_indoor.sh,create_data_indoor.sh, labelmap_indoor.prototxt   (4)对上面新生成的两个create文件进行修改,主要修改是将VOC0712相关的信息替换成indoor   修改后的这两个文件分别为:  
#!/bin/bash
# Build the SSD list files $bash_dir/trainval.txt and $bash_dir/test.txt for the
# "indoor" dataset (adapted from data/VOC0712/create_list.sh).  Every id in
# $root_dir/<name>/ImageSets/Main/<dataset>.txt becomes one output line:
#   <name>/JPEGImages/<id>.jpg <name>/Annotations/<id>.xml
root_dir=$HOME/data/VOCdevkit/
sub_dir=ImageSets/Main
bash_dir="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
for dataset in trainval test
do
  dst_file="$bash_dir/$dataset.txt"
  rm -f "$dst_file"
  for name in indoor
  do
    echo "Create list for $name $dataset..."
    dataset_file="$root_dir/$name/$sub_dir/$dataset.txt"
    if [ ! -f "$dataset_file" ]
    then
      echo "Missing id list: $dataset_file" >&2
      exit 1
    fi

    # <id>  ->  <name>/JPEGImages/<id>.jpg
    img_file="$bash_dir/${dataset}_img.txt"
    cp "$dataset_file" "$img_file"
    sed -i "s/^/$name\/JPEGImages\//g" "$img_file"
    sed -i "s/$/.jpg/g" "$img_file"

    # <id>  ->  <name>/Annotations/<id>.xml
    label_file="$bash_dir/${dataset}_label.txt"
    cp "$dataset_file" "$label_file"
    sed -i "s/^/$name\/Annotations\//g" "$label_file"
    sed -i "s/$/.xml/g" "$label_file"

    paste -d' ' "$img_file" "$label_file" >> "$dst_file"
    rm -f "$label_file"
    rm -f "$img_file"
  done

  # Generate image name and size information.
  if [ "$dataset" == "test" ]
  then
    "$bash_dir/../../build/tools/get_image_size" "$root_dir" "$dst_file" "$bash_dir/${dataset}_name_size.txt"
  fi

  # Shuffle trainval file: perl must read the lines from <STDIN>, otherwise
  # it prints nothing and the list is silently emptied.
  if [ "$dataset" == "trainval" ]
  then
    rand_file="$dst_file.random"
    perl -MList::Util=shuffle -e 'print shuffle(<STDIN>);' < "$dst_file" > "$rand_file"
    mv "$rand_file" "$dst_file"
  fi
done
#!/bin/bash
# Convert the "indoor" list files into LMDB databases with SSD's
# scripts/create_annoset.py (adapted from data/VOC0712/create_data.sh).
cur_dir=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
root_dir=$cur_dir/../..

# The output link directory below (examples/$dataset_name) is relative, so
# stop if we cannot change into the caffe root.
cd "$root_dir" || exit 1

redo=1
data_root_dir="$HOME/data/VOCdevkit"
dataset_name="indoor"
mapfile="$root_dir/data/$dataset_name/labelmap_indoor.prototxt"
anno_type="detection"
db="lmdb"
min_dim=0
max_dim=0
width=0
height=0

extra_cmd="--encode-type=jpg --encoded"
if [ -n "$redo" ]
then
  extra_cmd="$extra_cmd --redo"
fi
for subset in test trainval
do
  # $extra_cmd is intentionally unquoted so it splits into separate options.
  python "$root_dir/scripts/create_annoset.py" --anno-type=$anno_type \
    --label-map-file="$mapfile" --min-dim=$min_dim --max-dim=$max_dim \
    --resize-width=$width --resize-height=$height --check-label $extra_cmd \
    "$data_root_dir" "$root_dir/data/$dataset_name/$subset.txt" \
    "$data_root_dir/$dataset_name/$db/${dataset_name}_${subset}_${db}" \
    "examples/$dataset_name"
done
(5)修改labelmap_indoor.prototxt,将该文件中的类别修改成和自己的数据集相匹配,注意需要保留一个label 0 , background类别
# Label map for the "indoor" dataset, read by create_data_indoor.sh (--label-map-file).
# label 0 is reserved for the background class and must be kept.
item {
  name: "none_of_the_above"
  label: 0
  display_name: "background"
}
# NOTE(review): each object class presumably needs one item whose `name`
# matches the <name> written into the Annotations XML (the class name from the
# BBox-Label-Tool label files), with labels numbered 1, 2, ... — confirm.
item {
  name: "door"
  label: 1
  display_name: "door"
}

  完成上面步骤的修改后,可以开始LMDB数据的制作,在$CAFFE_ROOT目录下分别运行:

  ./data/indoor/create_list_indoor.sh

  ./data/indoor/create_data_indoor.sh

  命令执行完毕后,LMDB数据生成在$HOME/data/VOCdevkit/indoor/lmdb目录下,同时可以通过$CAFFE_ROOT/examples/indoor目录访问这些LMDB数据(训练脚本中的train_data/test_data使用的就是这个路径)。

3 使用SSD进行自己数据集的训练 训练时使用ssd demo中提供的预训练好的VGGNet模型, 将该模型保存到$CAFFE_ROOT/models/VGGNet下。 将ssd_pascal.py复制一份, 命名为ssd_pascal_indoor.py, 然后根据自己的数据集修改ssd_pascal_indoor.py, 主要修改点:  (1)train_data和test_data修改成指向自己的数据集LMDB    train_data = "examples/indoor/indoor_trainval_lmdb"             test_data = "examples/indoor/indoor_test_lmdb" (2) num_test_image修改成自己数据集中测试图片的数量(即test.txt中的行数) (3)num_classes修改成自己数据集中的标签类别数 + 1(多出的1是label 0的background类别) 针对我的数据集,ssd_pascal_indoor.py的内容为:
from __future__ import print_function
import caffe
from caffe.model_libs import *
from google.protobuf import text_format

import math
import os
import shutil
import stat
import subprocess
import sys


# Add extra layers on top of a "base" network (e.g. VGGNet or Inception).
def AddExtraLayers(net, use_batchnorm=True):
    """Append the SSD extra feature layers conv6_* .. conv8_* and pool6 to ``net``.

    The first new layer is chained onto whichever layer was added to ``net``
    last. The four trailing integers of each ConvBNLayer call are presumably
    (num_output, kernel_size, pad, stride) -- confirm against caffe.model_libs.

    Args:
        net: caffe.NetSpec that already holds the base network.
        use_batchnorm: forwarded unchanged to every ConvBNLayer call.

    Returns:
        The same ``net`` object, modified in place.
    """
    use_relu = True

    # Add additional convolutional layers.
    # list(...) keeps the "last key" lookup valid whether keys() returns a
    # list or a non-indexable view.
    from_layer = list(net.keys())[-1]
    # TODO(weiliu89): Construct the name using the last layer to avoid duplication.
    out_layer = "conv6_1"
    ConvBNLayer(net, from_layer, out_layer, use_batchnorm, use_relu, 256, 1, 0, 1)

    from_layer = out_layer
    out_layer = "conv6_2"
    ConvBNLayer(net, from_layer, out_layer, use_batchnorm, use_relu, 512, 3, 1, 2)

    # conv7_* and conv8_* share the same 1x1 -> 3x3 pattern.
    for i in range(7, 9):
        from_layer = out_layer
        out_layer = "conv{}_1".format(i)
        ConvBNLayer(net, from_layer, out_layer, use_batchnorm, use_relu, 128, 1, 0, 1)

        from_layer = out_layer
        out_layer = "conv{}_2".format(i)
        ConvBNLayer(net, from_layer, out_layer, use_batchnorm, use_relu, 256, 3, 1, 2)

    # Add global pooling layer.
    name = list(net.keys())[-1]
    net.pool6 = L.Pooling(net[name], pool=P.Pooling.AVE, global_pooling=True)

    return net


### Modify the following parameters accordingly ###
# The directory which contains the caffe code.
# We assume you are running the script at the CAFFE_ROOT.
caffe_root = os.getcwd()

# Set true if you want to start training right after generating all files.
run_soon = True
# Set true if you want to load from most recently saved snapshot.
# Otherwise, we will load from the pretrain_model defined below.
resume_training = True
# If true, Remove old model files.
remove_old_models = False

# The database file for training data. Created by data/indoor/create_data_indoor.sh
train_data = "examples/indoor/indoor_trainval_lmdb"
# The database file for testing data.
# Created by data/indoor/create_data_indoor.sh
test_data = "examples/indoor/indoor_test_lmdb"

# Specify the batch sampler.
resize_width = 300
resize_height = 300
resize = "{}x{}".format(resize_width, resize_height)


def _crop_sampler(sample_constraint):
    """Return one batch-sampler entry with the shared crop ranges.

    Every constrained sampler uses the same scale / aspect-ratio range and
    trial counts; only ``sample_constraint`` differs between them.
    """
    return {
        'sampler': {
            'min_scale': 0.3,
            'max_scale': 1.0,
            'min_aspect_ratio': 0.5,
            'max_aspect_ratio': 2.0,
        },
        'sample_constraint': sample_constraint,
        'max_trials': 50,
        'max_sample': 1,
    }


batch_sampler = (
    # Entry with an empty sampler and no constraint.
    [{'sampler': {}, 'max_trials': 1, 'max_sample': 1}]
    # One entry per min_jaccard_overlap threshold.
    + [_crop_sampler({'min_jaccard_overlap': overlap})
       for overlap in (0.1, 0.3, 0.5, 0.7, 0.9)]
    # Final entry constrained only by max_jaccard_overlap.
    + [_crop_sampler({'max_jaccard_overlap': 1.0})]
)

train_transform_param = {
    'mirror': True,
    'mean_value': [104, 117, 123],
    'resize_param': {
        'prob': 1,
        'resize_mode': P.Resize.WARP,
        'height': resize_height,
        'width': resize_width,
        'interp_mode': [
            P.Resize.LINEAR,
            P.Resize.AREA,
            P.Resize.NEAREST,
            P.Resize.CUBIC,
            P.Resize.LANCZOS4,
        ],
    },
    'emit_constraint': {
        'emit_type': caffe_pb2.EmitConstraint.CENTER,
    },
}
test_transform_param = {
    'mean_value': [104, 117, 123],
    'resize_param': {
        'prob': 1,
        'resize_mode': P.Resize.WARP,
        'height': resize_height,
        'width': resize_width,
        'interp_mode': [P.Resize.LINEAR],
    },
}

# If true, use batch norm for all newly added layers.
# Currently only the non batch norm version has been tested.
use_batchnorm = False
# Use different initial learning rate.
if use_batchnorm:
    base_lr = 0.0004
else:
    # A learning rate for batch_size = 1, num_gpus = 1.
    base_lr = 0.00004

# Modify the job name if you want.
job_name = "SSD_{}".format(resize)
# The name of the model. Modify it if you want.
model_name = "VGG_VOC0712_{}".format(job_name)

# Directory which stores the model .prototxt file.
save_dir = "models/VGGNet/VOC0712/{}".format(job_name)
# Directory which stores the snapshot of models.
snapshot_dir = "models/VGGNet/VOC0712/{}".format(job_name)
# Directory which stores the job script and log file.
job_dir = "jobs/VGGNet/VOC0712/{}".format(job_name)
# Directory which stores the detection results.
output_result_dir = "{}/data/VOCdevkit/results/VOC2007/{}/Main".format(
    os.environ['HOME'], job_name)

# model definition files.
train_net_file = "{}/train.prototxt".format(save_dir)
test_net_file = "{}/test.prototxt".format(save_dir)
deploy_net_file = "{}/deploy.prototxt".format(save_dir)
solver_file = "{}/solver.prototxt".format(save_dir)
# snapshot prefix.
snapshot_prefix = "{}/{}".format(snapshot_dir, model_name)
# job script path.
job_file = "{}/{}.sh".format(job_dir, model_name)

# Stores the test image names and sizes. Created by data/indoor/create_list_indoor.sh
name_size_file = "data/indoor/test_name_size.txt"
# The pretrained model. We use the Fully convolutional reduced (atrous) VGGNet.
pretrain_model = "models/VGGNet/VGG_ILSVRC_16_layers_fc_reduced.caffemodel"
# Stores LabelMapItem.
label_map_file = "data/indoor/labelmap_indoor.prototxt"

# MultiBoxLoss parameters.
# Number of labels in label_map_file, background (label 0) included.
num_classes = 2
share_location = True
background_label_id = 0
train_on_diff_gt = True
normalization_mode = P.Loss.VALID
code_type = P.PriorBox.CENTER_SIZE
neg_pos_ratio = 3.
loc_weight = (neg_pos_ratio + 1.) / 4.
multibox_loss_param = {
    'loc_loss_type': P.MultiBoxLoss.SMOOTH_L1,
    'conf_loss_type': P.MultiBoxLoss.SOFTMAX,
    'loc_weight': loc_weight,
    'num_classes': num_classes,
    'share_location': share_location,
    'match_type': P.MultiBoxLoss.PER_PREDICTION,
    'overlap_threshold': 0.5,
    'use_prior_for_matching': True,
    'background_label_id': background_label_id,
    'use_difficult_gt': train_on_diff_gt,
    'do_neg_mining': True,
    'neg_pos_ratio': neg_pos_ratio,
    'neg_overlap': 0.5,
    'code_type': code_type,
}
loss_param = {
    'normalization': normalization_mode,
}

# parameters for generating priors.
# minimum dimension of input image
min_dim = 300
# conv4_3 ==> 38 x 38
# fc7 ==> 19 x 19
# conv6_2 ==> 10 x 10
# conv7_2 ==> 5 x 5
# conv8_2 ==> 3 x 3
# pool6 ==> 1 x 1
mbox_source_layers = ['conv4_3', 'fc7', 'conv6_2', 'conv7_2', 'conv8_2', 'pool6']
# in percent %
min_ratio = 20
max_ratio = 95
step = int(math.floor((max_ratio - min_ratio) / (len(mbox_source_layers) - 2)))
min_sizes = []
max_sizes = []
for ratio in range(min_ratio, max_ratio + 1, step):
    min_sizes.append(min_dim * ratio / 100.)
    max_sizes.append(min_dim * (ratio + step) / 100.)
# conv4_3 gets its own smaller prior size and no max size.
min_sizes = [min_dim * 10 / 100.] + min_sizes
max_sizes = [[]] + max_sizes
aspect_ratios = [[2], [2, 3], [2, 3], [2, 3], [2, 3], [2, 3]]
# L2 normalize conv4_3.
normalizations = [20, -1, -1, -1, -1, -1]
# variance used to encode/decode prior bboxes.
if code_type == P.PriorBox.CENTER_SIZE:
    prior_variance = [0.1, 0.1, 0.2, 0.2]
else:
    prior_variance = [0.1]
flip = True
clip = True

# Solver parameters.
# Defining which GPUs to use, e.g. "0" or "0,1"; set to "" to train on CPU.
gpus = "0"
# "".split(",") is [''], so drop empty entries; otherwise int('') below would
# raise instead of falling back to CPU mode.
gpulist = [g.strip() for g in gpus.split(",") if g.strip()]
num_gpus = len(gpulist)

# Divide the mini-batch to different GPUs.
batch_size = 4
accum_batch_size = 32
iter_size = accum_batch_size // batch_size
solver_mode = P.Solver.CPU
device_id = 0
batch_size_per_device = batch_size
if num_gpus > 0:
    batch_size_per_device = int(math.ceil(float(batch_size) / num_gpus))
    iter_size = int(math.ceil(float(accum_batch_size) / (batch_size_per_device * num_gpus)))
    solver_mode = P.Solver.GPU
    device_id = int(gpulist[0])

if normalization_mode == P.Loss.NONE:
    base_lr /= batch_size_per_device
elif normalization_mode == P.Loss.VALID:
    base_lr *= 25. / loc_weight
elif normalization_mode == P.Loss.FULL:
    # Roughly there are 2000 prior bboxes per image.
    # TODO(weiliu89): Estimate the exact # of priors.
    base_lr *= 2000.

# Which layers to freeze (no backward) during training.
freeze_layers = ['conv1_1', 'conv1_2', 'conv2_1', 'conv2_2']

# Evaluate on whole test set.
# Set num_test_image to the number of images in the test split.
num_test_image = 800
test_batch_size = 1
# Round up so the whole test set is covered even when num_test_image is not a
# multiple of test_batch_size; int() keeps the value an integer on Python 3.
test_iter = int(math.ceil(float(num_test_image) / test_batch_size))

solver_param = {
    # Train parameters
    'base_lr': base_lr,
    'weight_decay': 0.0005,
    'lr_policy': "step",
    'stepsize': 40000,
    'gamma': 0.1,
    'momentum': 0.9,
    'iter_size': iter_size,
    'max_iter': 60000,
    'snapshot': 40000,
    'display': 10,
    'average_loss': 10,
    'type': "SGD",
    'solver_mode': solver_mode,
    'device_id': device_id,
    'debug_info': False,
    'snapshot_after_train': True,
    # Test parameters
    'test_iter': [test_iter],
    'test_interval': 10000,
    'eval_type': "detection",
    'ap_version': "11point",
    'test_initialization': False,
}

# parameters for generating detection output.
det_out_param = {
    'num_classes': num_classes,
    'share_location': share_location,
    'background_label_id': background_label_id,
    'nms_param': {'nms_threshold': 0.45, 'top_k': 400},
    'save_output_param': {
        'output_directory': output_result_dir,
        'output_name_prefix': "comp4_det_test_",
        'output_format': "VOC",
        'label_map_file': label_map_file,
        'name_size_file': name_size_file,
        'num_test_image': num_test_image,
    },
    'keep_top_k': 200,
    'confidence_threshold': 0.01,
    'code_type': code_type,
}

# parameters for evaluating detection results.
det_eval_param = {
    'num_classes': num_classes,
    'background_label_id': background_label_id,
    'overlap_threshold': 0.5,
    'evaluate_difficult_gt': False,
    'name_size_file': name_size_file,
}

### Hopefully you don't need to change the following ###
# Check file.
check_if_exist(train_data)
check_if_exist(test_data)
check_if_exist(label_map_file)
check_if_exist(pretrain_model)
make_if_not_exist(save_dir)
make_if_not_exist(job_dir)
make_if_not_exist(snapshot_dir)


def _add_ssd_body(net):
    """Add the VGG base, the extra layers and the multibox head to ``net``.

    Shared by the train and test nets, which are built with identical
    settings. Returns the list produced by CreateMultiBoxHead.
    """
    VGGNetBody(net, from_layer='data', fully_conv=True, reduced=True, dilated=True,
               dropout=False, freeze_layers=freeze_layers)

    AddExtraLayers(net, use_batchnorm)

    return CreateMultiBoxHead(net, data_layer='data', from_layers=mbox_source_layers,
                              use_batchnorm=use_batchnorm, min_sizes=min_sizes,
                              max_sizes=max_sizes, aspect_ratios=aspect_ratios,
                              normalizations=normalizations, num_classes=num_classes,
                              share_location=share_location, flip=flip, clip=clip,
                              prior_variance=prior_variance, kernel_size=3, pad=1)


# Create train net.
net = caffe.NetSpec()
net.data, net.label = CreateAnnotatedDataLayer(
    train_data, batch_size=batch_size_per_device,
    train=True, output_label=True, label_map_file=label_map_file,
    transform_param=train_transform_param, batch_sampler=batch_sampler)
mbox_layers = _add_ssd_body(net)

# Create the MultiBoxLossLayer.
name = "mbox_loss"
mbox_layers.append(net.label)
net[name] = L.MultiBoxLoss(*mbox_layers, multibox_loss_param=multibox_loss_param,
                           loss_param=loss_param,
                           include=dict(phase=caffe_pb2.Phase.Value('TRAIN')),
                           propagate_down=[True, True, False, False])

with open(train_net_file, 'w') as f:
    print('name: "{}_train"'.format(model_name), file=f)
    print(net.to_proto(), file=f)
shutil.copy(train_net_file, job_dir)

# Create test net.
net = caffe.NetSpec()
net.data, net.label = CreateAnnotatedDataLayer(
    test_data, batch_size=test_batch_size,
    train=False, output_label=True, label_map_file=label_map_file,
    transform_param=test_transform_param)
mbox_layers = _add_ssd_body(net)

# Turn raw confidence scores into probabilities before DetectionOutput.
conf_name = "mbox_conf"
if multibox_loss_param["conf_loss_type"] == P.MultiBoxLoss.SOFTMAX:
    reshape_name = "{}_reshape".format(conf_name)
    net[reshape_name] = L.Reshape(net[conf_name], shape=dict(dim=[0, -1, num_classes]))
    softmax_name = "{}_softmax".format(conf_name)
    net[softmax_name] = L.Softmax(net[reshape_name], axis=2)
    flatten_name = "{}_flatten".format(conf_name)
    net[flatten_name] = L.Flatten(net[softmax_name], axis=1)
    mbox_layers[1] = net[flatten_name]
elif multibox_loss_param["conf_loss_type"] == P.MultiBoxLoss.LOGISTIC:
    sigmoid_name = "{}_sigmoid".format(conf_name)
    net[sigmoid_name] = L.Sigmoid(net[conf_name])
    mbox_layers[1] = net[sigmoid_name]

net.detection_out = L.DetectionOutput(*mbox_layers,
                                      detection_output_param=det_out_param,
                                      include=dict(phase=caffe_pb2.Phase.Value('TEST')))
net.detection_eval = L.DetectionEvaluate(net.detection_out, net.label,
                                         detection_evaluate_param=det_eval_param,
                                         include=dict(phase=caffe_pb2.Phase.Value('TEST')))

with open(test_net_file, 'w') as f:
    print('name: "{}_test"'.format(model_name), file=f)
    print(net.to_proto(), file=f)
shutil.copy(test_net_file, job_dir)

# Create deploy net.
deploy_net = net
with open(deploy_net_file, 'w') as f:
    net_param = deploy_net.to_proto()
    # Remove the first (AnnotatedData) and last (DetectionEvaluate) layer from test net.
    del net_param.layer[0]
    del net_param.layer[-1]
    net_param.name = '{}_deploy'.format(model_name)
    net_param.input.extend(['data'])
    net_param.input_shape.extend([
        caffe_pb2.BlobShape(dim=[1, 3, resize_height, resize_width])])
    print(net_param, file=f)
shutil.copy(deploy_net_file, job_dir)

# Create solver.
solver = caffe_pb2.SolverParameter(
    train_net=train_net_file,
    test_net=[test_net_file],
    snapshot_prefix=snapshot_prefix,
    **solver_param)

with open(solver_file, 'w') as f:
    print(solver, file=f)
shutil.copy(solver_file, job_dir)


def _snapshot_iter(filename):
    """Return N for a file named '<model_name>_iter_<N>.<ext>', else None.

    Files in snapshot_dir that do not follow this pattern are ignored instead
    of raising IndexError/ValueError.
    """
    basename = os.path.splitext(filename)[0]
    prefix = "{}_iter_".format(model_name)
    if not basename.startswith(prefix):
        return None
    suffix = basename[len(prefix):]
    return int(suffix) if suffix.isdigit() else None


# Find most recent snapshot.
max_iter = 0
for fname in os.listdir(snapshot_dir):
    if fname.endswith(".solverstate"):
        snap_iter = _snapshot_iter(fname)
        if snap_iter is not None and snap_iter > max_iter:
            max_iter = snap_iter

train_src_param = '--weights="{}" \\\n'.format(pretrain_model)
if resume_training:
    if max_iter > 0:
        train_src_param = '--snapshot="{}_iter_{}.solverstate" \\\n'.format(
            snapshot_prefix, max_iter)

if remove_old_models:
    # Remove any snapshots smaller than max_iter.
    for fname in os.listdir(snapshot_dir):
        if fname.endswith((".solverstate", ".caffemodel")):
            snap_iter = _snapshot_iter(fname)
            if snap_iter is not None and max_iter > snap_iter:
                os.remove(os.path.join(snapshot_dir, fname))

# Create job file.
with open(job_file, 'w') as f:
    f.write('cd {}\n'.format(caffe_root))
    f.write('./build/tools/caffe train \\\n')
    f.write('--solver="{}" \\\n'.format(solver_file))
    f.write(train_src_param)
    if solver_param['solver_mode'] == P.Solver.GPU:
        f.write('--gpu {} 2>&1 | tee {}/{}.log\n'.format(
            ",".join(gpulist), job_dir, model_name))
    else:
        f.write('2>&1 | tee {}/{}.log\n'.format(job_dir, model_name))

# Copy the python script to job_dir.
py_file = os.path.abspath(__file__)
shutil.copy(py_file, job_dir)

# Run the job. The job script has no shebang, so it is run through the shell.
os.chmod(job_file, stat.S_IRWXU)
if run_soon:
    subprocess.call(job_file, shell=True)
训练命令(需在$CAFFE_ROOT目录下执行,脚本把当前工作目录作为caffe_root):

python examples/ssd/ssd_pascal_indoor.py

4 测试

SSD框架中提供了测试代码,有C++版本和python版本。

 4.1 c++版本

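编译完SSD后,C++版本的可执行文件位于:.build_release/examples/ssd/ssd_detect.bin

测试命令:

./.build_release/examples/ssd/ssd_detect.bin models/VGGNet/indoor/deploy.prototxt models/VGGNet/indoor/VGG_VOC0712_SSD_300x300_iter_60000.caffemodel pictures.txt

其中pictures.txt中保存的是待测试图片的列表。注意:按上面ssd_pascal_indoor.py的配置,deploy.prototxt和训练得到的caffemodel实际生成在models/VGGNet/VOC0712/SSD_300x300/目录下(即save_dir和snapshot_dir)。请将命令中的路径改为实际位置,或先把这两个文件复制到models/VGGNet/indoor/。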

 4.2 python版本

    python 版本的测试过程参见examples/detection.ipynb

参考:  1 2
你可能感兴趣的文章
MySQL内存表使用技巧
查看>>
mysql函数汇总之条件判断函数
查看>>
mysql函数汇总之系统信息函数
查看>>
MySQL函数简介
查看>>
mysql函数遍历json数组
查看>>
MySQL函数(转发)
查看>>
mysql分区表
查看>>
MySQL分层架构与运行机制详解
查看>>
mysql分库分表中间件简书_MySQL分库分表
查看>>
MySQL分库分表会带来哪些问题?分库分表问题
查看>>
MySQL分组函数
查看>>
MySQL分组查询
查看>>
Mysql分表后同结构不同名称表之间复制数据以及Update语句只更新日期加减不更改时间
查看>>
mySql分页Iimit优化
查看>>
mysql列转行函数是什么
查看>>
mysql创建函数报错_mysql在创建存储函数时报错
查看>>
mysql加强(4)~多表查询:笛卡尔积、消除笛卡尔积操作(等值、非等值连接),内连接(隐式连接、显示连接)、外连接、自连接
查看>>
mysql加强(5)~DML 增删改操作和 DQL 查询操作
查看>>
mysql加强(6)~子查询简单介绍、子查询分类
查看>>
MySqL双机热备份(二)--MysqL主-主复制实现
查看>>