Make sure the GPU is enabled (Edit -> Notebook settings -> Hardware accelerator: GPU)
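To confirm that the runtime really has a GPU attached, one option is to look for the `nvidia-smi` tool. This is a minimal sketch: `gpu_visible` is a hypothetical helper name, and it assumes an NVIDIA runtime where `nvidia-smi` is on the PATH.

```python
import shutil
import subprocess

def gpu_visible():
    """Return True if nvidia-smi is installed and exits successfully."""
    exe = shutil.which("nvidia-smi")
    if exe is None:
        return False
    try:
        result = subprocess.run(
            [exe],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            timeout=30,
        )
    except (OSError, subprocess.TimeoutExpired):
        return False
    return result.returncode == 0

print("GPU visible:", gpu_visible())
```

If this prints `False`, change the hardware accelerator setting before cloning and training.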
Clone and cd into repo
!git clone https://github.com/nshepperd/gpt-2.git
cd gpt-2
Install requirements
!pip3 install -r requirements.txt
Download the model
!python3 download_model.py 117M
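# `!export` runs in a temporary subshell and is lost; %env sets the variable for the whole notebook session
%env PYTHONIOENCODING=UTF-8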
Clone joke dataset
cd /content/gpt-2
! git clone https://github.com/taivop/joke-dataset
Extract useful jokes and save to plain text
cd joke-dataset/
import json

# Load the stupidstuff jokes and concatenate their bodies into one text file
with open("stupidstuff.json") as f:
    jokes = json.load(f)
#print(jokes)
text = ''.join(x.get("body") for x in jokes)
#print(text)
with open('stupidstuff.txt', 'w') as out_file:
    out_file.write(text)
# Same extraction for the wocka jokes
with open("wocka.json") as f:
    jokes = json.load(f)
text = ''.join(x.get("body") for x in jokes)
with open('wocka.txt', 'w') as out_file:
    out_file.write(text)
def catjoke(x):
    # Reddit jokes split the setup (title) from the punchline (body)
    title = x.get("title")
    body = x.get("body")
    return title + body

with open("reddit_jokes.json") as f:
    jokes = json.load(f)
text = ''.join(catjoke(x) for x in jokes)
with open('reddit_jokes.txt', 'w') as out_file:
    out_file.write(text)
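Joining with an empty string runs each joke straight into the next one, and a joke with a missing field would raise an error. One possible variation is sketched below. `jokes_to_text` is a hypothetical helper, and the blank-line separator is an assumption rather than something the training script requires.

```python
def jokes_to_text(jokes, fields=("body",), sep="\n\n"):
    """Join the chosen fields of each joke, skipping empty or missing values."""
    chunks = []
    for joke in jokes:
        parts = [(joke.get(name) or "").strip() for name in fields]
        parts = [p for p in parts if p]
        if parts:
            # Title and body of one joke go on consecutive lines
            chunks.append("\n".join(parts))
    # A separator between jokes keeps them from fusing together
    return sep.join(chunks)

sample = [
    {"title": "Why did the chicken cross the road?", "body": "To get to the other side."},
    {"title": None, "body": ""},
    {"body": "A joke with no title."},
]
print(jokes_to_text(sample, fields=("title", "body")))
```

For the Reddit file this would be called with `fields=("title", "body")`, and with the default `fields=("body",)` for the other two.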
Let's get our train on! Here the training file is the Reddit jokes text scraped from /r/jokes. We are going to fine-tune the GPT-2 117M model on this custom text dataset. Small datasets work too, but the fine-tuning must not run for too long or the model will overfit badly.
The default training setting will save checkpoints every 1000 steps.
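To get a rough feel for how long is "too long", it can help to estimate how many training steps correspond to one pass' worth of tokens. The sketch below is back-of-the-envelope arithmetic only. `approx_steps_per_pass` is a hypothetical helper, and its defaults are assumptions to check against `train.py`: about 4 characters per token, 1024 tokens per sample, and a batch size of 1.

```python
def approx_steps_per_pass(num_chars, chars_per_token=4.0,
                          tokens_per_sample=1024, batch_size=1):
    """Estimate how many steps consume as many tokens as the dataset holds."""
    approx_tokens = num_chars / chars_per_token
    return approx_tokens / (tokens_per_sample * batch_size)

# Example: a text file of 4,096,000 characters
print(approx_steps_per_pass(4_096_000))  # prints 1000.0
```

Under these assumptions, a file of that size is covered roughly once per saved checkpoint, so a much smaller file would be seen many times over before the first checkpoint is written.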
cd /content/gpt-2
!PYTHONPATH=src ./train.py --dataset /content/gpt-2/joke-dataset/reddit_jokes.txt
Load the trained model for use by copying the fine-tuned checkpoint over the 117M model files
!cp -r /content/gpt-2/checkpoint/run1/* /content/gpt-2/models/117M/
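Before copying, it can be useful to know which training step the newest checkpoint comes from. The sketch below assumes the checkpoint files follow the TensorFlow saver naming `model-<step>.index`; that naming and the helper `latest_checkpoint_step` are assumptions, so list the directory first to confirm.

```python
import os
import re
import tempfile

def latest_checkpoint_step(run_dir):
    """Return the highest step among files named model-<step>.index, or None."""
    pattern = re.compile(r"^model-(\d+)\.index$")
    steps = []
    for name in os.listdir(run_dir):
        match = pattern.match(name)
        if match:
            steps.append(int(match.group(1)))
    return max(steps) if steps else None

# Demo on a throwaway directory with two fake checkpoint files
with tempfile.TemporaryDirectory() as demo_dir:
    for step in (1000, 2000):
        open(os.path.join(demo_dir, "model-%d.index" % step), "w").close()
    print(latest_checkpoint_step(demo_dir))  # prints 2000
```

In the notebook it would be pointed at `/content/gpt-2/checkpoint/run1`.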
Generate conditional samples from the model given a prompt - change the top-k hyperparameter if desired (40 is used here)
!python3 src/interactive_conditional_samples.py --top_k 40
To check flag descriptions, use:
!python3 src/interactive_conditional_samples.py -- --help
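To see what the top-k setting does, here is a minimal pure-Python sketch of top-k filtering: only the k highest-scoring tokens stay eligible for sampling. This is an illustration, not the repository's code. `top_k_filter` and `softmax` are hypothetical helpers, and treating `k <= 0` as "no truncation" is an assumption of this sketch.

```python
import math

def top_k_filter(logits, k):
    """Keep the k largest logits; set every other logit to -inf."""
    if k <= 0 or k >= len(logits):
        return list(logits)
    threshold = sorted(logits, reverse=True)[k - 1]
    # Ties with the k-th largest value are kept
    return [x if x >= threshold else -math.inf for x in logits]

def softmax(logits):
    """Turn logits into probabilities; -inf entries get probability 0."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

logits = [1.0, 3.0, 2.0, 0.5]
filtered = top_k_filter(logits, 2)
print(filtered)  # prints [-inf, 3.0, 2.0, -inf]
print(softmax(filtered))
```

A smaller k makes the output more conservative and repetitive, while a larger k allows more varied (and more error-prone) text.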
Generate unconditional samples from the model
!python3 src/generate_unconditional_samples.py | tee /tmp/samples
To check flag descriptions, use:
!python3 src/generate_unconditional_samples.py -- --help