深度学习入门之猫vs狗(超简单),
深度学习入门之猫vs狗(超简单),
学习深度学习需要从简单模型入手,可以选择手写字识别或者猫vs狗数据入手。
这篇文章从猫和狗的识别入手对深度学习有一个简单的认知, 最后可以输入自己的图片做测试。
文章结构如下:
深度学习框架
文章用的是tensorflow2.0版本的深度学习框架。所以开始之前需要下载python,安装tensorflow2.0的库。
1.猫狗数据集
数据集导入
tensorflow_datasets中有许多数据集,我们训练用的猫狗数据集就从tensorflow_datasets中引入即可。
import tensorflow_datasets as tfds
tfds.disable_progress_bar()
SPLIT_WEIGHTS = (8, 1, 1) # 将数据集按8:1:1分为训练集,验证集,测试集。
splits = tfds.Split.TRAIN.subsplit(weighted=SPLIT_WEIGHTS)
# 加载数据集
(raw_train, raw_validation, raw_test), metadata = tfds.load('cats_vs_dogs',
split=list(splits),with_info=True,
as_supervised=True)
print(raw_train)
print(raw_validation)
print(raw_test)
打印的结果是训练集,验证集,测试集的shape,它们都是三通道的数组与标签组成,形如((None, None, 3), (1)),如下所示:
<_OptionsDataset shapes: ((None, None, 3), ()), types: (tf.uint8, tf.int64)>
<_OptionsDataset shapes: ((None, None, 3), ()), types: (tf.uint8, tf.int64)>
<_OptionsDataset shapes: ((None, None, 3), ()), types: (tf.uint8, tf.int64)>
数据查看
先看看数据长啥样吧。
代码中metadata是数据自带的元数据,这里是cat或者dog。可以不用关心这个数据。
raw_train.take(2)表示取两个数据。
get_label_name = metadata.features['label'].int2str
for image, label in raw_train.take(2):
print(“label = ”,label)
plt.figure()
plt.imshow(image)
plt.title(get_label_name(label))
输出的tf.Tensor(1, shape=(), dtype=int64)表示这是一个tensor变量(不清楚可以百度),并且值为1,这里1是狗,0是猫。
如下所示:
至此,数据集的载入完成,接下来是对数据做一些预处理,做归一化与resize等。
2. 数据预处理
由于图片大小不一样,所以我们在训练时需要将其resize到一个大小,以便使用批处理,至于为什么用批处理,可以看数据预处理部分。
注意这里对数据做的几个处理,一是将数据类型变为float32(tensorflow模型输入的统一数据类型), 二是做归一化,将输入图片的像素值变为(0,1), 最后是做resize。
tensorflow2.0对数据的处理方面也很人性化,很容易理解,如下的raw_train为一个dataset类型,可以调用.map对它做映射,.shuffle打乱顺序,.batch设置批量。对自己的数据集做预处理载入可以看说明文档。
# 图片预处理
IMG_SIZE = 160 # All images will be resized to 160x160
def format_example(image, label):
image = tf.cast(image, tf.float32)
image = (image/127.5) - 1
image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
return image, label
train = raw_train.map(format_example)
validation = raw_validation.map(format_example)
test = raw_test.map(format_example)
BATCH_SIZE = 32
SHUFFLE_BUFFER_SIZE = 1000
train_batches = train.shuffle(SHUFFLE_BUFFER_SIZE).batch(BATCH_SIZE)
validation_batches = validation.batch(BATCH_SIZE)
test_batches = test.batch(BATCH_SIZE)
3. 模型加载与训练
上面做好了模型的输入后,接下来就要进行模型的加载与训练了。模型加载完了,激动人心的测试阶段还会远吗。
加载内置模型作为基础模型
先加载MobileNetV2模型作为卷积层。
加载tensorflow内置的模型方法,以及参数可以参看模型加载部分。
# 加载模型
import tensorflow as tf
IMG_SHAPE = (IMG_SIZE, IMG_SIZE, 3)
# 以 MobileNet V2为基础模型。
base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,
include_top=False,
weights='imagenet')
feature_batch = base_model(image_batch)
添加全连接层
在上述模型的基础上添加全连接层。
在制作模型时涉及的函数参数可以参看tensorflow2.0 api。其中Sequential是一个序列,模型按照序列中的顺序执行。
# 修改模型
global_average_layer = tf.keras.layers.GlobalAveragePooling2D()
# 分类层, Dense中的参数为输出的类别数量,这里分1类,即只识别狗。
prediction_layer = keras.layers.Dense(1)
model = tf.keras.Sequential([
base_model,
global_average_layer,
prediction_layer
])
模型训练
tensorflow2.0的模型训练过程就特别人性化,特别简单了,很多参数已经默认,一般需要注意输入与输出即可,这里输入用dataset,输出为预测值。后面再细看输出是什么。
base_learning_rate = 0.0001 # 学习率,代表每次优化的大小,一般1e-3与1e-4比较合适
initial_epochs = 5 # 训练的轮数
model.compile(optimizer=tf.keras.optimizers.RMSprop(lr=base_learning_rate),
loss='binary_crossentropy',
metrics=['accuracy'])
model.fit(train_batches,
epochs=initial_epochs,
validation_data=validation_batches)
weight_path = os.path.join('cat_vs_dogs')
model.save_weights(weight_path) # 保存模型参数
4. 输入自己的数据做测试
激动人心的时刻终于来了,接下来用自己的照片验证模型识别效果,我百度了几张猫和狗的照片作为例子验证。
from PIL import Image
weight_path = os.path.join('cat_vs_dogs')
model.load_weights(weight_path)
def format_test(image):
#对输入做预处理
image = np.array(image)
image = tf.cast(image, tf.float32)
image = (image/127.5) - 1
image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
return image
image = Image.open("dog2.jpg")
plt.imshow(image)
test_image = format_test(image)
test_image = test_image[np.newaxis,:,:,:] # 给输入增加一个维度变为[1, h, w, c], 这里batch为1
pred = model.predict(test_image)
print(pred)
# 结果大于0为狗,小于0为猫
pred[pred>0]=1
pred[pred<=0]=0
print(pred)
结果中的predict为[batch, 1]维数组,在文章的例子中,batch取1,所以一次训练一张图。同时,由于二分类可以只识别狗,概率大于0的是狗,小于0的是猫, 这是模型中的分类层决定的。
结果:
猫狗识别是比较简单的分类网络,输入的图片与label分别为n通道的图与数字, 输出的为各个类别的概率列表。
评论暂时关闭