本章节主要讲述如何解析Pascal VOC数据集的标注信息:将每个标注好的xml文件解析为一个annotation_data字典,并统一存储到all_imgs列表中,以便后续读取并用于目标检测与分类。代码的解析以注释的形式写在代码中,方便查看。
import os
import xml.etree.ElementTree as ET  # parser for the VOC XML annotation files

import numpy as np

try:
    import cv2
except ImportError:  # OpenCV is only needed for get_data(..., visualise=True)
    cv2 = None


def _read_image_set(path):
    """Return the set of '<id>.jpg' filenames listed in a VOC ImageSets file.

    Blank lines are ignored. Raises OSError if the file cannot be opened.
    """
    with open(path) as f:
        return {line.strip() + '.jpg' for line in f if line.strip()}


def _parse_annotation(annot, imgs_path, trainval_files, test_files):
    """Parse one VOC XML annotation file into a per-image dict.

    Returns None when the file contains no <object> elements. Raises
    (ET.ParseError, AttributeError, ValueError, ...) on malformed files; no
    partial result is produced in that case.
    """
    root = ET.parse(annot).getroot()
    objects = root.findall('object')
    filename = root.find('filename').text
    size = root.find('size')
    width = int(size.find('width').text)
    height = int(size.find('height').text)
    if not objects:
        return None

    # Images listed in trainval.txt, or in neither list, count as 'trainval'.
    if filename in trainval_files or filename not in test_files:
        imageset = 'trainval'
    else:
        imageset = 'test'

    bboxes = []
    for obj in objects:
        class_name = obj.find('name').text
        bndbox = obj.find('bndbox')
        x1, y1, x2, y2 = [int(round(float(bndbox.find(tag).text)))
                          for tag in ('xmin', 'ymin', 'xmax', 'ymax')]
        # Some annotations omit <difficult> (or leave it empty); treat that as 0.
        difficult = int((obj.findtext('difficult') or '0').strip()) == 1
        bboxes.append({'class': class_name, 'x1': x1, 'x2': x2,
                       'y1': y1, 'y2': y2, 'difficult': difficult})

    return {'filepath': os.path.join(imgs_path, filename),
            'width': width,
            'height': height,
            'bboxes': bboxes,
            'imageset': imageset}


def _show_annotation(annotation_data):
    """Draw an image's boxes (color (0, 0, 255), i.e. red in OpenCV's BGR) and wait for a key."""
    img = cv2.imread(annotation_data['filepath'])
    for bbox in annotation_data['bboxes']:
        cv2.rectangle(img, (bbox['x1'], bbox['y1']), (bbox['x2'], bbox['y2']),
                      (0, 0, 255))
    cv2.imshow('img', img)
    cv2.waitKey(0)


def get_data(input_path, visualise=False, subsets=('VOC2007', 'VOC2012')):
    """Parse Pascal VOC annotations found under ``input_path``.

    For every subset directory ``input_path/<subset>``, reads
    ``ImageSets/Main/{trainval,test}.txt`` and every ``*.xml`` file in
    ``Annotations``. Annotation files are processed in sorted order, so the
    result (including class indices) is deterministic. Files that fail to
    parse are reported and skipped. Such files contribute nothing to the class
    statistics. Missing subset directories are reported and skipped.

    Args:
        input_path: directory that contains the VOC subset directories.
        visualise: if True, show each parsed image with its boxes (needs cv2).
        subsets: names of the subset directories to scan.

    Returns:
        (all_imgs, classes_count, class_mapping) where
        all_imgs is a list of dicts with keys 'filepath', 'width', 'height',
        'bboxes' (list of {'class', 'x1', 'x2', 'y1', 'y2', 'difficult'}) and
        'imageset' ('trainval' or 'test'); only images with at least one
        object are included.
        classes_count maps class name -> number of boxes of that class.
        class_mapping maps class name -> integer index, in first-seen order.

    Raises:
        ImportError: if ``visualise`` is True and OpenCV is not installed.
    """
    if visualise and cv2 is None:
        raise ImportError('OpenCV (cv2) is required when visualise=True')

    all_imgs = []
    classes_count = {}
    class_mapping = {}

    data_paths = [os.path.join(input_path, s) for s in subsets]
    print('Parsing annotation files')

    for data_path in data_paths:
        annot_path = os.path.join(data_path, 'Annotations')  # .xml files
        imgs_path = os.path.join(data_path, 'JPEGImages')    # .jpg files
        imgsets_path_trainval = os.path.join(data_path, 'ImageSets', 'Main', 'trainval.txt')
        imgsets_path_test = os.path.join(data_path, 'ImageSets', 'Main', 'test.txt')

        # Sets, not lists: membership is tested once per annotation file.
        try:
            trainval_files = _read_image_set(imgsets_path_trainval)
        except OSError as e:
            print(e)
            trainval_files = set()

        try:
            test_files = _read_image_set(imgsets_path_test)
        except OSError as e:
            test_files = set()
            # Expected: most Pascal VOC 2012 distributions have no test.txt.
            if not data_path.endswith('VOC2012'):
                print(e)

        try:
            annot_names = sorted(os.listdir(annot_path))
        except OSError as e:
            print(e)
            continue

        for annot_name in annot_names:
            if not annot_name.lower().endswith('.xml'):
                continue
            annot = os.path.join(annot_path, annot_name)
            try:
                annotation_data = _parse_annotation(
                    annot, imgs_path, trainval_files, test_files)
            except Exception as e:  # best effort: skip unreadable annotation files
                print('Skipping {}: {}'.format(annot, e))
                continue
            if annotation_data is None:
                continue

            # Update the class statistics only after the whole file parsed.
            for bbox in annotation_data['bboxes']:
                class_name = bbox['class']
                classes_count[class_name] = classes_count.get(class_name, 0) + 1
                if class_name not in class_mapping:
                    class_mapping[class_name] = len(class_mapping)
            all_imgs.append(annotation_data)

            if visualise:
                try:
                    _show_annotation(annotation_data)
                except Exception as e:
                    print('Could not display {}: {}'.format(annot, e))

    return all_imgs, classes_count, class_mapping

# Return values of get_data (该函数返回参数):
all_imgs:一个列表,每个元素是一张图片(至少含有一个object)的标注字典,包含filepath(图片路径)、width、height、imageset('trainval'或'test')以及bboxes(每个目标框的类别class、坐标x1/y1/x2/y2以及是否为difficult);
classes_count:一个字典,记录所有被解析图片(包括trainval和test)中每个类别的目标框总数;
class_mapping:一个字典,key为类别名称,value为该类别的整数编号(按类别首次出现的顺序从0开始递增)。