🥳 200만 유저의 친구 ‘이루다’ 기술로 AI 캐릭터를 자유롭게 만들어보세요 ‘핑퐁 스튜디오’ 보러가기

Tech

하나의 조직에서 TensorFlow와 PyTorch 동시 활용하기

불타는 텐서 흐름!

정욱재 | 2020년 12월 04일 | #Machine_Learning #Engineering

보통 제품이 있는 조직은 TensorFlow를, 리서치 조직은 PyTorch를 활용하는 것이 좋다고 알려져 있습니다. 그리고 하나의 조직에서는 하나의 프레임워크만 사용하는 것이 일반적이죠. 하지만 핑퐁팀의 형태는 다소 특이합니다. TensorFlow와 PyTorch를 동시에 사용하고 있습니다. 어떻게 둘 다 동시에 사용이 가능한지와 장단점에 대해 소개하겠습니다.

근데 왜 TensorFlow와 PyTorch를 동시에 써요?

핑퐁팀은 작년 초까지 TensorFlow만을 사용했습니다. 그러던 중 리서치 팀에서 조금 더 유연한 리서치를 위해 PyTorch를 사용하고 싶다고 말해주었고, 이 때부터는 PyTorch 하나로 통일하여 사용해왔습니다. 하지만 대형 모델이 점점 더 많아지면서 기존의 방식으로 배포하기에 어려웠습니다. 이 때부터 연산과 가중치가 같다면 같은 결과를 낸다는 가정 하에 부분적으로 TensorFlow를 도입했습니다. 왜 TensorFlow가 배포에 유리한지는 아래에서 설명하겠습니다.

왜 PyTorch로 배포를 안했나요?

결정적인 이유는 아래와 같이 요약할 수 있습니다.

우리는 성능 상의 큰 오버헤드인 Python 오버헤드를 제외하고 싶었습니다. 많은 사람들이 Python flask와 같은 서버로 서빙을 한다고 하지만 이 때 Python과 HTTP에서 오는 오버헤드는 정말 큽니다. 모델이 작은 경우에는 모델 추론 시간보다 오버헤드가 더 심한 경우도 있습니다. 이 때 사용할 수 있는 옵션은 TensorFlow Serving과 ONNX, 그리고 PyTorch의 C++ 런타임인 libtorch 정도였습니다. 지금에 와서는 TorchServe의 JIT 적용된 모델이 archive 된 형태도 하나의 옵션일 수 있습니다. 하지만 이 옵션은 그 당시 존재하지 않았습니다.

또한 배포 편의성이 좋아야 했고, 어느 시스템에서도 잘 구동되는 형태여야 했습니다. 배포를 편하게 할 수 있어야 서로의 리소스를 아낄 수 있고, 모델 개발에 집중할 수 있습니다. 현재도 Cortex를 제외한다면 ONNX는 이렇다 할 배포 옵션이 존재하지 않습니다. 이 때문에 Docker 이미지와 모델 파일 하나로 서빙이 가능한 TensorFlow Serving만이 유일한 선택지로 남게 되었죠.

지금 시점에서 다시 생각해본다면 TorchServe가 대체재가 될 수 있습니다. 하지만 아직은 Model Archiver를 따로 돌려야 하는 점, 유연한 리서치를 위한 코드에서 TorchScript와 호환되는 코드로 다시 수정해서 유지해야하는 점을 생각해볼 때 TensorFlow Serving이 제일 합리적인 선택지라고 판단했습니다.

TensorFlow Serving을 사용하면 TensorFlow 기본 저장 옵션인 Saved Model로 저장할 경우 그 파일만을 이용하여 바로 gRPC/HTTP 서버를 구동할 수 있습니다. 해당 서버들은 고성능 C++ 서버로 작성되어 있어서 Python에서 오는 오버헤드가 존재하지 않습니다.

그럼 TensorFlow 코드와 PyTorch 코드를 이중으로 유지해야하는 불편함이 있을 것 같아요. 이 부분에 들이는 리소스는 크지 않았나요?

이 부분이 제일 걱정이 되었습니다. 아무래도 이중으로 코드를 유지하다보면 실수할 여지가 크고, 좋은 Inner-source 라이브러리로 남기 힘들다는 생각이 들었습니다. 하지만 아래 이유들 때문에 리소스를 많이 들이더라도 TensorFlow 코드를 유지하기로 마음먹었습니다.

TPU Compatible한 코드를 항상 작성할 수 있습니다. Custom Ops만 넣지 않는다면 Strategy 변경만으로 TPU를 바로 사용할 수 있는 코드가 나옵니다. 이 부분은 대형 모델을 학습할 때 매우 큰 장점이 되었고, 덕분에 핑퐁팀에서 대형 모델 학습을 시작할 때 모델과 학습 코드 등 많은 부분이 TensorFlow 코드로 옮겨졌습니다.

대형 모델 학습 시의 데이터 파이프라인 코드가 간결해집니다. PyTorch가 동적으로 데이터를 생성해낼 수 있다는 장점이 있지만, 자유로운 만큼 버그는 증가하고 가독성은 떨어집니다. 하지만 TensorFlow는 tf.data.Dataset만으로 코드를 작성할 경우 모든 device (CPU, GPU, TPU)에서 작동하며 순수함수에 가깝게 작성할 수 있습니다.

프로덕션과 엮여있는 모델들의 학습 코드가 정말 쉬워집니다. PyTorch를 사용할 경우 TensorFlow Keras에 버금가는 읽기 좋은 Pythonic한 학습 코드를 작성하려 할 때 pytorch-lightning을 사용해야 합니다. 하지만 pytorch-lightning을 사용하더라도 결국 PyTorch 기반의 모델이 나오기 때문에 위에서 말한 PyTorch로 서빙하기 어려운 이유가 동일하게 나타나게 됩니다.

두 라이브러리의 코드를 함께 유지하기 위해 어떤 것들이 필요했나요?

우선 내부에서 사용하는 모델들을 PyTorch, TensorFlow 버전으로 다시 작성하였습니다. PyTorch까지 재작성을 한 이유는 변경을 많이 해야하는 리서치 프로젝트에서 최종 모델이 나온 시점에 코드를 확정하기 위함이었고, 이를 기반으로 TensorFlow로 다시 작성하였습니다. 유지보수가 편리한 코드를 작성하기 위해서 별도의 코드 베이스에서 작업하였습니다.

그 결과 읽기 쉬우면서도 기존 리서치 프로젝트와 호환되는 PyTorch 모델 코드가 탄생하였고, PyTorch 모델의 가중치를 옮길 대상인 TensorFlow 모델 코드도 잘 작성되었습니다. 예시로 현재 스캐터랩의 내부 라이브러리를 이용하여 아래의 코드로 같은 동작을 보장하는 TensorFlow, PyTorch용 Language Model을 생성할 수 있습니다.

#
# TensorFlow 모델 로딩
from models.tf.language_model import LanguageModel

model = LanguageModel(vocab_size=32000, word_embedding_size=64, hidden_size=64, num_layers=4)

#
# PyTorch 모델 로딩
from models.torch.language_model import LanguageModel

model = LanguageModel(vocab_size=32000, word_embedding_size=64, hidden_size=64, num_layers=4)

그 다음에는 모델의 모든 가중치를 변환해주는 코드를 추가로 작성하였습니다. 변환해야하는 상황은 1) 리서치 중 나온 PyTorch 가중치를 TensorFlow로 옮겨 배포를 해야하거나, 2) TensorFlow로 학습된 대형 모델 가중치를 리서치를 할 수 있도록 PyTorch로 옮겨주어야 할 때입니다. 자세한 사항은 아래에서 설명하겠습니다.

모델 가중치 변환

이 경우 여러 가지 고려사항이 있을 수 있지만 크게는 다음과 같습니다.

TensorFlow Checkpoint → PyTorch Model

TensorFlow Checkpoint에서 PyTorch Model Weight를 만들어내는 경우에는 다음 두 함수를 사용할 수 있습니다.

위 함수 이름에서 보여지듯, tf.train.list_variables로 Checkpoint들의 값들을 확인한 다음 값들을 tf.train.load_variable로 로딩하여 PyTorch 모델에 적용하였습니다. tf.train.load_variable의 반환 타입은 numpy.ndarray이기 때문에 torch model로 적용하기 위해서는 torch.from_numpy를 호출하면 됩니다. 결론적으로 아래와 같은 방식으로 로딩이 가능합니다.

# 가중치 목록 확인
# print(tf.train.list_variables("checkpoint-path"))
weight = tf.train.load_variables("checkpoint-path", "variable-name")
torch_model.weight.data = torch.from_numpy(weight)

위에서 특정 Weight를 가져오기 위해 data 필드에 접근하는 이유는, PyTorch 내부의 모듈들이 대부분 torch.Tensor보다 torch.nn.parameter.Parameter를 모델 가중치로 사용하기 때문입니다.

PyTorch State Dict → TensorFlow Model

이 경우도 위의 경우와 크게 다르지 않습니다. torch의 state dict 로딩은 아래와 같은 과정을 거칩니다.

import torch

# PyTorch의 가중치는 GPU용으로 저장이 되어있는 경우가 많기 때문에 꼭 map_location 인자를 넣어주어야 합니다.
torch_state_dict = torch.load(args.model_path, map_location=torch.device("cpu"))
# 기본적으로 torch.Tensor로 로딩됩니다. 따라서 detach()와 numpy() 메소드를 불러주는 것이 꼭 필요합니다.
torch_state_dict = {key: val.detach().numpy() for key, val in torch_state_dict.items()}

현재 pytorch의 master branch에 torch/serialization.py#L484-L488과 같은 코드가 존재하기 때문에 앞으로 나올 버전을 사용하는 경우 map_location 인자를 꼭 명시하지 않아도 괜찮습니다.

위 과정을 통해 가져온 state_dict는 아래와 같이 set_weights 함수를 통해 가져올 수 있습니다.

tf_module.set_weights(
    [
        torch_state_dict["some-keys-for-weight"],
        torch_state_dict["some-keys-for-bias"],
    ]
)

Weight를 위와 같이 적용하게 되는데, 이 경우 Weight의 순서는 아래처럼 알 수 있습니다.

tf.keras.layers.Dense의 경우 tensorflow/python/keras/layers/core.py#L1067-L1233 코드를 참고하면 build 메소드에서 kernel과 bias를 순서대로 할당하는 것을 알 수 있습니다. 이 경우에는 set_weights의 인자에도 kernel과 bias가 순서대로 들어가야 합니다.

TensorFlow Model → PyTorch Model

모델과 모델 사이의 가중치 변환은 조금 더 쉽습니다. TensorFlow Model은 set_weights와 함께 get_weights가 존재하는데 이 함수의 반환값은 list of numpy array이기 때문에 set_weights의 인자와 같다고 생각하시면 됩니다. 따라서 아래와 같은 방법으로 가져올 수 있습니다.

# tf_module의 첫번째 weight를 torch_module의 weight에 적용하는 예시
tf_weights = tf_module.get_weights()
torch_module.weight.data = torch.from_numpy(tf_weights[0])

PyTorch Model → TensorFlow Model

여기서는 nn.Module에 존재하는 state_dict를 활용할 수 있습니다. 아래는 weight, bias를 tf_module에 적용하는 에시입니다.

weight = torch_module.state_dict()["weight"].detach().numpy()
bias = torch_module.state_dict()["bias"].detach().numpy()

tf_module.set_weights([weight, bias])

테스팅

아무리 같은 구조의 모델을 작성하고, 같은 weight를 가진다고 해도 추론 결과값이 동일한지 직접 확인할 수 없는 이상 제대로 구현했는지 확신할 수 없습니다. 따라서 핑퐁팀에서는 아래와 같은 방법으로 테스팅을 진행했습니다.

  1. torch, tf 모델 초기화 및 빌드
  2. torch, tf 사이의 가중치 변환
  3. numpy로 예시 입력값을 만들어 두 모델의 결과값 비교

아래는 PyTorch LayerNormalization을 받아서 TensorFlow LayerNormalization으로 변환해주는 convert_torch_layer_norm를 테스트하는 예시입니다.

@pytest.mark.parametrize("input_dim", [pytest.param(10), pytest.param(100)])
def test_convert_torch_layer_normalization_with_dims(input_dim: int):
    batch_size = 10
    epsilon = 1e-6

    # Build Layer
    tf_layer_norm = tf.keras.layers.LayerNormalization(epsilon=epsilon)
    tf_layer_norm(tf.keras.Input([input_dim]))

    torch_layer_norm = nn.LayerNorm(input_dim, eps=epsilon)
    torch_layer_norm.eval()

    # Convert Weight
    convert_torch_layer_norm(torch_layer_norm, tf_layer_norm)

    for _ in range(100):
        # Build Input
        test_input = np.random.randn(batch_size, input_dim).astype(np.float32)
        tf_input = tf.constant(test_input, dtype=tf.float32)
        torch_input = torch.tensor(test_input, dtype=torch.float32)

        # Check Output
        tf_output = tf_layer_norm(tf_input).numpy()
        torch_output = torch_layer_norm(torch_input).detach().numpy()

        # Layer Normalization은 구하는 방식에 따라 조금씩 값이 차이가 나기 때문에 tolerance를 조금은 높게 줍니다.
        assert np.allclose(tf_output, torch_output, rtol=1e-5, atol=1e-6)

추가 고려사항

아래는 제가 weight 변환을 진행하면서 발견했던 대표적인 이슈입니다. 아래의 이슈 외에도 다른 이슈가 존재하였으나 대부분 큰 시간 소요 없이 해결 가능하였습니다.

tf.keras.layers.Dense, torch.nn.Linear의 weight shape

>>> import torch
>>> import tensorflow as tf
>>> torch_linear = torch.nn.Linear(10, 20)
>>> torch_linear.weight.data.shape
torch.Size([20, 10])
>>> tf_dense = tf.keras.layers.Dense(20)
>>> tf_dense(tf.keras.Input([10]))
<tf.Tensor 'dense/BiasAdd:0' shape=(None, 20) dtype=float32>
>>> tf_dense.get_weights()[0].shape
(10, 20)

위에서 볼 수 있듯이 TensorFlow와 PyTorch의 Feed Forward Layer의 weight shape가 다릅니다. Transpose를 해주어야만 정확하게 부여됩니다.

특정 경우의 Matrix Multiplication 결과의 차이

>>> import tensorflow as tf
>>> import torch
>>> import numpy as np
>>> a = np.random.randn(20, 30).astype(np.float32)
>>> b = np.random.randn(30, 20).astype(np.float32)
>>> tf_result = tf.matmul(tf.constant(a), tf.constant(b))
>>> torch_result = torch.matmul(torch.tensor(a), torch.tensor(b))
>>> # absolute diff
>>> np.max(np.abs(tf_result.numpy() - torch_result.numpy()))
2.861023e-06
>>> # relative diff
>>> np.max(np.abs((tf_result.numpy() - torch_result.numpy()) / tf_result.numpy()))
1.0196954e-06

이것은 대부분의 경우에 문제가 되지 않습니다. 위의 예시에서도 numpy의 기본 allclose 연산이 가진 범위 안에서 안전해 보이지만, 이런 사소한 오차가 깊은 신경망에서 여러 번 누적된다면 세밀한 결과가 필요한 모델은 결과가 많이 달라질 수 있으니 유의해야 합니다.

GRU와 같은 레이어의 Weight 순서

LSTM, GRU와 같은 레이어를 TensorFlow와 PyTorch 사이에 서로 변환할 때 Gate의 순서가 다를 수 있습니다. 여기서는 GRU를 예시로 설명해보겠습니다.

TensorFlow의 GRU

TensorFlow의 GRU는 아래와 같이 Gate 연산을 진행합니다. (tensorflow/python/keras/layers/recurrent.py#L1672-L1949)

      x_z = K.dot(inputs_z, self.kernel[:, :self.units])
      x_r = K.dot(inputs_r, self.kernel[:, self.units:self.units * 2])
      x_h = K.dot(inputs_h, self.kernel[:, self.units * 2:])

self.kernel이 Update Gate, Reset Gate, Output Candidate를 계산하기 위한 커널이 순서대로 연결되어 있다고 이해할 수 있습니다. 또한 bias는 아래처럼 계산합니다. (전체 코드가 아닌 일부 코드입니다.)

      recurrent_z = K.dot(h_tm1_z, self.recurrent_kernel[:, :self.units])
      recurrent_r = K.dot(h_tm1_r,
                          self.recurrent_kernel[:, self.units:self.units * 2])
      if self.reset_after and self.use_bias:
        recurrent_z = K.bias_add(recurrent_z, recurrent_bias[:self.units])
        recurrent_r = K.bias_add(recurrent_r,
                                 recurrent_bias[self.units:self.units * 2])

self.recurrent_kernelself.bias(recurrent_biasself.bias에서 나온 값입니다.)도 구역을 나누어 쓰는 것을 알 수 있습니다. tf.keras.layers.GRUCell 전체 코드를 들여다보면 결국 아래처럼 weight가 구성되어 있는 것을 알 수 있습니다.

PyTorch의 GRU

그렇다면 PyTorch의 GRU 구현은 어떻게 되어 있을까요? 아래와 같이 연산을 진행합니다. (aten/src/ATen/native/RNN.cpp#L723-L753)

template <typename cell_params>
struct GRUCell : Cell<Tensor, cell_params> {
  using hidden_type = Tensor;

  hidden_type operator()(
      const Tensor& input,
      const hidden_type& hidden,
      const cell_params& params,
      bool pre_compute_input = false) const override {
    if (input.is_cuda()) {
      TORCH_CHECK(!pre_compute_input);
      auto igates = params.matmul_ih(input);
      auto hgates = params.matmul_hh(hidden);
      auto result = at::_thnn_fused_gru_cell(
          igates, hgates, hidden, params.b_ih(), params.b_hh());
      // Slice off the workspace argument (it's needed only for AD).
      return std::move(std::get<0>(result));
    }
    const auto chunked_igates = pre_compute_input
        ? input.unsafe_chunk(3, 1)
        : params.linear_ih(input).unsafe_chunk(3, 1);
    auto chunked_hgates = params.linear_hh(hidden).unsafe_chunk(3, 1);
    const auto reset_gate =
        chunked_hgates[0].add_(chunked_igates[0]).sigmoid_();
    const auto input_gate =
        chunked_hgates[1].add_(chunked_igates[1]).sigmoid_();
    const auto new_gate =
        chunked_igates[2].add(chunked_hgates[2].mul_(reset_gate)).tanh_();
    return (hidden - new_gate).mul_(input_gate).add_(new_gate);
  }
};

chunked_hgateschunked_igates 변수를 활용하는 것을 잘 보면, reset, input, new로 순서대로 구성이 되어있는 것을 알 수 있습니다. 각각 Reset gate, Update gate, Output candidate에 해당합니다. 그리고 PyTorch GRU의 State Dict를 추출해보면 weight_ih_l{LAYER 숫자}, weight_hh_l{LAYER 숫자}, bias_ih_l{LAYER 숫자}, bias_hh_l{LAYER 숫자}(예를 들면 weight_ih_l0)처럼 나옵니다. 즉 아래처럼 구성되어 있다는 것을 알 수 있죠.

Torch 혹은 TF의 kernel을 가져와서 Reset gate, Update gate의 순서를 바꾸어주고 Transpose를 한 뒤, bias들을 합치거나 분해하면 각각의 weight로 무리 없이 변환할 수 있습니다.

결론

TensorFlow와 PyTorch를 한 조직 내에서 사용하는 것은 어렵고 고된 일입니다. 리소스가 많이 들지만, 리서치의 편의성을 보장해주면서도 배포의 편리함을 가져가는 것은 분명한 장점입니다. 핑퐁팀은 리서치 코드 베이스와 엔지니어링 코드 베이스를 분리하기로 하였고 그 결과 리서치를 진행할 때 팀 내의 모든 모델을 PyTorch 버전으로 편하게 가져다 쓸 수 있으면서, 리서치가 끝난 경우 그 가중치를 TensorFlow 모델로 빠르게 옮겨 배포도 용이하게 하는 라이브러리가 탄생하게 되었습니다.

스캐터랩이 직접 전해주는
AI에 관한 소식을 받아보세요

능력있는 현업 개발자, 기획자, 디자이너가
지금 스캐터랩에서 하고 있는 일, 세상에 벌어지고 있는 흥미로운 일들을 알려드립니다.