1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72
class ImageTransform:
    """Callable that preprocesses an image differently per dataset phase.

    ``'train'`` uses ``RandomResizedCrop(resize, scale=(0.5, 1.0))`` and a
    random horizontal flip, then tensor conversion and normalization.
    ``'val'`` uses ``Resize(resize)`` followed by ``CenterCrop(resize)``, then
    tensor conversion and normalization.

    Args:
        resize: output size handed to the crop/resize transforms.
        mean: per-channel means handed to ``transforms.Normalize``.
        std: per-channel standard deviations handed to ``transforms.Normalize``.
    """

    def __init__(self, resize, mean, std):
        train_steps = [
            transforms.RandomResizedCrop(resize, scale=(0.5, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ]
        train_pipeline = transforms.Compose(train_steps)

        val_steps = [
            transforms.Resize(resize),
            transforms.CenterCrop(resize),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ]
        val_pipeline = transforms.Compose(val_steps)

        # Phase name -> preprocessing pipeline.
        self.data_transform = {'train': train_pipeline, 'val': val_pipeline}

    def __call__(self, img, phase='train'):
        """Run the pipeline registered for *phase* on *img* and return the result.

        Raises:
            KeyError: if *phase* has no registered pipeline (only ``'train'``
                and ``'val'`` are built in ``__init__``).
        """
        pipeline = self.data_transform[phase]
        return pipeline(img)


# Target crop size and per-channel normalization statistics.
# NOTE(review): these values match the ImageNet statistics commonly used with
# torchvision pretrained models — presumably why they were chosen; confirm.
size = 224
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
def make_datapath_list(
    phase='train',
    rootpath='./pytorch_advanced-master/1_image_classification/data/hymenoptera_data/',
):
    """Collect the ``.jpg`` image paths for one dataset split.

    Matches files one directory level below the split directory, i.e.
    ``<rootpath>/<phase>/<class_dir>/*.jpg``.

    Args:
        phase: name of the split sub-directory, e.g. ``'train'`` or ``'val'``.
        rootpath: dataset root directory containing the split directories.
            Defaults to the previously hard-coded location.

    Returns:
        A sorted list of matching file paths (empty if nothing matches).
    """
    # '*' for the class directory: the previous '**' pattern was used without
    # recursive=True, in which case glob treats '**' exactly like '*'.
    target_path = osp.join(rootpath, phase, '*', '*.jpg')
    # glob's ordering is filesystem-dependent; sort for reproducible datasets.
    return sorted(glob.glob(target_path))
# Image file paths for the training and validation splits, in that order.
train_list, val_list = [make_datapath_list(phase=split) for split in ('train', 'val')]
class Hdataset(data.Dataset):
    """Ants-vs-bees image dataset built from a list of file paths.

    Each item is ``(transformed_image, label)``. The label is derived from the
    name of the directory that directly contains the image, via
    ``CLASS_TO_IDX`` (``'ants'`` -> 0, ``'bees'`` -> 1).

    Args:
        file_list: image file paths.
        transform: callable invoked as ``transform(img, phase)``; when ``None``
            the loaded RGB PIL image is returned as-is.
        phase: phase name forwarded to ``transform`` (e.g. ``'train'``/``'val'``).
    """

    # Class-directory name -> integer label.
    CLASS_TO_IDX = {'ants': 0, 'bees': 1}

    def __init__(self, file_list, transform=None, phase='train'):
        self.file_list = file_list
        self.transform = transform
        self.phase = phase

    def __len__(self):
        """Return the number of image paths in the dataset."""
        return len(self.file_list)

    @classmethod
    def _label_from_path(cls, img_path):
        """Map the parent-directory name of *img_path* to its integer label.

        Raises:
            ValueError: if the parent directory is not a known class name.
        """
        # Use the directory name instead of fixed character offsets into the
        # path, so labels stay correct for any dataset root location.
        class_name = osp.basename(osp.dirname(img_path))
        try:
            return cls.CLASS_TO_IDX[class_name]
        except KeyError:
            raise ValueError(
                f"cannot infer label for {img_path!r}: parent directory "
                f"{class_name!r} is not one of {sorted(cls.CLASS_TO_IDX)}"
            ) from None

    def __getitem__(self, index):
        """Load image *index*, apply the transform, and return ``(image, label)``."""
        img_path = self.file_list[index]
        # Image.open is lazy and keeps the file open; convert() forces the load
        # so the handle can be closed here. RGB conversion also guarantees three
        # channels (grayscale/RGBA inputs would otherwise break channel-wise
        # normalization in the transform).
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')

        if self.transform is not None:
            img = self.transform(img, self.phase)

        label = self._label_from_path(img_path)
        return img, label
# One dataset per split; each gets its own preprocessing pipeline instance.
train_dataset, val_dataset = (
    Hdataset(file_list=paths, transform=ImageTransform(size, mean, std), phase=split)
    for paths, split in ((train_list, 'train'), (val_list, 'val'))
)
|