seq2seq 训练时 feed 自己的数据

xiaoxiao2021-02-28  3

在这个文件加入以下代码https://github.com/tensorflow/models/blob/master/tutorials/rnn/translate/translate.py

def vectorize_data(data, word_idx): #word_idx >=1 ,frist is unknow-token Q = [] for line in data: ss = [] for word in line: if word not in word_idx: ss.append(0) else: ss.append(word_idx[word]) Q.append(ss) return Q def load_data(file): with open(file) as f: lines = f.readlines() chinese_data = [] english_data = [] index = 0 for line in lines: if line == "\n": continue words_list=[] words = line.split(' ') [words_list.append(word.strip("\n")) for word in words] if index % 2 == 0: chinese_data.append(words_list) elif index%2 == 1: english_data.append(words_list) index+=1 return chinese_data,english_data chinese_data, english_data = load_data('./data.txt') _PAD = b"_PAD" _GO = b"_GO" _EOS = b"_EOS" _UNK = b"_UNK" PAD_ID = 0 GO_ID = 1 EOS_ID = 2 UNK_ID = 3 temp = reduce(lambda x, y: x + y, [story for story in chinese_data]) chinese_vocab = set(temp) chinese_word_idx = dict((c, i + 4) for i, c in enumerate(chinese_vocab)) chinese_word_idx[_PAD]= PAD_ID chinese_word_idx[_GO] = GO_ID chinese_word_idx[_EOS] = EOS_ID chinese_word_idx[_UNK] = UNK_ID sentence_max_word_number_chinese = max(map(len, chinese_data)) temp = reduce(lambda x, y: x + y, [story for story in english_data]) english_vocab = set(temp) english_word_idx = dict((c, i + 4) for i, c in enumerate(english_vocab)) english_word_idx[_PAD] = PAD_ID english_word_idx[_GO] = GO_ID english_word_idx[_EOS] = EOS_ID english_word_idx[_UNK] = UNK_ID sentence_max_word_number_english = max(map(len, english_data)) chinese_ids = vectorize_data(chinese_data,chinese_word_idx) english_ids = vectorize_data(english_data, english_word_idx) for line in english_ids: line.append(EOS_ID) data_set = [[] for _ in _buckets] for chinese_line,english_line in zip(chinese_ids,english_ids): for bucket_id, (source_size, target_size) in enumerate(_buckets): if len(chinese_line) < source_size and len(english_line) < target_size: data_set[bucket_id].append([chinese_line, english_line]) break train_set = data_set # 替换原来的train_set

数据文件的样子

纽约 比 加州 早 三个 小时 New York is 3 hours ahead of California 但 这 没有 让 加州 变慢 but it does not make California slow 有人 22岁 毕业了 Someone graduated at the age of 22 但 等了 五年 才 找到 好的 工作 but waited 5 years before securing a good job 有人 25岁 当上 CEO Someone became a CEO at 25 却 在 50岁 去世 and died at 50 然而 另一个人 50岁 当上 CEO While another became a CEO at 50 然后 活到 90岁 and lived to 90 years 有人 依然 单身 Someone is still single 然而 也 有人 已经 结婚 while someone else got married 奥巴马 55岁 退休 Obama retires at 55 但 川普 70岁 开始 but Trump starts at 70 本来 世界上 每个人 在 自己的 时区 工作 Absolutely everyone in this world works based on their Time Zone 身边 有人 可能 看似 走 在 你 前面 People around you might seem to go ahead of you 有人 可能 看似 在 你 后面 some might seem to be behind you. 但 每个人 正在 以 他们的 速度 奔跑 在 他们 自己的 时区 But everyone is running their own RACE in their own TIME. 不要 嫉妒 或 嘲笑 他们 Don’t envy them or mock them 他们 在 他们的 时区 你 在 你的 They are in their TIME ZONE and you are in yours 生命 是 关于 等待 正确的 时机 行动 Life is about waiting for the right moment to act 所以 放轻松 So RELAX 你 没有 落后 You’re not LATE 你 没有 领先 You’re not EARLY 你 非常 准时 在 命运 为 你 安排 的 时区 You are very much ON TIME, and in your TIME ZONE Destiny set up for you.

转载请注明原文地址: https://www.6miu.com/read-250373.html

最新回复(0)