Computer >> คอมพิวเตอร์ >  >> การเขียนโปรแกรม >> Python

จะหา k-th และองค์ประกอบ k ด้านบนของเทนเซอร์ใน PyTorch ได้อย่างไร?


PyTorch จัดเตรียมวิธีการ torch.kthvalue() เพื่อหาองค์ประกอบที่ k ของเทนเซอร์ ส่งคืนค่าขององค์ประกอบที่ k-th ของเทนเซอร์ที่เรียงลำดับจากน้อยไปมาก และดัชนีขององค์ประกอบในเทนเซอร์เดิม

torch.topk() วิธีที่ใช้เพื่อค้นหาองค์ประกอบ "k" ด้านบน ส่งกลับองค์ประกอบ "k" ด้านบนหรือ "k" ที่ใหญ่ที่สุดในเทนเซอร์

ขั้นตอน

  • นำเข้าไลบรารีที่จำเป็น ในตัวอย่าง Python ทั้งหมดต่อไปนี้ ไลบรารี Python ที่จำเป็นคือ ไฟฉาย . ตรวจสอบให้แน่ใจว่าคุณได้ติดตั้งแล้ว

  • สร้างเทนเซอร์ PyTorch แล้วพิมพ์ออกมา

  • คำนวณ torch.kthvalue(อินพุต, k) . มันส่งกลับเมตริกซ์สองตัว กำหนดเมตริกซ์สองตัวนี้ให้กับตัวแปรใหม่สองตัว "ค่า" และ "ดัชนี" . ในที่นี้อินพุตคือเทนเซอร์และ k เป็นจำนวนเต็ม

  • คำนวณ torch.topk(อินพุต, k) . มันส่งกลับเมตริกซ์สองตัว เทนเซอร์แรกมีค่าขององค์ประกอบ "k" บนสุด และเทนเซอร์ที่สองมีดัชนีขององค์ประกอบเหล่านี้ในเทนเซอร์ดั้งเดิม กำหนดเมตริกซ์ทั้งสองนี้ให้กับตัวแปรใหม่ "ค่า" และ "ดัชนี" .

  • พิมพ์ค่าและดัชนีขององค์ประกอบที่ k ของเทนเซอร์ และค่าและดัชนีขององค์ประกอบ "k" บนสุดของเทนเซอร์

ตัวอย่างที่ 1

โปรแกรม python นี้แสดงวิธีค้นหาองค์ประกอบที่ k ของเทนเซอร์

# Python program to find k-th element of a tensor
# import necessary library
import torch

# Create a 1D tensor
T = torch.Tensor([2.334,4.433,-4.33,-0.433,5, 4.443])
print("Original Tensor:\n", T)

# Find the 3rd element in sorted tensor. First it sorts the
# tensor in ascending order then returns the kth element value
# from sorted tensor and the index of element in original tensor
value, index = torch.kthvalue(T, 3)

# print 3rd element with value and index
print("3rd element value:", value)
print("3rd element index:", index)

ผลลัพธ์

Original Tensor:
   tensor([ 2.3340, 4.4330, -4.3300, -0.4330, 5.0000, 4.4430])
3rd element value: tensor(2.3340)
3rd element index: tensor(0)

ตัวอย่างที่ 2

โปรแกรม Python ต่อไปนี้จะแสดงวิธีค้นหาองค์ประกอบ "k" หรือ "k" ที่ใหญ่ที่สุดของเทนเซอร์

# Python program to find to top k elements of a tensor
# import necessary library
import torch

# Create a 1D tensor
T = torch.Tensor([2.334,4.433,-4.33,-0.433,5, 4.443])
print("Original Tensor:\n", T)

# Find the top k=2 or 2 largest elements of the tensor
# returns the 2 largest values and their indices in original
# tensor
values, indices = torch.topk(T, 2)

# print top 2 elements with value and index
print("Top 2 element values:", values)
print("Top 2 element indices:", indices)

ผลลัพธ์

Original Tensor:
   tensor([ 2.3340, 4.4330, -4.3300, -0.4330, 5.0000, 4.4430])
Top 2 element values: tensor([5.0000, 4.4430])
Top 2 element indices: tensor([4, 5])