发布于2023-02-03 21:17 阅读(349) 评论(0) 点赞(6) 收藏(4)
我有一个数据集,每个观察都有多个组件(比标准的多X
,y
假设我有额外的组件Z
)。每个观察值都可以有可变长度,因此我想使用bucket_by_sequence_length
API tf.data.Dataset
。X
我分别为、Z
、创建了数据集y
,然后将它们压缩在一起。这是最小的例子:
import numpy as np
import tensorflow as tf
np.random.seed(42)
X = []
Z = []
y = []
for i in range(100):
obs_len = np.random.randint(5, 25)
X.append(np.random.random(size=[obs_len, 4]))
Z.append(np.random.random(size=[obs_len, 1]))
y.append(np.random.randint(0, 2, size=[obs_len,]))
def create_generator(list_of_arrays):
for i in list_of_arrays:
yield i
X_dataset = tf.data.Dataset.from_generator(lambda: create_generator(X), output_types= tf.float32, output_shapes=(None, 4))
Z_dataset = tf.data.Dataset.from_generator(lambda: create_generator(Z), output_types= tf.float32, output_shapes=(None, 1))
y_dataset = tf.data.Dataset.from_generator(lambda: create_generator(y), output_types= tf.float32, output_shapes=(None, ))
dataset = tf.data.Dataset.zip((X_dataset, Z_dataset, y_dataset))
现在我想继续bucket_by_sequence_length
,dataset
但出现以下错误(请参阅摘录):
(...)
2686 def element_to_bucket_id(*args):
2687 """Return int64 id of the length bucket for this element."""
-> 2688 seq_length = element_length_func(*args)
2689
2690 boundaries = list(bucket_boundaries)
TypeError: <lambda>() takes 1 positional argument but 3 were given
由于 tf 抱怨element_length_func
它被定义为采用一个参数(一个包含 3 个元素的元组),但收到了三个参数(一个元组被展开),我试图更改长度函数的实现:
def get_len(X, Z, y):
return X.shape[0]
dataset.bucket_by_sequence_length(element_length_func=get_len,
bucket_boundaries=[15],
bucket_batch_sizes=[8, 8])
但是也没有成功,导致报错:
ValueError: Tried to convert 'y' to a tensor and failed. Error: None values not supported.
不幸的是,这两个错误的回溯真的很长,所以我决定修剪它们,但这个例子应该很容易重现。
bucket_by_sequence_length
因此,我的问题是 -当我的数据集包含不止一个或两个组件时,我该如何使用?
只使用动态形状tf.shape
而get_len
不是静态形状.shape
。
import numpy as np
import tensorflow as tf
np.random.seed(42)
X = []
Z = []
y = []
for i in range(100):
obs_len = np.random.randint(5, 25)
X.append(np.random.random(size=[obs_len, 4]))
Z.append(np.random.random(size=[obs_len, 1]))
y.append(np.random.randint(0, 2, size=[obs_len,]))
def create_generator(list_of_arrays):
for i in list_of_arrays:
yield i
X_dataset = tf.data.Dataset.from_generator(lambda: create_generator(X), output_types= tf.float32, output_shapes=(None, 4))
Z_dataset = tf.data.Dataset.from_generator(lambda: create_generator(Z), output_types= tf.float32, output_shapes=(None, 1))
y_dataset = tf.data.Dataset.from_generator(lambda: create_generator(y), output_types= tf.float32, output_shapes=(None, ))
dataset = tf.data.Dataset.zip((X_dataset, Z_dataset, y_dataset))
def get_len(X, Z, y):
return tf.shape(X)[0]
bucket_dataset = dataset.bucket_by_sequence_length(
element_length_func=get_len,
bucket_boundaries=[15],
bucket_batch_sizes=[8, 8]
)
next(iter(bucket_dataset))
作者:黑洞官方问答小能手
链接:https://www.pythonheidong.com/blog/article/1895302/0e3d6a4e78c0b5bbe52f/
来源:python黑洞网
任何形式的转载都请注明出处,如有侵权 一经发现 必将追究其法律责任
昵称:
评论内容:(最多支持255个字符)
---无人问津也好,技不如人也罢,你都要试着安静下来,去做自己该做的事,而不是让内心的烦躁、焦虑,坏掉你本来就不多的热情和定力
Copyright © 2018-2021 python黑洞网 All Rights Reserved 版权所有,并保留所有权利。 京ICP备18063182号-1
投诉与举报,广告合作请联系vgs_info@163.com或QQ3083709327
免责声明:网站文章均由用户上传,仅供读者学习交流使用,禁止用做商业用途。若文章涉及色情,反动,侵权等违法信息,请向我们举报,一经核实我们会立即删除!