본문 바로가기

AI 프레임워크 (파이토치,텐서플로우)

JIT / Torch Script : 1. TRACE MODE , 2. SCRIPT MODE

PyTorch JIT 은 Pytorch 의 IR(중간표현)이다.

이 중간 표현을 Torch Script 라고 부르는데, Pytorch 의 '그래프' 표현이다.

파이토치는 텐서플로우와 다르게 동적 그래프를 이용하기 때문에, 따로 그래프 표현을 해주어야 한다.

 

사용하는 이유 : Python 언어에 의존하지 않고, 불러낼 수 있음.

-> 최적화가 가능하고, 다른 언어에서 사용이 가능한 형태.

 

파이토치 모델을 IR(Torch Script) 으로 바꾸는 방법.

1. 추적 모드 : 실행되지 않은 제어흐름 (EX) if문) 을 캡처할 수 없음.

코드를 실행하고 발생하는 작업을 기록하며 정확하게 수행하는 스크립트 모듈을 구성함. - 제어흐름은 지워짐.

 

 

=> input과 같은 shape 를 갖춘 random input 을 만들고, 밑의 구문을 실행.

traced_model = torch.jit.trace(model, (random_input)) 

 

(1) traced_model.graph 에는 중간표현의 그래프가 기록되어있다.

(2) traced_model.code 에는 파이썬 구문으로 제공됨.  (우리가 원하는 형태) (제어흐름은 기록되지 못함) 

 

 

 

 

2. 스크립트 모드 : 함수/클래스를 취하고, 파이썬 코드를 재해석하며, TorchScript IR 을 직접 출력한다.

Script Compiler 를 이용하여 Python 소스 코드를 직접 분석하여 TorchScript를 변환함.

 

=> 

---------------------------------------------------------------------------------------------------

1-1

class decision(torch.nn.Module) :

            def forward(self,x):

                  pass

 

class model(torch.nn.Module) :

           def __init__(self,dg) :

                  self.dg = dg

           def forward(self,x):

                  pass (dg 이용한 branch..)

 

---------------------------------------------------------------------------------------------------

1-2.0

(trace 이용 ver.)

 

decision_branch = decision()

my_model = model( decision_branch )

traced_model = torch.jit.trace( my_model , (random_input) )

print(traced_model.code)

-> if-else 분기 (decision) 가 사라짐. 

스크립트 컴파일러를 이용해야함.

---------------------------------------------------------------------------------------------------

1-2.1

(script compiler 이용 ver.)

 

decision_branch = torch.jit.script(decision())

my_model = model(decision_branch)

traced_model = torch.jit.script(my_model)

print(traced_model.code)

---------------------------------------------------------------------------------------------------

1-2.2

(둘 다 이용)

why? : 제어흐름이 우리가 IR에 기록하고 싶지 않은 상수값을 기반으로 만들어진 경우가 있다면, 이 부분은 script 가 아닌 trace를 이용해야 합니다.

 

 

class MyRNNLoop(torch.nn.Module):

    def __init__(self):

           super(MyRNNLoop, self).__init__()

self.cell = torch.jit.trace(MyCell(scripted_gate), (x, h))

def forward(self, xs):

    h, y = torch.zeros(3, 4), torch.zeros(3, 4)

    for i in range(xs.size(0)):

         y, h = self.cell(xs[i], h)

    return y, h

rnn_loop = torch.jit.script(MyRNNLoop())

print(rnn_loop.code)

 

class WrapRNN(torch.nn.Module):

       def __init__(self):

               super(WrapRNN, self).__init__()

               self.loop = torch.jit.script(MyRNNLoop())

       def forward(self, xs):

               y, h = self.loop(xs)

               return torch.relu(y)

traced = torch.jit.trace(WrapRNN(), (torch.rand(10, 3, 4)))

print(traced.code)

 

---------------------------------------------------------------------------------------------------

저장 , 불러오기

traced.save('wrapped_rnn.zip')

loaded = torch.jit.load('wrapped_rnn.zip')

print(loaded)

print(loaded.code)

 

 

 

참고한 글 

data-newbie.tistory.com/425

tutorials.pytorch.kr/beginner/Intro_to_TorchScript_tutorial.html