tek4

Dataset Và DataLoader Trong Pytorch

by - September. 21, 2021
Kiến thức
Machine Learning
Python
<p style="text-align: justify;">B&agrave;i viết n&agrave;y sẽ tiếp tục tr&igrave;nh b&agrave;y về quy tr&igrave;nh chuẩn bị dữ liệu cho dự &aacute;n, cụ thể l&agrave; ta sẽ thấy c&aacute;ch l&agrave;m việc với c&aacute;c lớp <strong>Dataset</strong> v&agrave; <strong>DataLoader</strong> của PyTorch. Bắt đầu th&ocirc;i!</p> <p><img style="width: 591px; display: block; margin-left: auto; margin-right: auto;" src="http://tek4vn.2soft.top/public_files/dataset-va-dataloader-png-1" alt="Dataset V&agrave; DataLoader" height="310" /></p> <h3 class="section-heading" style="text-align: justify;">PyTorch Dataset:&nbsp; L&agrave;m việc với bộ training set</h3> <p style="text-align: justify;">H&atilde;y bắt đầu bằng c&aacute;ch xem x&eacute;t một số thao t&aacute;c m&agrave; ch&uacute;ng ta c&oacute; thể thực hiện để hiểu r&otilde; hơn về dữ liệu của m&igrave;nh.</p> <h4 class="sub-section-heading" style="text-align: justify;">Kh&aacute;m Ph&aacute; Dữ Liệu</h4> <p style="text-align: justify;">Để xem c&oacute; bao nhi&ecirc;u h&igrave;nh ảnh trong tập huấn luyện, ch&uacute;ng ta c&oacute; thể kiểm tra độ d&agrave;i của tập dữ liệu bằng c&aacute;ch sử dụng h&agrave;m Python <em>len():</em></p> <pre class="language-python"><code>&gt; len(train_set) 60000</code></pre> <p style="text-align: justify;">Giả sử ch&uacute;ng ta muốn xem c&aacute;c nh&atilde;n cho mỗi h&igrave;nh ảnh, ch&uacute;ng ta c&oacute; thể l&agrave;m như dưới đ&acirc;y.</p> <p style="text-align: justify;">Lưu &yacute; rằng API torchvision đ&atilde; được thay đổi bắt đầu từ phi&ecirc;n bản 0.2.1. Bạn c&oacute; thể xem th&ecirc;m <a href="https://github.com/pytorch/vision/releases/tag/v0.2.2" target="_blank" rel="noopener">tại đ&acirc;y</a>.</p> <pre class="language-python"><code># Trước phi&ecirc;n bản torchvision 0.2.2 &gt; train_set.train_labels tensor([9, 0, 0, ..., 3, 0, 5]) # Bắt đầu từ torchvision 0.2.2 &gt; train_set.targets tensor([9, 0, 0, ..., 3, 0, 5])</code></pre> <p style="text-align: justify;">H&igrave;nh đầu ti&ecirc;n l&agrave; số 9 v&agrave; hai h&igrave;nh tiếp theo l&agrave; số 0. H&atilde;y xem lại b&agrave;i&nbsp;<a href="https://tek4.vn/du-lieu-trong-deep-learning-lap-trinh-neural-network-bai-13/" target="_blank" rel="noopener">Dữ Liệu Trong Deep Learning</a>&nbsp;để thấy được c&aacute;c số n&agrave;y đại diện cho mặt h&agrave;ng thời trang n&agrave;o. V&iacute; dụ: số 9 l&agrave;&nbsp;Ankle boot v&agrave; s&ocirc; 0 l&agrave;&nbsp;T-shirt/top.</p> <p><img style="width: 293px; display: block; margin-left: auto; margin-right: auto;" src="http://tek4vn.2soft.top/public_files/fashion-mnist-t-shirt-png" alt="fashion-mnist-t-shirt" height="296" /></p> <p style="text-align: justify;">Nếu ch&uacute;ng ta muốn xem số lượng mẫu của mỗi nh&atilde;n tồn tại trong tập dữ liệu, ch&uacute;ng ta c&oacute; thể sử dụng h&agrave;m PyTorch <em>bincount()</em> như sau</p> <p style="text-align: justify;">Lưu &yacute; rằng API torchvision đ&atilde; được thay đổi bắt đầu từ phi&ecirc;n bản 0.2.1. Bạn c&oacute; thể xem th&ecirc;m <a href="https://github.com/pytorch/vision/releases/tag/v0.2.2" target="_blank" rel="noopener">tại đ&acirc;y</a>.</p> <pre class="language-python"><code># Trước phi&ecirc;n bản torchvision 0.2.2 &gt; train_set.train_labels.bincount() tensor([6000, 6000, 6000, 6000, 6000, 6000, 6000, 6000, 6000, 6000]) # Từ torchvision 0.2.2 &gt; train_set.targets.bincount() tensor([6000, 6000, 6000, 6000, 6000, 6000, 6000, 6000, 6000, 6000])</code></pre> <h4 class="sub-section-heading" style="text-align: justify;">Balanced v&agrave; Unbalanced Datasets</h4> <p style="text-align: justify;">Từ v&iacute; dụ tr&ecirc;n cho ta thấy rằng bộ dữ liệu&nbsp;Fashion-MNIST c&oacute; số lượng mẫu của mỗi nh&atilde;n c&acirc;n bằng nhau, ch&uacute;ng ta c&oacute; 6000 mẫu cho mỗi nh&atilde;n.&nbsp;Kết quả l&agrave;, tập dữ liệu n&agrave;y được cho l&agrave; c&acirc;n bằng (balanced).&nbsp;Nếu c&aacute;c nh&atilde;n c&oacute; số lượng mẫu kh&aacute;c nhau, ch&uacute;ng ta sẽ gọi l&agrave; tập dữ liệu kh&ocirc;ng c&acirc;n bằng (unbalanced).</p> <p style="text-align: justify;">Mất c&acirc;n bằng (<em>Class imbalance</em>) l&agrave; một vấn đề phổ biến, nhưng trong trường hợp n&agrave;y, ch&uacute;ng ta thấy rằng tập dữ liệu Fashion-MNIST thực sự c&acirc;n bằng, v&igrave; vậy ch&uacute;ng ta kh&ocirc;ng cần lo lắng về điều đ&oacute; cho dự &aacute;n của m&igrave;nh.</p> <p style="text-align: justify;">Để đọc th&ecirc;m về c&aacute;c c&aacute;ch giảm thiểu mất c&acirc;n bằng dữ liệu trong Deep learning h&atilde;y xem b&agrave;i b&aacute;o n&agrave;y: <a href="https://arxiv.org/abs/1710.05381">A systematic study of the class imbalance problem in convolutional neural networks.</a></p> <h4 class="sub-section-heading" style="text-align: justify;">Truy Cập Dữ Liệu Trong Tập Huấn Luyện</h4> <p style="text-align: justify;">Để truy cập một phần tử ri&ecirc;ng lẻ từ tập huấn luyện, trước ti&ecirc;n ch&uacute;ng ta cần chuyển đối tượng <em>train_set</em> tới h&agrave;m t&iacute;ch hợp sẵn của Python l&agrave; <em>iter(),</em> h&agrave;m n&agrave;y trả về một đối tượng đại diện cho một luồng dữ liệu.</p> <p style="text-align: justify;">Với luồng dữ liệu, ch&uacute;ng ta c&oacute; thể sử dụng h&agrave;m <em>next()</em> t&iacute;ch hợp sẵn trong Python để lấy phần tử dữ liệu tiếp theo trong luồng dữ liệu.</p> <pre class="language-python"><code>&gt; sample = next(iter(train_set)) &gt; len(sample) 2</code></pre> <p style="text-align: justify;">Sau khi truyền sample cho h&agrave;m <em>len(),</em> ch&uacute;ng ta c&oacute; thể thấy rằng sample chứa hai mục v&agrave; điều n&agrave;y l&agrave; do tập dữ liệu chứa c&aacute;c cặp nh&atilde;n v&agrave; h&igrave;nh ảnh.&nbsp;Mỗi mẫu m&agrave; ch&uacute;ng ta lấy từ tập huấn luyện chứa dữ liệu h&igrave;nh ảnh dưới dạng tensor v&agrave; nh&atilde;n tương ứng dưới dạng tensor.</p> <p style="text-align: justify;">V&igrave; sample l&agrave; loại&nbsp;<a href="https://docs.python.org/3/library/stdtypes.html#typesseq">sequence</a>&nbsp;n&ecirc;n ch&uacute;ng ta c&oacute; thể sử dụng <em>sequence unpacking</em> để g&aacute;n h&igrave;nh ảnh v&agrave; nh&atilde;n.&nbsp;B&acirc;y giờ ch&uacute;ng ta sẽ kiểm tra type của h&igrave;nh ảnh v&agrave; nh&atilde;n, c&oacute; thể thấy ch&uacute;ng đều l&agrave; <em>torch.Tensor.</em></p> <pre class="language-python"><code>&gt; type(image) torch.Tensor # Trước phi&ecirc;n bản torchvision 0.2.2 &gt; type(label) torch.Tensor # Kể từ torchvision 0.2.2 &gt; type(label) int</code></pre> <p style="text-align: justify;">Ch&uacute;ng ta sẽ kiểm tra shape để thấy rằng h&igrave;nh ảnh l&agrave; tensor 1 x 28 x 28 trong khi nh&atilde;n l&agrave; tensor c&oacute; gi&aacute; trị v&ocirc; hướng:</p> <pre class="language-python"><code>&gt; image.shape torch.Size([1, 28, 28]) &gt; torch.tensor(label).shape torch.Size([])</code></pre> <p style="text-align: justify;">Ch&uacute;ng ta cũng sẽ gọi h&agrave;m <em>squeeze()</em> tr&ecirc;n h&igrave;nh ảnh để xem c&aacute;ch ch&uacute;ng ta c&oacute; thể x&oacute;a dimension của k&iacute;ch thước 1:</p> <pre class="language-python"><code>&gt; image.squeeze().shape torch.Size([28, 28])</code></pre> <p style="text-align: justify;">Để hiển thị h&igrave;nh ảnh ta l&agrave;m như sau:</p> <pre class="language-python"><code>&gt; plt.imshow(image.squeeze(), cmap="gray") &gt; torch.tensor(label) tensor(9)</code></pre> <p><img style="width: 280px; display: block; margin-left: auto; margin-right: auto;" src="http://tek4vn.2soft.top/public_files/fashion-mnist-ankle-boot-png" alt="fashion-mnist-ankle-boot" height="278" /></p> <p style="text-align: justify;">Ch&uacute;ng ta nhận được một h&igrave;nh&nbsp;ankle-boot v&agrave; nh&atilde;n số 9.&nbsp;Ch&uacute;ng ta đ&atilde; biết rằng nh&atilde;n số 9 đại diện cho một đ&ocirc;i gi&agrave;y ankle-boot v&igrave; n&oacute; đ&atilde; được chỉ định trong b&agrave;i viết trước đ&oacute;.</p> <p style="text-align: justify;">Tiếp theo ch&uacute;ng ta h&atilde;y xem c&aacute;ch l&agrave;m việc với data loader.</p> <h3 class="section-heading" style="text-align: justify;">PyTorch DataLoader: L&agrave;m việc với h&agrave;ng loạt dữ liệu</h3> <p style="text-align: justify;">Ch&uacute;ng ta sẽ bắt đầu bằng c&aacute;ch tạo một tr&igrave;nh tải dữ liệu mới với batch size nhỏ l&agrave; 10 để dễ d&agrave;ng chứng minh điều g&igrave; đang xảy ra:</p> <pre class="language-python"><code>&gt; display_loader = torch.utils.data.DataLoader( train_set, batch_size=10 )</code></pre> <p style="text-align: justify;">Ch&uacute;ng ta sẽ nhận được 1 batch từ loader theo c&aacute;ch giống như ch&uacute;ng ta đ&atilde; l&agrave;m với tập huấn luyện ph&iacute;a tr&ecirc;n. Sử dụng h&agrave;m <em>iter()</em> v&agrave; <em>next().</em></p> <pre class="language-python"><code>&gt; batch = next(iter(display_loader)) &gt; print('len:', len(batch)) len: 2</code></pre> <p style="text-align: justify;">Kiểm tra độ d&agrave;i của batch được trả về, ch&uacute;ng ta cũng nhận được len = 2 giống như tập huấn luyện. H&atilde;y cũng unpack batch v&agrave; xem x&eacute;t 2 tensor cũng như shape của ch&uacute;ng:</p> <pre class="language-python"><code>&gt; images, labels = batch &gt; print('types:', type(images), type(labels)) &gt; print('shapes:', images.shape, labels.shape) types: &lt;class 'torch.Tensor'&gt; &lt;class 'torch.Tensor'&gt; shapes: torch.Size([10, 1, 28, 28]) torch.Size([10])</code></pre> <p style="text-align: justify;">V&igrave; <em>batch_size=10&nbsp;</em>n&ecirc;n ch&uacute;ng ta biết rằng ch&uacute;ng ta đang xử l&yacute; một loạt 10 h&igrave;nh ảnh v&agrave; 10 nh&atilde;n tương ứng.</p> <p style="text-align: justify;">Quan s&aacute;t shape ta thấy n&oacute; kh&aacute;c khi ch&uacute;ng ta kiểm tra với 1 mẫu, thay v&igrave; c&oacute; một gi&aacute; trị v&ocirc; hướng duy nhất l&agrave;m nh&atilde;n, ch&uacute;ng ta lại c&oacute; một tensor rank-1 với 10 gi&aacute; trị. Trường hợp n&agrave;y shape được x&aacute;c định như sau:</p> <pre><code>(batch size, số k&ecirc;nh m&agrave;u, chiều cao h&igrave;nh ảnh, chiều rộng h&igrave;nh ảnh)</code></pre> <p style="text-align: justify;">Batch size l&agrave; 10, đ&oacute; l&agrave; l&yacute; do tại sao b&acirc;y giờ ch&uacute;ng ta c&oacute; s&ocirc; 10 trong shape.</p> <p style="text-align: justify;">Để vẽ một loạt h&igrave;nh ảnh, ch&uacute;ng ta c&oacute; thể sử dụng h&agrave;m <em>torchvision.utils.make_grid()</em>&nbsp;như sau:</p> <pre class="language-python"><code>&gt; grid = torchvision.utils.make_grid(images, nrow=10) &gt; plt.figure(figsize=(15,15)) &gt; plt.imshow(np.transpose(grid, (1,2,0))) &gt; print('labels:', labels) labels: tensor([9, 0, 0, 3, 0, 2, 7, 2, 5, 5])</code></pre> <p><img style="width: 715px; display: block; margin-left: auto; margin-right: auto;" src="http://tek4vn.2soft.top/public_files/fashion-mnist-grid-sample-png" alt="fashion-mnist-grid-sample" height="100" /></p> <p style="text-align: justify;">Ch&uacute; &yacute;: Ta c&oacute; thể sử dụng&nbsp;<em>permute()</em> thay thế cho&nbsp;<em>np.transpose()</em> ở tr&ecirc;n như sau:</p> <pre class="language-python"><code>&gt; grid = torchvision.utils.make_grid(images, nrow=10) &gt; plt.figure(figsize=(15,15)) &gt; plt.imshow(grid.permute(1,2,0)) &gt; print('labels:', labels) labels: tensor([9, 0, 0, 3, 0, 2, 7, 2, 5, 5])</code></pre> <p style="text-align: justify;"><img style="width: 715px; display: block; margin-left: auto; margin-right: auto;" src="http://tek4vn.2soft.top/public_files/fashion-mnist-grid-sample-png" alt="fashion-mnist-grid-sample" height="100" /></p> <h3 class="section-heading" style="text-align: justify;">C&aacute;ch Vẽ C&aacute;c H&igrave;nh Ảnh Bằng PyTorch DataLoader</h3> <p style="text-align: justify;">Đ&acirc;y l&agrave; một c&aacute;ch kh&aacute;c để vẽ c&aacute;c h&igrave;nh ảnh bằng PyTorch DataLoader:</p> <pre class="language-python"><code>how_many_to_plot = 20 train_loader = torch.utils.data.DataLoader( train_set, batch_size=1, shuffle=True ) plt.figure(figsize=(50,50)) for i, batch in enumerate(train_loader, start=1): image, label = batch plt.subplot(10,10,i) plt.imshow(image.reshape(28,28), cmap='gray') plt.axis('off') plt.title(train_set.classes[label.item()], fontsize=28) if (i &gt;= how_many_to_plot): break plt.show()</code></pre> <p style="text-align: justify;"><img style="width: 100%;" src="http://tek4vn.2soft.top/public_files/fashion-mnist-grid-sample-2-png" alt="fashion-mnist-grid-sample-2" /></p> <h3 style="text-align: justify;">Kết Luận</h3> <p style="text-align: justify;">B&acirc;y giờ ch&uacute;ng ta đ&atilde; hiểu r&otilde; về c&aacute;ch kh&aacute;m ph&aacute; v&agrave; tương t&aacute;c với <em>Datasets</em> v&agrave; <em>DataLoaders.</em>&nbsp;Cả hai điều n&agrave;y sẽ được chứng minh l&agrave; quan trọng khi ch&uacute;ng ta bắt đầu x&acirc;y dựng mạng neural t&iacute;ch chập v&agrave; v&ograve;ng lặp đ&agrave;o tạo của m&igrave;nh.&nbsp;Tr&ecirc;n thực tế, bộ&nbsp;<em>DataLoaders</em> sẽ được sử dụng trực tiếp b&ecirc;n trong v&ograve;ng lặp đ&agrave;o tạo. Hi vọng bạn th&iacute;ch b&agrave;i viết n&agrave;y. Hẹn gặp lại bạn trong b&agrave;i viết tiếp theo tr&ecirc;n <a href="https://tek4.vn/">tek4.vn</a>.</p> <p style="text-align: justify;">&nbsp;</p> <hr /> <p style="text-align: center;"><em><strong>Fanpage Facebook:</strong>&nbsp;<a href="https://www.facebook.com/tek4.vn/">TEK4.VN</a></em>&nbsp;</p> <p style="text-align: center;"><em><strong>Tham gia cộng đồng để chia sẻ, trao đổi v&agrave; thảo luận:</strong>&nbsp;<a href="https://www.facebook.com/groups/tek4.vn/">TEK4.VN - Học Lập Tr&igrave;nh Miễn Ph&iacute;</a></em></p>