政安晨: [Keras Machine Learning Examples Walkthrough] (8) - Point Cloud Segmentation with PointNet

Table of Contents

Introduction

Imports

Downloading the dataset

Loading the dataset

Structuring the dataset

Preprocessing

Creating TensorFlow datasets

The PointNet model

Permutation invariance

Transformation invariance

Point interactions

Instantiating the model

Training

Visualizing training progress

Inference

Final notes



Goal of this article: to implement a point cloud segmentation model based on PointNet.

Introduction


点云 "是存储几何形状数据的一种重要数据结构类型。由于其格式不规则,在用于深度学习应用之前,通常要将其转换为规则的三维体素网格或图像集合,这一步骤会使数据变得不必要的庞大。PointNet 系列模型通过直接消耗点云解决了这一问题,同时尊重点数据的包络不变性属性。PointNet 系列模型提供了一个简单、统一的架构,适用于从物体分类、部件分割到场景语义解析等各种应用。

In this example, we demonstrate an implementation of the PointNet architecture for shape segmentation.

Imports

import os
import json
import random
import numpy as np
import pandas as pd
from tqdm import tqdm
from glob import glob

import tensorflow as tf  # For tf.data
import keras
from keras import layers

import matplotlib.pyplot as plt

Downloading the dataset


The ShapeNet dataset is an ongoing effort to establish a richly annotated, large-scale dataset of 3D shapes. ShapeNetCore is a subset of the full ShapeNet dataset containing clean single 3D models with manually verified category and alignment annotations. It covers 55 common object categories, with about 51,300 unique 3D models.

In this example, we use one of the 12 object categories of PASCAL 3D+, which is included in the ShapeNetCore dataset.

dataset_url = "https://git.io/JiY4i"

dataset_path = keras.utils.get_file(
    fname="shapenet.zip",
    origin=dataset_url,
    cache_subdir="datasets",
    hash_algorithm="auto",
    extract=True,
    archive_format="auto",
    cache_dir="datasets",
)

Loading the dataset


We parse the dataset metadata so that we can easily map model categories to their respective directories, and map segmentation classes to colors for visualization.

with open("/tmp/.keras/datasets/PartAnnotation/metadata.json") as json_file:
    metadata = json.load(json_file)

print(metadata)
{'Airplane': {'directory': '02691156', 'lables': ['wing', 'body', 'tail', 'engine'], 'colors': ['blue', 'green', 'red', 'pink']}, 'Bag': {'directory': '02773838', 'lables': ['handle', 'body'], 'colors': ['blue', 'green']}, 'Cap': {'directory': '02954340', 'lables': ['panels', 'peak'], 'colors': ['blue', 'green']}, 'Car': {'directory': '02958343', 'lables': ['wheel', 'hood', 'roof'], 'colors': ['blue', 'green', 'red']}, 'Chair': {'directory': '03001627', 'lables': ['leg', 'arm', 'back', 'seat'], 'colors': ['blue', 'green', 'red', 'pink']}, 'Earphone': {'directory': '03261776', 'lables': ['earphone', 'headband'], 'colors': ['blue', 'green']}, 'Guitar': {'directory': '03467517', 'lables': ['head', 'body', 'neck'], 'colors': ['blue', 'green', 'red']}, 'Knife': {'directory': '03624134', 'lables': ['handle', 'blade'], 'colors': ['blue', 'green']}, 'Lamp': {'directory': '03636649', 'lables': ['canopy', 'lampshade', 'base'], 'colors': ['blue', 'green', 'red']}, 'Laptop': {'directory': '03642806', 'lables': ['keyboard'], 'colors': ['blue']}, 'Motorbike': {'directory': '03790512', 'lables': ['wheel', 'handle', 'gas_tank', 'light', 'seat'], 'colors': ['blue', 'green', 'red', 'pink', 'yellow']}, 'Mug': {'directory': '03797390', 'lables': ['handle'], 'colors': ['blue']}, 'Pistol': {'directory': '03948459', 'lables': ['trigger_and_guard', 'handle', 'barrel'], 'colors': ['blue', 'green', 'red']}, 'Rocket': {'directory': '04099429', 'lables': ['nose', 'body', 'fin'], 'colors': ['blue', 'green', 'red']}, 'Skateboard': {'directory': '04225987', 'lables': ['wheel', 'deck'], 'colors': ['blue', 'green']}, 'Table': {'directory': '04379243', 'lables': ['leg', 'top'], 'colors': ['blue', 'green']}}

In this example, we train PointNet to segment the parts of an Airplane model.

points_dir = "/tmp/.keras/datasets/PartAnnotation/{}/points".format(
    metadata["Airplane"]["directory"]
)
labels_dir = "/tmp/.keras/datasets/PartAnnotation/{}/points_label".format(
    metadata["Airplane"]["directory"]
)
LABELS = metadata["Airplane"]["lables"]
COLORS = metadata["Airplane"]["colors"]

VAL_SPLIT = 0.2
NUM_SAMPLE_POINTS = 1024
BATCH_SIZE = 32
EPOCHS = 60
INITIAL_LR = 1e-3

Structuring the dataset


We generate the following in-memory data structures from the airplane point clouds and their labels:

- point_clouds is a list of np.array objects that represent the point cloud data in the form of x, y and z coordinates. Axis 0 represents the number of points in the point cloud, while axis 1 represents the coordinates. all_labels is the list that represents the label of each coordinate as a string (needed mainly for visualization purposes).
- test_point_clouds is in the same format as point_clouds, but does not have the corresponding point cloud labels.
- all_labels is a list of np.array objects that represent the point cloud labels of each coordinate, corresponding to the point_clouds list.
- point_cloud_labels is a list of np.array objects that represent the point cloud labels of each coordinate in one-hot encoded form (see the sketch below), corresponding to the point_clouds list.
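
To make the one-hot layout concrete before we build it, here is a minimal sketch with hypothetical label indices (for the Airplane category, LABELS has four entries and the extra class stands in for "none"):

import keras

# Hypothetical dense labels for four points: indices into LABELS,
# with index 4 (== len(LABELS)) standing in for "none".
dense_labels = [0, 2, 4, 1]

# keras.utils.to_categorical turns each index into a one-hot row.
one_hot = keras.utils.to_categorical(dense_labels, num_classes=5)
print(one_hot.shape)  # (4, 5): axis 0 indexes points, axis 1 indexes classes
print(one_hot[0])     # [1. 0. 0. 0. 0.]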

point_clouds, test_point_clouds = [], []
point_cloud_labels, all_labels = [], []

points_files = glob(os.path.join(points_dir, "*.pts"))
for point_file in tqdm(points_files):
    point_cloud = np.loadtxt(point_file)
    if point_cloud.shape[0] < NUM_SAMPLE_POINTS:
        continue

    # Get the file-id of the current point cloud for parsing its
    # labels.
    file_id = point_file.split("/")[-1].split(".")[0]
    label_data, num_labels = {}, 0
    for label in LABELS:
        label_file = os.path.join(labels_dir, label, file_id + ".seg")
        if os.path.exists(label_file):
            label_data[label] = np.loadtxt(label_file).astype("float32")
            num_labels = len(label_data[label])

    # Point clouds having labels will be our training samples.
    try:
        label_map = ["none"] * num_labels
        for label in LABELS:
            for i, data in enumerate(label_data[label]):
                label_map[i] = label if data == 1 else label_map[i]
        label_data = [
            LABELS.index(label) if label != "none" else len(LABELS)
            for label in label_map
        ]
        # Apply one-hot encoding to the dense label representation.
        label_data = keras.utils.to_categorical(
            label_data, num_classes=len(LABELS) + 1
        )

        point_clouds.append(point_cloud)
        point_cloud_labels.append(label_data)
        all_labels.append(label_map)
    except KeyError:
        test_point_clouds.append(point_cloud)
100%|██████████████████████████████████████████████████████████████████████| 4045/4045 [01:30<00:00, 44.54it/s]

Next, let's look at a few samples from the in-memory arrays we just generated:

for _ in range(5):
    i = random.randint(0, len(point_clouds) - 1)
    print(f"point_clouds[{i}].shape:", point_clouds[0].shape)
    print(f"point_cloud_labels[{i}].shape:", point_cloud_labels[0].shape)
    for j in range(5):
        print(
            f"all_labels[{i}][{j}]:",
            all_labels[i][j],
            f"\tpoint_cloud_labels[{i}][{j}]:",
            point_cloud_labels[i][j],
            "\n",
        )

Output:

point_clouds[333].shape: (2571, 3)
point_cloud_labels[333].shape: (2571, 5)
all_labels[333][0]: tail    point_cloud_labels[333][0]: [0. 0. 1. 0. 0.] 
all_labels[333][1]: wing    point_cloud_labels[333][1]: [1. 0. 0. 0. 0.] 
all_labels[333][2]: tail    point_cloud_labels[333][2]: [0. 0. 1. 0. 0.] 
all_labels[333][3]: engine  point_cloud_labels[333][3]: [0. 0. 0. 1. 0.] 
all_labels[333][4]: wing    point_cloud_labels[333][4]: [1. 0. 0. 0. 0.] 
point_clouds[3273].shape: (2571, 3)
point_cloud_labels[3273].shape: (2571, 5)
all_labels[3273][0]: body   point_cloud_labels[3273][0]: [0. 1. 0. 0. 0.] 
all_labels[3273][1]: body   point_cloud_labels[3273][1]: [0. 1. 0. 0. 0.] 
all_labels[3273][2]: tail   point_cloud_labels[3273][2]: [0. 0. 1. 0. 0.] 
all_labels[3273][3]: wing   point_cloud_labels[3273][3]: [1. 0. 0. 0. 0.] 
all_labels[3273][4]: wing   point_cloud_labels[3273][4]: [1. 0. 0. 0. 0.] 
point_clouds[929].shape: (2571, 3)
point_cloud_labels[929].shape: (2571, 5)
all_labels[929][0]: body    point_cloud_labels[929][0]: [0. 1. 0. 0. 0.] 
all_labels[929][1]: tail    point_cloud_labels[929][1]: [0. 0. 1. 0. 0.] 
all_labels[929][2]: wing    point_cloud_labels[929][2]: [1. 0. 0. 0. 0.] 
all_labels[929][3]: tail    point_cloud_labels[929][3]: [0. 0. 1. 0. 0.] 
all_labels[929][4]: body    point_cloud_labels[929][4]: [0. 1. 0. 0. 0.] 
point_clouds[496].shape: (2571, 3)
point_cloud_labels[496].shape: (2571, 5)
all_labels[496][0]: body    point_cloud_labels[496][0]: [0. 1. 0. 0. 0.] 
all_labels[496][1]: body    point_cloud_labels[496][1]: [0. 1. 0. 0. 0.] 
all_labels[496][2]: body    point_cloud_labels[496][2]: [0. 1. 0. 0. 0.] 
all_labels[496][3]: wing    point_cloud_labels[496][3]: [1. 0. 0. 0. 0.] 
all_labels[496][4]: body    point_cloud_labels[496][4]: [0. 1. 0. 0. 0.] 
point_clouds[3508].shape: (2571, 3)
point_cloud_labels[3508].shape: (2571, 5)
all_labels[3508][0]: body   point_cloud_labels[3508][0]: [0. 1. 0. 0. 0.] 
all_labels[3508][1]: body   point_cloud_labels[3508][1]: [0. 1. 0. 0. 0.] 
all_labels[3508][2]: body   point_cloud_labels[3508][2]: [0. 1. 0. 0. 0.] 
all_labels[3508][3]: body   point_cloud_labels[3508][3]: [0. 1. 0. 0. 0.] 
all_labels[3508][4]: body   point_cloud_labels[3508][4]: [0. 1. 0. 0. 0.] 

Now, let's visualize some of the point clouds along with their labels.

def visualize_data(point_cloud, labels):
    df = pd.DataFrame(
        data={
            "x": point_cloud[:, 0],
            "y": point_cloud[:, 1],
            "z": point_cloud[:, 2],
            "label": labels,
        }
    )
    fig = plt.figure(figsize=(15, 10))
    ax = plt.axes(projection="3d")
    for index, label in enumerate(LABELS):
        c_df = df[df["label"] == label]
        try:
            ax.scatter(
                c_df["x"], c_df["y"], c_df["z"], label=label, alpha=0.5, c=COLORS[index]
            )
        except IndexError:
            pass
    ax.legend()
    plt.show()


visualize_data(point_clouds[0], all_labels[0])
visualize_data(point_clouds[300], all_labels[300])

Preprocessing


Note that all the point clouds we have loaded consist of a variable number of points, which makes it difficult to batch them together. To overcome this problem, we randomly sample a fixed number of points from each point cloud. We also normalize the point clouds to make the data scale-invariant.

for index in tqdm(range(len(point_clouds))):
    current_point_cloud = point_clouds[index]
    current_label_cloud = point_cloud_labels[index]
    current_labels = all_labels[index]
    num_points = len(current_point_cloud)

    # Randomly sampling respective indices.
    sampled_indices = random.sample(list(range(num_points)), NUM_SAMPLE_POINTS)
    # Sampling points corresponding to sampled indices.
    sampled_point_cloud = np.array([current_point_cloud[i] for i in sampled_indices])
    # Sampling corresponding one-hot encoded labels.
    sampled_label_cloud = np.array([current_label_cloud[i] for i in sampled_indices])
    # Sampling corresponding labels for visualization.
    sampled_labels = np.array([current_labels[i] for i in sampled_indices])

    # Normalizing sampled point cloud.
    norm_point_cloud = sampled_point_cloud - np.mean(sampled_point_cloud, axis=0)
    norm_point_cloud /= np.max(np.linalg.norm(norm_point_cloud, axis=1))

    point_clouds[index] = norm_point_cloud
    point_cloud_labels[index] = sampled_label_cloud
    all_labels[index] = sampled_labels
100%|█████████████████████████████████████████████████████████████████████| 3694/3694 [00:08<00:00, 446.45it/s]

Let's visualize the sampled and normalized point clouds along with their corresponding labels.

visualize_data(point_clouds[0], all_labels[0])
visualize_data(point_clouds[300], all_labels[300])

Creating TensorFlow datasets


We create tf.data.Dataset objects for the training and validation data. We also augment the training point clouds by applying random jitter to them.

def load_data(point_cloud_batch, label_cloud_batch):
    point_cloud_batch.set_shape([NUM_SAMPLE_POINTS, 3])
    label_cloud_batch.set_shape([NUM_SAMPLE_POINTS, len(LABELS) + 1])
    return point_cloud_batch, label_cloud_batch


def augment(point_cloud_batch, label_cloud_batch):
    noise = tf.random.uniform(
        tf.shape(label_cloud_batch), -0.001, 0.001, dtype=tf.float64
    )
    point_cloud_batch += noise[:, :, :3]
    return point_cloud_batch, label_cloud_batch


def generate_dataset(point_clouds, label_clouds, is_training=True):
    dataset = tf.data.Dataset.from_tensor_slices((point_clouds, label_clouds))
    dataset = dataset.shuffle(BATCH_SIZE * 100) if is_training else dataset
    dataset = dataset.map(load_data, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(batch_size=BATCH_SIZE)
    dataset = (
        dataset.map(augment, num_parallel_calls=tf.data.AUTOTUNE)
        if is_training
        else dataset
    )
    return dataset


split_index = int(len(point_clouds) * (1 - VAL_SPLIT))
train_point_clouds = point_clouds[:split_index]
train_label_cloud = point_cloud_labels[:split_index]
total_training_examples = len(train_point_clouds)

val_point_clouds = point_clouds[split_index:]
val_label_cloud = point_cloud_labels[split_index:]

print("Num train point clouds:", len(train_point_clouds))
print("Num train point cloud labels:", len(train_label_cloud))
print("Num val point clouds:", len(val_point_clouds))
print("Num val point cloud labels:", len(val_label_cloud))

train_dataset = generate_dataset(train_point_clouds, train_label_cloud)
val_dataset = generate_dataset(val_point_clouds, val_label_cloud, is_training=False)

print("Train Dataset:", train_dataset)
print("Validation Dataset:", val_dataset)
Num train point clouds: 2955
Num train point cloud labels: 2955
Num val point clouds: 739
Num val point cloud labels: 739
Train Dataset: <_ParallelMapDataset element_spec=(TensorSpec(shape=(None, 1024, 3), dtype=tf.float64, name=None), TensorSpec(shape=(None, 1024, 5), dtype=tf.float64, name=None))>
Validation Dataset: <_BatchDataset element_spec=(TensorSpec(shape=(None, 1024, 3), dtype=tf.float64, name=None), TensorSpec(shape=(None, 1024, 5), dtype=tf.float64, name=None))>

The PointNet model


The figure below (see the original Keras example) depicts the internals of the PointNet model family:

Given that PointNet consumes an unordered set of coordinates as its input data, its architecture needs to match the following characteristic properties of point cloud data:

Permutation invariance


Given the unstructured nature of point cloud data, a scan made up of n points has n! permutations. The subsequent data processing must be invariant to these different representations. In order to make PointNet invariant to input permutations, we use a symmetric function (such as max-pooling) once the n input points have been mapped to a higher-dimensional space. The result is a global feature vector that aims to capture an aggregate signature of the n input points. The global feature vector is used alongside local point features for segmentation.
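
A quick way to see why a symmetric function achieves this: max-pooling over the point axis yields the same global vector regardless of point order. A minimal sketch, not part of the original example:

import numpy as np

rng = np.random.default_rng(0)
embedded = rng.normal(size=(8, 64))     # 8 points mapped to a 64-D space
shuffled = embedded[rng.permutation(8)]  # same set of points, reordered

# Max-pooling along the point axis is symmetric: both orderings
# produce the identical global feature vector.
assert np.allclose(embedded.max(axis=0), shuffled.max(axis=0))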

Transformation invariance


If an object undergoes certain transformations, such as translation or scaling, the segmentation output should remain unchanged. For a given input point cloud, we apply an appropriate rigid or affine transformation to achieve pose normalization. Because each of the n input points is represented as a vector and mapped to an embedding space independently, applying a geometric transformation simply amounts to matrix-multiplying each point with a transformation matrix. This is where the concept of Spatial Transformer Networks comes in.
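
Applying such a transformation is just a batched matrix multiplication, which is exactly what the layers.Dot(axes=(2, 1)) call in transformation_block below computes. A minimal sketch, assuming a batch of clouds and one 3x3 matrix per cloud:

import numpy as np

batch, num_points = 4, 1024
clouds = np.random.normal(size=(batch, num_points, 3))
transforms = np.tile(np.eye(3), (batch, 1, 1))  # stand-in for the T-Net output

# Each point (a 1x3 row vector) is multiplied by its cloud's 3x3 matrix;
# with identity matrices the clouds come back unchanged.
transformed = np.einsum("bij,bjk->bik", clouds, transforms)
assert transformed.shape == (batch, num_points, 3)
assert np.allclose(transformed, clouds)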

The operations that make up the T-Net are motivated by the higher-level architecture of PointNet. MLPs (or fully connected layers) are used to map the input points independently and identically to a higher-dimensional space; max-pooling is used to encode a global feature vector, whose dimensionality is then reduced with fully connected layers. The input-dependent features at the final fully connected layer are combined with globally trainable weights and biases, resulting in a 3-by-3 transformation matrix.

Point interactions


The interactions between neighboring points often carry useful information (i.e., a single point should not be treated in isolation). Whereas classification only needs to make use of global features, segmentation must be able to leverage local point features along with global point features.
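
In the segmentation head implemented later, this is achieved by tiling the pooled global feature vector across the point axis and concatenating it with the per-point features. A shape-level sketch of the pattern, using hypothetical feature widths:

import numpy as np

num_points = 1024
local_features = np.zeros((1, num_points, 64))  # per-point (local) features
global_features = np.zeros((1, 1, 2048))        # max-pooled global vector

# Tile the global vector so that every point sees it, then concatenate
# it with the local features along the channel axis.
tiled = np.tile(global_features, (1, num_points, 1))
combined = np.concatenate([local_features, tiled], axis=-1)
assert combined.shape == (1, num_points, 64 + 2048)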

Now that we know the building blocks that make up the PointNet model, we can implement it. We start with the basic blocks, i.e., the convolutional block and the multi-layer perceptron block.

def conv_block(x, filters, name):
    x = layers.Conv1D(filters, kernel_size=1, padding="valid", name=f"{name}_conv")(x)
    x = layers.BatchNormalization(name=f"{name}_batch_norm")(x)
    return layers.Activation("relu", name=f"{name}_relu")(x)


def mlp_block(x, filters, name):
    x = layers.Dense(filters, name=f"{name}_dense")(x)
    x = layers.BatchNormalization(name=f"{name}_batch_norm")(x)
    return layers.Activation("relu", name=f"{name}_relu")(x)

We implement a regularizer (taken from this example) to enforce orthogonality in the feature space. This is needed to ensure that the magnitudes of the transformed features do not vary too much.

class OrthogonalRegularizer(keras.regularizers.Regularizer):
    """Reference: https://keras.io/examples/vision/pointnet/#build-a-model"""

    def __init__(self, num_features, l2reg=0.001):
        self.num_features = num_features
        self.l2reg = l2reg
        self.identity = keras.ops.eye(num_features)

    def __call__(self, x):
        x = keras.ops.reshape(x, (-1, self.num_features, self.num_features))
        xxt = keras.ops.tensordot(x, x, axes=(2, 2))
        xxt = keras.ops.reshape(xxt, (-1, self.num_features, self.num_features))
        return keras.ops.sum(self.l2reg * keras.ops.square(xxt - self.identity))

    def get_config(self):
        config = super().get_config()
        config.update(
            {"num_features": self.num_features, "l2reg_strength": self.l2reg}
        )
        return config
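
As a quick sanity check on the penalty (a sketch, not part of the original example): a flattened orthogonal matrix such as the identity incurs zero penalty, since its product with its own transpose equals the identity, while a scaled copy does not.

reg = OrthogonalRegularizer(num_features=3)

identity_flat = np.eye(3).reshape(1, 9)
print(float(reg(identity_flat)))        # 0.0: the identity is orthogonal
# 2I gives (2I)(2I)^T = 4I, so the penalty is 0.001 * sum((4I - I)^2) = 0.027.
print(float(reg(identity_flat * 2.0)))  # > 0: the scaled matrix is penalized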

The next piece is the transformation network we introduced earlier.

def transformation_net(inputs, num_features, name):
    """Reference: https://keras.io/examples/vision/pointnet/#build-a-model.

    The `filters` values come from the original paper:
    https://arxiv.org/abs/1612.00593.
    """
    x = conv_block(inputs, filters=64, name=f"{name}_1")
    x = conv_block(x, filters=128, name=f"{name}_2")
    x = conv_block(x, filters=1024, name=f"{name}_3")
    x = layers.GlobalMaxPooling1D()(x)
    x = mlp_block(x, filters=512, name=f"{name}_1_1")
    x = mlp_block(x, filters=256, name=f"{name}_2_1")
    return layers.Dense(
        num_features * num_features,
        kernel_initializer="zeros",
        bias_initializer=keras.initializers.Constant(np.eye(num_features).flatten()),
        activity_regularizer=OrthogonalRegularizer(num_features),
        name=f"{name}_final",
    )(x)


def transformation_block(inputs, num_features, name):
    transformed_features = transformation_net(inputs, num_features, name=name)
    transformed_features = layers.Reshape((num_features, num_features))(
        transformed_features
    )
    return layers.Dot(axes=(2, 1), name=f"{name}_mm")([inputs, transformed_features])

Finally, we piece the above blocks together and implement the segmentation model.

def get_shape_segmentation_model(num_points, num_classes):
    input_points = keras.Input(shape=(None, 3))

    # PointNet Classification Network.
    transformed_inputs = transformation_block(
        input_points, num_features=3, name="input_transformation_block"
    )
    features_64 = conv_block(transformed_inputs, filters=64, name="features_64")
    features_128_1 = conv_block(features_64, filters=128, name="features_128_1")
    features_128_2 = conv_block(features_128_1, filters=128, name="features_128_2")
    transformed_features = transformation_block(
        features_128_2, num_features=128, name="transformed_features"
    )
    features_512 = conv_block(transformed_features, filters=512, name="features_512")
    features_2048 = conv_block(features_512, filters=2048, name="pre_maxpool_block")
    global_features = layers.MaxPool1D(pool_size=num_points, name="global_features")(
        features_2048
    )
    global_features = keras.ops.tile(global_features, [1, num_points, 1])

    # Segmentation head.
    segmentation_input = layers.Concatenate(name="segmentation_input")(
        [
            features_64,
            features_128_1,
            features_128_2,
            transformed_features,
            features_512,
            global_features,
        ]
    )
    segmentation_features = conv_block(
        segmentation_input, filters=128, name="segmentation_features"
    )
    outputs = layers.Conv1D(
        num_classes, kernel_size=1, activation="softmax", name="segmentation_head"
    )(segmentation_features)
    return keras.Model(input_points, outputs)

Instantiating the model

x, y = next(iter(train_dataset))

num_points = x.shape[1]
num_classes = y.shape[-1]

segmentation_model = get_shape_segmentation_model(num_points, num_classes)
segmentation_model.summary()

Output:

Model: "functional_1"
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓
┃ Layer (type)        ┃ Output Shape      ┃ Param # ┃ Connected to         ┃
┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩
│ input_layer         │ (None, None, 3)   │       0 │ -                    │
│ (InputLayer)        │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ input_transformati… │ (None, None, 64)  │     256 │ input_layer[0][0]    │
│ (Conv1D)            │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ input_transformati… │ (None, None, 64)  │     256 │ input_transformatio… │
│ (BatchNormalizatio… │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ input_transformati… │ (None, None, 64)  │       0 │ input_transformatio… │
│ (Activation)        │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ input_transformati… │ (None, None, 128) │   8,320 │ input_transformatio… │
│ (Conv1D)            │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ input_transformati… │ (None, None, 128) │     512 │ input_transformatio… │
│ (BatchNormalizatio… │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ input_transformati… │ (None, None, 128) │       0 │ input_transformatio… │
│ (Activation)        │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ input_transformati… │ (None, None,      │ 132,096 │ input_transformatio… │
│ (Conv1D)            │ 1024)             │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ input_transformati… │ (None, None,      │   4,096 │ input_transformatio… │
│ (BatchNormalizatio… │ 1024)             │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ input_transformati… │ (None, None,      │       0 │ input_transformatio… │
│ (Activation)        │ 1024)             │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ global_max_pooling… │ (None, 1024)      │       0 │ input_transformatio… │
│ (GlobalMaxPooling1… │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ input_transformati… │ (None, 512)       │ 524,800 │ global_max_pooling1… │
│ (Dense)             │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ input_transformati… │ (None, 512)       │   2,048 │ input_transformatio… │
│ (BatchNormalizatio… │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ input_transformati… │ (None, 512)       │       0 │ input_transformatio… │
│ (Activation)        │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ input_transformati… │ (None, 256)       │ 131,328 │ input_transformatio… │
│ (Dense)             │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ input_transformati… │ (None, 256)       │   1,024 │ input_transformatio… │
│ (BatchNormalizatio… │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ input_transformati… │ (None, 256)       │       0 │ input_transformatio… │
│ (Activation)        │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ input_transformati… │ (None, 9)         │   2,313 │ input_transformatio… │
│ (Dense)             │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ reshape (Reshape)   │ (None, 3, 3)      │       0 │ input_transformatio… │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ input_transformati… │ (None, None, 3)   │       0 │ input_layer[0][0],   │
│ (Dot)               │                   │         │ reshape[0][0]        │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ features_64_conv    │ (None, None, 64)  │     256 │ input_transformatio… │
│ (Conv1D)            │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ features_64_batch_… │ (None, None, 64)  │     256 │ features_64_conv[0]… │
│ (BatchNormalizatio… │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ features_64_relu    │ (None, None, 64)  │       0 │ features_64_batch_n… │
│ (Activation)        │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ features_128_1_conv │ (None, None, 128) │   8,320 │ features_64_relu[0]… │
│ (Conv1D)            │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ features_128_1_bat… │ (None, None, 128) │     512 │ features_128_1_conv… │
│ (BatchNormalizatio… │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ features_128_1_relu │ (None, None, 128) │       0 │ features_128_1_batc… │
│ (Activation)        │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ features_128_2_conv │ (None, None, 128) │  16,512 │ features_128_1_relu… │
│ (Conv1D)            │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ features_128_2_bat… │ (None, None, 128) │     512 │ features_128_2_conv… │
│ (BatchNormalizatio… │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ features_128_2_relu │ (None, None, 128) │       0 │ features_128_2_batc… │
│ (Activation)        │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ transformed_featur… │ (None, None, 64)  │   8,256 │ features_128_2_relu… │
│ (Conv1D)            │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ transformed_featur… │ (None, None, 64)  │     256 │ transformed_feature… │
│ (BatchNormalizatio… │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ transformed_featur… │ (None, None, 64)  │       0 │ transformed_feature… │
│ (Activation)        │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ transformed_featur… │ (None, None, 128) │   8,320 │ transformed_feature… │
│ (Conv1D)            │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ transformed_featur… │ (None, None, 128) │     512 │ transformed_feature… │
│ (BatchNormalizatio… │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ transformed_featur… │ (None, None, 128) │       0 │ transformed_feature… │
│ (Activation)        │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ transformed_featur… │ (None, None,      │ 132,096 │ transformed_feature… │
│ (Conv1D)            │ 1024)             │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ transformed_featur… │ (None, None,      │   4,096 │ transformed_feature… │
│ (BatchNormalizatio… │ 1024)             │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ transformed_featur… │ (None, None,      │       0 │ transformed_feature… │
│ (Activation)        │ 1024)             │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ global_max_pooling… │ (None, 1024)      │       0 │ transformed_feature… │
│ (GlobalMaxPooling1… │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ transformed_featur… │ (None, 512)       │ 524,800 │ global_max_pooling1… │
│ (Dense)             │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ transformed_featur… │ (None, 512)       │   2,048 │ transformed_feature… │
│ (BatchNormalizatio… │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ transformed_featur… │ (None, 512)       │       0 │ transformed_feature… │
│ (Activation)        │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ transformed_featur… │ (None, 256)       │ 131,328 │ transformed_feature… │
│ (Dense)             │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ transformed_featur… │ (None, 256)       │   1,024 │ transformed_feature… │
│ (BatchNormalizatio… │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ transformed_featur… │ (None, 256)       │       0 │ transformed_feature… │
│ (Activation)        │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ transformed_featur… │ (None, 16384)     │ 4,210,… │ transformed_feature… │
│ (Dense)             │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ reshape_1 (Reshape) │ (None, 128, 128)  │       0 │ transformed_feature… │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ transformed_featur… │ (None, None, 128) │       0 │ features_128_2_relu… │
│ (Dot)               │                   │         │ reshape_1[0][0]      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ features_512_conv   │ (None, None, 512) │  66,048 │ transformed_feature… │
│ (Conv1D)            │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ features_512_batch… │ (None, None, 512) │   2,048 │ features_512_conv[0… │
│ (BatchNormalizatio… │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ features_512_relu   │ (None, None, 512) │       0 │ features_512_batch_… │
│ (Activation)        │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ pre_maxpool_block_… │ (None, None,      │ 1,050,… │ features_512_relu[0… │
│ (Conv1D)            │ 2048)             │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ pre_maxpool_block_… │ (None, None,      │   8,192 │ pre_maxpool_block_c… │
│ (BatchNormalizatio… │ 2048)             │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ pre_maxpool_block_… │ (None, None,      │       0 │ pre_maxpool_block_b… │
│ (Activation)        │ 2048)             │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ global_features     │ (None, None,      │       0 │ pre_maxpool_block_r… │
│ (MaxPooling1D)      │ 2048)             │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ tile (Tile)         │ (None, None,      │       0 │ global_features[0][… │
│                     │ 2048)             │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ segmentation_input  │ (None, None,      │       0 │ features_64_relu[0]… │
│ (Concatenate)       │ 3008)             │         │ features_128_1_relu… │
│                     │                   │         │ features_128_2_relu… │
│                     │                   │         │ transformed_feature… │
│                     │                   │         │ features_512_relu[0… │
│                     │                   │         │ tile[0][0]           │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ segmentation_featu… │ (None, None, 128) │ 385,152 │ segmentation_input[… │
│ (Conv1D)            │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ segmentation_featu… │ (None, None, 128) │     512 │ segmentation_featur… │
│ (BatchNormalizatio… │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ segmentation_featu… │ (None, None, 128) │       0 │ segmentation_featur… │
│ (Activation)        │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ segmentation_head   │ (None, None, 5)   │     645 │ segmentation_featur… │
│ (Conv1D)            │                   │         │                      │
└─────────────────────┴───────────────────┴─────────┴──────────────────────┘
 Total params: 7,370,062 (28.11 MB)
 Trainable params: 7,356,110 (28.06 MB)
 Non-trainable params: 13,952 (54.50 KB)

Training


For training, the authors recommend a learning rate schedule that halves the initial learning rate every 20 epochs. In this example, we decay it every 5 epochs.

steps_per_epoch = total_training_examples // BATCH_SIZE
total_training_steps = steps_per_epoch * EPOCHS
print(f"Steps per epoch: {steps_per_epoch}.")
print(f"Total training steps: {total_training_steps}.")

lr_schedule = keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=0.003,
    decay_steps=steps_per_epoch * 5,
    decay_rate=0.5,
    staircase=True,
)

steps = range(total_training_steps)
lrs = [lr_schedule(step) for step in steps]

plt.plot(lrs)
plt.xlabel("Steps")
plt.ylabel("Learning Rate")
plt.show()

Output:

Steps per epoch: 92.
Total training steps: 5520.

Finally, we implement a utility for running experiments and launching model training.

def run_experiment(epochs):
    segmentation_model = get_shape_segmentation_model(num_points, num_classes)
    segmentation_model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=lr_schedule),
        loss=keras.losses.CategoricalCrossentropy(),
        metrics=["accuracy"],
    )

    checkpoint_filepath = "checkpoint.weights.h5"
    checkpoint_callback = keras.callbacks.ModelCheckpoint(
        checkpoint_filepath,
        monitor="val_loss",
        save_best_only=True,
        save_weights_only=True,
    )

    history = segmentation_model.fit(
        train_dataset,
        validation_data=val_dataset,
        epochs=epochs,
        callbacks=[checkpoint_callback],
    )

    segmentation_model.load_weights(checkpoint_filepath)
    return segmentation_model, history


segmentation_model, history = run_experiment(epochs=EPOCHS)
Epoch 1/60
  2/93 ━━━━━━━━━━━━━━━━━━━━  7s 86ms/step - accuracy: 0.1427 - loss: 48748.8203
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1699916678.434176   90326 device_compiler.h:187] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
93/93 ━━━━━━━━━━━━━━━━━━━━ 53s 259ms/step - accuracy: 0.3739 - loss: 27980.7305 - val_accuracy: 0.4340 - val_loss: 10361231.0000
Epoch 2/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 48s 82ms/step - accuracy: 0.6355 - loss: 339.9151 - val_accuracy: 0.3820 - val_loss: 19069320.0000
Epoch 3/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.6695 - loss: 281.5728 - val_accuracy: 0.2859 - val_loss: 15993839.0000
Epoch 4/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 89ms/step - accuracy: 0.6812 - loss: 253.0939 - val_accuracy: 0.2287 - val_loss: 9633191.0000
Epoch 5/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 89ms/step - accuracy: 0.6873 - loss: 231.1317 - val_accuracy: 0.3030 - val_loss: 6001454.0000
Epoch 6/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 88ms/step - accuracy: 0.6860 - loss: 216.6793 - val_accuracy: 0.0620 - val_loss: 1945100.8750
Epoch 7/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.6947 - loss: 210.2683 - val_accuracy: 0.4539 - val_loss: 7908162.5000
Epoch 8/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7014 - loss: 203.2560 - val_accuracy: 0.4035 - val_loss: 17741164.0000
Epoch 9/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7006 - loss: 197.3710 - val_accuracy: 0.1900 - val_loss: 34120616.0000
Epoch 10/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7047 - loss: 192.0777 - val_accuracy: 0.3391 - val_loss: 33157422.0000
Epoch 11/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7102 - loss: 188.4875 - val_accuracy: 0.3394 - val_loss: 4630613.5000
Epoch 12/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 89ms/step - accuracy: 0.7186 - loss: 184.9940 - val_accuracy: 0.1662 - val_loss: 487790.1250
Epoch 13/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 89ms/step - accuracy: 0.7175 - loss: 182.7206 - val_accuracy: 0.1602 - val_loss: 70590.3203
Epoch 14/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 88ms/step - accuracy: 0.7159 - loss: 180.5028 - val_accuracy: 0.1631 - val_loss: 16990.2324
Epoch 15/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 88ms/step - accuracy: 0.7201 - loss: 180.1674 - val_accuracy: 0.2318 - val_loss: 4992.7783
Epoch 16/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 88ms/step - accuracy: 0.7222 - loss: 176.5523 - val_accuracy: 0.6246 - val_loss: 647.5634
Epoch 17/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 88ms/step - accuracy: 0.7291 - loss: 175.6139 - val_accuracy: 0.6551 - val_loss: 324.0956
Epoch 18/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 88ms/step - accuracy: 0.7285 - loss: 175.0228 - val_accuracy: 0.6430 - val_loss: 257.9340
Epoch 19/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 88ms/step - accuracy: 0.7300 - loss: 172.7668 - val_accuracy: 0.6399 - val_loss: 253.2745
Epoch 20/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 89ms/step - accuracy: 0.7316 - loss: 172.9001 - val_accuracy: 0.6084 - val_loss: 232.9293
Epoch 21/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 89ms/step - accuracy: 0.7364 - loss: 170.8767 - val_accuracy: 0.6451 - val_loss: 191.7183
Epoch 22/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 88ms/step - accuracy: 0.7395 - loss: 171.4525 - val_accuracy: 0.6825 - val_loss: 180.2473
Epoch 23/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7392 - loss: 170.1975 - val_accuracy: 0.6095 - val_loss: 180.3243
Epoch 24/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 88ms/step - accuracy: 0.7362 - loss: 169.2144 - val_accuracy: 0.6017 - val_loss: 178.3013
Epoch 25/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7409 - loss: 169.2571 - val_accuracy: 0.6582 - val_loss: 178.3481
Epoch 26/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 89ms/step - accuracy: 0.7415 - loss: 167.7480 - val_accuracy: 0.6808 - val_loss: 177.8774
Epoch 27/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 89ms/step - accuracy: 0.7440 - loss: 167.7844 - val_accuracy: 0.7131 - val_loss: 176.5841
Epoch 28/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 88ms/step - accuracy: 0.7423 - loss: 167.5307 - val_accuracy: 0.6891 - val_loss: 176.1687
Epoch 29/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 88ms/step - accuracy: 0.7409 - loss: 166.4581 - val_accuracy: 0.7136 - val_loss: 174.9417
Epoch 30/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 88ms/step - accuracy: 0.7419 - loss: 165.9243 - val_accuracy: 0.7407 - val_loss: 173.0663
Epoch 31/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 88ms/step - accuracy: 0.7471 - loss: 166.9746 - val_accuracy: 0.7454 - val_loss: 172.9663
Epoch 32/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7472 - loss: 165.9707 - val_accuracy: 0.7480 - val_loss: 173.9868
Epoch 33/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7443 - loss: 165.9368 - val_accuracy: 0.7076 - val_loss: 174.4526
Epoch 34/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7496 - loss: 165.5322 - val_accuracy: 0.7441 - val_loss: 174.6099
Epoch 35/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7453 - loss: 164.2007 - val_accuracy: 0.7469 - val_loss: 174.2793
Epoch 36/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7503 - loss: 165.3418 - val_accuracy: 0.7469 - val_loss: 174.0812
Epoch 37/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7491 - loss: 164.4796 - val_accuracy: 0.7524 - val_loss: 173.9656
Epoch 38/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 10s 82ms/step - accuracy: 0.7489 - loss: 164.4573 - val_accuracy: 0.7516 - val_loss: 175.3401
Epoch 39/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7437 - loss: 163.4484 - val_accuracy: 0.7532 - val_loss: 173.8172
Epoch 40/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7507 - loss: 163.6720 - val_accuracy: 0.7537 - val_loss: 173.9127
Epoch 41/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7506 - loss: 164.0555 - val_accuracy: 0.7556 - val_loss: 173.0979
Epoch 42/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 89ms/step - accuracy: 0.7517 - loss: 164.1554 - val_accuracy: 0.7562 - val_loss: 172.8895
Epoch 43/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 10s 82ms/step - accuracy: 0.7527 - loss: 164.6351 - val_accuracy: 0.7567 - val_loss: 173.0476
Epoch 44/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 88ms/step - accuracy: 0.7505 - loss: 164.1568 - val_accuracy: 0.7571 - val_loss: 172.2751
Epoch 45/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 88ms/step - accuracy: 0.7500 - loss: 163.8129 - val_accuracy: 0.7579 - val_loss: 171.8897
Epoch 46/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7534 - loss: 163.6473 - val_accuracy: 0.7577 - val_loss: 172.5457
Epoch 47/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7510 - loss: 163.7318 - val_accuracy: 0.7580 - val_loss: 172.2256
Epoch 48/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7517 - loss: 163.3274 - val_accuracy: 0.7575 - val_loss: 172.3276
Epoch 49/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 89ms/step - accuracy: 0.7511 - loss: 163.5069 - val_accuracy: 0.7581 - val_loss: 171.2155
Epoch 50/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 89ms/step - accuracy: 0.7507 - loss: 163.7366 - val_accuracy: 0.7578 - val_loss: 171.1100
Epoch 51/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7519 - loss: 163.1190 - val_accuracy: 0.7580 - val_loss: 171.7971
Epoch 52/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 81ms/step - accuracy: 0.7510 - loss: 162.7351 - val_accuracy: 0.7579 - val_loss: 171.9780
Epoch 53/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7510 - loss: 162.9639 - val_accuracy: 0.7577 - val_loss: 171.6770
Epoch 54/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 88ms/step - accuracy: 0.7530 - loss: 162.7419 - val_accuracy: 0.7578 - val_loss: 170.5556
Epoch 55/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7515 - loss: 163.2893 - val_accuracy: 0.7582 - val_loss: 171.9172
Epoch 56/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7505 - loss: 164.2843 - val_accuracy: 0.7584 - val_loss: 171.9182
Epoch 57/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7498 - loss: 162.6679 - val_accuracy: 0.7587 - val_loss: 173.7610
Epoch 58/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7523 - loss: 163.3332 - val_accuracy: 0.7585 - val_loss: 172.5207
Epoch 59/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7529 - loss: 162.4575 - val_accuracy: 0.7586 - val_loss: 171.6861
Epoch 60/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7498 - loss: 162.9523 - val_accuracy: 0.7586 - val_loss: 172.3012

Visualizing training progress

def plot_result(item):
    plt.plot(history.history[item], label=item)
    plt.plot(history.history["val_" + item], label="val_" + item)
    plt.xlabel("Epochs")
    plt.ylabel(item)
    plt.title("Train and Validation {} Over Epochs".format(item), fontsize=14)
    plt.legend()
    plt.grid()
    plt.show()


plot_result("loss")
plot_result("accuracy")

Inference

validation_batch = next(iter(val_dataset))
val_predictions = segmentation_model.predict(validation_batch[0])
print(f"Validation prediction shape: {val_predictions.shape}")


def visualize_single_point_cloud(point_clouds, label_clouds, idx):
    label_map = LABELS + ["none"]
    point_cloud = point_clouds[idx]
    label_cloud = label_clouds[idx]
    visualize_data(point_cloud, [label_map[np.argmax(label)] for label in label_cloud])


idx = np.random.choice(len(validation_batch[0]))
print(f"Index selected: {idx}")

# Plotting with ground-truth.
visualize_single_point_cloud(validation_batch[0], validation_batch[1], idx)

# Plotting with predicted labels.
visualize_single_point_cloud(validation_batch[0], val_predictions, idx)

Output:

 1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 1s/step
Validation prediction shape: (32, 1024, 5)
Index selected: 26

Final notes


If you are interested in learning more about this topic, you may find this repository useful.

