torch.where() 함수

이 개념은 샵투스쿨의 “트랜스포머 모델로 GPT만들기” 학습 중 수강생분들이 더 자세히 알고 싶어하시는 용어들을 설명한 것입니다.

`torch.where()` 함수는 조건에 따라 두 개의 값 중 하나를 선택하는 역할을 합니다. 일반적으로 다음과 같은 형식을 가지고 있습니다:

“`python
torch.where(condition, x, y)
“`

– `condition`: 조건 텐서입니다. 이는 같은 모양(shape)을 가진 불리언(True/False) 텐서입니다.
– `x`: `condition`이 True인 위치에서 선택되는 텐서입니다.
– `y`: `condition`이 False인 위치에서 선택되는 텐서입니다.

`torch.where()` 함수는 `condition`의 각 요소를 확인하고, 해당 위치에서 `x`의 값이 선택되거나 `y`의 값이 선택됩니다. 따라서 `x`와 `y`는 같은 모양을 가지고 있어야 합니다.

예를 들어, 다음과 같은 코드를 살펴봅시다:

“`python
import torch

a = torch.tensor([1, 2, 3, 4])
b = torch.tensor([10, 20, 30, 40])
condition = torch.tensor([True, False, True, False])

result = torch.where(condition, a, b)
print(result)
“`

위의 코드에서 `condition`은 `[True, False, True, False]`이므로 첫 번째와 세 번째 위치에서는 `a`의 값이 선택되고, 두 번째와 네 번째 위치에서는 `b`의 값이 선택됩니다. 결과는 `[1, 20, 3, 40]`이 됩니다.

따라서 `torch.where()` 함수는 조건에 따라 두 개의 값을 선택하는 유용한 함수로, 다양한 상황에서 유연한 값을 선택하기 위해 사용됩니다.

답글 남기기

이메일 주소는 공개되지 않습니다. 필수 필드는 *로 표시됩니다