adding resnet18 support by cl3m3nt · Pull Request #694 · autorope/donkeycar · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

adding resnet18 support #694

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 110 additions & 0 deletions donkeycar/parts/keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@
from tensorflow import keras
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.layers import Convolution2D, MaxPooling2D, BatchNormalization
from tensorflow.keras.layers import GlobalAveragePooling2D
from tensorflow.keras.layers import Activation, Dropout, Flatten
from tensorflow.keras.layers import LSTM
from tensorflow.keras.layers import TimeDistributed as TD
from tensorflow.keras.layers import Conv3D, MaxPooling3D, Conv2DTranspose
from tensorflow.keras.backend import concatenate
from tensorflow.keras.models import Model, Sequential
from donkeycar.parts.keras_resnet18 import identity_block,conv_block,ResNet18
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line fails in CI; please also use PEP 8 formatting.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any guess why it fails in CI?
Might it be related to the keras_resnet18 import
'from keras_applications.imagenet_utils'?


import donkeycar as dk
from donkeycar.utils import normalize_image
Expand Down Expand Up @@ -698,3 +700,111 @@ def default_latent(num_outputs, input_shape):

model = Model(inputs=[img_in], outputs=outputs)
return model

# ResNet18 prerequisites: ResNet18() below is called with these Keras
# sub-modules as its backend/layers/models/utils keyword arguments.
backend, layers, models, utils = (
    tf.compat.v1.keras.backend,
    tf.keras.layers,
    tf.keras.models,
    tf.keras.utils,
)


def resnet18_default_n_linear(num_outputs, input_shape=(120,60,3)):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As you already made a new module for the new model, please put those functions there.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you mean here? Do you recommend moving the resnet18 functions and classes to the keras_resnet18 module?


# Instantiate a ResNet18 model
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of a comment, please rather provide a doctoring.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is a doctoring :) ?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a typo (autocomplete :-) meant to be 'docstring'.

resnet18 = ResNet18(include_top=False,weights='cifar100_coarse',input_shape=input_shape,backend=backend,layers=layers,models=models,utils=utils)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use pep-8 line length and formatting.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok

for layer in resnet18.layers:
layer.trainable=False
resnet18_preprocess = tf.keras.applications.resnet.preprocess_input

# Transfer learning with Resnet18
drop = 0.2
img_in = Input(shape=input_shape,name='img_in')
x = resnet18_preprocess(img_in)
x = resnet18(img_in,training=True)
x = GlobalAveragePooling2D()(x)
# Classifier
x = Dense(128,activation='relu',name='dense_1')(x)
x = Dropout(drop)(x)
x = Dense(64,activation='relu',name='dense_2')(x)
x = Dropout(drop)(x)

outputs = []
for i in range(num_outputs):
outputs.append(
Dense(1, activation='linear', name='n_outputs' + str(i))(x))

model = Model(inputs=[img_in],outputs=outputs)

return model


def resnet18_default_categorical(input_shape=(120, 60, 3), angle_bins=15,
                                 throttle_bins=20, drop=0.2):
    """Build a categorical steering/throttle model on a frozen ResNet18.

    A ResNet18 backbone without its classification top is created with the
    ``'cifar100_coarse'`` pre-trained weights. All of its layers are frozen,
    so only the dense classifier head added here is trained.

    Args:
        input_shape: Shape of the image input (``img_in``).
            NOTE(review): this default is (120, 60, 3), while
            ``Resnet18CategoricalKeras`` defaults to (120, 160, 3).
            Confirm whether 60 is intentional.
        angle_bins: Number of softmax bins of the ``angle_out`` head
            (default 15, as in the original code).
        throttle_bins: Number of softmax bins of the ``throttle_out`` head
            (default 20, as in the original code).
        drop: Dropout rate applied after each hidden dense layer.

    Returns:
        A Keras ``Model`` mapping ``img_in`` to ``[angle_out, throttle_out]``.
    """
    resnet18 = ResNet18(include_top=False,
                        weights='cifar100_coarse',
                        input_shape=input_shape,
                        backend=backend,
                        layers=layers,
                        models=models,
                        utils=utils)
    # Freeze the pre-trained backbone; only the head below is trainable.
    for layer in resnet18.layers:
        layer.trainable = False

    img_in = Input(shape=input_shape, name='img_in')
    # NOTE(review): the previous version computed
    # tf.keras.applications.resnet.preprocess_input(img_in) and then
    # overwrote the result, so no preprocessing ever reached the backbone.
    # That dead call is removed. Confirm which input scaling the
    # cifar100_coarse weights expect before adding preprocessing back.
    #
    # Call the frozen backbone in inference mode, the documented pattern
    # for transfer learning, so BatchNormalization uses stored statistics.
    x = resnet18(img_in, training=False)
    x = GlobalAveragePooling2D()(x)

    # Classifier head.
    x = Dense(128, activation='relu', name='dense_1')(x)
    x = Dropout(drop)(x)
    x = Dense(64, activation='relu', name='dense_2')(x)
    x = Dropout(drop)(x)

    # Categorical output of the steering angle.
    angle_out = Dense(angle_bins, activation='softmax', name='angle_out')(x)
    # Categorical output of the throttle.
    throttle_out = Dense(throttle_bins, activation='softmax',
                         name='throttle_out')(x)

    return Model(inputs=[img_in], outputs=[angle_out, throttle_out])


class Resnet18LinearKeras(KerasPilot):
def __init__(self, num_outputs=2, input_shape=(120, 160, 3)):
super().__init__()
self.model = resnet18_default_n_linear(num_outputs, input_shape)
self.optimizer = 'adam'
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is set in the base class already.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean the optimizer definition? If so, I can remove it.


def compile(self):
    """Compile the model with a mean-squared-error loss and MSE metric.

    Uses ``self.optimizer``. ``metrics`` is passed as a list, the form
    documented by Keras ``Model.compile``; older Keras versions reject a
    bare string here.
    """
    self.model.compile(optimizer=self.optimizer, loss='mse',
                       metrics=['mse'])

def inference(self, img_arr, other_arr):
    """Predict steering and throttle for a single image.

    Like KerasLinear, steering and throttle come from two independent model
    output heads. Throttle is no longer derived from the steering value.

    Args:
        img_arr: One image without a batch dimension. A batch dimension of
            size 1 is prepended before prediction.
        other_arr: Not used by this pilot.

    Returns:
        ``(steering, throttle)`` as scalars taken from the first and second
        output heads. Requires a model built with ``num_outputs >= 2``.
    """
    img_arr = img_arr.reshape((1,) + img_arr.shape)
    outputs = self.model.predict(img_arr)
    # Each head is Dense(1), so its prediction has shape (1, 1) here.
    steering = outputs[0]
    throttle = outputs[1]
    return steering[0][0], throttle[0][0]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the behaviour of KerasInferred; KerasLinear returns steering and throttle independently.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, you are right; I'll fix this to follow KerasLinear.



class Resnet18CategoricalKeras(KerasPilot):
def __init__(self, input_shape=(120, 160, 3), throttle_range=0.5):
super().__init__()
self.model = resnet18_default_categorical(input_shape)
self.optimizer = 'adam'
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is set in the base class.

self.compile()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove it; we only compile before training.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes makes sense.

self.throttle_range = throttle_range

def compile(self):
    """Compile both softmax heads with equally weighted cross-entropy.

    ``angle_out`` and ``throttle_out`` each use categorical cross-entropy
    with weight 0.5, and accuracy is reported as the metric.
    """
    heads = ('angle_out', 'throttle_out')
    per_head_loss = {head: 'categorical_crossentropy' for head in heads}
    per_head_weight = {head: 0.5 for head in heads}
    self.model.compile(
        optimizer=self.optimizer,
        metrics=['accuracy'],
        loss=per_head_loss,
        loss_weights=per_head_weight,
    )

def inference(self, img_arr, other_arr):
    """Decode ``(angle, throttle)`` from the two softmax output heads.

    Args:
        img_arr: One image without a batch dimension, or ``None``.
        other_arr: Not used by this pilot.

    Returns:
        ``(angle, throttle)``. When ``img_arr`` is ``None``, a message is
        printed and ``(0.0, 0.0)`` is returned. The throttle is unbinned over
        the model's own number of throttle bins and scaled by
        ``self.throttle_range``.
    """
    if img_arr is None:
        print('no image')
        return 0.0, 0.0

    batch = img_arr.reshape((1,) + img_arr.shape)
    angle_binned, throttle_binned = self.model.predict(batch)
    throttle_bin_count = len(throttle_binned[0])
    throttle = dk.utils.linear_unbin(
        throttle_binned,
        N=throttle_bin_count,
        offset=0.0,
        R=self.throttle_range,
    )
    angle = dk.utils.linear_unbin(angle_binned)
    return angle, throttle
191 changes: 191 additions & 0 deletions donkeycar/parts/keras_resnet18.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
import tensorflow as tf
from tensorflow.keras.layers import ZeroPadding2D, Input, GlobalAveragePooling2D,GlobalMaxPooling2D,Dense
from tensorflow.keras.layers import Convolution2D,MaxPooling2D,BatchNormalization
from tensorflow.keras.layers import Activation,Dropout,Flatten
from tensorflow.keras.models import Model,Sequential
from keras_applications.imagenet_utils import _obtain_input_shape, get_submodules_from_kwargs
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a question: _obtain_... looks like a function not for public consumption - do you know why?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No idea...

import os
import warnings

# Module-level Keras handles. identity_block/conv_block read
# backend.image_data_format(); ResNet18() declares these names `global`
# and rebinds them from its keyword arguments.
backend, layers, models, utils = (
    tf.compat.v1.keras.backend,
    tf.keras.layers,
    tf.keras.models,
    tf.keras.utils,
)

# Download locations of the 'cifar100_coarse' weights: with the
# classification top (include_top=True) and without it.
WEIGHTS_PATH = (
    'https://raw.githubusercontent.com/cl3m3nt/resnet/master/resnet18_cifar100_top.h5'
)
WEIGHTS_PATH_NO_TOP = (
    'https://raw.githubusercontent.com/cl3m3nt/resnet/master/resnet18_cifar100_no_top.h5'
)


def identity_block(input_tensor, kernel_size, filters, stage, block):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a docstring (a one-liner might be enough, but a full one is better, of course)?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok sure.

filters1, filters2 = filters
if backend.image_data_format() == 'channels_last':
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The following code seems to appear a couple of times; maybe factor it into a small helper function.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, same here: I followed what the Keras team did for ResNet50.
There are 2 occurrences of the same code, which could be acceptable, don't you think?

bn_axis = 3

else:
bn_axis = 1
conv_name_base = 'res' + str(stage) + block + '_branch'
bn_name_base = 'bn' + str(stage) + block +'_branch'

x = Convolution2D(filters1,(1,1),
kernel_initializer='he_normal',
name = conv_name_base + '2a')(input_tensor)
x = BatchNormalization(axis=bn_axis,name=bn_name_base + '2a')(x)
x = Activation('relu')(x)

x = Convolution2D(filters2, kernel_size,
padding='same',
kernel_initializer='he_normal',
name = conv_name_base + '2b')(x)
x = BatchNormalization(axis=bn_axis,name=bn_name_base+'2b')(x)
x = Activation('relu')(x)

x = tf.keras.layers.add([x,input_tensor])
x = Activation('relu')(x)
return x


def conv_block(input_tensor,kernel_size,filters,stage,block,strides=(2,2)):

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok.

filters1, filters2 = filters
if backend.image_data_format() == 'channels_last':
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can compact this into bn_axis = 3 if ... else 1.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok.

bn_axis = 3
else:
bn_axis = 1
conv_name_base = 'res' + str(stage) + block + '_branch'
bn_name_base = 'bn' + str(stage) + block + '_branch'

x = Convolution2D(filters1,(1,1),strides=strides,
kernel_initializer='he_normal',
name = conv_name_base + '2a')(input_tensor)
x = BatchNormalization(bn_axis,name=bn_name_base + '2a')(x)
x = Activation('relu')(x)


Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No double blank lines (PEP 8).

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok

x = Convolution2D(filters2, kernel_size,
padding='same',
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indentation: PEP 8.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok

kernel_initializer='he_normal',
name = conv_name_base + '2b')(x)
x = BatchNormalization(axis=bn_axis,name=bn_name_base+'2b')(x)
x = Activation('relu')(x)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here too.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok


shortcut = Convolution2D(filters2,(1,1),strides=strides,
kernel_initializer='he_normal',
name=conv_name_base+'1')(input_tensor)

shortcut = BatchNormalization(
axis=bn_axis,name=bn_name_base+'1')(shortcut)

x = tf.keras.layers.add([x,shortcut])
x = Activation('relu')(x)
return x


# ResnNet18
def ResNet18(include_top=True,
weights='cifar100_coarse',
input_tensor=None,
input_shape=None,
pooling=None,
classes=20,
**kwargs):
global backend, layers, models, keras_utils
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please try to avoid global variables. If you need to share them between functions, it's fine to put them into a class.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

backend, layers, models, keras_utils = get_submodules_from_kwargs(kwargs)

# Check Weights
if not (weights in {'cifar100_coarse', None} or os.path.exists(weights)):
raise ValueError('The `weights` argument should be either '
'`None` (random initialization), `cifar100_coarse` '
'(pre-training on cifar100 coarse (super) classes), '
'or the path to the weights file to be loaded.')

if weights == 'cifar100_coarse' and include_top and classes != 20:
raise ValueError('If using `weights` as `"cifar100_coarse"` with `include_top`'
' as true, `classes` should be 20')

# Determine proper input shape
input_shape = _obtain_input_shape(input_shape,
default_size=224,
min_size=32,
data_format=backend.image_data_format(),
require_flatten=include_top,
weights=weights)

if input_tensor is None:
img_input = layers.Input(shape=input_shape)
else:
if not tf.keras.backend.is_keras_tensor(input_tensor):
img_input = layers.Input(tensor=input_tensor, shape=input_shape)
else:
img_input = input_tensor
if backend.image_data_format() == 'channels_last':
bn_axis = 3
else:
bn_axis = 1

# Build ResNet18 architecture
x = ZeroPadding2D(padding=(3,3),name='conv1_pad')(img_input)
x = Convolution2D(64,(7,7),
strides=(2,2),
padding='valid',
kernel_initializer='he_normal',
name='conv1')(x)
x = BatchNormalization(axis=bn_axis,name='bn_conv1')(x)
x = Activation('relu')(x)
x = ZeroPadding2D(padding=(1,1),name='pool1_pad')(x)
x = MaxPooling2D((3,3),strides=(2,2))(x)

x = identity_block(x,3,[64,64],stage=2,block='a')
x = identity_block(x,3,[64,64],stage=2,block='b')

x = conv_block(x,3,[128,128],stage=3,block='a')
x = identity_block(x,3,[128,128],stage=3,block='b')

x = conv_block(x,3,[256,256],stage=4,block='a')
x = identity_block(x,3,[256,256],stage=4,block='b')

x = conv_block(x,3,[512,512],stage=5,block='a')
x = identity_block(x,3,[512,512],stage=5,block='b')

# Managing Top
if include_top:
x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
x = layers.Dense(classes, activation='softmax', name='fc20')(x)
else:
if pooling == 'avg':
x = layers.GlobalAveragePooling2D()(x)
elif pooling == 'max':
x = layers.GlobalMaxPooling2D()(x)
else:
warnings.warn('No flattenting layer operation like AveragePooling2D or MaxPooling2D has been added'
'whereas there are not top. You will need to apply AveragePooling2D or MaxPooling2D in case of'
'doing transfer learning')

# Ensure that the model takes into account
# any potential predecessors of `input_tensor`.
if input_tensor is not None:
inputs = keras_utils.get_source_inputs(input_tensor)
else:
inputs = img_input
# Create model
model = Model(inputs, x, name='resnet18')

# Load weights
if weights == 'cifar100_coarse':
if include_top:
weights_path = keras_utils.get_file(
'resnet18_cifar100_top.h5',
WEIGHTS_PATH,
cache_subdir='models',
md5_hash='e0798dd90ac7e0498cbdea853bd3ed7f')
else:
weights_path = keras_utils.get_file(
'resnet18_cifar100_no_top.h5',
WEIGHTS_PATH_NO_TOP,
cache_subdir='models',
md5_hash='bfeace78cec55f2b0401c1f41c81e1dd')
model.load_weights(weights_path)


return model
5 changes: 5 additions & 0 deletions donkeycar/parts/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ def load(self, model_path):
'TFlitePilot should load only .tflite files'
# Load TFLite model and allocate tensors.
self.interpreter = tf.lite.Interpreter(model_path=model_path)
'''
#Uncomment below self.interpreter and comment above in case you have TPU edge on your donkeycar to accelerate inference
#You need tpu edge runtime installed as pre-requesite: https://coral.ai/docs/accelerator/get-started
self.interpreter = tf.lite.Interpreter(model_path=model_path,experimental_delegates=[tflite.load_delegate('libedgetpu.so.1')])
'''
self.interpreter.allocate_tensors()

# Get input and output tensors.
Expand Down
Loading
0