tek4

Giới Thiệu Reshape Trong Pytorch

by - September. 21, 2021
Kiến thức
Machine Learning
Python
<p style="text-align: justify;"><em>Reshaping operations&nbsp;</em>l&agrave; hoạt động tensor quan trọng nhất. Bởi v&igrave;, shape của tensor c&oacute; thể cho ch&uacute;ng ta c&aacute;c th&ocirc;ng tin cụ thể m&agrave; ch&uacute;ng ta c&oacute; thể sử dụng để x&aacute;c định c&aacute;c thuộc t&iacute;nh kh&aacute;c trong tensor.</p> <p style="text-align: justify;"><img style="width: 657px; display: block; margin-left: auto; margin-right: auto;" src="http://tek4vn.2soft.top/public_files/1-6-png-1" alt="1-6" height="343" /></p> <h4 class="sub-section-heading" style="text-align: justify;">Tensor Shape Review</h4> <p style="text-align: justify;">Giả sử ch&uacute;ng ta c&oacute; tensor sau:</p> <pre class="language-python"><code>&gt; t = torch.tensor([ [1,1,1,1], [2,2,2,2], [3,3,3,3] ], dtype=torch.float32)</code></pre> <p style="text-align: justify;">Để x&aacute;c định <em>shape</em> của tensor n&agrave;y, ta đếm thấy c&oacute; 3 h&agrave;ng, 4 cột&nbsp; v&agrave; do đ&oacute; tensor n&agrave;y c&oacute; <em>shape</em> l&agrave; 3 x 4 v&agrave; <em>rank</em> 2.</p> <p style="text-align: justify;">Trong Pytorch ch&uacute;ng ta c&oacute; 2 c&aacute;ch để lấy shape:</p> <pre class="language-python"><code>&gt; t.size() torch.Size([3, 4]) &gt; t.shape torch.Size([3, 4])</code></pre> <p style="text-align: justify;">Trong Pytorch&nbsp;<em>size</em>&nbsp;and&nbsp;<em>shape&nbsp;</em>c&oacute; &yacute; nghĩa như nhau.&nbsp;Th&ocirc;ng thường, sau khi ch&uacute;ng ta biết shape của tensor, ch&uacute;ng ta c&oacute; thể suy ra một v&agrave;i điều.&nbsp;Đầu ti&ecirc;n, ch&uacute;ng ta c&oacute; thể suy ra rank của tensor. Rank&nbsp;của tensor bằng&nbsp;length của shape.</p> <pre class="language-python"><code>&gt; len(t.shape) 2</code></pre> <p style="text-align: justify;">Ch&uacute;ng ta cũng c&oacute; thể suy ra số phần tử chứa trong tensor.&nbsp;Số phần tử b&ecirc;n trong một tensor (trong v&iacute; dụ của ch&uacute;ng ta l&agrave; 12) bằng t&iacute;ch c&aacute;c gi&aacute; trị th&agrave;nh phần của shape.</p> <pre class="language-python"><code>&gt; torch.tensor(t.shape).prod() tensor(12)</code></pre> <p style="text-align: justify;">Trong PyTorch, c&oacute; một h&agrave;m d&agrave;nh ri&ecirc;ng cho việc n&agrave;y:</p> <pre class="language-python"><code>&gt; t.numel() 12</code></pre> <p style="text-align: justify;">Số lượng phần tử chứa trong một tensor rất quan trọng để reshape v&igrave; việc reshape phải t&iacute;nh đến tổng số phần tử hiện c&oacute;. Reshape&nbsp;thay đổi shape của tensor nhưng kh&ocirc;ng thay đổi dữ liệu cơ bản. Tensor của ch&uacute;ng ta c&oacute; 12 phần tử, v&igrave; vậy bất kỳ sự reshape n&agrave;o cũng phải c&ograve;n đ&uacute;ng 12 phần tử.</p> <h4 style="text-align: justify;">Reshape một tensor trong pytorch</h4> <p style="text-align: justify;">B&acirc;y giờ ch&uacute;ng ta h&atilde;y xem x&eacute;t tất cả c&aacute;c c&aacute;ch m&agrave; tensor <em>t</em>&nbsp;c&oacute; thể được reshape m&agrave; kh&ocirc;ng thay đổi rank:</p> <pre class="language-python"><code>&gt; t.reshape([1,12]) tensor([[1., 1., 1., 1., 2., 2., 2., 2., 3., 3., 3., 3.]]) &gt; t.reshape([2,6]) tensor([[1., 1., 1., 1., 2., 2.], [2., 2., 3., 3., 3., 3.]]) &gt; t.reshape([3,4]) tensor([[1., 1., 1., 1.], [2., 2., 2., 2.], [3., 3., 3., 3.]]) &gt; t.reshape([4,3]) tensor([[1., 1., 1.], [1., 2., 2.], [2., 2., 3.], [3., 3., 3.]]) &gt; t.reshape(6,2) tensor([[1., 1.], [1., 1.], [2., 2.], [2., 2.], [3., 3.], [3., 3.]]) &gt; t.reshape(12,1) tensor([[1.], [1.], [1.], [1.], [2.], [2.], [2.], [2.], [3.], [3.], [3.], [3.]])</code></pre> <p style="text-align: justify;">Sử dụng h&agrave;m <em>reshape(),</em> ch&uacute;ng ta c&oacute; thể chỉ định shape (<em>row x column)</em> m&agrave; ch&uacute;ng ta cần.&nbsp;Lưu &yacute; rằng tất cả c&aacute;c shape phải t&iacute;nh đến số phần tử trong tensor. Đối với v&iacute; dụ của ch&uacute;ng ta như sau:</p> <pre><code>rows * columns = 12 phần tử </code></pre> <p style="text-align: justify;">Ch&uacute;ng ta chỉ c&oacute; thể sử dụng c&aacute;c từ <em>rows</em> v&agrave; <em>columns</em> khi ch&uacute;ng ta xử l&yacute; một tensor rank 2.</p> <pre class="language-python"><code>&gt; t.reshape(2,2,3) tensor( [ [ [1., 1., 1.], [1., 2., 2.] ], [ [2., 2., 3.], [3., 3., 3.] ] ])</code></pre> <p style="text-align: justify;">Trong v&iacute; dụ n&agrave;y, ch&uacute;ng ta tăng rank l&ecirc;n 3 v&agrave; do đ&oacute; sẽ mất đi kh&aacute;i niệm <em>row</em> v&agrave; <em>column.&nbsp;</em>Tuy nhi&ecirc;n, t&iacute;ch của c&aacute;c th&agrave;nh phần của shape(2,2,3) vẫn phải bằng số phần tử trong tensor ban đầu (12).</p> <h3 style="text-align: justify;">Thay đổi shape bằng c&aacute;ch&nbsp;Squeezing v&agrave; Unsqueezing</h3> <p style="text-align: justify;">C&aacute;ch tiếp theo ch&uacute;ng ta c&oacute; thể thay đổi shape của tensor l&agrave; bằng&nbsp;<em>squeezing</em>&nbsp;v&agrave;&nbsp;<em>unsqueezing.&nbsp;</em></p> <ul style="text-align: justify;"> <li><em>Squeezing&nbsp;</em>một tensor sẽ loại c&aacute;c k&iacute;ch thước hoặc axes c&oacute; độ d&agrave;i bằng một.</li> <li><em>Unsqueezing&nbsp;</em>một tensor&nbsp;sẽ th&ecirc;m một k&iacute;ch thước c&oacute; chiều d&agrave;i l&agrave; một.</li> </ul> <p style="text-align: justify;">C&aacute;c&nbsp;functions n&agrave;y&nbsp;cho ph&eacute;p ch&uacute;ng ta mở rộng hoặc thu nhỏ rank của tensor. V&iacute; dụ:</p> <pre class="language-python"><code>&gt; print(t.reshape([1,12])) &gt; print(t.reshape([1,12]).shape) tensor([[1., 1., 1., 1., 2., 2., 2., 2., 3., 3., 3., 3.]]) torch.Size([1, 12]) &gt; print(t.reshape([1,12]).squeeze()) &gt; print(t.reshape([1,12]).squeeze().shape) tensor([1., 1., 1., 1., 2., 2., 2., 2., 3., 3., 3., 3.]) torch.Size([12]) &gt; print(t.reshape([1,12]).squeeze().unsqueeze(dim=0)) &gt; print(t.reshape([1,12]).squeeze().unsqueeze(dim=0).shape) tensor([[1., 1., 1., 1., 2., 2., 2., 2., 3., 3., 3., 3.]]) torch.Size([1, 12])</code></pre> <p style="text-align: justify;">H&atilde;y ch&uacute; &yacute; xem shape đ&atilde; thay đổi như thế n&agrave;o khi ch&uacute;ng ta thực hiện&nbsp;squeeze v&agrave; unsqueeze một tensor trong v&iacute; dụ tr&ecirc;n.</p> <p style="text-align: justify;">B&acirc;y giờ h&atilde;y xem một trường hợp phổ biến sử dụng&nbsp;squeezing bằng c&aacute;ch x&acirc;y dựng một h&agrave;m&nbsp;<em>flatten.</em></p> <h4 class="sub-section-heading" style="text-align: justify;">Flatten một Tensor</h4> <p style="text-align: justify;">Flatten một tensor l&agrave; reshape lại tensor để c&oacute; shape bằng số phần tử c&oacute; trong tensor.</p> <pre><code>Flattening a tensor means to remove all of the dimensions except for one.</code></pre> <p style="text-align: justify;">H&atilde;y tạo một h&agrave;m Python c&oacute; t&ecirc;n l&agrave; <em>flatten():</em></p> <pre class="language-python"><code>def flatten(t): t = t.reshape(1, -1) t = t.squeeze() return t</code></pre> <p style="text-align: justify;">H&agrave;m <em>flatten()</em> nhận tensor <em>t</em> l&agrave;m đối số.&nbsp;V&igrave; đối số t c&oacute; thể l&agrave; tensor bất kỳ, ch&uacute;ng ta truyền -1 l&agrave;m đối số thứ hai cho h&agrave;m <em>reshape().&nbsp;</em>Trong PyTorch, -1 l&agrave;m cho h&agrave;m <em>reshape()</em>&nbsp;tự t&igrave;m ra gi&aacute; trị dựa tr&ecirc;n số lượng phần tử c&oacute; trong tensor. H&atilde;y nhớ rằng, shape mới của tensor&nbsp;phải bằng shape của tensor ban đầu.&nbsp;Đ&acirc;y l&agrave; c&aacute;ch PyTorch c&oacute; thể t&igrave;m ra gi&aacute; trị n&ecirc;n l&agrave; g&igrave;.&nbsp;V&igrave; tensor <em>t</em> của ch&uacute;ng ta c&oacute; 12 phần tử, n&ecirc;n h&agrave;m <em>reshape()</em> c&oacute; thể t&igrave;m ra rằng 12 l&agrave; cần thiết.</p> <p style="text-align: justify;">Sau khi squeezing, axis đầu ti&ecirc;n (axis-0) bị loại bỏ v&agrave; ch&uacute;ng ta thu được kết quả mong muốn, một mảng 1d c&oacute; độ d&agrave;i 12. Dưới đ&acirc;y l&agrave; một v&iacute; dụ:</p> <pre class="language-python"><code>&gt; t = torch.ones(4, 3) &gt; t tensor([[1., 1., 1.], [1., 1., 1.], [1., 1., 1.], [1., 1., 1.]]) &gt; flatten(t) tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])</code></pre> <p style="text-align: justify;">Trong c&aacute;c b&agrave;i viết tiếp theo, khi ch&uacute;ng ta bắt đầu x&acirc;y dựng một mạng CNN, ch&uacute;ng ta sẽ hiểu r&otilde; hơn khi sử dụng h&agrave;m&nbsp;<em>flatten()</em> n&agrave;y.</p> <h4 class="sub-section-heading" style="text-align: justify;">Concatenating Tensors</h4> <p style="text-align: justify;">Ch&uacute;ng ta kết hợp c&aacute;c tensor bằng c&aacute;ch sử dụng h&agrave;m <em>cat(),</em> v&agrave; tensor kết quả sẽ c&oacute; shape phụ thuộc v&agrave;o shape của hai tensor đầu v&agrave;o.</p> <p style="text-align: justify;">Giả sử ch&uacute;ng ta c&oacute; hai tensor:</p> <pre class="language-python"><code>&gt; t1 = torch.tensor([ [1,2], [3,4] ]) &gt; t2 = torch.tensor([ [5,6], [7,8] ])</code></pre> <p style="text-align: justify;">Ch&uacute;ng ta c&oacute; thể kết hợp t1 v&agrave; t2 theo h&agrave;ng (axis-0) theo c&aacute;ch sau:</p> <pre class="language-python"><code>&gt; torch.cat((t1, t2), dim=0) tensor([[1, 2], [3, 4], [5, 6], [7, 8]])</code></pre> <p style="text-align: justify;">Ch&uacute;ng ta c&oacute; thể kết hợp ch&uacute;ng theo cột (axis-1) như sau:</p> <pre class="language-python"><code>&gt; torch.cat((t1, t2), dim=1) tensor([[1, 2, 5, 6], [3, 4, 7, 8]])</code></pre> <p style="text-align: justify;">Khi ch&uacute;ng ta nối c&aacute;c tensor, ch&uacute;ng ta l&agrave;m tăng số phần tử c&oacute; trong tensor kết quả.&nbsp;Điều n&agrave;y l&agrave;m cho c&aacute;c gi&aacute; trị th&agrave;nh phần trong shape (chiều d&agrave;i của c&aacute;c axes) tự điều chỉnh để t&iacute;nh cho c&aacute;c phần tử bổ sung.</p> <pre class="language-python"><code>&gt; torch.cat((t1, t2), dim=0).shape torch.Size([4, 2]) &gt; torch.cat((t1, t2), dim=1).shape torch.Size([2, 4])</code></pre> <h3 style="text-align: justify;">Kết Luận</h3> <p style="text-align: justify;">B&acirc;y giờ ch&uacute;ng ta đ&atilde; hiểu r&otilde; về reshape một tensor. Hi vọng bạn th&iacute;ch b&agrave;i viết n&agrave;y.</p> <p style="text-align: justify;">Hẹn gặp lại bạn trong c&aacute;c b&agrave;i viết tiếp theo!</p> <hr /> <p style="text-align: center;"><em><strong>Fanpage Facebook:</strong>&nbsp;<a href="https://www.facebook.com/tek4.vn/">TEK4.VN</a></em>&nbsp;</p> <p style="text-align: center;"><em><strong>Tham gia cộng đồng để chia sẻ, trao đổi v&agrave; thảo luận:</strong>&nbsp;<a href="https://www.facebook.com/groups/tek4.vn/">TEK4.VN - Học Lập Tr&igrave;nh Miễn Ph&iacute;</a></em></p>