Kā PyTorch izmantot metodi “torch.argmax()”?

Ka Pytorch Izmantot Metodi Torch Argmax



Programmā PyTorch “ torch.argmax() ” metode ir iebūvēta funkcija, kas atgriež noteikta tensora maksimālo vērtību indeksus noteiktā dimensijā. Lietotāji izmanto šo funkciju, strādājot ar tensoriem un vēlas atrast maksimālās vērtības indeksu pa tensora norādīto dimensiju. Turklāt šī metode var būt noderīga arī klasifikācijai, kur lietotāji vēlas zināt, kurai klasei ir vislielākā varbūtība.

Šajā emuārā ir parādīts piemērs, kā PyTorch izmantot metodi “torch.argmax()”.

Kā PyTorch izmantot metodi “torch.argmax()”?

Metode “torch.argmax()” izmanto jebkuru 1D vai 2D tensoru kā ievadi un atgriež tensoru, kas satur maksimālo vērtību indeksus/indeksi norādītajā dimensijā.







Metodes “torch.argmax()” sintakse ir norādīta zemāk:



lāpa. argmax ( < ievades_tensors > )

Lai izmantotu šo metodi programmā PyTorch, labākai izpratnei skatiet tālāk norādītos piemērus.



1. piemērs: izmantojiet metodi “torch.argmax()” ar 1D tensoru

Pirmajā piemērā mēs izveidosim 1D tensoru un ar to izmantosim metodi “torch.argmax()”. Izpildiet tālāk norādīto soli pa solim procedūru:





1. darbība: importējiet PyTorch bibliotēku

Vispirms importējiet ' lāpa ” bibliotēka, lai izmantotu metodi “torch.argmax()”:

imports lāpa

2. darbība. Izveidojiet 1D tensoru

Pēc tam izveidojiet 1D tensoru un izdrukājiet tā elementus. Šeit mēs izveidojam šādu ' Desmitnieki1 ' tensoru no saraksta, izmantojot ' torch.tensor() ” funkcija:



Desmitnieki1 = lāpa. tenzors ( [ 5 , 0 , - 8 , 1 , 9 , 7 ] )

drukāt ( Desmitnieki1 )

Tas ir izveidojis 1D tensoru, kā redzams tālāk:

3. darbība: atrodiet maksimālās vērtības indeksus

Tagad izmantojiet ' torch.argmax() ”, lai atrastu maksimālās vērtības indeksu/indeksus Desmitnieki1 ” tensors:

T1_ind = lāpa. argmax ( Desmitnieki1 )

4. darbība: drukājiet maksimālās vērtības indeksu

Visbeidzot, ievades tensorā parādiet maksimālās vērtības indeksu:

drukāt ( 'Indeksi:' , T1_ind )

Zemāk redzamā izvade parāda maksimālās vērtības indeksu ' Desmitnieki1 ” tenzors, t.i., 4. Tas nozīmē, ka tenzora augstākā vērtība ir pie 4. indeksa, kas ir “ 9 ”:

2. piemērs: izmantojiet metodi “torch.argmax()” ar 2D tensoru

Otrajā piemērā mēs izveidosim 2D tensoru un ar to izmantosim metodi “torch.argmax()”. Izpildiet norādītās darbības:

1. darbība: importējiet PyTorch bibliotēku

Vispirms importējiet ' lāpa ” bibliotēka, lai izmantotu metodi “torch.argmax()”:

imports lāpa

2. darbība. Izveidojiet 2D tensoru

Pēc tam izmantojiet ' torch.tensor() ” funkciju, lai izveidotu 2D tensoru un izdrukātu tā elementus. Šeit mēs izveidojam šādu ' Desmitie2 “2D tensors:

Desmitie2 = lāpa. tenzors ( [ [ 4 , 1 , - 7 ] , [ piecpadsmit , 6 , 0 ] , [ - 7 , 9 , 2 ] ] )

drukāt ( Desmitie2 )

Tas ir izveidojis 2D tensoru, kā redzams tālāk:

3. darbība: atrodiet maksimālās vērtības indeksus

Tagad atrodiet maksimālās vērtības indeksu sadaļā “ Desmitie2 'tensoru, izmantojot ' torch.argmax() ” funkcija:

T2_ind = lāpa. argmax ( Desmitie2 )

4. darbība: drukājiet maksimālās vērtības indeksu

Visbeidzot, ievades tensorā parādiet maksimālās vērtības indeksu:

drukāt ( 'Indeksi:' , T2_ind )

Saskaņā ar zemāk redzamo izvadi maksimālās vērtības indekss ' Desmitie2 ” tensors ir “3”. Tas nozīmē, ka tenzora augstākā vērtība ir pie 3. indeksa, kas ir ' piecpadsmit ”:

5. darbība. Kolonnās atrodiet maksimālās vērtības indeksus

Turklāt lietotāji var atrast arī maksimālo vērtību indeksus/indeksus katrā tenzora kolonnā. Piemēram, mēs varam izmantot “ dim=0 ” argumentu ar funkciju “torch.argmax()”. Tas atrod maksimālo vērtību indeksus kolonnās ' Desmitie2 ” tensoru un pēc tam izdrukā šos indeksus:

col_index = lāpa. argmax ( Desmitie2 , blāvs = 0 )

drukāt ( 'Indeksi kolonnās:' , col_index )

Zemāk esošajā izvadē ir parādīti maksimālo vērtību indeksi katrā tenzora kolonnā:

6. darbība: atrodiet maksimālās vērtības indeksus gar rindām

Tāpat lietotāji var atrast arī maksimālo vērtību indeksus/indeksus katrā tenzora rindā. Piemēram, izmantojiet ' dim=1 ” argumentu ar funkciju “torch.argmax()”, lai atrastu maksimālo vērtību indeksus pa rindām “Tens2” tensorā un pēc tam izdrukātu šos indeksus:

rindas_indekss = lāpa. argmax ( Tens2 , blāvs = 1 )

drukāt ( 'Indeksi rindās:' , rindas_indekss )

Maksimālo vērtību indeksus katrā Tens2 tensora rindā var redzēt zemāk:

Mēs esam efektīvi izskaidrojuši metodi “torch.argmax()” izmantošanai programmā PyTorch.

Piezīme : varat piekļūt mūsu Google Colab piezīmju grāmatiņai šeit saite .

Secinājums

Lai PyTorch izmantotu metodi “torch.argmax()”, vispirms importējiet “ lāpa ” bibliotēka. Pēc tam izveidojiet vajadzīgo 1D vai 2D tensoru un apskatiet tā elementus. Pēc tam izmantojiet ' torch.argmax() ” metodi, lai atrastu/aprēķinātu tensora maksimālo vērtību indeksus/indeksus. Turklāt lietotāji var arī atrast maksimālo vērtību indeksus katrā tensora rindā vai kolonnā, izmantojot ' blāvs ' arguments. Visbeidzot, ievades tensorā parādiet maksimālās vērtības indeksu. Šajā emuārā ir parādīts piemērs, kā PyTorch izmantot metodi “torch.argmax()”.