-
Notifications
You must be signed in to change notification settings - Fork 1.3k
adding resnet18 support #694
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,12 +15,14 @@ | |
from tensorflow import keras | ||
from tensorflow.keras.layers import Input, Dense | ||
from tensorflow.keras.layers import Convolution2D, MaxPooling2D, BatchNormalization | ||
from tensorflow.keras.layers import GlobalAveragePooling2D | ||
from tensorflow.keras.layers import Activation, Dropout, Flatten | ||
from tensorflow.keras.layers import LSTM | ||
from tensorflow.keras.layers import TimeDistributed as TD | ||
from tensorflow.keras.layers import Conv3D, MaxPooling3D, Conv2DTranspose | ||
from tensorflow.keras.backend import concatenate | ||
from tensorflow.keras.models import Model, Sequential | ||
from donkeycar.parts.keras_resnet18 import identity_block,conv_block,ResNet18 | ||
|
||
import donkeycar as dk | ||
from donkeycar.utils import normalize_image | ||
|
@@ -698,3 +700,111 @@ def default_latent(num_outputs, input_shape): | |
|
||
model = Model(inputs=[img_in], outputs=outputs) | ||
return model | ||
|
||
# ResNet18 prerequisites: Keras sub-modules handed to
# keras_resnet18.ResNet18 (through its backend / layers / models / utils
# keyword arguments) by the ResNet18 model builders below.
backend, layers, models, utils = (
    tf.compat.v1.keras.backend,
    tf.keras.layers,
    tf.keras.models,
    tf.keras.utils,
)
|
||
|
||
# PR review notes for this function:
# - Reviewer: "As you already made a new module for the new model, pls put
#   those functions there." Author asked whether that means moving the
#   resnet18 functions and classes into the keras_resnet18 module (open).
# - Reviewer: provide a docstring instead of a comment ("doctoring" in the
#   thread was an autocomplete typo for "docstring").
# - Reviewer: use pep-8 line length and formatting.
def resnet18_default_n_linear(num_outputs, input_shape=(120, 60, 3)):
    """Build a ResNet18 transfer-learning model with linear regression heads.

    A ResNet18 backbone (``include_top=False``) loaded with the
    'cifar100_coarse' weights is frozen and followed by global average
    pooling, Dense(128, relu) + Dropout(0.2), Dense(64, relu) + Dropout(0.2)
    and ``num_outputs`` single-unit linear heads named ``n_outputs0``,
    ``n_outputs1``, ...

    Args:
        num_outputs: number of linear output heads.
        input_shape: image input shape. NOTE(review): this default
            (120, 60, 3) differs from Resnet18LinearKeras' default
            (120, 160, 3); confirm which one is intended.

    Returns:
        An uncompiled ``tf.keras`` Model whose input layer is ``img_in``.
    """
    # Pre-trained backbone, frozen so only the new head is trained.
    resnet18 = ResNet18(include_top=False, weights='cifar100_coarse',
                        input_shape=input_shape, backend=backend,
                        layers=layers, models=models, utils=utils)
    for layer in resnet18.layers:
        layer.trainable = False

    drop = 0.2
    img_in = Input(shape=input_shape, name='img_in')
    # NOTE(review): no input preprocessing is applied. A previous
    # tf.keras.applications.resnet.preprocess_input call had its result
    # immediately overwritten, so it was dead code and has been removed.
    # Confirm whether the cifar100 weights expect any preprocessing.
    # training=False keeps the backbone's BatchNormalization layers in
    # inference mode, the recommended transfer-learning setup. Its layers
    # are frozen above, and this stays correct if they are later unfrozen.
    x = resnet18(img_in, training=False)
    x = GlobalAveragePooling2D()(x)
    # Classifier head
    x = Dense(128, activation='relu', name='dense_1')(x)
    x = Dropout(drop)(x)
    x = Dense(64, activation='relu', name='dense_2')(x)
    x = Dropout(drop)(x)

    outputs = [Dense(1, activation='linear', name='n_outputs' + str(i))(x)
               for i in range(num_outputs)]

    model = Model(inputs=[img_in], outputs=outputs)
    return model
|
||
|
||
def resnet18_default_categorical(input_shape=(120, 60, 3)):
    """Build a ResNet18 transfer-learning model with categorical outputs.

    A ResNet18 backbone (``include_top=False``) loaded with the
    'cifar100_coarse' weights is frozen and followed by global average
    pooling, Dense(128, relu) + Dropout(0.2), Dense(64, relu) + Dropout(0.2)
    and two softmax heads: ``angle_out`` (15 bins) and ``throttle_out``
    (20 bins).

    Args:
        input_shape: image input shape. NOTE(review): this default
            (120, 60, 3) differs from Resnet18CategoricalKeras' default
            (120, 160, 3); confirm which one is intended.

    Returns:
        An uncompiled ``tf.keras`` Model with input ``img_in`` and outputs
        ``[angle_out, throttle_out]``.
    """
    # Pre-trained backbone, frozen so only the new head is trained.
    resnet18 = ResNet18(include_top=False, weights='cifar100_coarse',
                        input_shape=input_shape, backend=backend,
                        layers=layers, models=models, utils=utils)
    for layer in resnet18.layers:
        layer.trainable = False

    drop = 0.2
    img_in = Input(shape=input_shape, name='img_in')
    # NOTE(review): no input preprocessing is applied. A previous
    # tf.keras.applications.resnet.preprocess_input call had its result
    # immediately overwritten, so it was dead code and has been removed.
    # Confirm whether the cifar100 weights expect any preprocessing.
    # training=False keeps the backbone's BatchNormalization layers in
    # inference mode, the recommended transfer-learning setup.
    x = resnet18(img_in, training=False)
    x = GlobalAveragePooling2D()(x)
    # Classifier head
    x = Dense(128, activation='relu', name='dense_1')(x)
    x = Dropout(drop)(x)
    x = Dense(64, activation='relu', name='dense_2')(x)
    x = Dropout(drop)(x)

    # Categorical output of the angle into 15 bins
    angle_out = Dense(15, activation='softmax', name='angle_out')(x)
    # Categorical output of the throttle into 20 bins
    throttle_out = Dense(20, activation='softmax', name='throttle_out')(x)

    model = Model(inputs=[img_in], outputs=[angle_out, throttle_out])
    return model
|
||
|
||
class Resnet18LinearKeras(KerasPilot):
    """KerasPilot around `resnet18_default_n_linear` (linear regression).

    Args:
        num_outputs: number of linear heads; with 2 or more, the first two
            are read as steering and throttle.
        input_shape: image input shape passed to the model builder.
    """

    def __init__(self, num_outputs=2, input_shape=(120, 160, 3)):
        super().__init__()
        self.model = resnet18_default_n_linear(num_outputs, input_shape)
        # PR review: "This is set in the base class already." The author
        # offered to remove it; kept so compile() does not rely on that.
        self.optimizer = 'adam'

    def compile(self):
        """Compile the model with MSE loss and an MSE metric."""
        # `metrics` is documented as a list; a bare string is not accepted
        # by every Keras version.
        self.model.compile(optimizer=self.optimizer, loss='mse',
                           metrics=['mse'])

    def inference(self, img_arr, other_arr):
        """Predict ``(steering, throttle)`` for a single image.

        With two or more output heads, steering and throttle are read from
        the first two heads as scalars. With a single head, throttle is
        derived from the steering value with ``dk.utils.throttle``.
        """
        img_arr = img_arr.reshape((1,) + img_arr.shape)
        outputs = self.model.predict(img_arr)
        if isinstance(outputs, (list, tuple)) and len(outputs) > 1:
            # One (1, 1) array per head: take the single predicted value.
            steering = outputs[0][0][0]
            throttle = outputs[1][0][0]
            return steering, throttle
        # Single-head model: only steering is predicted.
        steering = outputs[0]
        return steering[0], dk.utils.throttle(steering[0])
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is the behaviour of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes you are right, I'll fix this following KerasLinear. |
||
|
||
|
||
class Resnet18CategoricalKeras(KerasPilot):
    """KerasPilot around `resnet18_default_categorical` (binned outputs).

    Args:
        input_shape: image input shape passed to the model builder.
        throttle_range: ``R`` argument used when un-binning the throttle
            output in `inference`.
    """

    def __init__(self, input_shape=(120, 160, 3), throttle_range=0.5):
        super().__init__()
        self.model = resnet18_default_categorical(input_shape)
        # PR review: "This is set in the base class." Kept so compile()
        # does not rely on that.
        self.optimizer = 'adam'
        # The model is intentionally not compiled here. Per PR review, "we
        # only compile before training", matching Resnet18LinearKeras.
        # Call compile() before training.
        self.throttle_range = throttle_range

    def compile(self):
        """Compile with equally weighted categorical cross-entropy losses."""
        self.model.compile(optimizer=self.optimizer, metrics=['accuracy'],
                           loss={'angle_out': 'categorical_crossentropy',
                                 'throttle_out': 'categorical_crossentropy'},
                           loss_weights={'angle_out': 0.5,
                                         'throttle_out': 0.5})

    def inference(self, img_arr, other_arr):
        """Predict ``(angle, throttle)`` for a single image.

        Prints 'no image' and returns ``(0.0, 0.0)`` when `img_arr` is None.
        Both binned outputs are decoded with ``dk.utils.linear_unbin``. The
        throttle is decoded with its own bin count, ``offset=0.0`` and
        ``R=self.throttle_range``.
        """
        if img_arr is None:
            print('no image')
            return 0.0, 0.0

        img_arr = img_arr.reshape((1,) + img_arr.shape)
        angle_binned, throttle_binned = self.model.predict(img_arr)
        N = len(throttle_binned[0])
        throttle = dk.utils.linear_unbin(throttle_binned, N=N,
                                         offset=0.0, R=self.throttle_range)
        angle = dk.utils.linear_unbin(angle_binned)
        return angle, throttle
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,191 @@ | ||
import tensorflow as tf | ||
from tensorflow.keras.layers import ZeroPadding2D, Input, GlobalAveragePooling2D,GlobalMaxPooling2D,Dense | ||
from tensorflow.keras.layers import Convolution2D,MaxPooling2D,BatchNormalization | ||
from tensorflow.keras.layers import Activation,Dropout,Flatten | ||
from tensorflow.keras.models import Model,Sequential | ||
from keras_applications.imagenet_utils import _obtain_input_shape, get_submodules_from_kwargs | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just a question: There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No idea... |
||
import os | ||
import warnings | ||
|
||
# Default Keras sub-modules. identity_block and conv_block read `backend`
# from here to choose the BatchNormalization channel axis.
backend, layers, models, utils = (
    tf.compat.v1.keras.backend,
    tf.keras.layers,
    tf.keras.models,
    tf.keras.utils,
)

# Download locations of the pre-trained ResNet18 weights, with and without
# the classification top.
_WEIGHTS_BASE_URL = 'https://raw.githubusercontent.com/cl3m3nt/resnet/master/'
WEIGHTS_PATH = _WEIGHTS_BASE_URL + 'resnet18_cifar100_top.h5'
WEIGHTS_PATH_NO_TOP = _WEIGHTS_BASE_URL + 'resnet18_cifar100_no_top.h5'
|
||
|
||
# PR review notes: the reviewer asked for a docstring and suggested
# factoring out the repeated channel-axis check (compacted to one line
# here). The author noted the structure follows keras-team's ResNet50.
def identity_block(input_tensor, kernel_size, filters, stage, block):
    """Residual block whose shortcut is the unchanged input tensor.

    Main path: 1x1 conv -> BN -> ReLU -> `kernel_size` conv ('same'
    padding) -> BN -> ReLU. The result is added to `input_tensor` and
    passed through a final ReLU.

    Args:
        input_tensor: input Keras tensor.
        kernel_size: kernel size of the second convolution.
        filters: pair ``(filters1, filters2)`` of filter counts for the two
            convolutions. Because the shortcut is added unchanged,
            ``filters2`` presumably must equal the input channel count.
        stage: stage label used in layer names (e.g. 2).
        block: block label string used in layer names (e.g. 'a').

    Returns:
        Output Keras tensor.
    """
    first_filters, second_filters = filters
    channel_axis = 3 if backend.image_data_format() == 'channels_last' else 1
    name_suffix = str(stage) + block + '_branch'
    conv_prefix = 'res' + name_suffix
    bn_prefix = 'bn' + name_suffix

    branch = input_tensor
    # Each entry: (filters, kernel size, extra Conv2D kwargs, name tag).
    for n_filters, ksize, extra, tag in (
            (first_filters, (1, 1), {}, '2a'),
            (second_filters, kernel_size, {'padding': 'same'}, '2b')):
        branch = Convolution2D(n_filters, ksize,
                               kernel_initializer='he_normal',
                               name=conv_prefix + tag, **extra)(branch)
        branch = BatchNormalization(axis=channel_axis,
                                    name=bn_prefix + tag)(branch)
        branch = Activation('relu')(branch)

    merged = tf.keras.layers.add([branch, input_tensor])
    return Activation('relu')(merged)
|
||
|
||
# PR review notes: add a docstring, compact the channel-axis check, and use
# pep-8 blank lines and indentation. All are addressed here.
def conv_block(input_tensor, kernel_size, filters, stage, block,
               strides=(2, 2)):
    """Residual block with a strided 1x1 convolution on the shortcut.

    Main path: 1x1 conv (`strides`) -> BN -> ReLU -> `kernel_size` conv
    ('same' padding) -> BN -> ReLU. Shortcut: 1x1 conv with `filters2`
    filters and `strides` -> BN. The two are added and passed through a
    final ReLU.

    Args:
        input_tensor: input Keras tensor.
        kernel_size: kernel size of the second main-path convolution.
        filters: pair ``(filters1, filters2)`` of filter counts.
        stage: stage label used in layer names (e.g. 3).
        block: block label string used in layer names (e.g. 'a').
        strides: strides of the first main-path conv and the shortcut conv.

    Returns:
        Output Keras tensor.
    """
    first_filters, second_filters = filters
    channel_axis = 3 if backend.image_data_format() == 'channels_last' else 1
    name_suffix = str(stage) + block + '_branch'
    conv_prefix = 'res' + name_suffix
    bn_prefix = 'bn' + name_suffix

    branch = input_tensor
    # Each entry: (filters, kernel size, extra Conv2D kwargs, name tag).
    for n_filters, ksize, extra, tag in (
            (first_filters, (1, 1), {'strides': strides}, '2a'),
            (second_filters, kernel_size, {'padding': 'same'}, '2b')):
        branch = Convolution2D(n_filters, ksize,
                               kernel_initializer='he_normal',
                               name=conv_prefix + tag, **extra)(branch)
        branch = BatchNormalization(axis=channel_axis,
                                    name=bn_prefix + tag)(branch)
        branch = Activation('relu')(branch)

    # Projection shortcut so its shape matches the main path.
    shortcut = Convolution2D(second_filters, (1, 1), strides=strides,
                             kernel_initializer='he_normal',
                             name=conv_prefix + '1')(input_tensor)
    shortcut = BatchNormalization(axis=channel_axis,
                                  name=bn_prefix + '1')(shortcut)

    merged = tf.keras.layers.add([branch, shortcut])
    return Activation('relu')(merged)
|
||
|
||
# ResNet18
# PR review note: the reviewer asked to avoid global variables ("If you
# need to share them between functions it's fine to put them into a
# class"). The author had followed keras-team's ResNet50, which rebinds
# module globals. The sub-modules are now resolved into locals instead.
def ResNet18(include_top=True,
             weights='cifar100_coarse',
             input_tensor=None,
             input_shape=None,
             pooling=None,
             classes=20,
             **kwargs):
    """Instantiate the ResNet18 architecture.

    Args:
        include_top: add the global-average-pooling + softmax classifier
            (layer 'fc20') on top of the convolutional stages.
        weights: None (random initialization), 'cifar100_coarse' (download
            WEIGHTS_PATH / WEIGHTS_PATH_NO_TOP), or a path to an existing
            weights file, which is then loaded.
        input_tensor: optional Keras tensor to use as the image input.
        input_shape: optional shape tuple, resolved with
            `_obtain_input_shape` (default size 224, minimum size 32).
        pooling: when `include_top` is False, 'avg' or 'max' adds a global
            pooling layer. Any other value adds none and emits a warning.
        classes: number of classes when `include_top` is True. Must be 20
            with the 'cifar100_coarse' weights.
        **kwargs: optional Keras sub-modules `backend`, `layers`, `models`
            and `utils`. Missing ones default to this module's tf.keras
            sub-modules.

    Returns:
        A Keras Model named 'resnet18'.

    Raises:
        TypeError: on an unknown keyword argument.
        ValueError: on an invalid `weights` / `classes` combination.
    """
    unknown = sorted(set(kwargs) - {'backend', 'layers', 'models', 'utils'})
    if unknown:
        raise TypeError('Invalid keyword argument(s): ' + ', '.join(unknown))
    # Resolve sub-modules locally instead of rebinding module globals.
    # keras_applications' get_submodules_from_kwargs may return None when
    # nothing is passed, which would also break every later
    # identity_block/conv_block call.
    k_backend = kwargs.get('backend', backend)
    k_layers = kwargs.get('layers', layers)
    k_utils = kwargs.get('utils', utils)

    # Check weights
    if not (weights in {'cifar100_coarse', None} or os.path.exists(weights)):
        raise ValueError('The `weights` argument should be either '
                         '`None` (random initialization), `cifar100_coarse` '
                         '(pre-training on cifar100 coarse (super) classes), '
                         'or the path to the weights file to be loaded.')

    if weights == 'cifar100_coarse' and include_top and classes != 20:
        raise ValueError('If using `weights` as `"cifar100_coarse"` with `include_top`'
                         ' as true, `classes` should be 20')

    # Determine proper input shape
    input_shape = _obtain_input_shape(input_shape,
                                      default_size=224,
                                      min_size=32,
                                      data_format=k_backend.image_data_format(),
                                      require_flatten=include_top,
                                      weights=weights)

    if input_tensor is None:
        img_input = k_layers.Input(shape=input_shape)
    elif not tf.keras.backend.is_keras_tensor(input_tensor):
        img_input = k_layers.Input(tensor=input_tensor, shape=input_shape)
    else:
        img_input = input_tensor
    bn_axis = 3 if k_backend.image_data_format() == 'channels_last' else 1

    # Stem: 7x7/2 convolution followed by 3x3/2 max pooling.
    x = ZeroPadding2D(padding=(3, 3), name='conv1_pad')(img_input)
    x = Convolution2D(64, (7, 7),
                      strides=(2, 2),
                      padding='valid',
                      kernel_initializer='he_normal',
                      name='conv1')(x)
    x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x)
    x = Activation('relu')(x)
    x = ZeroPadding2D(padding=(1, 1), name='pool1_pad')(x)
    x = MaxPooling2D((3, 3), strides=(2, 2))(x)

    # Residual stages 2-5, two blocks each.
    x = identity_block(x, 3, [64, 64], stage=2, block='a')
    x = identity_block(x, 3, [64, 64], stage=2, block='b')

    x = conv_block(x, 3, [128, 128], stage=3, block='a')
    x = identity_block(x, 3, [128, 128], stage=3, block='b')

    x = conv_block(x, 3, [256, 256], stage=4, block='a')
    x = identity_block(x, 3, [256, 256], stage=4, block='b')

    x = conv_block(x, 3, [512, 512], stage=5, block='a')
    x = identity_block(x, 3, [512, 512], stage=5, block='b')

    # Managing top
    if include_top:
        x = k_layers.GlobalAveragePooling2D(name='avg_pool')(x)
        x = k_layers.Dense(classes, activation='softmax', name='fc20')(x)
    elif pooling == 'avg':
        x = k_layers.GlobalAveragePooling2D()(x)
    elif pooling == 'max':
        x = k_layers.GlobalMaxPooling2D()(x)
    else:
        warnings.warn('No global pooling layer (GlobalAveragePooling2D or '
                      'GlobalMaxPooling2D) has been added because there is '
                      'no top and `pooling` is neither "avg" nor "max". '
                      'You will need to apply one yourself when doing '
                      'transfer learning.')

    # Ensure that the model takes into account
    # any potential predecessors of `input_tensor`.
    if input_tensor is not None:
        inputs = k_utils.get_source_inputs(input_tensor)
    else:
        inputs = img_input
    # Create model
    model = Model(inputs, x, name='resnet18')

    # Load weights
    if weights == 'cifar100_coarse':
        if include_top:
            weights_path = k_utils.get_file(
                'resnet18_cifar100_top.h5',
                WEIGHTS_PATH,
                cache_subdir='models',
                md5_hash='e0798dd90ac7e0498cbdea853bd3ed7f')
        else:
            weights_path = k_utils.get_file(
                'resnet18_cifar100_no_top.h5',
                WEIGHTS_PATH_NO_TOP,
                cache_subdir='models',
                md5_hash='bfeace78cec55f2b0401c1f41c81e1dd')
        model.load_weights(weights_path)
    elif weights is not None:
        # A path to a weights file was validated above, so load it.
        model.load_weights(weights)

    return model
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This line falls over in the CI, please also use pep-8 formatting
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Any guess why it falls over in CI?
Might it be related to the keras_resnet18 import
`from keras_applications.imagenet_utils`?