
TensorFlow dynamic_rnn input for regression

I'm trying to convert an existing TensorFlow sequence-to-sequence classifier into a regressor.

Currently I'm stuck on handling the input for tf.nn.dynamic_rnn(). According to the documentation and other answers, the input should have the shape (batch_size, sequence_length, input_size), or (sequence_length, batch_size, input_size) when time_major=True, as in the code below. However, my input data has only two dimensions: (sequence_length, batch_size).
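
A quick way to see what dynamic_rnn() expects: every time step must carry a feature vector, even if that vector has length 1. The NumPy sketch below (sizes taken from the example code, values are random stand-ins) adds a trailing feature axis to a time-major (sequence_length, batch_size) array and also shows the batch-major layout obtained by transposing. Inside a TensorFlow graph the analogous step would presumably be tf.expand_dims(x, -1) applied to a tf.float32 tensor, since the LSTM cells work on floats rather than int32 ids.

```python
import numpy as np

LENGTH_MAX, BATCH_SIZE = 8, 10

# Time-major 2-D data, like the output of get_random_sequences(): (time, batch)
x_2d = np.random.randn(LENGTH_MAX, BATCH_SIZE).astype(np.float32)

# Add a feature axis of size 1 -> (time, batch, input_size=1),
# the layout dynamic_rnn() expects with time_major=True.
x_time_major = np.expand_dims(x_2d, -1)

# Batch-major layout (time_major=False): (batch, time, input_size=1)
x_batch_major = np.transpose(x_time_major, (1, 0, 2))

print(x_time_major.shape)   # (8, 10, 1)
print(x_batch_major.shape)  # (10, 8, 1)
```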

The original solution uses tf.nn.embedding_lookup() as an intermediate step before feeding the input to dynamic_rnn(). If I understand correctly, I don't need this step, since I'm working on a regression problem, not a classification problem.
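
To see why the embedding step belongs to the categorical setup: tf.nn.embedding_lookup(params, ids) essentially selects rows of params by integer id, turning a (time, batch) array of token ids into a (time, batch, embedding_size) array of dense vectors. The NumPy sketch below mimics that with fancy indexing; the sizes come from the example code and the random matrix is only a stand-in for the trained embeddings variable. With continuous regression inputs there are no ids to look up, so this step most likely does not apply.

```python
import numpy as np

VOCAB_SIZE, input_embedding_size = 10, 20
LENGTH_MAX, BATCH_SIZE = 8, 10

rng = np.random.default_rng(0)
# Stand-in for the trainable `embeddings` variable: one row per token id
embeddings = rng.uniform(-1.0, 1.0, size=(VOCAB_SIZE, input_embedding_size))

# Integer token ids, time-major: (time, batch)
ids = rng.integers(3, 10, size=(LENGTH_MAX, BATCH_SIZE))

# Row gather: each id is replaced by its embedding vector
embedded = embeddings[ids]
print(embedded.shape)  # (8, 10, 20)
```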

Do I need the embedding_lookup step? If so, why? If not, how can I fit my encoder_inputs directly into dynamic_rnn()?

Below is a minimal working example of the general idea:

import numpy as np
import tensorflow as tf

tf.reset_default_graph()
sess = tf.InteractiveSession()

PAD = 0
EOS = 1
VOCAB_SIZE = 10 # Don't think I should need this for regression?
input_embedding_size = 20

encoder_hidden_units = 20
decoder_hidden_units = encoder_hidden_units

LENGTH_MIN = 3
LENGTH_MAX = 8
VOCAB_LOWER = 2
VOCAB_UPPER = VOCAB_SIZE
BATCH_SIZE = 10

def get_random_sequences():
    # Random token ids in [3, 10), time-major: shape (LENGTH_MAX, BATCH_SIZE)
    sequences = np.random.randint(3, 10, size=(LENGTH_MAX, BATCH_SIZE))
    return sequences

def next_feed():
    batch = get_random_sequences()

    encoder_inputs_ = batch
    eos = np.ones(BATCH_SIZE)
    decoder_targets_ = np.hstack((batch.T, np.atleast_2d(eos).T)).T
    decoder_inputs_ = np.hstack((np.atleast_2d(eos).T, batch.T)).T

    #print(encoder_inputs_)
    #print(decoder_inputs_)

    return {
        encoder_inputs: encoder_inputs_,
        decoder_inputs: decoder_inputs_,
        decoder_targets: decoder_targets_,
    }

### "MAIN"

# Placeholders
encoder_inputs = tf.placeholder(shape=(LENGTH_MAX, BATCH_SIZE), dtype=tf.int32, name='encoder_inputs')
decoder_targets = tf.placeholder(shape=(LENGTH_MAX + 1, BATCH_SIZE), dtype=tf.int32, name='decoder_targets')
decoder_inputs = tf.placeholder(shape=(LENGTH_MAX + 1, BATCH_SIZE), dtype=tf.int32, name='decoder_inputs')

# Don't think I should need this for regression problems
embeddings = tf.Variable(tf.random_uniform([VOCAB_SIZE, input_embedding_size], -1.0, 1.0), dtype=tf.float32)
encoder_inputs_embedded = tf.nn.embedding_lookup(embeddings, encoder_inputs)
decoder_inputs_embedded = tf.nn.embedding_lookup(embeddings, decoder_inputs)

# Encoder RNN
encoder_cell = tf.contrib.rnn.LSTMCell(encoder_hidden_units)
encoder_outputs, encoder_final_state = tf.nn.dynamic_rnn(
    encoder_cell, encoder_inputs_embedded, # Throws 'ValueError: Shape (8, 10) must have rank at least 3' if encoder_inputs is used
    dtype=tf.float32, time_major=True,
)

# Decoder RNN
decoder_cell = tf.contrib.rnn.LSTMCell(decoder_hidden_units)
decoder_outputs, decoder_final_state = tf.nn.dynamic_rnn(
    decoder_cell, decoder_inputs_embedded, 
    initial_state=encoder_final_state,
    dtype=tf.float32, time_major=True, scope="plain_decoder",
)
decoder_logits = tf.contrib.layers.linear(decoder_outputs, VOCAB_SIZE)
decoder_prediction = tf.argmax(decoder_logits, 2)

# Loss function
loss = tf.reduce_mean(tf.squared_difference(decoder_logits, tf.one_hot(decoder_targets, depth=VOCAB_SIZE, dtype=tf.float32)))
train_op = tf.train.AdamOptimizer().minimize(loss)


sess.run(tf.global_variables_initializer())

max_batches = 5000
batches_in_epoch = 500

print('Starting train')
try:
    for batch in range(max_batches):
        feed = next_feed()
        _, l = sess.run([train_op, loss], feed)

        if batch == 0 or batch % batches_in_epoch == 0:
            print('batch {}'.format(batch))
            print('  minibatch loss: {}'.format(sess.run(loss, feed)))
            predict_ = sess.run(decoder_prediction, feed)
            for i, (inp, pred) in enumerate(zip(feed[encoder_inputs].T, predict_.T)):
                print('  sample {}:'.format(i + 1))
                print('    input     > {}'.format(inp))
                print('    predicted > {}'.format(pred))
                if i >= 2:
                    break
            print()
except KeyboardInterrupt:
    print('training interrupted')
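
For the noisy-signal case mentioned in the edit below, a hypothetical replacement for get_random_sequences()/next_feed() could produce float arrays that already carry the feature axis, so no embedding is involved. This sketch assumes a noisy sine wave as the signal and uses 0.0 as a stand-in start/end frame where the classifier used the EOS token; make_regression_batch and GO_VALUE are illustrative names, not part of the original code. The matching placeholders would then presumably be tf.float32 with shapes (LENGTH_MAX, BATCH_SIZE, 1) and (LENGTH_MAX + 1, BATCH_SIZE, 1).

```python
import numpy as np

LENGTH_MAX = 8
BATCH_SIZE = 10
GO_VALUE = 0.0  # hypothetical start/end frame, replacing the EOS token id

def make_regression_batch(rng):
    """Return time-major float32 arrays shaped (time, batch, 1)."""
    t = np.arange(LENGTH_MAX)[:, None]                       # (time, 1)
    phase = rng.uniform(0, 2 * np.pi, size=(1, BATCH_SIZE))  # one phase per sequence
    signal = np.sin(0.5 * t + phase)                         # (time, batch)
    signal += 0.1 * rng.standard_normal(signal.shape)        # additive noise
    signal = signal[:, :, None].astype(np.float32)           # add feature axis

    go = np.full((1, BATCH_SIZE, 1), GO_VALUE, dtype=np.float32)
    encoder_inputs_ = signal
    decoder_inputs_ = np.concatenate([go, signal], axis=0)   # start frame, then signal
    decoder_targets_ = np.concatenate([signal, go], axis=0)  # signal, then end frame
    return encoder_inputs_, decoder_inputs_, decoder_targets_

enc, dec_in, dec_tgt = make_regression_batch(np.random.default_rng(0))
print(enc.shape, dec_in.shape, dec_tgt.shape)  # (8, 10, 1) (9, 10, 1) (9, 10, 1)
```

As in the original next_feed(), the decoder input is the target sequence shifted by one step (teacher forcing).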

I have read similar questions here on Stack Overflow, but I'm still puzzled as to how to solve this.

EDIT: I should clarify that the code above works well; however, the real desired output should mimic a noisy signal (text-to-speech, for example), which is why I think I need continuous output values instead of words or letters.
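
For continuous outputs, the VOCAB_SIZE-wide projection, tf.one_hot and tf.argmax from the example would no longer fit: each step would instead be projected to a single real value and compared directly with a float target. In the TF1 graph this would roughly correspond to tf.contrib.layers.linear(decoder_outputs, 1) followed by tf.reduce_mean(tf.squared_difference(prediction, decoder_targets)). The NumPy sketch below shows the same computation on random stand-in arrays (hidden size 20 as in the example; the weights are untrained).

```python
import numpy as np

T, B, H = 9, 10, 20  # decoder steps (LENGTH_MAX + 1), batch, decoder_hidden_units
rng = np.random.default_rng(1)

decoder_outputs = rng.standard_normal((T, B, H))  # stand-in for the RNN outputs
targets = rng.standard_normal((T, B, 1))          # continuous targets

# Linear projection to one value per step (replaces the VOCAB_SIZE-wide logits)
W = rng.standard_normal((H, 1)) * 0.1
b = np.zeros(1)
prediction = decoder_outputs @ W + b              # (T, B, 1); no argmax needed

# Mean squared error against the real-valued targets (no one-hot encoding)
loss = np.mean((prediction - targets) ** 2)
print(prediction.shape, float(loss))
```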


Source: https://stackoverflow.com/questions/44871420
Updated: 2023-03-31 10:03

Best answer

gtk.main() runs until you close the window. It is called the "main loop" or "event loop", and it does everything in a GUI program: it gets key and mouse events, sends them to widgets, redraws widgets, runs functions when you press a button, etc.

You have to use threading to run the (long-running) code at the same time, or use a timer class from the GUI toolkit to execute some code periodically.
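
The threading idea can be sketched without GTK using only the standard library: the long-running job runs in a worker thread and reports its result through a queue, while the main thread (standing in for the GUI main loop) stays free and periodically checks the queue, the role a GUI timer would play. In real PyGObject code the polling would typically be done with GLib.timeout_add, or the worker would hand results back via GLib.idle_add; long_task and the 0.05 s poll interval are made up for illustration.

```python
import queue
import threading
import time

results = queue.Queue()

def long_task(n):
    """Hypothetical long-running job; it must not touch GUI widgets directly."""
    total = 0
    for i in range(n):
        total += i
        time.sleep(0.001)  # simulate slow work
    results.put(total)     # hand the result back to the main thread

worker = threading.Thread(target=long_task, args=(100,), daemon=True)
worker.start()

# Stand-in for the GUI event loop: keep running and poll periodically,
# the way a GUI timer callback would check for finished work.
received = None
while received is None:
    try:
        received = results.get_nowait()
    except queue.Empty:
        time.sleep(0.05)  # a real main loop would be processing events here

worker.join()
print(received)  # 4950
```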

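Related questions

  • Just follow the instructions and use keyword arguments. I also changed the buttons to use .add_buttons(), since it was also throwing a DeprecationWarning: import gi gi.require_version('Gtk', '3.0') from gi.repository import Gtk dialog = Gtk.FileChooserDialog( title="Please choose a folder", action=Gtk.FileChooserAction.SELECT_FOL ...
  • gtk Functions: The gtk.main() function runs the main loop until the gtk.main_quit() function is called. So you need to run your while loop in a separate thread and call main_quit() when it is done.
  • Gtk.Container is not a Widget; it is an interface that you have to implement. That is unlikely to be what you actually want to do, because implementing a new container is not trivial. If you want it to contain multiple children, what you probably want is Gtk.Box; if you only want one child, you probably want Gtk.Bin.
  • Check whether the following code works for you: #include int main( int argc, char *argv[]) { GtkWidget *window; GtkWidget *layout; GtkWidget *image; GtkWidget *button; gtk_init(&argc, &argv); window = gtk_window_new(GTK_WINDOW_TOPLEVEL); gtk_window ...
  • Not 100% sure what you want to do, but if you want to filter what the tree view displays based on the entered text, you should look at GtkTreeModelFilter. You can use it as a TreeModel and set your own VisibleFunc, which decides whether a row is visible based on the entered text. When the text changes, just call refilter(): it will call VisibleFunc() for every row. Sorry for not giving you a Python example or docs; I hope this still helps...
  • Since menu is a local variable in the show_menu() function and is not referenced by anything else, its reference count drops to 0 and it is destroyed when the function ends. Unfortunately, that is exactly when you expect to see it. Creating menu in the global scope instead means menu is no longer local to a function, so it is not destroyed when the function ends. from gi.repository import Gtk def show_menu(self, *args): i1 = Gtk.MenuItem("Item 1") menu.append(i1) i2 = Gtk ...
  • gtk.main() runs until you close the window (it is called the "main loop" or "event loop" and it does everything in a GUI program: gets key/mouse events, sends them to widgets, redraws widgets, runs functions when you press a button, etc.). You have to use threading to run (long-running) code at the same time, or use some timer class from the GUI to execute some code periodically.
  • Instead of Python's subprocess module, you can use Gio.Subprocess, which integrates with the GTK main loop: #!/usr/bin/python3 from gi.repository import Gtk, Gio # ... class updateWindow(Gtk.Window): def __init__(self): Gtk.Window.__init__(self, title="Updating...") s ...
  • This shows how to increase the row height. Sadly, there seems to be no row_expand(), so to speak. I guess you could always get the window's height on resize and do some math to work out the row height... class MainWindow(Gtk.Window): def __init__(self): Gtk.Window.__init__(self) self.calendar = Gtk.Calendar() self.calendar.set_detail_height_rows( ...
  • You never actually started the thread; you only instantiated an object that could be used to start it. A complete solution requires carefully separating responsibilities between the GUI thread and the worker thread. What you want to do is the following: do the heavy computation in a separate thread, spawned and joined by the GUI code. The computation should not spawn its own threads, nor does it need to know about threads (apart from being thread-safe, of course). When the thread finishes, use gobject.idle_add() to tell the GUI that the progress indicator can be withdrawn. (gobject.idle_add is the only GTK function that is safe to call from another thread.) With this setup, the GUI stays fully responsive and the progress bar keeps updating regardless of the computation, and ...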
