主页 > 游戏开发  > 

数据加载器


数据集类必要的三个方法:

init:从哪里读取数据,写路径

getitem:获取数据的内容

len:获取数据的总个数

数据加载器类相当于一个容器,你把数据放进去由它进行下一步管理。

注意文件的路径

import torch from torch.utils.data import Dataset,DataLoader data_path = r"C:\Users\Li\Desktop\learnPytorch\2\smsspamcollection\SMSSpamCollection" #完成数据集类 class MyDateset(Dataset): def __init__(self): self.lines = open(data_path,'r',encoding='utf-8').readlines() def __getitem__(self, index): #获取索引对应位置的一条数据 这里有无strip都无关紧要 cur_line = self.lines[index].strip() #将是否是骚扰和短信的内容分开。label表示标记,content是里面的内容 label = cur_line[:4].strip() content = cur_line[4:].strip() return label,content def __len__(self): #返回数据的总数量 return len(self.lines) my_dataset = MyDateset() data_loader = DataLoader(dataset=my_dataset,batch_size=2,shuffle=True) if __name__ == '__main__': # my_dataset = MyDateset() # print(my_dataset[0]) # print(my_dataset.__len__()) for i in data_loader: print(i)

标签:

数据加载器由讯客互联游戏开发栏目发布,感谢您对讯客互联的认可,以及对我们原创作品以及文章的青睐,非常欢迎各位朋友分享到个人网站或者朋友圈,但转载请说明文章出处“数据加载器