KoBART Summarization (3/3)
지난 포스팅에 이어 모델 학습 및 테스트에 대한 내용을 적을 것이다.
작업 환경
구글 코랩 프로를 사용하고 있다.
1
2
# gpu 사용 체크
!nvidia-smi
모델 학습
자세한 학습 코드는 KoBART-summarization 깃허브 페이지를 통해 확인할 수 있다.
GPU
1
python train.py --gradient_clip_val 1.0 --max_epochs 1 --default_root_dir logs --gpus 1 --batch_size 10 --num_workers 4
or
1
sh run_train.sh
CPU
1
python train.py --gradient_clip_val 1.0 --max_epochs 50 --default_root_dir logs --batch_size 4 --num_workers 4
or
1
sh run_train_cpu.sh
위 코드를 통해 학습을 진행할 수 있다. 구글 코랩 p100 기준으로 1 epoch 진행에 약 1시간 30분이 소요되는 것을 확인할 수 있었다. (가공한 데이터 포함 기준 - data length:61275)
torchtext 버전 오류가 나는 경우, 학습 진행 전 아래 코드를 실행해 패키지를 설치해준다. (맥에서 학습을 진행할 때에는 별다른 에러 없이 잘 진행되었다. 하지만 코랩을 사용하니 여러 오류가 발생하는 것 같다.)
1
!pip install torchtext==0.8.0
Get Model
학습을 끝냈다. 이제 학습시킨 모델을 가져와 저장해보자.
1
2
3
4
# kobart_summary 디렉토리에 모델 저장
# hparams: logs 하위 디렉토리에서 사용할 모델의 버전 골라 hparams.yaml set
# model_binary: logs 하위 디렉토리에서 사용할 체크포인트 골라 *.ckpt set
!python get_model_binary.py --hparams ./logs/tb_logs/default/version_0/hparams.yaml --model_binary ./logs/model_chp/epoch=00-val_loss=1.464.ckpt
위 실행 코드는 version_0의 파라미터를 사용하고, (val_loss가 1.464인) 0번째 epoch의 체크포인트를 사용해 모델을 가져오겠다는 의미다.
코드를 실행하면 폴더명이 kobart_summary인 하위 폴더에 모델이 저장될 것이다.
- Loader 관련 오류 발생 시 pyyaml 설치를 진행 후 get_model_binary.py를 실행해준다.
1
!pip install pyyaml==5.4.1
Test
테스트를 진행하기 전 패키지 설치를 진행해준다.
1
pip install git+https://github.com/SKT-AI/KoBART#egg=kobart
get_model_binary.py의 실행을 통해 저장한 모델을 사용해 테스트를 진행한다.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import torch
from transformers import PreTrainedTokenizerFast
from transformers.models.bart import BartForConditionalGeneration
model = BartForConditionalGeneration.from_pretrained('./kobart_summary')
tokenizer = PreTrainedTokenizerFast.from_pretrained('gogamza/kobart-base-v1')
text = input()
if text:
input_ids = tokenizer.encode(text)
input_ids = torch.tensor(input_ids)
input_ids = input_ids.unsqueeze(0)
output = model.generate(input_ids, eos_token_id=1, max_length=512, num_beams=5)
output = tokenizer.decode(output[0], skip_special_tokens=True)
print(output)
텍스트를 입력받아 가공을 진행하고, 모델에 넣어 결과를 출력하는 과정이다.
max_length는 512로 제한되어있다. 사용자가 사용할 텍스트에 따라 조정하면 될 것 같다.
네이버에서 기사를 하나 골라 테스트를 진행해보았다.
원본 기사 링크 : 한국일보-아픈 아이 데려오려고 남편 도장 위조… 대법 “정당 행위”
Output
1
27일 대법원 3부(주심 노정희 대법관)는 사인위조 등 혐의로 기소된 A씨에 대해 사인위조 등 혐의로 기소된 A씨에 대해 무죄를 선고한 원심을 확정했다고 밝혔다.
epoch을 1로 주어 학습량이 상당히 적은 모델을 사용했다. 그래서 그런지 ‘사인위조 등 혐의로 기소된 A씨에 대해’와 같은 부분이 중복되는 것을 확인했고, 요약 모델로서 원하는 성능에 미치지는 못한다.
하지만 기사에서 말하고자 하는 중요한 부분을 한 문장으로 잘 표현하고 있는 것을 보니 학습을 계속해서 진행한다면 좋은 성능을 기대해볼만 하다고 생각한다.
이로써 KoBART 요약 모델 학습을 진행해보는 과정이 끝났다.
학습을 조금 더 진행한 후, Flask에 올려 요약 웹, 앱 개발 프로젝트에 사용할 수 있을 것이라 기대한다.
Comments powered by Disqus.