"""Train a small CNN binary image classifier from image directories.

Builds a Keras Sequential model, trains it on images read with
ImageDataGenerator.flow_from_directory, and saves the trained model to
'human_detection_model.h5'.
"""
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Training and validation image directories, consumed by
# ImageDataGenerator.flow_from_directory, which expects one sub-directory
# per class inside each of them.
# NOTE(review): these look like placeholder paths — point them at the real
# dataset before running.
train_dir = 'train_data_directory'
validation_dir = 'validation_data_directory'

# Model definition: a small convolutional network for binary image
# classification.
# - Four Conv2D(3x3, ReLU) + MaxPooling2D(2x2) stages with 32/64/128/128
#   filters extract features from 150x150 RGB inputs.
# - Flatten + Dense(512, ReLU) combine the features.
# - A single sigmoid unit outputs the probability of class index 1
#   (flow_from_directory assigns class indices alphabetically by
#   sub-directory name).
# The input size here must match `target_size` used by the data generators.
model = models.Sequential([
    # Declaring the input with an Input layer (rather than `input_shape=` on
    # the first Conv2D) is the form current Keras recommends; the resulting
    # network is the same.
    layers.Input(shape=(150, 150, 3)),
    layers.Conv2D(32, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(128, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(128, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Flatten(),
    layers.Dense(512, activation='relu'),
    layers.Dense(1, activation='sigmoid'),
])

# Compile the model. The single sigmoid output pairs with binary
# cross-entropy. 'adam' selects the Adam optimizer with its default settings.
# Accuracy is reported during training and validation.
model.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=['accuracy'],
)

# Data preprocessing: both generators only multiply pixel values by 1/255
# (mapping 8-bit [0, 255] values into [0, 1]). No augmentation is applied.
# If augmentation is wanted, add options such as rotation_range or
# horizontal_flip to train_datagen only — never to the validation generator.
train_datagen = ImageDataGenerator(rescale=1./255)
test_datagen = ImageDataGenerator(rescale=1./255)

# Stream images from disk in batches of 20.
# - Each image is resized to 150x150 to match the model input.
# - class_mode='binary' yields 0/1 labels, one class per sub-directory.
train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=(150, 150),
    batch_size=20,
    class_mode='binary',
)

validation_generator = test_datagen.flow_from_directory(
    validation_dir,
    target_size=(150, 150),
    batch_size=20,
    class_mode='binary',
)

# Train the model for 30 epochs, validating after each epoch.
# Model.fit accepts the directory iterators directly. fit_generator is
# deprecated and has been removed in Keras 3.
# Steps are derived from the iterators, where
# len(iterator) == ceil(num_images / batch_size), rather than hard-coded.
# Each epoch therefore covers the training set once and validation uses
# every validation image. A fixed step count would run out of data on a
# smaller dataset, or silently skip images on a larger one.
history = model.fit(
    train_generator,
    steps_per_epoch=len(train_generator),
    epochs=30,
    validation_data=validation_generator,
    validation_steps=len(validation_generator),
)

# Save the trained model (architecture, weights and optimizer state).
# The '.h5' extension selects the legacy HDF5 format.
# NOTE(review): newer Keras versions recommend the native '.keras' format;
# consider switching if downstream loaders allow it.
model.save('human_detection_model.h5')