# CIFAR-10 Model 1
# wide_residual_network
def wide_residual_network(input_shape, num_class):
    """Build and compile a wide residual network classifier.

    The architecture is fixed to depth 28 with widening factor 10: a 16-filter
    3x3 stem convolution followed by three groups of ``(depth - 4) // 6 = 4``
    residual blocks with 160, 320 and 640 filters. The second and third groups
    halve the spatial resolution in their first block. Every block applies
    BatchNorm -> ReLU before each convolution.

    The number of channels entering each block is passed explicitly between
    the nested helpers. Earlier revisions tracked it through a module-level
    ``IN_FILTERS`` global, which had to be initialised outside this function
    and leaked state between successive calls; this function no longer reads
    or writes that global.

    Args:
        input_shape: Shape of one input image, forwarded to ``Input``.
            NOTE(review): the fixed 8x8 average pool presumably assumes
            32x32 inputs (two stride-2 stages) -- confirm against callers.
        num_class: Number of output classes (units of the softmax layer).

    Returns:
        A ``Model`` compiled with categorical cross-entropy, Nesterov SGD
        (learning rate 0.1, momentum 0.9) and the ``accuracy`` metric.
    """
    weight_decay = 0.0005
    depth = 28
    k = 10
    n_filters = [16, 16 * k, 32 * k, 64 * k]
    n_stack = (depth - 4) // 6

    def conv(x, filters, kernel_size, strides=(1, 1)):
        # Single source of truth for the convolution settings shared by every
        # conv layer in the network (he_normal init, L2 decay, no bias).
        return Conv2D(filters=filters,
                      kernel_size=kernel_size,
                      strides=strides,
                      padding='same',
                      kernel_initializer='he_normal',
                      kernel_regularizer=l2(weight_decay),
                      use_bias=False)(x)

    def bn_relu(x):
        x = BatchNormalization(momentum=0.9, epsilon=1e-5)(x)
        return Activation('relu')(x)

    def residual_block(x, in_filters, out_filters, increase=False):
        # `increase` marks a down-sampling block: stride 2 on the first conv
        # and on the shortcut.
        stride = (2, 2) if increase else (1, 1)

        o1 = bn_relu(x)
        conv_1 = conv(o1, out_filters, (3, 3), stride)
        o2 = bn_relu(conv_1)
        conv_2 = conv(o2, out_filters, (3, 3))

        if increase or in_filters != out_filters:
            # Shapes differ between input and residual branch, so the shortcut
            # is a 1x1 projection of the pre-activated input.
            shortcut = conv(o1, out_filters, (1, 1), stride)
        else:
            shortcut = x
        return add([conv_2, shortcut])

    def wide_residual_layer(x, in_filters, out_filters, increase=False):
        # The first block may change width/resolution; the remaining
        # n_stack - 1 blocks keep both fixed.
        x = residual_block(x, in_filters, out_filters, increase)
        for _ in range(1, n_stack):
            x = residual_block(x, out_filters, out_filters)
        return x

    img_input = Input(shape=input_shape)
    x = conv(img_input, n_filters[0], (3, 3))

    x = wide_residual_layer(x, n_filters[0], n_filters[1])
    x = wide_residual_layer(x, n_filters[1], n_filters[2], increase=True)
    x = wide_residual_layer(x, n_filters[2], n_filters[3], increase=True)

    x = bn_relu(x)
    x = AveragePooling2D((8, 8))(x)
    x = Flatten()(x)
    x = Dense(num_class,
              activation='softmax',
              kernel_initializer='he_normal',
              kernel_regularizer=l2(weight_decay),
              use_bias=False)(x)

    resnet = Model(img_input, x)

    # set optimizer
    # NOTE(review): `lr` is the legacy keyword; newer Keras releases expect
    # `learning_rate` -- confirm against the installed version before changing.
    sgd = optimizers.SGD(lr=.1, momentum=0.9, nesterov=True)
    resnet.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])
    return resnet