# -*- coding: utf-8 -*-
# Author: allen
# Copyright (c) 2021 RuiCare All rights reserved.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
os.environ["SM_FRAMEWORK"] = "tf.keras"
import tensorflow as tf
from net import refinenetplus3d, refinenetplus3d_v2
from util.generator import Nii3DGenerator
from keras.callbacks import ModelCheckpoint, TensorBoard
def binary_accuracy_loss(y_true, y_pred):
    """Return ``1 - binary_accuracy(y_true, y_pred)`` (the binary error rate).

    Both inputs are cast to float16 before being handed to
    ``tf.keras.metrics.binary_accuracy``. Per the tf.keras documentation,
    that function thresholds ``y_pred`` (0.5 by default) and averages the
    matches over the last axis.

    Because the value is computed from thresholded predictions, it is not
    expected to provide a useful gradient; it is better suited as a metric
    than as the sole training loss.

    NOTE(review): float16 spacing just above 0.5 is ~4.9e-4, so predictions
    in roughly (0.5, 0.50024) round to exactly 0.5 and fail the ``> 0.5``
    test. Consider float32 if that precision matters.
    """
    # Cast once each (the original repeated both casts; the repeat was a no-op).
    y_true = tf.cast(y_true, tf.float16)
    y_pred = tf.cast(y_pred, tf.float16)
    return 1. - tf.keras.metrics.binary_accuracy(y_true, y_pred)
def main(*, use_tensorboard=False):
    """Train RefineNet+ 3-D (v2) on the bladder ``After256`` training set.

    Image/mask batches come from ``Nii3DGenerator`` (batch size 2, crops of
    32x256x256, every batch required to contain a label). If the checkpoint
    file already exists, its weights are loaded first, so re-running the
    script resumes training. Training runs for 1000 epochs of 100 steps.
    ``ModelCheckpoint`` saves the model whenever the training loss improves.

    Args:
        use_tensorboard: also attach a ``TensorBoard`` callback that logs to
            ``./log`` (per-batch updates, weight histograms every epoch).
            Off by default.
    """
    # Enable on-demand GPU memory growth. This must happen before anything
    # initializes the GPU, hence it comes first. A bare
    # tf.compat.v1.Session(config=...) that is never handed to Keras is not
    # guaranteed to affect model.fit under TF2 eager execution.
    for gpu in tf.config.experimental.list_physical_devices('GPU'):
        tf.config.experimental.set_memory_growth(gpu, True)

    train_generator = Nii3DGenerator(
        img_dir='/data/Data/NII/WFY/bladder/After256/Train/img',
        mask_dir='/data/Data/NII/WFY/bladder/After256/Train/mask',
        batch_size=2,
        crop=[32, 256, 256],
        batch_must_exist_label=True,
    )
    c = train_generator.crop
    # NOTE(review): ':' in the file name is not allowed on Windows file systems.
    checkpoint_path = f'./model/RefNetPv2_xxxx_{c[0]}:{c[1]}:{c[2]}.hdf5'
    # Warm-start source. It defaults to the checkpoint itself, so re-runs
    # resume. Set it to None to start from scratch, or to another file to
    # initialise from different weights.
    weight_path = checkpoint_path

    # Use refinenetplus3d.load_refine_net_plus_3d here for the v1 network.
    # [*c, 1] appends the single input channel and works for list or tuple crops.
    model = refinenetplus3d_v2.load_refine_net_plus_3d(input_shape=[*c, 1])
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    # NOTE(review): this monkey-patch is kept from the original. It presumably
    # works around callbacks from the standalone `keras` package (see the
    # `keras.callbacks` import) calling this private tf.keras method.
    # Confirm whether it is still needed.
    model._get_distribution_strategy = lambda: None
    if weight_path and os.path.exists(weight_path):
        model.load_weights(weight_path)

    # Make sure the checkpoint directory exists before the first save.
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
    callbacks = [
        ModelCheckpoint(checkpoint_path, monitor='loss', verbose=1, save_best_only=True),
    ]
    if use_tensorboard:
        callbacks.append(
            TensorBoard(log_dir='./log', histogram_freq=1, update_freq='batch')
        )

    # The generator yields whole batches, so neither `y` nor `batch_size` is
    # passed. Keras documents that batch_size must not be specified for
    # generator / keras.utils.Sequence inputs, and older TF 2.x releases
    # raise a ValueError if it is.
    model.fit(
        train_generator,
        shuffle=train_generator.shuffle,
        epochs=1000,
        steps_per_epoch=100,
        use_multiprocessing=False,
        workers=12,
        callbacks=callbacks,
    )


if __name__ == '__main__':
    main()