Tensorflow 2.0 新手向- 壹之型

Source: Deep Learning on Medium

tf.data.Dataset

在 tensorflow中要建立一個 input pipeline 非常適合用 Dataset 來完成,也提供了一些好用的method

  • map: 可以定義自己的前處理 function
  • batch: 批次執行你的輸入資料
  • shuffle: 打散你的輸入資料排序

在 tf2.0 中要注意一下使用方式

Datasets are iterables (not iterators), and work just like other Python iterables in Eager mode

# define your dataset
(x, y), _ = tf.keras.datasets.mnist.load_data()
ds = tf.data.Dataset.from_tensor_slices((x, y))
# work like iterable
for x, y in ds:
......

但之前在使用 alpha 版本的時候也遇到一些問題,之前在訓練一個 model,結果一直很差,後來把資料印出來才發現雖然有經過 shuffle 但每次的 shuffle 結果都一樣,後來才發現在有人已經提交這個問題,那時候這個問題找超久,果然測試版本真的要一直踩雷。

然後可能在寫說明文件的人有時也不是非常仔細,所以如果照著官方的 tutorial 執行一般使用上都沒問題

train_data = mnist_train.map(scale).shuffle(BUFFER_SIZE).batch(BATCH_SIZE).take(5)
test_data = mnist_test.map(scale).batch(BATCH_SIZE).take(5)

但那時在測速度的時候發現到整個跑起來很慢,後來去翻 Dataset 的文件才發現原來 batch 和 map 的順序有很大的影響。

if map does little work, this overhead can dominate the total cost. In such cases, we recommend vectorizing the user-defined function (that is, have it operate over a batch of inputs at once) and apply thebatch transformation before the map transformation.

所以像這個例子 map 只做一些簡單的型態轉換還有正規化,就應該先 batch 再 map,我用 mnist 資料集測試,只有改了這個順序之後,執行時間就快了50%。

如果對 tf 2.0 有問題的,都歡迎留言一起討論,有興趣的朋友可以 follow,之後預計也會推出 tf 2.0 實際應用的系列文章。