发布于2023-11-17 20:58 阅读(1047) 评论(0) 点赞(26) 收藏(0)
我有这个代码,y的形状是(4,5),结果是[[0.81045085]]和[[1.1789, 0.489]],x1的形状是(5,1),x2的形状是(4,1) :
import os
import random
import numpy as np
import tensorflow as tf
# Set all random seeds for reproducibility on the same machine at least
RANDOM_SEED = 1
os.environ['PYTHONHASHSEED'] = str(RANDOM_SEED)
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
tf.random.set_seed(RANDOM_SEED)
# Data
x1 = np.array([1.05, 1.10, 1.15, 1.20, 1.25])
x2 = np.array([0.2, 0.3, 0.4, 0.5])
y = np.array([[0.000, 0.000, 0.000, 0.000, 0.000],
[0.350, 0.350, 0.350, 0.350, 0.350],
[0.615, 0.619, 0.623, 0.626, 0.628],
[0.805, 0.816, 0.826, 0.834, 0.839]])
X1_mesh, X2_mesh = np.meshgrid(x1, x2)
X = np.column_stack((X1_mesh.ravel(), X2_mesh.ravel()))
Y = y.ravel()
# Model
model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(64, input_dim=2, activation='relu'))
model.add(tf.keras.layers.Dense(64, activation='relu'))
model.add(tf.keras.layers.Dense(1, activation='linear'))
model.compile(optimizer='adam', loss='mean_squared_error')
model.fit(X, Y, epochs=1000, verbose=0)
# Predict
new_point = np.array([[1.1789, 0.489]])
predicted_value = model.predict(new_point)
print(predicted_value)
此时,如果我尝试使用相同的输入,结果会得到 [[0.80769724]],为什么它会改变?
import os
import random
import numpy as np
import tensorflow as tf
import pandas as pd
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
# Set all random seeds for reproducibility on same machine at least.
RANDOM_SEED = 1
os.environ['PYTHONHASHSEED'] = str(RANDOM_SEED)
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
tf.random.set_seed(RANDOM_SEED)
# Data
x1 = np.array([1.05, 1.10, 1.15, 1.20, 1.25, 1.30, 1.35, 1.40, 1.45, 1.50, 1.60, 1.70, 1.80, 1.90, 2.00, 2.20,
2.40, 2.60, 2.80, 3.00])
x2 = np.array([0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0, 2.1, 2.2,
2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3.0, 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4.0, 4.1, 4.2, 4.3,
4.4, 4.5, 4.6, 4.7, 4.8, 4.9, 5.0, 5.1, 5.2, 5.3, 5.4, 5.5, 5.6, 5.7, 5.8, 5.9, 6.0, 6.1, 6.2, 6.3, 6.4,
6.5, 6.6, 6.7, 6.8, 6.9, 7.0, 7.1, 7.2, 7.3, 7.4, 7.5, 7.6, 7.7, 7.8, 7.9, 8.0, 8.1, 8.2, 8.3, 8.4, 8.5,
8.6, 8.7, 8.8, 8.9, 9.0, 9.1, 9.2, 9.3, 9.4, 9.5, 9.6, 9.7, 9.8, 9.9, 10.0, 10.1, 10.2, 10.3, 10.4, 10.5,
10.6, 10.7, 10.8, 10.9, 11.0, 11.1, 11.2, 11.3, 11.4, 11.5, 11.6, 11.7, 11.8, 11.9, 12.0])
df = pd.DataFrame({
"x_0": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
"x_A": [0.350, 0.350, 0.350, 0.350, 0.350, 0.350, 0.350, 0.350, 0.350, 0.350, 0.350, 0.350, 0.350, 0.350, 0.350,
0.350, 0.350, 0.350, 0.350, 0.350],
"x_B": [0.615, 0.619, 0.623, 0.626, 0.628, 0.630, 0.632, 0.633, 0.634, 0.635, 0.636, 0.637, 0.638, 0.639, 0.639,
0.640, 0.640, 0.640, 0.640, 0.640],
"x_C": [0.805, 0.816, 0.826, 0.834, 0.839, 0.844, 0.848, 0.851, 0.854, 0.856, 0.860, 0.862, 0.864, 0.866, 0.867,
0.868, 0.869, 0.869, 0.869, 0.869],
"x_D": [0.955, 0.971, 0.985, 0.998, 1.011, 1.022, 1.032, 1.040, 1.045, 1.048, 1.049, 1.049, 1.050, 1.050, 1.050,
1.051, 1.051, 1.052, 1.052, 1.052],
"x_E": [1.078, 1.100, 1.124, 1.145, 1.162, 1.178, 1.190, 1.199, 1.203, 1.207, 1.210, 1.211, 1.213, 1.214, 1.216,
1.218, 1.219, 1.220, 1.220, 1.220],
"x_F": [1.175, 1.207, 1.239, 1.264, 1.285, 1.300, 1.313, 1.322, 1.332, 1.340, 1.347, 1.352, 1.357, 1.359, 1.360,
1.363, 1.364, 1.364, 1.364, 1.364],
"x_G": [1.256, 1.300, 1.335, 1.365, 1.386, 1.403, 1.471, 1.429, 1.440, 1.450, 1.462, 1.472, 1.480, 1.485, 1.489,
1.492, 1.494, 1.495, 1.495, 1.495],
"x_H": [1.327, 1.375, 1.420, 1.455, 1.479, 1.500, 1.415, 1.530, 1.541, 1.551, 1.568, 1.580, 1.590, 1.598, 1.602,
1.607, 1.608, 1.609, 1.610, 1.610],
"x_I": [1.380, 1.438, 1.435, 1.528, 1.552, 1.573, 1.591, 1.606, 1.616, 1.631, 1.653, 1.667, 1.676, 1.684, 1.691,
1.699, 1.702, 1.706, 1.709, 1.711],
"x_J": [1.433, 1.500, 1.550, 1.600, 1.625, 1.645, 1.666, 1.682, 1.690, 1.710, 1.737, 1.753, 1.761, 1.770, 1.780,
1.790, 1.795, 1.802, 1.808, 1.812],
"x_K": [1.463, 1.545, 1.602, 1.657, 1.684, 1.709, 1.731, 1.746, 1.758, 1.779, 1.810, 1.828, 1.836, 1.845, 1.858,
1.868, 1.875, 1.883, 1.890, 1.896],
"x_L": [1.492, 1.590, 1.654, 1.713, 1.742, 1.772, 1.795, 1.810, 1.825, 1.847, 1.882, 1.903, 1.911, 1.920, 1.935,
1.945, 1.954, 1.964, 1.972, 1.980],
"x_M": [1.510, 1.620, 1.690, 1.757, 1.791, 1.824, 1.848, 1.867, 1.884, 1.906, 1.938, 1.962, 1.973, 1.984, 1.997,
2.010, 2.019, 2.027, 2.036, 2.045],
"x_N": [1.527, 1.649, 1.726, 1.800, 1.839, 1.875, 1.900, 1.923, 1.943, 1.964, 1.993, 2.021, 2.035, 2.047, 2.059,
2.074, 2.083, 2.090, 2.100, 2.110],
"x_O": [1.544, 1.670, 1.754, 1.834, 1.876, 1.917, 1.943, 1.969, 1.991, 2.012, 2.043, 2.072, 2.089, 2.102, 2.116,
2.131, 2.141, 2.148, 2.159, 2.169],
"x_P": [1.560, 1.690, 1.782, 1.867, 1.913, 1.958, 1.985, 2.014, 2.038, 2.060, 2.093, 2.123, 2.142, 2.157, 2.172,
2.188, 2.198, 2.205, 2.217, 2.227],
"x_Q": [1.575, 1.708, 1.808, 1.896, 1.944, 1.993, 2.022, 2.054, 2.079, 2.100, 2.136, 2.165, 2.187, 2.204, 2.219,
2.237, 2.247, 2.256, 2.267, 2.279],
"x_R": [1.590, 1.725, 1.833, 1.924, 1.975, 2.027, 2.059, 2.093, 2.119, 2.140, 2.178, 2.207, 2.231, 2.250, 2.265,
2.285, 2.295, 2.307, 2.317, 2.330],
"x_S": [1.604, 1.743, 1.854, 1.947, 2.003, 2.057, 2.092, 2.126, 2.153, 2.176, 2.215, 2.248, 2.272, 2.292, 2.307,
2.326, 2.337, 2.350, 2.361, 2.375],
"x_T": [1.617, 1.761, 1.876, 1.971, 2.031, 2.086, 2.125, 2.160, 2.187, 2.212, 2.252, 2.288, 2.313, 2.334, 2.349,
2.366, 2.380, 2.394, 2.404, 2.420],
"x_1": [1.631, 1.779, 1.897, 1.994, 2.059, 2.116, 2.157, 2.193, 2.222, 2.249, 2.288, 2.329, 2.354, 2.375, 2.391,
2.407, 2.422, 2.437, 2.448, 2.465],
"x_2": [1.644, 1.797, 1.919, 2.018, 2.087, 2.145, 2.190, 2.227, 2.256, 2.285, 2.325, 2.369, 2.395, 2.417, 2.433,
2.447, 2.465, 2.481, 2.491, 2.510],
"x_3": [1.658, 1.815, 1.940, 2.041, 2.115, 2.175, 2.223, 2.260, 2.290, 2.321, 2.362, 2.410, 2.436, 2.459, 2.475,
2.488, 2.507, 2.524, 2.535, 2.555],
"x_4": [1.672, 1.830, 1.958, 2.061, 2.137, 2.198, 2.249, 2.288, 2.318, 2.350, 2.392, 2.442, 2.469, 2.492, 2.508,
2.523, 2.544, 2.562, 2.574, 2.593],
"x_5": [1.685, 1.845, 1.976, 2.081, 2.159, 2.221, 2.275, 2.316, 2.347, 2.379, 2.423, 2.474, 2.502, 2.525, 2.541,
2.559, 2.581, 2.599, 2.612, 2.630],
"x_6": [1.699, 1.860, 1.994, 2.101, 2.180, 2.245, 2.302, 2.344, 2.375, 2.407, 2.453, 2.506, 2.534, 2.557, 2.575,
2.594, 2.617, 2.637, 2.651, 2.668],
"x_7": [1.712, 1.875, 2.012, 2.121, 2.202, 2.268, 2.328, 2.372, 2.404, 2.436, 2.484, 2.538, 2.567, 2.590, 2.608,
2.630, 2.654, 2.674, 2.689, 2.705],
"x_8": [1.726, 1.890, 2.030, 2.140, 2.224, 2.291, 2.354, 2.400, 2.432, 2.465, 2.514, 2.570, 2.600, 2.623, 2.641,
2.665, 2.691, 2.712, 2.728, 2.743],
"x_9": [1.740, 1.904, 2.046, 2.157, 2.243, 2.311, 2.376, 2.423, 2.455, 2.489, 2.540, 2.597, 2.628, 2.652, 2.670,
2.694, 2.722, 2.744, 2.759, 2.775],
"x_10": [1.754, 1.918, 2.062, 2.175, 2.261, 2.331, 2.397, 2.446, 2.478, 2.512, 2.565, 2.623, 2.657, 2.681, 2.700,
2.723, 2.753, 2.775, 2.790, 2.806],
"x_11": [1.767, 1.932, 2.078, 2.192, 2.280, 2.350, 2.419, 2.469, 2.502, 2.536, 2.591, 2.650, 2.685, 2.709, 2.729,
2.752, 2.783, 2.807, 2.821, 2.838],
"x_12": [1.781, 1.946, 2.094, 2.210, 2.298, 2.370, 2.440, 2.492, 2.525, 2.559, 2.616, 2.676, 2.714, 2.738, 2.759,
2.781, 2.814, 2.838, 2.852, 2.869],
"x_13": [1.795, 1.960, 2.110, 2.227, 2.317, 2.390, 2.462, 2.515, 2.548, 2.583, 2.642, 2.703, 2.742, 2.767, 2.788,
2.810, 2.845, 2.870, 2.883, 2.901],
"x_14": [1.808, 1.974, 2.125, 2.243, 2.333, 2.407, 2.480, 2.535, 2.568, 2.603, 2.664, 2.726, 2.766, 2.792, 2.813,
2.836, 2.872, 2.910, 2.911, 2.929],
"x_15": [1.822, 1.988, 2.140, 2.259, 2.349, 2.424, 2.498, 2.556, 2.588, 2.624, 2.686, 2.748, 2.791, 2.817, 2.839,
2.862, 2.899, 2.950, 2.938, 2.957],
"x_16": [1.835, 2.002, 2.155, 2.275, 2.365, 2.440, 2.517, 2.576, 2.609, 2.644, 2.708, 2.771, 2.815, 2.843, 2.864,
2.888, 2.925, 2.990, 2.966, 2.984],
"x_17": [1.849, 2.016, 2.170, 2.291, 2.381, 2.457, 2.535, 2.597, 2.629, 2.665, 2.730, 2.793, 2.840, 2.868, 2.890,
2.914, 2.952, 3.030, 2.993, 3.012],
"x_18": [1.862, 2.030, 2.186, 2.306, 2.397, 2.474, 2.553, 2.617, 2.649, 2.685, 2.752, 2.816, 2.864, 2.893, 2.915,
2.940, 2.979, 3.070, 3.021, 3.040],
"x_19": [1.875, 2.044, 2.201, 2.321, 2.413, 2.490, 2.569, 2.634, 2.667, 2.703, 2.771, 2.836, 2.885, 2.915, 2.938,
2.963, 3.002, 3.081, 3.045, 3.064],
"x_20": [1.889, 2.058, 2.216, 2.336, 2.429, 2.506, 2.586, 2.651, 2.685, 2.721, 2.789, 2.856, 2.907, 2.937, 2.960,
2.985, 3.025, 3.092, 3.069, 3.088],
"x_21": [1.902, 2.073, 2.230, 2.351, 2.444, 2.523, 2.602, 2.669, 2.702, 2.740, 2.808, 2.875, 2.928, 2.958, 2.983,
3.008, 3.049, 3.103, 3.094, 3.112],
"x_22": [1.916, 2.087, 2.245, 2.366, 2.460, 2.539, 2.619, 2.686, 2.720, 2.758, 2.826, 2.895, 2.950, 2.980, 3.005,
3.030, 3.072, 3.114, 3.118, 3.136],
"x_23": [1.929, 2.101, 2.260, 2.381, 2.476, 2.555, 2.635, 2.703, 2.738, 2.776, 2.845, 2.915, 2.971, 3.002, 3.028,
3.053, 3.095, 3.125, 3.142, 3.160],
"x_24": [1.942, 2.115, 2.274, 2.395, 2.491, 2.570, 2.651, 2.719, 2.754, 2.793, 2.863, 2.933, 2.990, 3.022, 3.048,
3.074, 3.117, 3.147, 3.164, 3.182],
"x_25": [1.955, 2.128, 2.288, 2.409, 2.507, 2.586, 2.666, 2.735, 2.770, 2.810, 2.881, 2.952, 3.009, 3.041, 3.068,
3.095, 3.139, 3.168, 3.186, 3.203],
"x_26": [1.969, 2.142, 2.301, 2.423, 2.522, 2.601, 2.682, 2.752, 2.786, 2.826, 2.899, 2.970, 3.027, 3.061, 3.088,
3.115, 3.161, 3.190, 3.209, 3.225],
"x_27": [1.982, 2.155, 2.315, 2.437, 2.538, 2.617, 2.697, 2.768, 2.802, 2.843, 2.917, 2.989, 3.046, 3.080, 3.108,
3.136, 3.183, 3.211, 3.231, 3.246],
"x_28": [1.995, 2.169, 2.329, 2.451, 2.553, 2.632, 2.713, 2.784, 2.818, 2.860, 2.935, 3.007, 3.065, 3.100, 3.128,
3.157, 3.205, 3.233, 3.253, 3.268],
"x_29": [2.009, 2.183, 2.342, 2.465, 2.567, 2.646, 2.728, 2.799, 2.834, 2.876, 2.952, 3.024, 3.082, 3.118, 3.146,
3.177, 3.225, 3.253, 3.274, 3.288],
"x_30": [2.024, 2.197, 2.355, 2.479, 2.581, 2.661, 2.743, 2.814, 2.850, 2.892, 2.968, 3.042, 3.099, 3.136, 3.164,
3.196, 3.244, 3.273, 3.295, 3.308],
"x_31": [2.038, 2.210, 2.369, 2.492, 2.595, 2.675, 2.758, 2.830, 2.865, 2.908, 2.985, 3.059, 3.117, 3.153, 3.182,
3.216, 3.264, 3.294, 3.315, 3.328],
"x_32": [2.053, 2.224, 2.382, 2.506, 2.609, 2.690, 2.773, 2.845, 2.881, 2.924, 3.001, 3.077, 3.134, 3.171, 3.200,
3.235, 3.283, 3.314, 3.336, 3.348],
"x_33": [2.067, 2.238, 2.395, 2.520, 2.623, 2.704, 2.788, 2.860, 2.897, 2.940, 3.018, 3.094, 3.151, 3.189, 3.218,
3.255, 3.303, 3.334, 3.357, 3.368],
"x_34": [2.079, 2.251, 2.408, 2.533, 2.636, 2.718, 2.801, 2.874, 2.912, 2.955, 3.037, 3.110, 3.168, 3.206, 3.235,
3.273, 3.321, 3.352, 3.375, 3.386],
"x_35": [2.091, 2.264, 2.421, 2.547, 2.650, 2.731, 2.815, 2.888, 2.926, 2.970, 3.049, 3.125, 3.185, 3.224, 3.252,
3.291, 3.339, 3.370, 3.393, 3.405],
"x_36": [2.102, 2.277, 2.435, 2.560, 2.663, 2.745, 2.828, 2.902, 2.941, 2.985, 3.065, 3.141, 3.201, 3.241, 3.270,
3.309, 3.356, 3.389, 3.412, 3.423],
"x_37": [2.114, 2.290, 2.448, 2.574, 2.677, 2.758, 2.842, 2.916, 2.955, 3.000, 3.080, 3.156, 3.218, 3.259, 3.287,
3.327, 3.374, 3.407, 3.430, 3.442],
"x_38": [2.126, 2.303, 2.461, 2.587, 2.690, 2.772, 2.855, 2.930, 2.970, 3.015, 3.096, 3.172, 3.235, 3.276, 3.304,
3.345, 3.392, 3.425, 3.448, 3.460],
"x_39": [2.139, 2.316, 2.474, 2.600, 2.703, 2.785, 2.869, 2.943, 2.984, 3.029, 3.111, 3.187, 3.250, 3.292, 3.321,
3.362, 3.409, 3.442, 3.466, 3.477],
"x_40": [2.152, 2.328, 2.486, 2.612, 2.716, 2.799, 2.882, 2.956, 2.997, 3.043, 3.125, 3.202, 3.266, 3.308, 3.337,
3.379, 3.426, 3.459, 3.483, 3.494],
"x_41": [2.165, 2.341, 2.499, 2.625, 2.729, 2.812, 2.896, 2.970, 3.011, 3.056, 3.140, 3.218, 3.281, 3.323, 3.354,
3.395, 3.443, 3.476, 3.501, 3.511],
"x_42": [2.178, 2.353, 2.511, 2.637, 2.742, 2.826, 2.909, 2.983, 3.024, 3.070, 3.154, 3.233, 3.297, 3.339, 3.370,
3.412, 3.460, 3.493, 3.518, 3.528],
"x_43": [2.191, 2.366, 2.524, 2.650, 2.755, 2.839, 2.923, 2.996, 3.038, 3.084, 3.169, 3.248, 3.312, 3.355, 3.387,
3.429, 3.477, 3.510, 3.536, 3.545],
"x_44": [2.204, 2.379, 2.536, 2.662, 2.768, 2.852, 2.936, 3.009, 3.051, 3.098, 3.183, 3.262, 3.327, 3.370, 3.402,
3.444, 3.493, 3.526, 3.551, 3.561],
"x_45": [2.217, 2.391, 2.548, 2.675, 2.781, 2.864, 2.949, 3.022, 3.064, 3.112, 3.197, 3.276, 3.341, 3.385, 3.417,
3.459, 3.508, 3.542, 3.567, 3.577],
"x_46": [2.229, 2.404, 2.560, 2.687, 2.794, 2.877, 2.963, 3.034, 3.077, 3.126, 3.210, 3.291, 3.356, 3.399, 3.432,
3.475, 3.524, 3.557, 3.582, 3.592],
"x_47": [2.242, 2.416, 2.572, 2.700, 2.807, 2.889, 2.976, 3.047, 3.090, 3.140, 3.224, 3.305, 3.370, 3.414, 3.447,
3.490, 3.539, 3.573, 3.598, 3.608],
"x_48": [2.255, 2.429, 2.584, 2.712, 2.820, 2.902, 2.989, 3.060, 3.103, 3.154, 3.238, 3.319, 3.385, 3.429, 3.462,
3.505, 3.555, 3.589, 3.613, 3.624],
"x_49": [2.268, 2.442, 2.597, 2.724, 2.832, 2.915, 3.002, 3.073, 3.116, 3.167, 3.251, 3.332, 3.399, 3.443, 3.477,
3.520, 3.570, 3.604, 3.628, 3.639],
"x_50": [2.281, 2.454, 2.609, 2.737, 2.844, 2.928, 3.014, 3.085, 3.129, 3.180, 3.264, 3.345, 3.413, 3.457, 3.491,
3.534, 3.584, 3.618, 3.643, 3.654],
"x_51": [2.294, 2.467, 2.622, 2.749, 2.856, 2.941, 3.027, 3.098, 3.141, 3.194, 3.278, 3.359, 3.427, 3.472, 3.506,
3.549, 3.599, 3.633, 3.659, 3.670],
"x_52": [2.307, 2.479, 2.634, 2.762, 2.868, 2.954, 3.039, 3.110, 3.154, 3.207, 3.291, 3.372, 3.441, 3.486, 3.520,
3.563, 3.613, 3.647, 3.674, 3.685],
"x_53": [2.320, 2.492, 2.647, 2.774, 2.880, 2.967, 3.052, 3.123, 3.167, 3.220, 3.304, 3.385, 3.455, 3.500, 3.535,
3.578, 3.628, 3.662, 3.689, 3.700],
"x_54": [2.333, 2.505, 2.660, 2.786, 2.892, 2.979, 3.065, 3.135, 3.180, 3.233, 3.317, 3.398, 3.468, 3.514, 3.548,
3.591, 3.642, 3.676, 3.703, 3.714],
"x_55": [2.346, 2.517, 2.672, 2.799, 2.904, 2.991, 3.077, 3.147, 3.192, 3.246, 3.330, 3.411, 3.482, 3.528, 3.562,
3.605, 3.656, 3.690, 3.718, 3.728],
"x_56": [2.359, 2.530, 2.685, 2.811, 2.916, 3.003, 3.090, 3.160, 3.205, 3.260, 3.344, 3.424, 3.495, 3.541, 3.575,
3.618, 3.670, 3.704, 3.732, 3.742],
"x_57": [2.372, 2.542, 2.697, 2.824, 2.928, 3.015, 3.102, 3.172, 3.217, 3.274, 3.357, 3.437, 3.509, 3.555, 3.589,
3.632, 3.684, 3.718, 3.747, 3.756],
"x_58": [2.385, 2.555, 2.710, 2.836, 2.940, 3.027, 3.115, 3.184, 3.230, 3.287, 3.370, 3.450, 3.522, 3.569, 3.602,
3.645, 3.698, 3.732, 3.761, 3.770],
"x_59": [2.398, 2.568, 2.723, 2.848, 2.952, 3.039, 3.127, 3.197, 3.242, 3.299, 3.382, 3.462, 3.534, 3.581, 3.615,
3.658, 3.711, 3.745, 3.774, 3.783],
"x_60": [2.411, 2.580, 2.736, 2.861, 2.964, 3.051, 3.139, 3.209, 3.254, 3.311, 3.394, 3.474, 3.546, 3.594, 3.627,
3.671, 3.723, 3.758, 3.788, 3.796],
"x_61": [2.424, 2.593, 2.748, 2.873, 2.977, 3.064, 3.151, 3.222, 3.266, 3.323, 3.407, 3.486, 3.559, 3.606, 3.640,
3.684, 3.736, 3.771, 3.801, 3.810],
"x_62": [2.437, 2.605, 2.761, 2.886, 2.989, 3.076, 3.163, 3.234, 3.278, 3.335, 3.419, 3.498, 3.571, 3.619, 3.652,
3.697, 3.748, 3.784, 3.815, 3.823],
"x_63": [2.450, 2.618, 2.774, 2.898, 3.001, 3.088, 3.175, 3.247, 3.290, 3.347, 3.431, 3.510, 3.583, 3.631, 3.665,
3.710, 3.761, 3.797, 3.828, 3.836],
"x_64": [2.462, 2.631, 2.787, 2.910, 3.013, 3.100, 3.187, 3.259, 3.302, 3.359, 3.443, 3.523, 3.595, 3.643, 3.677,
3.722, 3.773, 3.810, 3.840, 3.849],
"x_65": [2.475, 2.643, 2.799, 2.923, 3.025, 3.112, 3.199, 3.270, 3.315, 3.370, 3.456, 3.535, 3.607, 3.655, 3.690,
3.734, 3.786, 3.823, 3.853, 3.862],
"x_66": [2.487, 2.656, 2.812, 2.935, 3.038, 3.124, 3.211, 3.282, 3.327, 3.382, 3.468, 3.548, 3.619, 3.666, 3.702,
3.746, 3.798, 3.835, 3.865, 3.875],
"x_67": [2.500, 2.668, 2.824, 2.948, 3.050, 3.136, 3.223, 3.293, 3.340, 3.393, 3.481, 3.560, 3.631, 3.678, 3.715,
3.758, 3.811, 3.848, 3.878, 3.888],
"x_68": [2.512, 2.681, 2.837, 2.960, 3.062, 3.148, 3.235, 3.305, 3.352, 3.405, 3.493, 3.573, 3.643, 3.690, 3.727,
3.770, 3.823, 3.861, 3.890, 3.901],
"x_69": [2.524, 2.693, 2.849, 2.972, 3.074, 3.159, 3.246, 3.317, 3.364, 3.417, 3.505, 3.585, 3.655, 3.702, 3.739,
3.782, 3.835, 3.873, 3.902, 3.913],
"x_70": [2.536, 2.706, 2.861, 2.985, 3.085, 3.170, 3.257, 3.329, 3.376, 3.429, 3.517, 3.597, 3.667, 3.714, 3.750,
3.794, 3.847, 3.885, 3.915, 3.925],
"x_71": [2.549, 2.718, 2.872, 2.997, 3.097, 3.182, 3.268, 3.340, 3.388, 3.440, 3.530, 3.608, 3.678, 3.725, 3.762,
3.806, 3.859, 3.897, 3.927, 3.938],
"x_72": [2.561, 2.731, 2.884, 3.010, 3.108, 3.193, 3.279, 3.352, 3.400, 3.452, 3.542, 3.620, 3.690, 3.737, 3.773,
3.818, 3.871, 3.909, 3.940, 3.950],
"x_73": [2.573, 2.743, 2.896, 3.022, 3.120, 3.204, 3.290, 3.364, 3.412, 3.464, 3.554, 3.632, 3.702, 3.749, 3.785,
3.830, 3.883, 3.921, 3.952, 3.962],
"x_74": [2.585, 2.755, 2.908, 3.034, 3.131, 3.216, 3.302, 3.376, 3.424, 3.475, 3.565, 3.644, 3.713, 3.760, 3.797,
3.842, 3.895, 3.933, 3.964, 3.974],
"x_75": [2.597, 2.767, 2.919, 3.045, 3.142, 3.228, 3.314, 3.388, 3.435, 3.487, 3.576, 3.656, 3.724, 3.772, 3.809,
3.854, 3.907, 3.945, 3.976, 3.986],
"x_76": [2.610, 2.780, 2.931, 3.057, 3.153, 3.239, 3.326, 3.399, 3.447, 3.498, 3.588, 3.667, 3.736, 3.783, 3.820,
3.865, 3.918, 3.957, 3.987, 3.999],
"x_77": [2.622, 2.792, 2.942, 3.068, 3.164, 3.251, 3.338, 3.411, 3.458, 3.510, 3.599, 3.679, 3.747, 3.795, 3.832,
3.877, 3.930, 3.969, 3.999, 4.011],
"x_78": [2.634, 2.804, 2.954, 3.080, 3.175, 3.263, 3.350, 3.423, 3.470, 3.521, 3.610, 3.691, 3.758, 3.806, 3.844,
3.889, 3.942, 3.981, 4.011, 4.023],
"x_79": [2.646, 2.816, 2.966, 3.092, 3.187, 3.274, 3.361, 3.434, 3.482, 3.532, 3.622, 3.702, 3.769, 3.817, 3.855,
3.900, 3.953, 3.992, 4.023, 4.035],
"x_80": [2.658, 2.828, 2.978, 3.103, 3.199, 3.286, 3.372, 3.446, 3.494, 3.544, 3.633, 3.714, 3.780, 3.828, 3.867,
3.911, 3.965, 4.004, 4.035, 4.046],
"x_81": [2.671, 2.840, 2.989, 3.115, 3.211, 3.297, 3.382, 3.457, 3.506, 3.555, 3.645, 3.725, 3.790, 3.840, 3.878,
3.923, 3.976, 4.015, 4.046, 4.058],
"x_82": [2.683, 2.852, 3.001, 3.126, 3.223, 3.309, 3.393, 3.469, 3.518, 3.567, 3.656, 3.737, 3.801, 3.851, 3.890,
3.934, 3.988, 4.027, 4.058, 4.069],
"x_83": [2.695, 2.864, 3.013, 3.138, 3.235, 3.320, 3.404, 3.480, 3.530, 3.578, 3.668, 3.748, 3.812, 3.862, 3.901,
3.945, 3.999, 4.038, 4.070, 4.081],
"x_84": [2.707, 2.876, 3.025, 3.150, 3.246, 3.332, 3.416, 3.492, 3.541, 3.588, 3.679, 3.758, 3.823, 3.873, 3.912,
3.956, 4.010, 4.049, 4.081, 4.092],
"x_85": [2.719, 2.888, 3.037, 3.161, 3.258, 3.343, 3.428, 3.504, 3.552, 3.598, 3.689, 3.769, 3.834, 3.883, 3.923,
3.967, 4.021, 4.060, 4.093, 4.104],
"x_86": [2.732, 2.900, 3.048, 3.173, 3.269, 3.355, 3.440, 3.515, 3.562, 3.609, 3.700, 3.779, 3.844, 3.894, 3.933,
3.978, 4.031, 4.071, 4.104, 4.115],
"x_87": [2.744, 2.912, 3.060, 3.184, 3.281, 3.366, 3.452, 3.527, 3.573, 3.619, 3.710, 3.790, 3.855, 3.904, 3.944,
3.989, 4.042, 4.082, 4.116, 4.127],
"x_88": [2.756, 2.924, 3.072, 3.196, 3.292, 3.378, 3.464, 3.539, 3.584, 3.629, 3.721, 3.800, 3.866, 3.915, 3.955,
4.000, 4.053, 4.093, 4.127, 4.138],
"x_89": [2.768, 2.936, 3.084, 3.208, 3.304, 3.389, 3.475, 3.551, 3.595, 3.639, 3.732, 3.811, 3.877, 3.926, 3.966,
4.011, 4.064, 4.104, 4.138, 4.149],
"x_90": [2.780, 2.948, 3.096, 3.220, 3.315, 3.401, 3.486, 3.562, 3.605, 3.650, 3.743, 3.822, 3.888, 3.937, 3.977,
4.022, 4.075, 4.116, 4.150, 4.160],
"x_91": [2.793, 2.960, 3.108, 3.231, 3.327, 3.412, 3.497, 3.574, 3.616, 3.660, 3.753, 3.832, 3.899, 3.947, 3.988,
4.033, 4.087, 4.127, 4.161, 4.172],
"x_92": [2.805, 2.972, 3.120, 3.243, 3.338, 3.424, 3.508, 3.585, 3.626, 3.671, 3.764, 3.843, 3.910, 3.958, 3.999,
4.044, 4.098, 4.139, 4.173, 4.183],
"x_93": [2.817, 2.984, 3.132, 3.255, 3.350, 3.435, 3.519, 3.597, 3.637, 3.681, 3.775, 3.854, 3.921, 3.969, 4.010,
4.055, 4.109, 4.150, 4.184, 4.194],
"x_94": [2.829, 2.996, 3.144, 3.267, 3.361, 3.446, 3.529, 3.607, 3.648, 3.692, 3.786, 3.865, 3.932, 3.980, 4.022,
4.067, 4.121, 4.161, 4.195, 4.205],
"x_95": [2.841, 3.008, 3.156, 3.279, 3.373, 3.456, 3.540, 3.617, 3.658, 3.702, 3.797, 3.876, 3.943, 3.991, 4.034,
4.079, 4.132, 4.172, 4.206, 4.216],
"x_96": [2.854, 3.020, 3.168, 3.290, 3.384, 3.467, 3.550, 3.628, 3.669, 3.713, 3.808, 3.886, 3.955, 4.003, 4.045,
4.090, 4.144, 4.183, 4.217, 4.227],
"x_97": [2.866, 3.032, 3.180, 3.302, 3.396, 3.477, 3.561, 3.638, 3.679, 3.723, 3.819, 3.897, 3.966, 4.014, 4.057,
4.102, 4.155, 4.194, 4.228, 4.238],
"x_98": [2.878, 3.044, 3.192, 3.314, 3.407, 3.488, 3.571, 3.648, 3.690, 3.734, 3.830, 3.908, 3.977, 4.025, 4.069,
4.114, 4.167, 4.205, 4.239, 4.249]
})
df_1 = df.T
numt = df_1.to_numpy()
y = numt
X1_mesh, X2_mesh = np.meshgrid(x1, x2)
X = np.column_stack((X1_mesh.ravel(), X2_mesh.ravel()))
Y = y.ravel()
# Model
model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(64, input_dim=2, activation='relu'))
model.add(tf.keras.layers.Dense(64, activation='relu'))
model.add(tf.keras.layers.Dense(1, activation='linear'))
model.compile(optimizer='adam', loss='mean_squared_error')
model.fit(X, Y, epochs=1000, verbose=0)
# Predict
new_point = np.array([[1.1789, 0.489]])
predicted_value = model.predict(new_point)
print(predicted_value)
我想知道为什么当数组的形状比开始时大时结果会改变,为什么它会改变?
作者:黑洞官方问答小能手
链接:https://www.pythonheidong.com/blog/article/2039335/389c10ed77356571115e/
来源:python黑洞网
任何形式的转载都请注明出处,如有侵权 一经发现 必将追究其法律责任
昵称:
评论内容:(最多支持255个字符)
---无人问津也好,技不如人也罢,你都要试着安静下来,去做自己该做的事,而不是让内心的烦躁、焦虑,坏掉你本来就不多的热情和定力
Copyright © 2018-2021 python黑洞网 All Rights Reserved 版权所有,并保留所有权利。 京ICP备18063182号-1
投诉与举报,广告合作请联系vgs_info@163.com或QQ3083709327
免责声明:网站文章均由用户上传,仅供读者学习交流使用,禁止用做商业用途。若文章涉及色情,反动,侵权等违法信息,请向我们举报,一经核实我们会立即删除!