본문 바로가기
jetson nano 실습

[ActionAI]jetson nano에서 trt_pose 설치

by 코딩새내기_ 2021. 7. 5.

Nvidia에서 제공하는 trt_pose를 설치해보려고 합니다.

https://github.com/NVIDIA-AI-IOT/trt_pose

 

NVIDIA-AI-IOT/trt_pose

Real-time pose estimation accelerated with NVIDIA TensorRT - NVIDIA-AI-IOT/trt_pose

github.com


먼저 환경설정은 아래 링크에서 하고 오시면 됩니다. 저는 도커 환경에서 진행하여 커맨드에 sudo가 없습니다.

https://ddo-code.tistory.com/21?category=943823 

tensorrt는 메모리를 많이 먹어서 일단 swap memory를 늘려줘야 합니다.

아래 링크대로 하고 오시면 됩니다.

https://ddo-code.tistory.com/24

torch2trt를 설치합니다.

저는 git clone할 때 도커에서 하지 않고 그냥 로컬 커맨드에서 하였습니다. 그리고 도커에서 폴더에 들어가서 진행하였습니다.

git clone https://github.com/NVIDIA-AI-IOT/torch2trt
cd torch2trt
python3 setup.py install --plugins

다른 패키지를 설치합니다.

pip3 install tqdm cython pycocotools
apt-get install python3-matplotlib

trt_pose를 설치합니다.

git clone https://github.com/NVIDIA-AI-IOT/trt_pose
cd trt_pose
python3 setup.py install

plugins를 복사해줍니다.

나중에 trt_pose를 실행할 때 plugins.cpython-36m-aarch64-linux-gnu.so가 없으면 오류가 뜹니다.

커맨드로 하면 이렇게 되겠습니다.

cp build/lib.linux-aarch64-3.6/trt_pose/plugins.cpython-36m-aarch64-linux-gnu.so trt_pose/.

 

torch2trt에서도 똑같이 plugins를 복사해줍니다.

그리고 torch2trt안에 있는 torch2trt폴더를 trt_pose안에 복사해줍니다.

tkinter를 설치합니다.

apt-get install python3-tk

trt_test.py를 생성합니다.

import json
import trt_pose.coco
import torch2trt

with open('human_pose.json', 'r') as f:
    human_pose = json.load(f)
print(human_pose)
topology = trt_pose.coco.coco_category_to_topology(human_pose)

import trt_pose.models

num_parts = len(human_pose['keypoints'])
num_links = len(human_pose['skeleton'])

model = trt_pose.models.resnet18_baseline_att(num_parts, 2 * num_links).cuda().eval()
import torch

MODEL_WEIGHTS = 'resnet18_baseline_att_224x224_A_epoch_249.pth'

model.load_state_dict(torch.load(MODEL_WEIGHTS))
print('load model complete')
WIDTH = 224
HEIGHT = 224

data = torch.zeros((1, 3, HEIGHT, WIDTH)).cuda()



model_trt = torch2trt.torch2trt(model, [data], fp16_mode=True, max_workspace_size=1<<25)
print('trt model complete')
OPTIMIZED_MODEL = 'resnet18_baseline_att_224x224_A_epoch_249_trt.pth'

torch.save(model_trt.state_dict(), OPTIMIZED_MODEL)

trt_pose폴더에서 python3 trt_test.py를 실행시켜서 resnet18_baseline_att_224x224_A_epoch_249_trt.pth가 생성되면 성공입니다.

댓글