記事の概要
TensorFlow特有のデータ保存形式「tfrecord」で保存したデータセットについて,何件のデータがtfrecordに圧縮されているかを確認する方法を記載しました.
単純にlen(tf.data.TFRecordDataset(filename))
では取得できないため,自分と同じように困っている人もいるのではと思い執筆しました.
想定ケース
以下のディレクトリ構造のように,「3分割に作成したtfrecord(record1.tfrec)」と「tfrecord内のデータの数をチェックするスクリプト(count.py)」が配置されているケースを考えます.
コード
N
に全tfrecordファイルに格納されているレコードの総数の情報が返されます.
import tensorflow as tf
FILENAMES = ['data/record1.tfrec',
'data/record2.tfrec',
'data/record3.tfrec']
N = 0
for fn in FILENAMES:
N += sum(1 for _ in tf.data.TFRecordDataset(fn))