bugfix> tensorflow > 投稿

私は cache を使用しようとしています dataset の変換 。ここに私の現在のコード(簡略化)があります:

dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=1)
dataset = dataset.apply(tf.contrib.data.shuffle_and_repeat(buffer_size=5000, count=1))
dataset = dataset.map(_parser_a, num_parallel_calls=12)
dataset = dataset.padded_batch(
    20, 
    padded_shapes=padded_shapes,
    padding_values=padding_values
)
dataset = dataset.prefetch(buffer_size=1)
dataset = dataset.cache()

最初のエポックの後、次のエラーメッセージが表示されました。

The calling iterator did not fully read the dataset we were attempting to cache. In order to avoid unexpected truncation of the sequence, the current [partially cached] sequence will be dropped. This can occur if you have a sequence similar to dataset.cache().take(k).repeat() 。 代わりに、順序を入れ替えます(つまり、 dataset.take(k).cache().repeat()

その後、コードは続行し、キャッシュではなくハードドライブからデータを読み取りました。だから、どこに dataset.cache() を置くべきですかエラーを回避するには? ありがとう。

回答 1 件
  • Dataset.cache() の実装  変換は非常に簡単です:繰り返し処理を行うときに、それを通過する要素のリストを作成します完全に それは初めてであり、それを反復する後続の試行でそのリストから要素を返します。最初のパスでのみ実行される場合部分的 データを渡すとリストは不完全になり、TensorFlowはキャッシュされたデータを使用しようとしません。残りの要素が必要かどうかがわからず、一般に計算するためにすべての先行要素を再処理する必要があるからです。残りの要素。

    データセット全体を使用するようにプログラムを変更し、 tf.errors.OutOfRangeError まで反復処理する  が発生すると、キャッシュにはデータセット内の要素の完全なリストが含まれ、以降のすべての反復で使用されます。

あなたの答え