def preprocess_file(Config):
    """Read the poetry corpus and build vocabulary mappings.

    Each line of ``Config.poetry_file`` is assumed to look like
    ``title:poem-body`` (TODO confirm against the corpus); the title prefix
    is dropped and a ``]`` is appended to mark the end of each poem.

    Returns:
        word2numF: callable mapping a character to its index; unknown /
            rare characters map to the last index (the padding " " slot).
        num2word: dict mapping index -> character.
        words: tuple of vocabulary characters, most frequent first,
            with a trailing " " slot for unknowns.
        files_content: the concatenated corpus text.
    """
    from collections import Counter

    files_content = ''
    with open(Config.poetry_file, 'r', encoding='utf-8') as f:
        for line in f:
            # Append "]" to mark the end of a poem, then keep only the text
            # after the "title:" prefix.  BUGFIX: the original wrote
            # `line.strip() + "]".split(":")[-1]`, which -- by operator
            # precedence -- split only the literal "]" and never removed
            # the title; the parentheses below restore the intended split.
            files_content += (line.strip() + "]").split(":")[-1]

    # Character frequencies.  Feeding sorted() input keeps the insertion
    # order identical to the original manual count over sorted characters,
    # so ties in the -count sort below resolve the same way.
    counted_words = Counter(sorted(files_content))

    # Drop rare characters (frequency <= 2) to shrink the vocabulary.
    for key in [k for k, v in counted_words.items() if v <= 2]:
        del counted_words[key]

    # Most frequent character first.
    word_pairs = sorted(counted_words.items(), key=lambda x: -x[1])
    words, _ = zip(*word_pairs)
    # Extra " " entry: a dedicated slot that unknown characters map to.
    words += (" ",)

    # character <-> id mappings
    word2num = {c: i for i, c in enumerate(words)}
    num2word = {i: c for i, c in enumerate(words)}
    word2numF = lambda x: word2num.get(x, len(words) - 1)
    return word2numF, num2word, words, files_content
def data_generator(self):
    """Infinite generator of one-hot (x, y) training pairs.

    Yields:
        x_vec: bool array of shape (1, max_len, vocab) -- one-hot window
            of ``max_len`` consecutive characters.
        y_vec: bool array of shape (1, vocab) -- one-hot of the character
            that follows the window.

    Windows containing the poem separator ']' are skipped so training
    never crosses a poem boundary.  The scan index wraps back to 0 at the
    end of the corpus, so the generator loops over its data indefinitely
    (as Keras-style fit_generator requires).  Assumes the corpus is longer
    than max_len -- TODO confirm at call site.
    """
    i = 0
    max_len = self.config.max_len      # hoisted loop invariants
    corpus = self.files_content
    vocab_size = len(self.words)
    while True:
        # BUGFIX: the original let i grow unboundedly and eventually
        # raised IndexError on corpus[i + max_len]; wrap around instead.
        if i + max_len >= len(corpus):
            i = 0
        x = corpus[i: i + max_len]
        y = corpus[i + max_len]
        # ']' marks a poem boundary -- never train across it.
        if ']' in x or ']' in y:
            i += 1
            continue
        # BUGFIX: np.bool was removed in NumPy 1.20; the builtin bool is
        # the supported dtype spelling.
        y_vec = np.zeros(shape=(1, vocab_size), dtype=bool)
        y_vec[0, self.word2numF(y)] = 1.0
        x_vec = np.zeros(shape=(1, max_len, vocab_size), dtype=bool)
        for t, char in enumerate(x):
            x_vec[0, t, self.word2numF(char)] = 1.0
        yield x_vec, y_vec
        i += 1
# NOTE (from the Keras fit_generator documentation): the generator's output
# must be a tuple (inputs, targets) or a tuple (inputs, targets,
# sample_weights).  This tuple (a single output of the generator) makes a
# single batch.  Therefore, all arrays in this tuple must have the same
# length (equal to the size of this batch).  Different batches may have
# different sizes.  For example, the last batch of the epoch is commonly
# smaller than the others, if the size of the dataset is not divisible by
# the batch size.  The generator is expected to loop over its data
# indefinitely.  An epoch finishes when steps_per_epoch batches have been
# seen by the model.