Super Kawaii Cute Cat Kaoani [Pytorch] Dataloader와 Sampler

연구/PyTorch

[Pytorch] Dataloader와 Sampler

치킨고양이짱아 2023. 10. 7. 17:13
728x90
728x90

모든 pytorch의 dataloader는 sampler라는걸 가지게 된다.

* RandomSampler

DataLoader(dataset=train_dataset, shuffle = True, batch_size = 1)

위와 같이 shuffle=True로 세팅하게 되면 dataloader의 sampler는 자동으로 RandomSampler로 선택된다. 만약 RandomSampler가 아닌 내가 원하는 방식대로 동작하는 sampler를 따로 지정해주고 싶다면 shuffle= False로 세팅하여야한다.

 

* SubsetRandomSampler

shuffle=True일때는 전체 dataset에서 data를 andom하게 뽑게 된다. 만약 전체 dataset에서가 아닌 일부 subset에서 data random하게 뽑고 싶을 때는 SubsetRandomSampler를 사용하면 된다. (예를 들어 train dataset과 valid dataset을 나눈다고 하면, training 시킬 때 사용해야하는 data는 전체 dataset이 아닌 train dataset이다. 이런 경우에 SubsetRandomSampler를 사용하면 된다.)

parameter로는 해당되는 subset의 인덱스를 넣어주면 된다.

ex) 예제코드

    # train_set : torchvision.datasets에서 가져온 Training 데이터셋
    num_train = len(train_set)          # 50,000
    indices = list( range(num_train) )  # 0부터 50,000까지의 숫자가 순서대로 나열된 리스트
    np.random.shuffle(indices)          # 0부터 50,000까지의 숫자가 순서없이 뒤섞인 리스트가 됨
    split = int( np.floor( 0.2 * num_train )  # 10,000
    train_idx, valid_idx = indices[split:], indices[:split]

    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)

    train_loader = torch.utils.data.DataLoader(train_data, ... , sampler=train_sampler)
    valid_loader = torch.utils.data.DataLoader(train_data, ... , sampler=valid_sampler)

 

* SequentialSampler

이 sampler를 사용하면 그냥 주어진 dataset에서 순서대로 뽑게된다.

 

* BatchSampler

BatchSampelr를 사용하면 batch 단위로 sampling하는 거이 가능하다.

Dataloader에 batch_size를 넣어주는거랑 BatchSampler에 batch_size를 넣어주는게 무슨 차이일까 했는데 BatchSampler는 batch_size를 가지는 인덱스를 return하는거고, Dataloader는 batch_size 만큼의 data를 return 한다는 차이가 있었다.

BatchSampler의 parameter인 drop_last는 batch_size보다 사이즈가 작은 last batch를 버릴 것인지 결정하는 옵션이다.

BatchSampler를 Dataloader안에서 사용할 때 batch_sampler를 sampler로 지정해주어야한다.

loader = DataLoader(
    dataset=dataset,
    # This line below!
    sampler=BatchSampler(
        SequentialSampler(dataset), batch_size=self.hparams.batch_size, drop_last=False
    ),
    num_workers=self.hparams.num_data_workers,
)

 


이 외에도 다양한 sampler가 존재한다. WeightedRandomSampler 와 같은 sampler를 사용하면 비율 등을 조정하는 것도 가능하다.

728x90
728x90