JAX (Just After eXecution;실행직후)
JAX 라이브러리는 TensorFlow와 유사한 기능을 가지고 있으며, 특히 자동 미분과 XLA(accelerated linear algebra) 및 JIT(just-in-time) 컴파일 기능을 강조합니다. 이러한 기능은 머신러닝 알고리즘을 더욱 효율적으로 구현하고 최적화하는 데 도움이 됩니다.
JAX 라이브러리가 “Just After eXecution”의 약자로 실행직후 라는 이름을 가지게 된것은, NumPy, SciPy, PyTorch 등과 같은 수치 연산 라이브러리와 같이 계산 그래프를 구성하고 실행하기 직전까지 모든 것이 동적으로 결정되는 동적 계산 그래프 (dynamic computation graph) 라이브러리이기 때문입니다. 이러한 동적 계산 그래프 방식은 머신러닝 프레임워크에서 일반적으로 사용되는 정적 계산 그래프와 대조적입니다. JAX는 JIT(Just-In-Time) 컴파일러와 함께 사용되어 동적 계산 그래프를 효율적으로 컴파일하여 실행 시간을 최적화합니다. 따라서 “Just After eXecution”이라는 이름이 지어진 것입니다.
JAX는 pip로 설치할 수 있으며, 아래 명령어를 사용하여 설치할 수 있습니다.
pip install jax jaxlib
다만, JAX는 GPU를 사용하기 때문에 GPU가 없는 경우 CPU 버전으로 설치해야 합니다. GPU가 없는 경우 아래 명령어를 사용하여 설치할 수 있습니다.
pip install jax jaxlib==0.1.70+cpu -f https://storage.googleapis.com/jax-releases/jax_releases.html
아래는 JAX를 사용하여 간단한 선형 회귀 모델을 구현하는 예시입니다.GPU 가 없는 컴퓨터를 사용하시는 분을 위해 구글 colab에서 실행 가능하도록 작성되었습니다.
import jax.numpy as jnp from jax import grad # 선형회귀 모델 함수 def linear_regression(X, w): return jnp.dot(X, w.T) # 평균제곱오차 손실 함수 def mse_loss(y_true, y_pred): return jnp.mean((y_true - y_pred) ** 2) # 경사하강법 함수 def gradient_descent(w, X, y, learning_rate): grad_func = grad(mse_loss) grad_w = grad_func(y, linear_regression(X, w)) return w - learning_rate * grad_w # 입력과 출력 정의 X = jnp.array([[1., 1.], [2., 2.], [3., 3.], [4., 4.]]) y = jnp.array([[2.], [4.], [6.], [8.]]) # 가중치 초기값 설정 w = jnp.array([[1., 1.]]) # 학습 for epoch in range(10): # 예측 계산 y_pred = linear_regression(X, w) # 손실 계산 loss = mse_loss(y, y_pred) print(f"에포크 {epoch}\n가중치: {w}\n손실: {loss}\n예측결과: {y_pred}") # 경사하강법을 사용하여 가중치 갱신 w = gradient_descent(w, X, y, 0.1)
JAX가 tensorflow 보다 쌘놈인가?
JAX (Just After eXecution) 는 기계 학습 라이브러리로, Google에서 개발했습니다. TensorFlow와 마찬가지로 JAX도 딥 러닝 모델을 만들고 훈련시키는 데 사용됩니다. 그러나 TensorFlow와 JAX 사이에는 몇 가지 차이점이 있습니다.
첫째, JAX는 미분 가능한 함수만 허용합니다. 이는 딥 러닝 모델을 구축하는 데 도움이 됩니다. 미분 가능한 함수는 모든 미분 가능한 지점에서 기울기를 가지기 때문입니다. JAX는 이러한 함수만 지원하므로 모델에서 발생할 수 있는 다양한 오류를 방지할 수 있습니다.
둘째, JAX는 함수를 컴파일하고 JIT(Just-in-time) 컴파일러로 컴파일합니다. 이는 계산 속도를 빠르게 만들어줍니다. TensorFlow와 같은 라이브러리는 모델을 그래프로 변환한 다음 실행합니다. 반면 JAX는 모델을 즉시 컴파일하고 실행합니다.
셋째, JAX는 TensorFlow와 비교해 더 깔끔한 구문을 제공합니다. TensorFlow는 상태 변수와 그래프를 별도로 추적해야 하지만 JAX는 함수형 프로그래밍을 사용하여 코드를 작성하므로 보다 간결합니다.
JAX와 TensorFlow는 모두 딥러닝 라이브러리이지만, 각각의 장단점이 있습니다. 따라서 뛰어난 것은 사용하는 목적과 상황에 따라 다릅니다.
JAX는 JAX를 이용한 코드는 다른 프레임워크와 달리 일부 또는 전체가 컴파일됩니다. 이는 빠른 속도와 함께 간소화된 코드를 제공합니다. 또한 JAX는 XLA 컴파일러를 사용하므로, TPU에서 더 빠르게 실행할 수 있습니다.
반면, TensorFlow는 이미 매우 널리 사용되고 있으며, TensorFlow 기반의 라이브러리 및 프레임워크가 많이 개발되어 있습니다. TensorFlow는 또한 다양한 하드웨어를 지원하며, 자체적으로 GPU를 활용하여 연산을 가속화할 수 있습니다.
따라서 뛰어난 것은 사용하는 목적과 상황에 따라 다르므로, 적절한 상황에서 적합한 도구를 선택하는 것이 중요합니다.
다음은 위에 JAX로 작성된 선형 회귀 모델 코드를 TensorFlow로 변환한 예시입니다.
JAX는 TensorFlow와 비슷한 API를 가지고 있기 때문에 코드 변환이 그리 어렵지 않습니다. 그러나 두 라이브러리의 내부 구현 방식이 다르기 때문에 성능 차이가 있을 수 있습니다. 또한 TensorFlow는 GPU를 이용한 병렬 처리에 대한 지원이 더욱 뛰어나기 때문에 대규모 데이터에 대한 학습 속도에서 우위를 보입니다.
import tensorflow as tf import numpy as np # 입력과 출력 정의 X = np.array([[1., 1.], [2., 2.], [3., 3.], [4., 4.]]) y = np.array([[2.], [4.], [6.], [8.]]) # 가중치 초기값 설정 w = tf.Variable(np.array([[1., 1.]])) # 선형 회귀 모델 함수 def linear_regression(X, w): return tf.matmul(X, tf.transpose(w)) # 평균제곱오차 손실 함수 def mse_loss(y_true, y_pred): return tf.reduce_mean(tf.square(y_true - y_pred)) # 경사하강법 함수 def gradient_descent(w, X, y, learning_rate): with tf.GradientTape() as tape: y_pred = linear_regression(X, w) loss = mse_loss(y, y_pred) grads = tape.gradient(loss, w) w.assign_sub(learning_rate * grads) return w # 학습 for epoch in range(10): # 예측 계산 y_pred = linear_regression(X, w) # 손실 계산 loss = mse_loss(y, y_pred) print(f"에포크 {epoch}\n가중치: {w}\n손실: {loss}\n예측결과: {y_pred.numpy()}") # 경사하강법을 사용하여 가중치 갱신 w = gradient_descent(w, X, y, 0.1)
아래는 jax로 작성된 코드를 tensorflow도 사용하지 말고, python으로 코딩한 예시입니다. 속도 측면에서는 jax나 tensorflow보다 느릴 수 있습니다.
import numpy as np # 선형회귀 모델 함수 def linear_regression(X, w): return np.dot(X, w.T) # 평균제곱오차 손실 함수 def mse_loss(y_true, y_pred): return np.mean((y_true - y_pred) ** 2) # 경사하강법 함수 def gradient_descent(w, X, y, learning_rate): def mse_grad(y, y_pred): return -2 * np.mean((y - y_pred) * X, axis=0) grad_w = mse_grad(y, linear_regression(X, w)) return w - learning_rate * grad_w # 입력과 출력 정의 X = np.array([[1., 1.], [2., 2.], [3., 3.], [4., 4.]]) y = np.array([[2.], [4.], [6.], [8.]]) # 가중치 초기값 설정 w = np.array([[1., 1.]]) # 학습 for epoch in range(10): # 예측 계산 y_pred = linear_regression(X, w) # 손실 계산 loss = mse_loss(y, y_pred) print(f"에포크 {epoch}\n가중치: {w}\n손실: {loss}\n예측결과: {y_pred}") # 경사하강법을 사용하여 가중치 갱신 w = gradient_descent(w, X, y, 0.1)
JAX의 자동 미분 기능
JAX는 자동 미분 기능을 내장하고 있기 때문에 딥러닝 모델의 학습에 필요한 그래디언트 계산을 자동으로 수행할 수 있습니다.딥러닝에서 모델 학습은 주로 그래디언트 기반 최적화 알고리즘을 사용하여 이루어집니다. 그래디언트 기반 최적화 알고리즘은 손실 함수의 그래디언트를 계산하여 모델 파라미터를 업데이트하는 방식으로 작동합니다. 따라서 자동 미분 기능이 없으면 그래디언트를 직접 계산해야 하는 번거로움이 있습니다.
JAX는 그래디언트 자동 계산 기능을 제공함으로써 이러한 번거로움을 줄여줍니다. 이를 통해 개발자는 모델 학습에 집중할 수 있습니다.
예를 들어, 다음과 같은 간단한 신경망 모델이 있다고 가정해봅시다.
import jax.numpy as jnp from jax import grad def loss_fn(params, x, y_true): y_pred = jnp.dot(x, params) loss = jnp.mean((y_pred - y_true) ** 2) return loss params = jnp.array([0.5, -0.2, 1.0]) x = jnp.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) y_true = jnp.array([10, 20, 30]) grad_loss_fn = grad(loss_fn) grad_params = grad_loss_fn(params, x, y_true) print(grad_params)
위 코드에서는 간단한 3개의 입력값과 1개의 출력값을 가지는 선형 회귀 모델을 정의하였습니다. 이 모델의 손실 함수(loss function)는 평균 제곱 오차(Mean Squared Error)를 사용하였습니다. 그리고 이 손실 함수의 그래디언트(gradient)를 자동으로 계산하기 위해 JAX의 grad
함수를 이용하였습니다.
grad
함수는 함수의 인수로 전달된 함수의 그래디언트를 계산하여 반환하는 함수입니다. grad
함수가 반환한 그래디언트 값은 params
에 대한 손실 함수의 편미분값이므로, 이 값은 params
를 어떻게 업데이트해야 하는지 알려주는 방향성을 가집니다. 따라서 이 값을 이용하여 모델의 가중치(weight)를 업데이트하면 됩니다.
위 코드를 실행하면, grad_params
의 값으로[-128.40001, -154.40001, -180.40001] 이 출력됩니다. 이 값은 입력값 x가 [1.0, 2.0, 3.0]일 때, 손실 함수를 x의 각 요소에 대해서 미분한 그래디언트 벡터를 의미합니다.