-
Notifications
You must be signed in to change notification settings - Fork 41
/
tweet_generator.py
70 lines (58 loc) · 2.24 KB
/
tweet_generator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
#!/usr/bin/python3
import yaml
import tweepy
import re
from textgenrnn import textgenrnn
def process_tweet_text(text):
    """Clean raw tweet text before it is used as training data.

    URLs and @-mentions are removed, surrounding spaces are stripped, runs of
    spaces are collapsed to one, and the HTML entities ``&lt;``, ``&gt;`` and
    ``&amp;`` are decoded back to ``<``, ``>`` and ``&``.

    Args:
        text: The raw tweet text.

    Returns:
        The cleaned text. This is the empty string when the tweet consisted
        only of URLs, mentions and spaces.
    """
    text = re.sub(r'http\S+', '', text)  # Remove URLs
    text = re.sub(r'@[a-zA-Z0-9_]+', '', text)  # Remove @ mentions
    text = text.strip(" ")  # Remove whitespace resulting from above
    text = re.sub(r' +', ' ', text)  # Remove redundant spaces
    # Handle common HTML entities. `&amp;` is decoded last so that a
    # double-escaped sequence such as "&amp;lt;" becomes "&lt;" rather than
    # being unescaped twice into "<".
    text = text.replace('&lt;', '<')
    text = text.replace('&gt;', '>')
    text = text.replace('&amp;', '&')
    return text
# Load the run configuration: API credentials, the accounts to download and
# the training settings read further down in this script.
with open("config.yml", "r") as f:
    # safe_load only builds plain Python objects (dicts, lists, scalars).
    # A bare yaml.load(f) can construct arbitrary objects and is rejected by
    # recent PyYAML releases, which require an explicit Loader argument.
    cfg = yaml.safe_load(f)

# Build an authenticated Twitter API client from the configured credentials.
auth = tweepy.OAuthHandler(cfg['consumer_key'], cfg['consumer_secret'])
auth.set_access_token(cfg['access_key'], cfg['access_secret'])
api = tweepy.API(auth)
# Parallel lists: texts[i] is a cleaned tweet and context_labels[i] is the
# screen name of the account it was downloaded from.
texts = []
context_labels = []

for user in cfg['twitter_users']:
    print("Downloading {}'s Tweets...".format(user))
    # Page through the user's timeline: at most 16 pages, asking for 200
    # tweets per page. Extended mode is requested because the loop below
    # reads `tweet.full_text`.
    all_tweets = tweepy.Cursor(api.user_timeline,
                               screen_name=user,
                               count=200,
                               tweet_mode='extended',
                               include_rts=False).pages(16)
    for page in all_tweets:
        for tweet in page:
            tweet_text = process_tweet_text(tweet.full_text)
            # Skip tweets that are empty after cleaning. This is a value
            # test; the earlier `is not ''` compared object identity with a
            # literal, which is implementation-dependent and triggers a
            # SyntaxWarning on Python 3.8+.
            if tweet_text:
                texts.append(tweet_text)
                context_labels.append(user)
# The model name is every configured screen name joined by underscores,
# followed by the "_twitter" suffix.
model_name = '{}_twitter'.format("_".join(cfg['twitter_users']))
textgen = textgenrnn(name=model_name)

# Keyword arguments that both training calls receive.
shared_training_args = dict(
    context_labels=context_labels,
    num_epochs=cfg['num_epochs'],
    gen_epochs=cfg['gen_epochs'],
    batch_size=cfg['batch_size'],
    train_size=cfg['train_size'],
)

# cfg['new_model'] selects the training call: train_new_model additionally
# receives the settings under cfg['model_config']; train_on_texts receives
# only the shared arguments.
if not cfg['new_model']:
    textgen.train_on_texts(texts, **shared_training_args)
else:
    model_cfg = cfg['model_config']
    textgen.train_new_model(
        texts,
        rnn_layers=model_cfg['rnn_layers'],
        rnn_size=model_cfg['rnn_size'],
        rnn_bidirectional=model_cfg['rnn_bidirectional'],
        max_length=model_cfg['max_length'],
        dim_embeddings=model_cfg['dim_embeddings'],
        word_level=model_cfg['word_level'],
        **shared_training_args)