Computer >> คอมพิวเตอร์ >  >> การเขียนโปรแกรม >> Python

จะเข้าถึงข้อมูลเมตาของเทนเซอร์ใน PyTorch ได้อย่างไร


เราเข้าถึงขนาด (หรือรูปร่าง) ของเทนเซอร์และจำนวนองค์ประกอบในเทนเซอร์เป็นข้อมูลเมตาของเมตริกซ์ ในการเข้าถึงขนาดของเมตริกซ์ เราใช้ .size() เมธอดและรูปร่างของเทนเซอร์เข้าถึงได้โดยใช้ .shape .

ทั้ง .size() และ .shape ให้ผลเช่นเดียวกัน เราใช้ torch.numel() ฟังก์ชันหาจำนวนองค์ประกอบทั้งหมดในเทนเซอร์

ขั้นตอน

  • นำเข้าไลบรารีที่จำเป็น ที่นี่ ห้องสมุดที่จำเป็นคือ ไฟฉาย . ตรวจสอบให้แน่ใจว่าคุณได้ติดตั้ง ไฟฉาย .

  • กำหนดเทนเซอร์ PyTorch

  • ค้นหาข้อมูลเมตาของเทนเซอร์ ใช้ .size() และ .shape เพื่อเข้าถึงขนาดและรูปร่างของเทนเซอร์ ใช้ torch.numel() เพื่อเข้าถึงจำนวนองค์ประกอบในเทนเซอร์

  • พิมพ์เทนเซอร์และข้อมูลเมตาเพื่อความเข้าใจที่ดีขึ้น

ตัวอย่างที่ 1

# Python Program to access meta-data of a Tensor
# import necessary libraries
import torch

# Create a tensor of size 4x3
T = torch.Tensor([[1,2,3],[2,1,3],[2,3,5],[5,6,4]])
print("T:\n", T)

# Find the meta-data of tensor
# Find the size of the above tensor "T"
size_T = T.size()
print("size of tensor T:\n", size_T)

# Other method to get size using .shape
print("Shape of tensor:\n", T.shape)

# Find the number of elements in the tensor "T"
num_T = torch.numel(T)
print("Number of elements in tensor T:\n", num_T)

ผลลัพธ์

เมื่อคุณเรียกใช้โค้ด Python 3 ด้านบน โค้ดจะสร้างเอาต์พุตต่อไปนี้

T:
tensor([[1., 2., 3.],
         [2., 1., 3.],
         [2., 3., 5.],
         [5., 6., 4.]])
size of tensor T:
torch.Size([4, 3])
Shape of tensor:
torch.Size([4, 3])
Number of elements in tensor T:
12

ตัวอย่างที่ 2

# Python Program to access meta-data of a Tensor
# import the libraries
import torch

# Create a tensor of random numbers
T = torch.randn(4,3,2)
print("T:\n", T)

# Find the meta-data of tensor
# Find the size of the above tensor "T"
size_T = T.size()
print("size of tensor T:\n", size_T)

# Other method to get size using .shape
print("Shape of tensor:\n", T.shape)

# Find the number of elements in the tensor "T"
num_T = torch.numel(T)
print("Number of elements in tensor T:\n", num_T)

ผลลัพธ์

เมื่อคุณเรียกใช้โค้ด Python 3 ด้านบน โค้ดจะสร้างเอาต์พุตต่อไปนี้

T:
tensor([[[-1.1806, 0.5569],
         [ 2.2237, 0.9709],
         [ 0.4775, -0.2491]],
         [[-0.9703, 1.9916],
         [ 0.1998, -0.6501],
         [-0.7489, -1.3013]],
         [[ 1.3191, 2.0049],
         [-0.1195, 0.1860],
         [-0.6061, -1.2451]],
         [[-0.6044, 0.6153],
         [-2.2473, -0.1531],
         [ 0.5341, 1.3697]]])
size of tensor T:
torch.Size([4, 3, 2])
Shape of tensor:
torch.Size([4, 3, 2])
Number of elements in tensor T:
24