Table of Contents
Introduction
Downloading the LoL Dataset
Creating a TensorFlow Dataset
The MIRNet Model
Selective Kernel Feature Fusion
Dual Attention Unit
Multi-Scale Residual Block
MIRNet Model
Training
Inference
Inference on Test Images
Goal of this article: implement the MIRNet architecture for low-light image enhancement.
Introduction
The goal of image restoration is to recover a high-quality image from its degraded version, which makes it useful in numerous applications, such as photography, security, medical imaging, and remote sensing.
In this example, we implement the MIRNet model for low-light image enhancement: a fully convolutional architecture that learns an enriched set of features, combining contextual information from multiple scales while simultaneously preserving high-resolution spatial details.
Downloading the LoL Dataset
The LoL dataset was created for low-light image enhancement. It provides 485 image pairs for training and 15 for testing. Each pair in the dataset consists of a low-light input image and its corresponding well-exposed reference image.
import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import random
import numpy as np
from glob import glob
from PIL import Image, ImageOps
import matplotlib.pyplot as plt

import keras
from keras import layers

import tensorflow as tf
Output:
!wget https://huggingface.co/datasets/geekyrakshit/LoL-Dataset/resolve/main/lol_dataset.zip
!unzip -q lol_dataset.zip && rm lol_dataset.zip
--2023-11-10 23:10:00-- https://huggingface.co/datasets/geekyrakshit/LoL-Dataset/resolve/main/lol_dataset.zip
Resolving huggingface.co (huggingface.co)... 3.163.189.74, 3.163.189.37, 3.163.189.114, ...
Connecting to huggingface.co (huggingface.co)|3.163.189.74|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs.huggingface.co/repos/d9/09/d909ef7668bb417b7065a311bd55a3084cc83a1f918e13cb41c5503328432db2/419fddc48958cd0f5599939ee0248852a37ceb8bb738c9b9525e95b25a89de9a?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27lol_dataset.zip%3B+filename%3D%22lol_dataset.zip%22%3B&response-content-type=application%2Fzip&Expires=1699917000&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTY5OTkxNzAwMH19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9yZXBvcy9kOS8wOS9kOTA5ZWY3NjY4YmI0MTdiNzA2NWEzMTFiZDU1YTMwODRjYzgzYTFmOTE4ZTEzY2I0MWM1NTAzMzI4NDMyZGIyLzQxOWZkZGM0ODk1OGNkMGY1NTk5OTM5ZWUwMjQ4ODUyYTM3Y2ViOGJiNzM4YzliOTUyNWU5NWIyNWE4OWRlOWE%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qJnJlc3BvbnNlLWNvbnRlbnQtdHlwZT0qIn1dfQ__&Signature=xyZ1oUBOnWdy6-vCAFzqZsDMetsPu6OSluyOoTS%7EKRZ6lvAy8yUwQgp5WjcZGJ7Jnex0IdnsPiUzsxaxjM-eZjUcQGPdGj4WhSV5DUBxr8xkwTEospYSg1fX%7EE2I1KkP9gBsXvinsKIOAZzchbg9f28xxdlvTbZ0h4ndcUfbDPknwlU1CIZNa5qjU6NqLMH2bPQmI1AIVau2DgQC%7E1n2dgTZsMfHTVmoM2ivsAl%7E9XgQ3m247ke2aj5BmgssZF52VWKTE-vwYDtbuiem73pS6gS-dZlmXYPE1OSRr2tsDo1cgPEBBtuK3hEnYcOq8jjEZk3AEAbFAJoHKLVIERZ30g__&Key-Pair-Id=KVTP0A1DKRTAX [following]
--2023-11-10 23:10:00-- https://cdn-lfs.huggingface.co/repos/d9/09/d909ef7668bb417b7065a311bd55a3084cc83a1f918e13cb41c5503328432db2/419fddc48958cd0f5599939ee0248852a37ceb8bb738c9b9525e95b25a89de9a?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27lol_dataset.zip%3B+filename%3D%22lol_dataset.zip%22%3B&response-content-type=application%2Fzip&Expires=1699917000&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTY5OTkxNzAwMH19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9yZXBvcy9kOS8wOS9kOTA5ZWY3NjY4YmI0MTdiNzA2NWEzMTFiZDU1YTMwODRjYzgzYTFmOTE4ZTEzY2I0MWM1NTAzMzI4NDMyZGIyLzQxOWZkZGM0ODk1OGNkMGY1NTk5OTM5ZWUwMjQ4ODUyYTM3Y2ViOGJiNzM4YzliOTUyNWU5NWIyNWE4OWRlOWE%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qJnJlc3BvbnNlLWNvbnRlbnQtdHlwZT0qIn1dfQ__&Signature=xyZ1oUBOnWdy6-vCAFzqZsDMetsPu6OSluyOoTS%7EKRZ6lvAy8yUwQgp5WjcZGJ7Jnex0IdnsPiUzsxaxjM-eZjUcQGPdGj4WhSV5DUBxr8xkwTEospYSg1fX%7EE2I1KkP9gBsXvinsKIOAZzchbg9f28xxdlvTbZ0h4ndcUfbDPknwlU1CIZNa5qjU6NqLMH2bPQmI1AIVau2DgQC%7E1n2dgTZsMfHTVmoM2ivsAl%7E9XgQ3m247ke2aj5BmgssZF52VWKTE-vwYDtbuiem73pS6gS-dZlmXYPE1OSRr2tsDo1cgPEBBtuK3hEnYcOq8jjEZk3AEAbFAJoHKLVIERZ30g__&Key-Pair-Id=KVTP0A1DKRTAX
Resolving cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)... 108.138.94.122, 108.138.94.14, 108.138.94.25, ...
Connecting to cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)|108.138.94.122|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 347171015 (331M) [application/zip]
Saving to: ‘lol_dataset.zip’
lol_dataset.zip 100%[===================>] 331.09M 316MB/s in 1.0s
2023-11-10 23:10:01 (316 MB/s) - ‘lol_dataset.zip’ saved [347171015/347171015]
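Before building the input pipeline, it can help to confirm that the archive extracted to the layout the rest of the code expects. This small check is not part of the original example, just a sketch assuming the zip unpacks to ./lol_dataset with our485 (training) and eval15 (test) splits:

print(len(glob("./lol_dataset/our485/low/*")))   # expected: 485 low-light training images
print(len(glob("./lol_dataset/our485/high/*")))  # expected: 485 well-exposed references
print(len(glob("./lol_dataset/eval15/low/*")))   # expected: 15 low-light test images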
Creating a TensorFlow Dataset
We use 300 image pairs from the LoL dataset's training set for training, and the remaining 185 pairs for validation. We generate random crops of size 128 x 128 from the image pairs for both training and validation.
random.seed(10)

IMAGE_SIZE = 128
BATCH_SIZE = 4
MAX_TRAIN_IMAGES = 300


def read_image(image_path):
    image = tf.io.read_file(image_path)
    image = tf.image.decode_png(image, channels=3)
    image.set_shape([None, None, 3])
    image = tf.cast(image, dtype=tf.float32) / 255.0
    return image


def random_crop(low_image, enhanced_image):
    low_image_shape = tf.shape(low_image)[:2]
    low_w = tf.random.uniform(
        shape=(), maxval=low_image_shape[1] - IMAGE_SIZE + 1, dtype=tf.int32
    )
    low_h = tf.random.uniform(
        shape=(), maxval=low_image_shape[0] - IMAGE_SIZE + 1, dtype=tf.int32
    )
    low_image_cropped = low_image[
        low_h : low_h + IMAGE_SIZE, low_w : low_w + IMAGE_SIZE
    ]
    enhanced_image_cropped = enhanced_image[
        low_h : low_h + IMAGE_SIZE, low_w : low_w + IMAGE_SIZE
    ]
    # in order to avoid `None` during shape inference
    low_image_cropped.set_shape([IMAGE_SIZE, IMAGE_SIZE, 3])
    enhanced_image_cropped.set_shape([IMAGE_SIZE, IMAGE_SIZE, 3])
    return low_image_cropped, enhanced_image_cropped


def load_data(low_light_image_path, enhanced_image_path):
    low_light_image = read_image(low_light_image_path)
    enhanced_image = read_image(enhanced_image_path)
    low_light_image, enhanced_image = random_crop(low_light_image, enhanced_image)
    return low_light_image, enhanced_image


def get_dataset(low_light_images, enhanced_images):
    dataset = tf.data.Dataset.from_tensor_slices((low_light_images, enhanced_images))
    dataset = dataset.map(load_data, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)
    return dataset


train_low_light_images = sorted(glob("./lol_dataset/our485/low/*"))[:MAX_TRAIN_IMAGES]
train_enhanced_images = sorted(glob("./lol_dataset/our485/high/*"))[:MAX_TRAIN_IMAGES]

val_low_light_images = sorted(glob("./lol_dataset/our485/low/*"))[MAX_TRAIN_IMAGES:]
val_enhanced_images = sorted(glob("./lol_dataset/our485/high/*"))[MAX_TRAIN_IMAGES:]

test_low_light_images = sorted(glob("./lol_dataset/eval15/low/*"))
test_enhanced_images = sorted(glob("./lol_dataset/eval15/high/*"))


train_dataset = get_dataset(train_low_light_images, train_enhanced_images)
val_dataset = get_dataset(val_low_light_images, val_enhanced_images)

print("Train Dataset:", train_dataset.element_spec)
print("Val Dataset:", val_dataset.element_spec)
Train Dataset: (TensorSpec(shape=(4, 128, 128, 3), dtype=tf.float32, name=None), TensorSpec(shape=(4, 128, 128, 3), dtype=tf.float32, name=None))
Val Dataset: (TensorSpec(shape=(4, 128, 128, 3), dtype=tf.float32, name=None), TensorSpec(shape=(4, 128, 128, 3), dtype=tf.float32, name=None))
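To eyeball what the pipeline produces, here is a small visualization sketch (not part of the original example) that plots the first low-light crop of a batch next to its well-exposed reference:

low, high = next(iter(train_dataset))  # one batch of (low-light, reference) crops
fig, axes = plt.subplots(1, 2, figsize=(8, 4))
axes[0].imshow(low[0].numpy())
axes[0].set_title("Low-light crop")
axes[1].imshow(high[0].numpy())
axes[1].set_title("Reference crop")
for ax in axes:
    ax.axis("off")
plt.show()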
The MIRNet Model
Here are the main features of the MIRNet model:
- A feature extraction model that computes a complementary set of features across multiple spatial scales, while maintaining the original high-resolution features to preserve precise spatial details.
- A regularly repeated mechanism for information exchange, where features across multi-resolution branches are progressively fused together for improved representation learning.
- A new approach to fuse multi-scale features using a selective kernel network that dynamically combines variable receptive fields and faithfully preserves the original feature information at each spatial resolution.
- A recursive residual design that progressively breaks down the input signal in order to simplify the overall learning process, and allows the construction of very deep networks.
Selective Kernel Feature Fusion
The Selective Kernel Feature Fusion (SKFF) module performs dynamic adjustment of receptive fields via two operations: Fuse and Select. The Fuse operator generates global feature descriptors by combining the information from multi-resolution streams. The Select operator uses these descriptors to recalibrate the feature maps (of the different streams), followed by their aggregation.
Fuse: The SKFF receives inputs from three parallel convolution streams carrying information at different scales.
We first combine these multi-scale features using an element-wise sum, then apply Global Average Pooling (GAP) across the spatial dimension.
Next, we apply a channel-downscaling convolution layer to generate a compact feature representation, which passes through three parallel channel-upscaling convolution layers (one for each resolution stream) and provides us with three feature descriptors.
Select: This operator applies the softmax function to the feature descriptors to obtain the corresponding activations, which are used to adaptively recalibrate the multi-scale feature maps. The aggregated features are defined as the sum of the products of the corresponding multi-scale features and the feature descriptors.
def selective_kernel_feature_fusion(
    multi_scale_feature_1, multi_scale_feature_2, multi_scale_feature_3
):
    channels = list(multi_scale_feature_1.shape)[-1]
    combined_feature = layers.Add()(
        [multi_scale_feature_1, multi_scale_feature_2, multi_scale_feature_3]
    )
    gap = layers.GlobalAveragePooling2D()(combined_feature)
    channel_wise_statistics = layers.Reshape((1, 1, channels))(gap)
    compact_feature_representation = layers.Conv2D(
        filters=channels // 8, kernel_size=(1, 1), activation="relu"
    )(channel_wise_statistics)
    feature_descriptor_1 = layers.Conv2D(
        channels, kernel_size=(1, 1), activation="softmax"
    )(compact_feature_representation)
    feature_descriptor_2 = layers.Conv2D(
        channels, kernel_size=(1, 1), activation="softmax"
    )(compact_feature_representation)
    feature_descriptor_3 = layers.Conv2D(
        channels, kernel_size=(1, 1), activation="softmax"
    )(compact_feature_representation)
    feature_1 = multi_scale_feature_1 * feature_descriptor_1
    feature_2 = multi_scale_feature_2 * feature_descriptor_2
    feature_3 = multi_scale_feature_3 * feature_descriptor_3
    aggregated_feature = layers.Add()([feature_1, feature_2, feature_3])
    return aggregated_feature
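As a quick sanity check (not part of the original example), we can confirm that SKFF is shape-preserving: given three streams brought to the same resolution, the aggregated output has the same shape as each input.

inputs = [keras.Input(shape=(IMAGE_SIZE, IMAGE_SIZE, 64)) for _ in range(3)]
fused = selective_kernel_feature_fusion(*inputs)
print(fused.shape)  # expected: (None, 128, 128, 64)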
Dual Attention Unit
The Dual Attention Unit (DAU) is used to extract features in the convolutional streams. While the SKFF block fuses information across multi-resolution branches, we also need a mechanism to share information within a feature tensor, along both the spatial and the channel dimensions, which is what the DAU block does. The DAU suppresses less useful features and only allows more informative ones to pass further. This feature recalibration is achieved by using channel attention and spatial attention mechanisms.
The channel attention branch exploits the inter-channel relationships of the convolutional feature maps by applying squeeze-and-excitation operations.
Given a feature map, the squeeze operation applies Global Average Pooling across the spatial dimensions to encode global context, thus yielding a feature descriptor.
The excitation operation passes this feature descriptor through two convolutional layers followed by sigmoid gating, and produces activations.
Finally, the output of the channel attention branch is obtained by rescaling the input feature map with the output activations.
class ChannelPooling(layers.Layer):
    def __init__(self, axis=-1, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.axis = axis
        self.concat = layers.Concatenate(axis=self.axis)

    def call(self, inputs):
        average_pooling = tf.expand_dims(tf.reduce_mean(inputs, axis=-1), axis=-1)
        max_pooling = tf.expand_dims(tf.reduce_max(inputs, axis=-1), axis=-1)
        return self.concat([average_pooling, max_pooling])

    def get_config(self):
        config = super().get_config()
        config.update({"axis": self.axis})
        return config  # return the config so the layer can be serialized and reloaded


def spatial_attention_block(input_tensor):
    compressed_feature_map = ChannelPooling(axis=-1)(input_tensor)
    feature_map = layers.Conv2D(1, kernel_size=(1, 1))(compressed_feature_map)
    feature_map = keras.activations.sigmoid(feature_map)
    return input_tensor * feature_map


def channel_attention_block(input_tensor):
    channels = list(input_tensor.shape)[-1]
    average_pooling = layers.GlobalAveragePooling2D()(input_tensor)
    feature_descriptor = layers.Reshape((1, 1, channels))(average_pooling)
    feature_activations = layers.Conv2D(
        filters=channels // 8, kernel_size=(1, 1), activation="relu"
    )(feature_descriptor)
    feature_activations = layers.Conv2D(
        filters=channels, kernel_size=(1, 1), activation="sigmoid"
    )(feature_activations)
    return input_tensor * feature_activations


def dual_attention_unit_block(input_tensor):
    channels = list(input_tensor.shape)[-1]
    feature_map = layers.Conv2D(
        channels, kernel_size=(3, 3), padding="same", activation="relu"
    )(input_tensor)
    feature_map = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(
        feature_map
    )
    channel_attention = channel_attention_block(feature_map)
    spatial_attention = spatial_attention_block(feature_map)
    concatenation = layers.Concatenate(axis=-1)([channel_attention, spatial_attention])
    concatenation = layers.Conv2D(channels, kernel_size=(1, 1))(concatenation)
    return layers.Add()([input_tensor, concatenation])
Multi-Scale Residual Block
The Multi-Scale Residual Block (MRB) is capable of generating a spatially-precise output by maintaining high-resolution representations, while receiving rich contextual information from low resolutions. The MRB consists of multiple (three in this paper) fully convolutional streams connected in parallel, and allows information exchange across the parallel streams so that the high-resolution features are consolidated with the help of low-resolution features, and vice versa. MIRNet employs a recursive residual design (with skip connections) to ease the flow of information during the learning process. In order to maintain the residual nature of the architecture, residual resizing modules are used to perform the downsampling and upsampling operations that the Multi-Scale Residual Block relies on.
# Recursive Residual Modules


def down_sampling_module(input_tensor):
    channels = list(input_tensor.shape)[-1]
    main_branch = layers.Conv2D(channels, kernel_size=(1, 1), activation="relu")(
        input_tensor
    )
    main_branch = layers.Conv2D(
        channels, kernel_size=(3, 3), padding="same", activation="relu"
    )(main_branch)
    main_branch = layers.MaxPooling2D()(main_branch)
    main_branch = layers.Conv2D(channels * 2, kernel_size=(1, 1))(main_branch)
    skip_branch = layers.MaxPooling2D()(input_tensor)
    skip_branch = layers.Conv2D(channels * 2, kernel_size=(1, 1))(skip_branch)
    return layers.Add()([skip_branch, main_branch])


def up_sampling_module(input_tensor):
    channels = list(input_tensor.shape)[-1]
    main_branch = layers.Conv2D(channels, kernel_size=(1, 1), activation="relu")(
        input_tensor
    )
    main_branch = layers.Conv2D(
        channels, kernel_size=(3, 3), padding="same", activation="relu"
    )(main_branch)
    main_branch = layers.UpSampling2D()(main_branch)
    main_branch = layers.Conv2D(channels // 2, kernel_size=(1, 1))(main_branch)
    skip_branch = layers.UpSampling2D()(input_tensor)
    skip_branch = layers.Conv2D(channels // 2, kernel_size=(1, 1))(skip_branch)
    return layers.Add()([skip_branch, main_branch])


# MRB Block
def multi_scale_residual_block(input_tensor, channels):
    # features
    level1 = input_tensor
    level2 = down_sampling_module(input_tensor)
    level3 = down_sampling_module(level2)
    # DAU
    level1_dau = dual_attention_unit_block(level1)
    level2_dau = dual_attention_unit_block(level2)
    level3_dau = dual_attention_unit_block(level3)
    # SKFF
    level1_skff = selective_kernel_feature_fusion(
        level1_dau,
        up_sampling_module(level2_dau),
        up_sampling_module(up_sampling_module(level3_dau)),
    )
    level2_skff = selective_kernel_feature_fusion(
        down_sampling_module(level1_dau),
        level2_dau,
        up_sampling_module(level3_dau),
    )
    level3_skff = selective_kernel_feature_fusion(
        down_sampling_module(down_sampling_module(level1_dau)),
        down_sampling_module(level2_dau),
        level3_dau,
    )
    # DAU 2
    level1_dau_2 = dual_attention_unit_block(level1_skff)
    level2_dau_2 = up_sampling_module(dual_attention_unit_block(level2_skff))
    level3_dau_2 = up_sampling_module(
        up_sampling_module(dual_attention_unit_block(level3_skff))
    )
    # SKFF 2
    skff_ = selective_kernel_feature_fusion(level1_dau_2, level2_dau_2, level3_dau_2)
    conv = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(skff_)
    return layers.Add()([input_tensor, conv])
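To make the resizing behavior concrete, here is a small shape check (not part of the original example): down_sampling_module halves the spatial resolution while doubling the channel count, up_sampling_module does the opposite, and the MRB as a whole preserves the shape of its input, as a residual block must.

x = keras.Input(shape=(128, 128, 64))
print(down_sampling_module(x).shape)  # (None, 64, 64, 128): half resolution, double channels
print(up_sampling_module(x).shape)  # (None, 256, 256, 32): double resolution, half channels
print(multi_scale_residual_block(x, channels=64).shape)  # (None, 128, 128, 64): shape-preserving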
MIRNet Model
def recursive_residual_group(input_tensor, num_mrb, channels):
    conv1 = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(input_tensor)
    for _ in range(num_mrb):
        conv1 = multi_scale_residual_block(conv1, channels)
    conv2 = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(conv1)
    return layers.Add()([conv2, input_tensor])


def mirnet_model(num_rrg, num_mrb, channels):
    input_tensor = keras.Input(shape=[None, None, 3])
    x1 = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(input_tensor)
    for _ in range(num_rrg):
        x1 = recursive_residual_group(x1, num_mrb, channels)
    conv = layers.Conv2D(3, kernel_size=(3, 3), padding="same")(x1)
    output_tensor = layers.Add()([input_tensor, conv])
    return keras.Model(input_tensor, output_tensor)


model = mirnet_model(num_rrg=3, num_mrb=2, channels=64)
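Stacking three recursive residual groups of two MRBs each yields a fairly deep network. If you want to see how large it is before training, a one-line check (not part of the original example):

print(f"Total parameters: {model.count_params():,}")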
Training
We train MIRNet using the Charbonnier loss as the loss function and the Adam optimizer with a learning rate of 1e-4.
We use Peak Signal-to-Noise Ratio (PSNR) as a metric: it expresses the ratio between the maximum possible value (power) of a signal and the power of the distorting noise that affects the fidelity of its representation.
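Concretely, the two quantities implemented below are

L_Charbonnier(y, y_hat) = mean( sqrt( (y - y_hat)^2 + eps^2 ) ), with eps = 1e-3
PSNR(y, y_hat) = 10 * log10( MAX^2 / MSE(y, y_hat) )

The small constant eps makes the loss a smooth approximation of the L1 loss, keeping the gradient finite at zero error. One caveat worth noting: tf.image.psnr is called with max_val=255.0 even though the images here are scaled to [0, 1], so the reported PSNR values carry a constant offset of 20 * log10(255), roughly 48 dB, relative to the max_val=1.0 convention. This matches the original example's output and does not affect relative comparisons between epochs.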
def charbonnier_loss(y_true, y_pred):
    return tf.reduce_mean(tf.sqrt(tf.square(y_true - y_pred) + tf.square(1e-3)))


def peak_signal_noise_ratio(y_true, y_pred):
    return tf.image.psnr(y_pred, y_true, max_val=255.0)


optimizer = keras.optimizers.Adam(learning_rate=1e-4)
model.compile(
    optimizer=optimizer,
    loss=charbonnier_loss,
    metrics=[peak_signal_noise_ratio],
)

history = model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=50,
    callbacks=[
        keras.callbacks.ReduceLROnPlateau(
            monitor="val_peak_signal_noise_ratio",
            factor=0.5,
            patience=5,
            verbose=1,
            min_delta=1e-7,
            mode="max",
        )
    ],
)


def plot_history(value, name):
    plt.plot(history.history[value], label=f"train_{name.lower()}")
    plt.plot(history.history[f"val_{value}"], label=f"val_{name.lower()}")
    plt.xlabel("Epochs")
    plt.ylabel(name)
    plt.title(f"Train and Validation {name} Over Epochs", fontsize=14)
    plt.legend()
    plt.grid()
    plt.show()


plot_history("loss", "Loss")
plot_history("peak_signal_noise_ratio", "PSNR")
Epoch 1/50
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1699658204.480352 77759 device_compiler.h:187] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process.
75/75 ━━━━━━━━━━━━━━━━━━━━ 445s 686ms/step - loss: 0.2162 - peak_signal_noise_ratio: 61.5549 - val_loss: 0.1358 - val_peak_signal_noise_ratio: 65.2699 - learning_rate: 1.0000e-04
Epoch 2/50
75/75 ━━━━━━━━━━━━━━━━━━━━ 29s 385ms/step - loss: 0.1745 - peak_signal_noise_ratio: 63.1785 - val_loss: 0.1237 - val_peak_signal_noise_ratio: 65.8360 - learning_rate: 1.0000e-04
Epoch 3/50
75/75 ━━━━━━━━━━━━━━━━━━━━ 29s 386ms/step - loss: 0.1681 - peak_signal_noise_ratio: 63.4903 - val_loss: 0.1205 - val_peak_signal_noise_ratio: 65.9048 - learning_rate: 1.0000e-04
Epoch 4/50
75/75 ━━━━━━━━━━━━━━━━━━━━ 29s 385ms/step - loss: 0.1668 - peak_signal_noise_ratio: 63.4793 - val_loss: 0.1185 - val_peak_signal_noise_ratio: 66.0290 - learning_rate: 1.0000e-04
Epoch 5/50
75/75 ━━━━━━━━━━━━━━━━━━━━ 29s 383ms/step - loss: 0.1564 - peak_signal_noise_ratio: 63.9205 - val_loss: 0.1217 - val_peak_signal_noise_ratio: 66.1207 - learning_rate: 1.0000e-04
Epoch 6/50
75/75 ━━━━━━━━━━━━━━━━━━━━ 29s 384ms/step - loss: 0.1601 - peak_signal_noise_ratio: 63.9336 - val_loss: 0.1166 - val_peak_signal_noise_ratio: 66.6102 - learning_rate: 1.0000e-04
Epoch 7/50
75/75 ━━━━━━━━━━━━━━━━━━━━ 29s 385ms/step - loss: 0.1600 - peak_signal_noise_ratio: 63.9043 - val_loss: 0.1335 - val_peak_signal_noise_ratio: 65.5639 - learning_rate: 1.0000e-04
Epoch 8/50
75/75 ━━━━━━━━━━━━━━━━━━━━ 29s 382ms/step - loss: 0.1609 - peak_signal_noise_ratio: 64.0606 - val_loss: 0.1135 - val_peak_signal_noise_ratio: 66.9369 - learning_rate: 1.0000e-04
Epoch 9/50
75/75 ━━━━━━━━━━━━━━━━━━━━ 29s 384ms/step - loss: 0.1539 - peak_signal_noise_ratio: 64.3915 - val_loss: 0.1165 - val_peak_signal_noise_ratio: 66.9783 - learning_rate: 1.0000e-04
Epoch 10/50
75/75 ━━━━━━━━━━━━━━━━━━━━ 43s 409ms/step - loss: 0.1536 - peak_signal_noise_ratio: 64.4491 - val_loss: 0.1118 - val_peak_signal_noise_ratio: 66.8747 - learning_rate: 1.0000e-04
Epoch 11/50
75/75 ━━━━━━━━━━━━━━━━━━━━ 29s 383ms/step - loss: 0.1449 - peak_signal_noise_ratio: 64.6579 - val_loss: 0.1167 - val_peak_signal_noise_ratio: 66.9626 - learning_rate: 1.0000e-04
Epoch 12/50
75/75 ━━━━━━━━━━━━━━━━━━━━ 29s 383ms/step - loss: 0.1501 - peak_signal_noise_ratio: 64.7929 - val_loss: 0.1143 - val_peak_signal_noise_ratio: 66.9400 - learning_rate: 1.0000e-04
Epoch 13/50
75/75 ━━━━━━━━━━━━━━━━━━━━ 29s 380ms/step - loss: 0.1510 - peak_signal_noise_ratio: 64.6816 - val_loss: 0.1302 - val_peak_signal_noise_ratio: 66.0576 - learning_rate: 1.0000e-04
Epoch 14/50
75/75 ━━━━━━━━━━━━━━━━━━━━ 29s 383ms/step - loss: 0.1632 - peak_signal_noise_ratio: 63.9234 - val_loss: 0.1146 - val_peak_signal_noise_ratio: 67.0321 - learning_rate: 1.0000e-04
Epoch 15/50
75/75 ━━━━━━━━━━━━━━━━━━━━ 28s 379ms/step - loss: 0.1486 - peak_signal_noise_ratio: 64.7125 - val_loss: 0.1284 - val_peak_signal_noise_ratio: 66.2105 - learning_rate: 1.0000e-04
Epoch 16/50
75/75 ━━━━━━━━━━━━━━━━━━━━ 28s 379ms/step - loss: 0.1482 - peak_signal_noise_ratio: 64.8123 - val_loss: 0.1176 - val_peak_signal_noise_ratio: 66.8114 - learning_rate: 1.0000e-04
Epoch 17/50
75/75 ━━━━━━━━━━━━━━━━━━━━ 29s 381ms/step - loss: 0.1459 - peak_signal_noise_ratio: 64.7795 - val_loss: 0.1092 - val_peak_signal_noise_ratio: 67.4173 - learning_rate: 1.0000e-04
Epoch 18/50
75/75 ━━━━━━━━━━━━━━━━━━━━ 28s 378ms/step - loss: 0.1482 - peak_signal_noise_ratio: 64.8821 - val_loss: 0.1175 - val_peak_signal_noise_ratio: 67.0296 - learning_rate: 1.0000e-04
Epoch 19/50
75/75 ━━━━━━━━━━━━━━━━━━━━ 29s 381ms/step - loss: 0.1524 - peak_signal_noise_ratio: 64.7275 - val_loss: 0.1028 - val_peak_signal_noise_ratio: 67.8485 - learning_rate: 1.0000e-04
Epoch 20/50
75/75 ━━━━━━━━━━━━━━━━━━━━ 28s 379ms/step - loss: 0.1350 - peak_signal_noise_ratio: 65.6166 - val_loss: 0.1040 - val_peak_signal_noise_ratio: 67.8551 - learning_rate: 1.0000e-04
Epoch 21/50
75/75 ━━━━━━━━━━━━━━━━━━━━ 29s 380ms/step - loss: 0.1383 - peak_signal_noise_ratio: 65.5167 - val_loss: 0.1071 - val_peak_signal_noise_ratio: 67.5902 - learning_rate: 1.0000e-04
Epoch 22/50
75/75 ━━━━━━━━━━━━━━━━━━━━ 28s 379ms/step - loss: 0.1393 - peak_signal_noise_ratio: 65.6293 - val_loss: 0.1096 - val_peak_signal_noise_ratio: 67.2940 - learning_rate: 1.0000e-04
Epoch 23/50
75/75 ━━━━━━━━━━━━━━━━━━━━ 29s 383ms/step - loss: 0.1399 - peak_signal_noise_ratio: 65.5146 - val_loss: 0.1044 - val_peak_signal_noise_ratio: 67.6932 - learning_rate: 1.0000e-04
Epoch 24/50
75/75 ━━━━━━━━━━━━━━━━━━━━ 28s 378ms/step - loss: 0.1390 - peak_signal_noise_ratio: 65.7525 - val_loss: 0.1135 - val_peak_signal_noise_ratio: 66.9891 - learning_rate: 1.0000e-04
Epoch 25/50
75/75 ━━━━━━━━━━━━━━━━━━━━ 0s 326ms/step - loss: 0.1333 - peak_signal_noise_ratio: 65.8340
Epoch 25: ReduceLROnPlateau reducing learning rate to 4.999999873689376e-05.
75/75 ━━━━━━━━━━━━━━━━━━━━ 28s 380ms/step - loss: 0.1332 - peak_signal_noise_ratio: 65.8348 - val_loss: 0.1252 - val_peak_signal_noise_ratio: 66.5684 - learning_rate: 1.0000e-04
Epoch 26/50
75/75 ━━━━━━━━━━━━━━━━━━━━ 29s 381ms/step - loss: 0.1547 - peak_signal_noise_ratio: 64.8968 - val_loss: 0.1105 - val_peak_signal_noise_ratio: 67.0688 - learning_rate: 5.0000e-05
Epoch 27/50
75/75 ━━━━━━━━━━━━━━━━━━━━ 28s 380ms/step - loss: 0.1269 - peak_signal_noise_ratio: 66.3882 - val_loss: 0.1035 - val_peak_signal_noise_ratio: 67.7006 - learning_rate: 5.0000e-05
Epoch 28/50
75/75 ━━━━━━━━━━━━━━━━━━━━ 30s 405ms/step - loss: 0.1243 - peak_signal_noise_ratio: 66.5826 - val_loss: 0.1063 - val_peak_signal_noise_ratio: 67.2497 - learning_rate: 5.0000e-05
Epoch 29/50
75/75 ━━━━━━━━━━━━━━━━━━━━ 29s 383ms/step - loss: 0.1292 - peak_signal_noise_ratio: 66.1734 - val_loss: 0.1064 - val_peak_signal_noise_ratio: 67.3989 - learning_rate: 5.0000e-05
Epoch 30/50
75/75 ━━━━━━━━━━━━━━━━━━━━ 0s 328ms/step - loss: 0.1304 - peak_signal_noise_ratio: 66.1267
Epoch 30: ReduceLROnPlateau reducing learning rate to 2.499999936844688e-05.
75/75 ━━━━━━━━━━━━━━━━━━━━ 29s 382ms/step - loss: 0.1304 - peak_signal_noise_ratio: 66.1294 - val_loss: 0.1109 - val_peak_signal_noise_ratio: 66.8935 - learning_rate: 5.0000e-05
Epoch 31/50
75/75 ━━━━━━━━━━━━━━━━━━━━ 29s 381ms/step - loss: 0.1141 - peak_signal_noise_ratio: 67.1338 - val_loss: 0.1145 - val_peak_signal_noise_ratio: 66.8367 - learning_rate: 2.5000e-05
Epoch 32/50
75/75 ━━━━━━━━━━━━━━━━━━━━ 28s 380ms/step - loss: 0.1141 - peak_signal_noise_ratio: 66.9369 - val_loss: 0.1132 - val_peak_signal_noise_ratio: 66.9264 - learning_rate: 2.5000e-05
Epoch 33/50
75/75 ━━━━━━━━━━━━━━━━━━━━ 28s 380ms/step - loss: 0.1184 - peak_signal_noise_ratio: 66.7723 - val_loss: 0.1090 - val_peak_signal_noise_ratio: 67.1115 - learning_rate: 2.5000e-05
Epoch 34/50
75/75 ━━━━━━━━━━━━━━━━━━━━ 29s 380ms/step - loss: 0.1243 - peak_signal_noise_ratio: 66.4147 - val_loss: 0.1080 - val_peak_signal_noise_ratio: 67.2300 - learning_rate: 2.5000e-05
Epoch 35/50
75/75 ━━━━━━━━━━━━━━━━━━━━ 0s 325ms/step - loss: 0.1230 - peak_signal_noise_ratio: 66.7113
Epoch 35: ReduceLROnPlateau reducing learning rate to 1.249999968422344e-05.
75/75 ━━━━━━━━━━━━━━━━━━━━ 29s 381ms/step - loss: 0.1229 - peak_signal_noise_ratio: 66.7121 - val_loss: 0.1038 - val_peak_signal_noise_ratio: 67.5288 - learning_rate: 2.5000e-05
Epoch 36/50
75/75 ━━━━━━━━━━━━━━━━━━━━ 28s 380ms/step - loss: 0.1181 - peak_signal_noise_ratio: 66.9202 - val_loss: 0.1030 - val_peak_signal_noise_ratio: 67.6249 - learning_rate: 1.2500e-05
Epoch 37/50
75/75 ━━━━━━━━━━━━━━━━━━━━ 28s 380ms/step - loss: 0.1086 - peak_signal_noise_ratio: 67.5034 - val_loss: 0.1016 - val_peak_signal_noise_ratio: 67.6940 - learning_rate: 1.2500e-05
Epoch 38/50
75/75 ━━━━━━━━━━━━━━━━━━━━ 28s 380ms/step - loss: 0.1127 - peak_signal_noise_ratio: 67.3735 - val_loss: 0.1004 - val_peak_signal_noise_ratio: 68.0042 - learning_rate: 1.2500e-05
Epoch 39/50
75/75 ━━━━━━━━━━━━━━━━━━━━ 28s 379ms/step - loss: 0.1135 - peak_signal_noise_ratio: 67.3436 - val_loss: 0.1150 - val_peak_signal_noise_ratio: 66.9541 - learning_rate: 1.2500e-05
Epoch 40/50
75/75 ━━━━━━━━━━━━━━━━━━━━ 29s 381ms/step - loss: 0.1152 - peak_signal_noise_ratio: 67.1675 - val_loss: 0.1093 - val_peak_signal_noise_ratio: 67.2030 - learning_rate: 1.2500e-05
Epoch 41/50
75/75 ━━━━━━━━━━━━━━━━━━━━ 28s 378ms/step - loss: 0.1191 - peak_signal_noise_ratio: 66.7586 - val_loss: 0.1095 - val_peak_signal_noise_ratio: 67.2424 - learning_rate: 1.2500e-05
Epoch 42/50
75/75 ━━━━━━━━━━━━━━━━━━━━ 30s 405ms/step - loss: 0.1062 - peak_signal_noise_ratio: 67.6856 - val_loss: 0.1092 - val_peak_signal_noise_ratio: 67.2187 - learning_rate: 1.2500e-05
Epoch 43/50
75/75 ━━━━━━━━━━━━━━━━━━━━ 0s 323ms/step - loss: 0.1099 - peak_signal_noise_ratio: 67.6400
Epoch 43: ReduceLROnPlateau reducing learning rate to 6.24999984211172e-06.
75/75 ━━━━━━━━━━━━━━━━━━━━ 28s 377ms/step - loss: 0.1099 - peak_signal_noise_ratio: 67.6378 - val_loss: 0.1079 - val_peak_signal_noise_ratio: 67.4591 - learning_rate: 1.2500e-05
Epoch 44/50
75/75 ━━━━━━━━━━━━━━━━━━━━ 28s 378ms/step - loss: 0.1155 - peak_signal_noise_ratio: 67.0911 - val_loss: 0.1019 - val_peak_signal_noise_ratio: 67.8073 - learning_rate: 6.2500e-06
Epoch 45/50
75/75 ━━━━━━━━━━━━━━━━━━━━ 28s 377ms/step - loss: 0.1145 - peak_signal_noise_ratio: 67.1876 - val_loss: 0.1067 - val_peak_signal_noise_ratio: 67.4283 - learning_rate: 6.2500e-06
Epoch 46/50
75/75 ━━━━━━━━━━━━━━━━━━━━ 29s 384ms/step - loss: 0.1077 - peak_signal_noise_ratio: 67.7168 - val_loss: 0.1114 - val_peak_signal_noise_ratio: 67.1392 - learning_rate: 6.2500e-06
Epoch 47/50
75/75 ━━━━━━━━━━━━━━━━━━━━ 28s 377ms/step - loss: 0.1117 - peak_signal_noise_ratio: 67.3210 - val_loss: 0.1081 - val_peak_signal_noise_ratio: 67.3622 - learning_rate: 6.2500e-06
Epoch 48/50
75/75 ━━━━━━━━━━━━━━━━━━━━ 0s 326ms/step - loss: 0.1074 - peak_signal_noise_ratio: 67.7986
Epoch 48: ReduceLROnPlateau reducing learning rate to 3.12499992105586e-06.
75/75 ━━━━━━━━━━━━━━━━━━━━ 29s 380ms/step - loss: 0.1074 - peak_signal_noise_ratio: 67.7992 - val_loss: 0.1101 - val_peak_signal_noise_ratio: 67.3376 - learning_rate: 6.2500e-06
Epoch 49/50
75/75 ━━━━━━━━━━━━━━━━━━━━ 28s 380ms/step - loss: 0.1081 - peak_signal_noise_ratio: 67.5032 - val_loss: 0.1121 - val_peak_signal_noise_ratio: 67.0685 - learning_rate: 3.1250e-06
Epoch 50/50
75/75 ━━━━━━━━━━━━━━━━━━━━ 28s 378ms/step - loss: 0.1077 - peak_signal_noise_ratio: 67.6709 - val_loss: 0.1084 - val_peak_signal_noise_ratio: 67.6183 - learning_rate: 3.1250e-06
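After training, you will probably want to persist the model. A minimal sketch (not part of the original example; the filename is arbitrary): because ChannelPooling is a custom layer and the loss and metric are custom functions, they must be passed back at load time, which also relies on get_config returning its config as fixed above.

model.save("mirnet.keras")  # hypothetical filename
restored_model = keras.models.load_model(
    "mirnet.keras",
    custom_objects={
        "ChannelPooling": ChannelPooling,
        "charbonnier_loss": charbonnier_loss,
        "peak_signal_noise_ratio": peak_signal_noise_ratio,
    },
)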
Inference
def plot_results(images, titles, figure_size=(12, 12)):
    fig = plt.figure(figsize=figure_size)
    for i in range(len(images)):
        fig.add_subplot(1, len(images), i + 1).set_title(titles[i])
        _ = plt.imshow(images[i])
        plt.axis("off")
    plt.show()


def infer(original_image):
    image = keras.utils.img_to_array(original_image)
    image = image.astype("float32") / 255.0
    image = np.expand_dims(image, axis=0)
    output = model.predict(image, verbose=0)
    output_image = output[0] * 255.0
    output_image = output_image.clip(0, 255)
    output_image = output_image.reshape(
        (np.shape(output_image)[0], np.shape(output_image)[1], 3)
    )
    output_image = Image.fromarray(np.uint8(output_image))
    return output_image
Inference on Test Images
We compare the test images from the LoL dataset enhanced by MIRNet with the same images enhanced via the PIL.ImageOps.autocontrast() function.
You can use the trained model hosted on Hugging Face Hub and try the demo on Hugging Face Spaces.
for low_light_image in random.sample(test_low_light_images, 6):
    original_image = Image.open(low_light_image)
    enhanced_image = infer(original_image)
    plot_results(
        [original_image, ImageOps.autocontrast(original_image), enhanced_image],
        ["Original", "PIL Autocontrast", "MIRNet Enhanced"],
        (20, 12),
    )
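The comparison above is purely visual. If you also want a quantitative comparison against the well-exposed references, here is a sketch (an addition to the original example, assuming the sorted low/high file lists of the eval15 split are aligned pairwise):

# Average PSNR of the MIRNet output against the reference images on eval15.
# Both images are compared as float arrays in [0, 255].
psnr_values = []
for low_path, high_path in zip(test_low_light_images, test_enhanced_images):
    enhanced = np.array(infer(Image.open(low_path)), dtype=np.float32)
    reference = np.array(Image.open(high_path), dtype=np.float32)
    psnr_values.append(tf.image.psnr(enhanced, reference, max_val=255.0).numpy())
print("Mean PSNR on eval15:", np.mean(psnr_values))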