# -*- coding: utf-8 -*-
# Author: allen
# Copyright (c) 2021 RuiCare All rights reserved.
from keras.models import Model
from keras.layers import Input
from keras.utils import plot_model
from keras.layers import Activation
from keras.backend import concatenate
from tensorflow.python.keras.layers import Dropout, Add
from keras.layers import BatchNormalization, GaussianNoise
from keras.layers import Conv3D, MaxPooling3D, UpSampling3D
from tensorflow.keras.mixed_precision import experimental as mixed_precision
def BNResConvPool(tensor, filters, pool=True, kernel_size=(3, 3, 3), strides=(1, 1, 1), pool_size=(2, 2, 2),
dropout=0.1):
    """Projection conv -> residual conv unit -> batch norm, with optional pooling.

    Args:
        tensor: input 5-D feature map (functional-API tensor).
        filters: number of output channels for the convolutions.
        pool: when True, finish with a MaxPooling3D downsampling step.
        kernel_size: Conv3D kernel size.
        strides: Conv3D strides.
        pool_size: MaxPooling3D window size.
        dropout: currently unused (dropout layers are disabled); kept for
            interface compatibility.

    Returns:
        The transformed tensor.
    """
    projected = Conv3D(filters, kernel_size, strides=strides, padding='same')(tensor)
    out = RCU(projected, filters)
    # Dropout intentionally disabled throughout this file.
    out = BatchNormalization(axis=-1)(out)
    if not pool:
        return out
    return MaxPooling3D(pool_size=pool_size)(out)
def BNConvUpsampleConcat(tensor, concat_tensor, filters, kernel_size=(3, 3, 3), strides=(1, 1, 1),
upsample_size=(2, 2, 2), dropout=0.1):
    """Upsample `tensor`, concatenate with a skip connection, then apply two
    conv -> BN -> ReLU stages.

    Args:
        tensor: coarse decoder tensor to be upsampled.
        concat_tensor: encoder skip tensor concatenated on the channel axis.
        filters: number of output channels for both convolutions.
        kernel_size: Conv3D kernel size.
        strides: Conv3D strides.
        upsample_size: UpSampling3D factor per spatial dimension.
        dropout: currently unused (dropout layers are disabled); kept for
            interface compatibility.

    Returns:
        The fused, refined tensor.
    """
    # BUGFIX: the module-level `concatenate` comes from `keras.backend` (a raw
    # tensor op), which is not valid on functional-API layer outputs when
    # building a Model. Use the layers-API helper, imported locally so this
    # fix is self-contained.
    from keras.layers import concatenate as layers_concatenate
    tensor = layers_concatenate([UpSampling3D(size=upsample_size)(tensor), concat_tensor], axis=-1)
    for _ in range(2):
        tensor = Conv3D(filters, kernel_size=kernel_size, strides=strides, padding='same')(tensor)
        # Dropout intentionally disabled throughout this file.
        tensor = BatchNormalization(axis=-1)(tensor)
        tensor = Activation('relu')(tensor)
    return tensor
def RCU(tensor, filters, kernel_size=(3, 3, 3), strides=(1, 1, 1), dropout=0.1):
    """Residual conv unit: two pre-activated (ReLU -> Conv3D -> BN) stages
    added back onto the input tensor.

    Args:
        tensor: input 5-D feature map; also the identity branch of the sum.
        filters: number of output channels for each convolution.
        kernel_size: Conv3D kernel size.
        strides: Conv3D strides.
        dropout: currently unused (dropout layers are disabled); kept for
            interface compatibility.

    Returns:
        tensor + residual_branch(tensor).
    """
    branch = tensor
    for _ in range(2):
        branch = Activation('relu')(branch)
        branch = Conv3D(filters, kernel_size, strides=strides, padding='same')(branch)
        # Dropout intentionally disabled throughout this file.
        branch = BatchNormalization(axis=-1)(branch)
    return Add()([tensor, branch])
def CRP(tensor, filters=None, pool_size=(2, 2, 2), kernel_size=(3, 3, 3), strides=1, dropout=0.1):
    """Chained residual pooling: an initial ReLU followed by three chained
    (Conv3D -> BN -> ReLU -> MaxPooling3D) stages, with the ReLU output and
    every stage output summed together.

    With the default stride of 1 and 'same' padding the pooling stages keep
    the spatial shape, so all four branches remain addable.

    Args:
        tensor: input 5-D feature map.
        filters: number of output channels for each convolution.
        pool_size: MaxPooling3D window size.
        kernel_size: Conv3D kernel size.
        strides: stride shared by the convolutions and pooling layers.
        dropout: currently unused (dropout layers are disabled); kept for
            interface compatibility.

    Returns:
        Element-wise sum of the activated input and the three pooled stages.
    """
    stream = Activation('relu')(tensor)
    branches = [stream]
    for _ in range(3):
        stream = Conv3D(filters=filters, kernel_size=kernel_size, strides=strides, padding='same')(stream)
        # Dropout intentionally disabled throughout this file.
        stream = BatchNormalization(axis=-1)(stream)
        stream = Activation('relu')(stream)
        stream = MaxPooling3D(pool_size=pool_size, strides=strides, padding='same')(stream)
        branches.append(stream)
    return Add()(branches)
def MRF(upper=None, lower=None, filters=None, kernel_size=(3, 3, 3), strides=(1, 1, 1), dropout=0.1):
    """Multi-resolution fusion: project each path to `filters` channels,
    upsample the coarser path 2x, and sum the two.

    Args:
        upper: higher-resolution input tensor.
        lower: lower-resolution input tensor, or None when there is no
            coarser path (the deepest block); then only `upper` is projected.
        filters: number of output channels for the projections.
        kernel_size: Conv3D kernel size.
        strides: Conv3D strides.
        dropout: currently unused (dropout layers are disabled); kept for
            interface compatibility.

    Returns:
        The fused tensor (or the projected `upper` when `lower` is None).
    """
    def _conv_bn_relu(x):
        # Shared Conv3D -> BN -> ReLU projection applied to each path.
        # Dropout intentionally disabled throughout this file.
        x = Conv3D(filters=filters, kernel_size=kernel_size, strides=strides, padding='same')(x)
        x = BatchNormalization(axis=-1)(x)
        return Activation('relu')(x)

    upper = _conv_bn_relu(upper)
    if lower is None:
        return upper
    lower = UpSampling3D(size=(2, 2, 2))(_conv_bn_relu(lower))
    return Add()([upper, lower])
def RefineBlock(upper=None, lower=None, filters=256):
    """RefineNet block: adaptive convs (two RCUs per path), multi-resolution
    fusion, chained residual pooling, and one output RCU.

    Args:
        upper: higher-resolution encoder tensor.
        lower: coarser decoder tensor from the previous RefineBlock, or None
            for the deepest block (no fusion path).
        filters: channel width of this block; the `lower` path is adapted at
            twice this width before fusion.

    Returns:
        The refined tensor at `upper`'s resolution with `filters` channels.
    """
    upper = RCU(RCU(upper, filters=filters), filters=filters)
    if lower is None:  # deepest block (block 4): nothing to fuse
        fused = MRF(upper=upper, lower=None, filters=filters)
    else:
        lower = RCU(RCU(lower, filters=filters * 2), filters=filters * 2)
        fused = MRF(upper=upper, lower=lower, filters=filters)
    pooled = CRP(fused, filters=filters)
    return RCU(pooled, filters=filters)
def load_refine_net_plus_3d(input_shape, num_classes=1):
    """Build the 3-D RefineNet++ segmentation model.

    Args:
        input_shape: shape of one sample, e.g. (depth, height, width, channels).
        num_classes: number of output channels (per-voxel sigmoid maps).

    Returns:
        An uncompiled Model named 'RefineNetPlus3D'.
    """
    # Mixed precision: float16 compute with float32 variables.
    # NOTE(review): `mixed_precision.experimental` (see module imports) is a
    # deprecated API; migrate to `tf.keras.mixed_precision` when possible.
    policy = mixed_precision.Policy('mixed_float16')
    mixed_precision.set_policy(policy)
    inputs = Input(shape=input_shape)
    # Encoder: one shape-preserving stage, then four 2x-downsampling stages.
    left_init = BNResConvPool(inputs, 8, pool=False)
    left0 = BNResConvPool(left_init, 16)
    left1 = BNResConvPool(left0, 32)
    left2 = BNResConvPool(left1, 64)
    left3 = BNResConvPool(left2, 128)
    # Decoder: RefineBlocks fuse each encoder level with the coarser output.
    right3 = RefineBlock(upper=left3, lower=None, filters=128)
    right2 = RefineBlock(left2, right3, filters=64)
    right1 = RefineBlock(left1, right2, filters=32)
    right0 = RefineBlock(left0, right1, filters=16)
    right_init = RefineBlock(left_init, right0, filters=8)
    # BUGFIX: under a mixed_float16 policy the model previously emitted
    # float16 outputs from the fused sigmoid conv. The TF mixed-precision
    # guide requires the final activation to run in float32 for numerical
    # stability, so the sigmoid is split into its own float32 layer.
    logits = Conv3D(num_classes, (1, 1, 1))(right_init)
    outputs = Activation('sigmoid', dtype='float32')(logits)
    return Model(inputs=inputs, outputs=outputs, name='RefineNetPlus3D')
if __name__ == '__main__':
    # Build the model and render its architecture diagram for inspection.
    model = load_refine_net_plus_3d(input_shape=(32, 256, 256, 1))
    plot_model(model, to_file='./RefineNetPlus3D_v2.png', show_shapes=True)