PyTorch จัดเตรียมวิธีการ torch.kthvalue() เพื่อหาองค์ประกอบที่ k ของเทนเซอร์ ส่งคืนค่าขององค์ประกอบที่ k-th ของเทนเซอร์ที่เรียงลำดับจากน้อยไปมาก และดัชนีขององค์ประกอบในเทนเซอร์เดิม
torch.topk() วิธีที่ใช้เพื่อค้นหาองค์ประกอบ "k" ด้านบน ส่งกลับองค์ประกอบ "k" ด้านบนหรือ "k" ที่ใหญ่ที่สุดในเทนเซอร์
ขั้นตอน
-
นำเข้าไลบรารีที่จำเป็น ในตัวอย่าง Python ทั้งหมดต่อไปนี้ ไลบรารี Python ที่จำเป็นคือ ไฟฉาย . ตรวจสอบให้แน่ใจว่าคุณได้ติดตั้งแล้ว
-
สร้างเทนเซอร์ PyTorch แล้วพิมพ์ออกมา
-
คำนวณ torch.kthvalue(อินพุต, k) . มันส่งกลับเมตริกซ์สองตัว กำหนดเมตริกซ์สองตัวนี้ให้กับตัวแปรใหม่สองตัว "ค่า" และ "ดัชนี" . ในที่นี้อินพุตคือเทนเซอร์และ k เป็นจำนวนเต็ม
-
คำนวณ torch.topk(อินพุต, k) . มันส่งกลับเมตริกซ์สองตัว เทนเซอร์แรกมีค่าขององค์ประกอบ "k" บนสุด และเทนเซอร์ที่สองมีดัชนีขององค์ประกอบเหล่านี้ในเทนเซอร์ดั้งเดิม กำหนดเมตริกซ์ทั้งสองนี้ให้กับตัวแปรใหม่ "ค่า" และ "ดัชนี" .
-
พิมพ์ค่าและดัชนีขององค์ประกอบที่ k ของเทนเซอร์ และค่าและดัชนีขององค์ประกอบ "k" บนสุดของเทนเซอร์
ตัวอย่างที่ 1
โปรแกรม python นี้แสดงวิธีค้นหาองค์ประกอบที่ k ของเทนเซอร์
# Python program to find k-th element of a tensor
# import necessary library
import torch
# Create a 1D tensor
T = torch.Tensor([2.334,4.433,-4.33,-0.433,5, 4.443])
print("Original Tensor:\n", T)
# Find the 3rd element in sorted tensor. First it sorts the
# tensor in ascending order then returns the kth element value
# from sorted tensor and the index of element in original tensor
value, index = torch.kthvalue(T, 3)
# print 3rd element with value and index
print("3rd element value:", value)
print("3rd element index:", index) ผลลัพธ์
Original Tensor: tensor([ 2.3340, 4.4330, -4.3300, -0.4330, 5.0000, 4.4430]) 3rd element value: tensor(2.3340) 3rd element index: tensor(0)
ตัวอย่างที่ 2
โปรแกรม Python ต่อไปนี้จะแสดงวิธีค้นหาองค์ประกอบ "k" หรือ "k" ที่ใหญ่ที่สุดของเทนเซอร์
# Python program to find to top k elements of a tensor
# import necessary library
import torch
# Create a 1D tensor
T = torch.Tensor([2.334,4.433,-4.33,-0.433,5, 4.443])
print("Original Tensor:\n", T)
# Find the top k=2 or 2 largest elements of the tensor
# returns the 2 largest values and their indices in original
# tensor
values, indices = torch.topk(T, 2)
# print top 2 elements with value and index
print("Top 2 element values:", values)
print("Top 2 element indices:", indices) ผลลัพธ์
Original Tensor: tensor([ 2.3340, 4.4330, -4.3300, -0.4330, 5.0000, 4.4430]) Top 2 element values: tensor([5.0000, 4.4430]) Top 2 element indices: tensor([4, 5])