Home [NLP] KoBART 요약 실행해보기(3/3)
Post
Cancel

[NLP] KoBART 요약 실행해보기(3/3)

KoBART Summarization (3/3)

지난 포스팅에 이어 모델 학습 및 테스트에 대한 내용을 적을 것이다.


작업 환경

구글 코랩 프로를 사용하고 있다.

1
2
# gpu 사용 체크
!nvidia-smi

image-20211227112227850



모델 학습

자세한 학습 코드는 KoBART-summarization 깃허브 페이지를 통해 확인할 수 있다.

GPU

1
python train.py --gradient_clip_val 1.0 --max_epochs 1 --default_root_dir logs --gpus 1 --batch_size 10 --num_workers 4

or

1
sh run_train.sh

CPU

1
python train.py  --gradient_clip_val 1.0 --max_epochs 50 --default_root_dir logs  --batch_size 4 --num_workers 4

or

1
sh run_train_cpu.sh


위 코드를 통해 학습을 진행할 수 있다. 구글 코랩 p100 기준으로 1 epoch 진행에 약 1시간 30분이 소요되는 것을 확인할 수 있었다. (가공한 데이터 포함 기준 - data length:61275)


  • torchtext 버전 오류가 나는 경우, 학습 진행 전 아래 코드를 실행해 패키지를 설치해준다. (맥에서 학습을 진행할 때에는 별다른 에러 없이 잘 진행되었다. 하지만 코랩을 사용하니 여러 오류가 발생하는 것 같다.)

    1
    
    !pip install torchtext==0.8.0
    



Get Model

학습을 끝냈다. 이제 학습시킨 모델을 가져와 저장해보자.

1
2
3
4
# kobart_summary 디렉토리에 모델 저장
# hparams: logs 하위 디렉토리에서 사용할 모델의 버전 골라 hparams.yaml set
# model_binary: logs 하위 디렉토리에서 사용할 체크포인트 골라 *.ckpt set
!python get_model_binary.py --hparams ./logs/tb_logs/default/version_0/hparams.yaml --model_binary ./logs/model_chp/epoch=00-val_loss=1.464.ckpt

위 실행 코드는 version_0의 파라미터를 사용하고, (val_loss가 1.464인) 0번째 epoch의 체크포인트를 사용해 모델을 가져오겠다는 의미다.

코드를 실행하면 폴더명이 kobart_summary인 하위 폴더에 모델이 저장될 것이다.


  • Loader 관련 오류 발생 시 pyyaml 설치를 진행 후 get_model_binary.py를 실행해준다.
    1
    
    !pip install pyyaml==5.4.1
    



Test

테스트를 진행하기 전 패키지 설치를 진행해준다.

1
pip install git+https://github.com/SKT-AI/KoBART#egg=kobart

get_model_binary.py의 실행을 통해 저장한 모델을 사용해 테스트를 진행한다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import torch
from transformers import PreTrainedTokenizerFast
from transformers.models.bart import BartForConditionalGeneration

model = BartForConditionalGeneration.from_pretrained('./kobart_summary')
tokenizer = PreTrainedTokenizerFast.from_pretrained('gogamza/kobart-base-v1')

text = input()

if text:
    input_ids = tokenizer.encode(text)
    input_ids = torch.tensor(input_ids)
    input_ids = input_ids.unsqueeze(0)
    output = model.generate(input_ids, eos_token_id=1, max_length=512, num_beams=5)
    output = tokenizer.decode(output[0], skip_special_tokens=True)
    print(output)

텍스트를 입력받아 가공을 진행하고, 모델에 넣어 결과를 출력하는 과정이다.

max_length는 512로 제한되어있다. 사용자가 사용할 텍스트에 따라 조정하면 될 것 같다.


네이버에서 기사를 하나 골라 테스트를 진행해보았다.

epoch을 1로 주어 학습량이 상당히 적은 모델을 사용했다. 그래서 그런지 ‘사인위조 등 혐의로 기소된 A씨에 대해’와 같은 부분이 중복되는 것을 확인했고, 요약 모델로서 원하는 성능에 미치지는 못한다.

하지만 기사에서 말하고자 하는 중요한 부분을 한 문장으로 잘 표현하고 있는 것을 보니 학습을 계속해서 진행한다면 좋은 성능을 기대해볼만 하다고 생각한다.



이로써 KoBART 요약 모델 학습을 진행해보는 과정이 끝났다.

학습을 조금 더 진행한 후, Flask에 올려 요약 웹, 앱 개발 프로젝트에 사용할 수 있을 것이라 기대한다.

This post is licensed under younghwani by the author.

[NLP] KoBART 요약 실행해보기(2/3)

[Env] vim 초기 세팅하기

Comments powered by Disqus.