今天在使用PyTorch中Dataset遇到了一個問題。先看代碼
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
|
class psDataset(Dataset): def __init__( self , x, y, transforms = None ): super (Dataset, self ).__init__() self .x = x self .y = y if transforms = = None : self .transforms = Compose([Resize(( 224 , 224 )), ToTensor()]) else : self .transforms = transforms def __len__( self ): return len ( self .x) def __getitem__( self , idx): img = Image. open ( self .x[idx]) img = self .transforms(img) return img, torch.tensor([[ self .y[idx]]]) |
結(jié)果運(yùn)行時報錯:RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 3 and 1 in dimension 1 at /opt/conda/conda-bld/pytorch_1522182087074/work/torch/lib/TH/generic/THTensorMath.c:2897
Google了一下發(fā)現(xiàn)是這樣的:讀入的圖片有些是灰度圖(1個通道),絕大多數(shù)是RGB圖片(3通道),也有些是帶透明度的(4通道)
。這導(dǎo)致在讀入后最后一個維度(通道數(shù))不一致(可能是1、3或者4)。
Dataloader在制作batch data時,tensor的shape必須一樣,就報了這個錯誤。解決的方法是:img = img.convert(“RGB”)。完
整代碼如下:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
|
class psDataset(Dataset): def __init__( self , x, y, transforms = None ): super (Dataset, self ).__init__() self .x = x self .y = y if transforms = = None : self .transforms = Compose([Resize(( 224 , 224 )), ToTensor()]) else : self .transforms = transforms def __len__( self ): return len ( self .x) def __getitem__( self , idx): img = Image. open ( self .x[idx]) img = img.convert( "RGB" ) img = self .transforms(img) return img, torch.tensor([[ self .y[idx]]]) |
以上這篇PyTorch 解決Dataset和Dataloader遇到的問題就是小編分享給大家的全部內(nèi)容了,希望能給大家一個參考,也希望大家多多支持服務(wù)器之家。
原文鏈接:https://blog.csdn.net/xgbm_k/article/details/84067245