PyTorch JIT 은 Pytorch 의 IR(중간표현)이다.
이 중간 표현을 Torch Script 라고 부르는데, Pytorch 의 '그래프' 표현이다.
파이토치는 텐서플로우와 다르게 동적 그래프를 이용하기 때문에, 따로 그래프 표현을 해주어야 한다.
사용하는 이유 : Python 언어에 의존하지 않고, 불러낼 수 있음.
-> 최적화가 가능하고, 다른 언어에서 사용이 가능한 형태.
파이토치 모델을 IR(Torch Script) 으로 바꾸는 방법.
1. 추적 모드 : 실행되지 않은 제어흐름 (EX) if문) 을 캡처할 수 없음.
코드를 실행하고 발생하는 작업을 기록하며 정확하게 수행하는 스크립트 모듈을 구성함. - 제어흐름은 지워짐.
=> input과 같은 shape 를 갖춘 random input 을 만들고, 밑의 구문을 실행.
traced_model = torch.jit.trace(model, (random_input))
(1) traced_model.graph 에는 중간표현의 그래프가 기록되어있다.
(2) traced_model.code 에는 파이썬 구문으로 제공됨. (우리가 원하는 형태) (제어흐름은 기록되지 못함)
2. 스크립트 모드 : 함수/클래스를 취하고, 파이썬 코드를 재해석하며, TorchScript IR 을 직접 출력한다.
Script Compiler 를 이용하여 Python 소스 코드를 직접 분석하여 TorchScript를 변환함.
=>
---------------------------------------------------------------------------------------------------
1-1
class decision(torch.nn.Module) :
def forward(self,x):
pass
class model(torch.nn.Module) :
def __init__(self,dg) :
self.dg = dg
def forward(self,x):
pass (dg 이용한 branch..)
---------------------------------------------------------------------------------------------------
1-2.0
(trace 이용 ver.)
decision_branch = decision()
my_model = model( decision_branch )
traced_model = torch.jit.trace( my_model , (random_input) )
print(traced_model.code)
-> if-else 분기 (decision) 가 사라짐.
스크립트 컴파일러를 이용해야함.
---------------------------------------------------------------------------------------------------
1-2.1
(script compiler 이용 ver.)
decision_branch = torch.jit.script(decision())
my_model = model(decision_branch)
traced_model = torch.jit.script(my_model)
print(traced_model.code)
---------------------------------------------------------------------------------------------------
1-2.2
(둘 다 이용)
why? : 제어흐름이 우리가 IR에 기록하고 싶지 않은 상수값을 기반으로 만들어진 경우가 있다면, 이 부분은 script 가 아닌 trace를 이용해야 합니다.
class MyRNNLoop(torch.nn.Module):
def __init__(self):
super(MyRNNLoop, self).__init__()
self.cell = torch.jit.trace(MyCell(scripted_gate), (x, h))
def forward(self, xs):
h, y = torch.zeros(3, 4), torch.zeros(3, 4)
for i in range(xs.size(0)):
y, h = self.cell(xs[i], h)
return y, h
rnn_loop = torch.jit.script(MyRNNLoop())
print(rnn_loop.code)
class WrapRNN(torch.nn.Module):
def __init__(self):
super(WrapRNN, self).__init__()
self.loop = torch.jit.script(MyRNNLoop())
def forward(self, xs):
y, h = self.loop(xs)
return torch.relu(y)
traced = torch.jit.trace(WrapRNN(), (torch.rand(10, 3, 4)))
print(traced.code)
---------------------------------------------------------------------------------------------------
저장 , 불러오기
traced.save('wrapped_rnn.zip')
loaded = torch.jit.load('wrapped_rnn.zip')
print(loaded)
print(loaded.code)
참고한 글
tutorials.pytorch.kr/beginner/Intro_to_TorchScript_tutorial.html
'AI 프레임워크 (파이토치,텐서플로우)' 카테고리의 다른 글
TF - eager execution , JIT (@tf.function , jax) (0) | 2021.10.30 |
---|