Как использовать метод torch.argmax() в PyTorch?

Kak Ispol Zovat Metod Torch Argmax V Pytorch



В PyTorch « факел.argmax() ” — это встроенная функция, которая возвращает индексы максимальных значений определенного тензора в заданном измерении. Пользователи используют эту функцию, когда работают с тензорами и хотят найти индекс максимального значения по заданному измерению тензора. Более того, этот метод также может быть полезен для классификации, когда пользователи хотят знать, какой класс имеет наибольшую вероятность.

В этом блоге будет продемонстрирован метод использования метода «torch.argmax()» в PyTorch.

Как использовать метод torch.argmax() в PyTorch?

Метод «torch.argmax()» принимает на вход любой 1D или 2D тензор и возвращает тензор, содержащий индексы/индексы максимальных значений по заданному измерению.







Синтаксис метода torch.argmax() приведен ниже:



факел. аргументмакс ( < input_tensor > )

Чтобы использовать этот метод в PyTorch, для лучшего понимания просмотрите следующие примеры:



Пример 1. Использование метода «torch.argmax()» с 1D-тензором

В первом примере мы создадим 1D-тензор и будем использовать с ним метод torch.argmax(). Давайте выполним следующую пошаговую процедуру:





Шаг 1. Импортируйте библиотеку PyTorch.

Сначала импортируйте « факел ” для использования метода “torch.argmax()”:

Импортировать факел

Шаг 2. Создайте 1D-тензор

Затем создайте 1D-тензор и распечатайте его элементы. Здесь мы создаем следующее: Десятки1 тензор из списка с помощью тензора « факел.тензор() » функция:



Десятки1 '=' факел. тензор ( [ 5 , 0 , - 8 , 1 , 9 , 7 ] )

Распечатать ( Десятки1 )

Это создало 1D-тензор, как показано ниже:

Шаг 3: Найдите индексы максимального значения

Теперь используйте « факел.argmax() ” для поиска индекса/индексов максимального значения в “ Десятки1 тензор:

T1_ind '=' факел. аргументмакс ( Десятки1 )

Шаг 4: Распечатайте индекс максимального значения

Наконец, отобразите индекс максимального значения во входном тензоре:

Распечатать ( «Индексы:» , T1_ind )

Вывод ниже показывает индекс максимального значения в « Десятки1 тензор, т. е. 4. Это означает, что наибольшее значение тензора находится в 4-м индексе, который равен « 9 »:

Пример 2. Использование метода «torch.argmax()» с 2D-тензором

Во втором примере мы создадим 2D-тензор и будем использовать с ним метод torch.argmax(). Давайте выполним предложенные шаги:

Шаг 1. Импортируйте библиотеку PyTorch.

Сначала импортируйте « факел ” для использования метода “torch.argmax()”:

Импортировать факел

Шаг 2. Создайте 2D-тензор

Затем используйте « факел.тензор() » для создания 2D-тензора и печати его элементов. Здесь мы создаем следующее: Десятки2 «2D-тензор:

Десятки2 '=' факел. тензор ( [ [ 4 , 1 , - 7 ] , [ пятнадцать , 6 , 0 ] , [ - 7 , 9 , 2 ] ] )

Распечатать ( Десятки2 )

Это создало 2D-тензор, как показано ниже:

Шаг 3: Найдите индексы максимального значения

Теперь найдите индекс максимального значения в « Десятки2 тензор с использованием тензора « факел.argmax() » функция:

Т2_инд '=' факел. аргументмакс ( Десятки2 )

Шаг 4: Распечатайте индекс максимального значения

Наконец, отобразите индекс максимального значения во входном тензоре:

Распечатать ( «Индексы:» , Т2_инд )

Согласно приведенному ниже выводу, индекс максимального значения в « Десятки2 тензор равен «3». Это означает, что наибольшее значение тензора находится в третьем индексе, который равен « пятнадцать »:

Шаг 5. Найдите индексы максимального значения по столбцам

Более того, пользователи также могут найти индексы/индексы максимальных значений по каждому столбцу тензора. Например, мы можем использовать « тусклый=0 аргумент с помощью функции torch.argmax(). Он находит индексы максимальных значений по столбцам в « Десятки2 тензор, а затем печатает эти индексы:

col_index '=' факел. аргументмакс ( Десятки2 , тусклый '=' 0 )

Распечатать ( 'Индексы в столбцах:' , col_index )

В приведенном ниже выводе показаны индексы максимальных значений по каждому столбцу тензора:

Шаг 6. Найдите индексы максимального значения по строкам

Аналогичным образом пользователи также могут найти индексы/индексы максимальных значений вдоль каждой строки тензора. Например, используйте « тусклый = 1 ” с функцией “torch.argmax()”, чтобы найти индексы максимальных значений по строкам в тензоре “Tens2”, а затем распечатать эти индексы:

индекс_строки '=' факел. аргументмакс ( Десятки2 , тусклый '=' 1 )

Распечатать ( 'Индексы в строках:' , индекс_строки )

Индексы максимального значения по каждой строке тензора «Tens2» можно увидеть ниже:

Мы подробно объяснили метод использования метода «torch.argmax()» в PyTorch.

Примечание : Вы можете получить доступ к нашему блокноту Google Colab по этому адресу. связь .

Заключение

Чтобы использовать метод «torch.argmax()» в PyTorch, сначала импортируйте « факел » библиотека. Затем создайте желаемый 1D или 2D-тензор и просмотрите его элементы. Далее используйте « факел.argmax() ” метод для поиска/вычисления индексов/индексов максимальных значений в тензоре. Более того, пользователи также могут найти индексы максимального значения вдоль каждой строки или столбца тензора, используя « тусклый аргумент. Наконец, отобразите индекс максимального значения во входном тензоре. В этом блоге приведен пример использования метода torch.argmax() в PyTorch.