58 lines
2.2 KiB
Python
58 lines
2.2 KiB
Python
from classes.DataLoader import DataSetLoader
|
|
import tensorflow as tf
|
|
import matplotlib.pyplot as plt
|
|
|
|
def create_model(image_size, num_classes=7, channels=3):
    """Build an uncompiled CNN classifier for square images.

    The depth of the convolutional stack scales with ``image_size``:

    * ``image_size <= 10``: one 32-filter 2x2 conv block whose pooling uses
      ``padding='same'``.
    * ``image_size > 10``: a 32-filter 3x3 conv block followed by a
      64-filter 3x3 conv block (both with 'valid' 2x2 max pooling).
    * ``image_size >= 22``: an additional 128-filter 3x3 conv block.

    The conv features are reduced with global average pooling and fed to a
    128-unit ReLU layer and a softmax output layer.

    Args:
        image_size: Height and width of the (square) input images, in pixels.
        num_classes: Number of softmax output units. Defaults to 7.
        channels: Number of input channels. Defaults to 3.

    Returns:
        An uncompiled ``tf.keras.Sequential`` model taking inputs of shape
        ``(batch, image_size, image_size, channels)`` and producing
        ``(batch, num_classes)`` softmax outputs.

    Raises:
        ValueError: If ``image_size`` is smaller than 2, in which case the
            first 2x2 'valid' convolution would leave no spatial extent.
    """
    if image_size < 2:
        raise ValueError(
            f"image_size must be at least 2 for the first 2x2 convolution, got {image_size}"
        )

    model = tf.keras.Sequential()
    # Declare the input explicitly rather than passing ``input_shape`` to the
    # first layer; Keras 3 deprecates the layer-level ``input_shape`` argument.
    model.add(tf.keras.Input(shape=(image_size, image_size, channels)))

    if image_size <= 10:
        # Small inputs: 2x2 kernel and 'same' pooling so the map is not emptied.
        model.add(tf.keras.layers.Conv2D(32, (2, 2), activation='relu'))
        model.add(tf.keras.layers.MaxPooling2D((2, 2), padding='same'))
    else:
        model.add(tf.keras.layers.Conv2D(32, (3, 3), activation='relu'))
        model.add(tf.keras.layers.MaxPooling2D((2, 2)))
        model.add(tf.keras.layers.Conv2D(64, (3, 3), activation='relu'))
        model.add(tf.keras.layers.MaxPooling2D((2, 2)))

    # The third block needs image_size >= 22. At image_size == 21 the spatial
    # size runs 19 -> 9 -> 7 -> 3 -> 1 through the first two blocks and this
    # 3x3 conv, and a 2x2 'valid' pool cannot reduce a 1x1 map. The former
    # ``> 20`` threshold therefore broke model construction at 21.
    if image_size >= 22:
        model.add(tf.keras.layers.Conv2D(128, (3, 3), activation='relu'))
        model.add(tf.keras.layers.MaxPooling2D((2, 2)))

    model.add(tf.keras.layers.GlobalAveragePooling2D())
    model.add(tf.keras.layers.Dense(128, activation='relu'))
    model.add(tf.keras.layers.Dense(num_classes, activation='softmax'))

    return model
|
|
|
|
if __name__ == '__main__':
    # Training-run settings, named here instead of being inlined below.
    image_side = 50
    batch_size = 5
    n_epochs = 7
    shuffle_buffer = 43349
    train_fraction = 0.9

    # One-off steps that originally built the dataset from the image folder
    # and its CSV index; kept for reference only.
    # loader = DataSetLoader(p='/Users/denysseredenko/Desktop/ML-cropped/processed', x=50, y=50)
    # ds = loader.create_dataset_from_csv('../all_data_cropped.csv')
    # ds.save('/Users/denysseredenko/Desktop/privat/StreetSignRecognitionTensor/dataset_cropped_50x50')

    # NOTE(review): tf.data.Dataset.load normally expects the directory written
    # by Dataset.save, whereas this path points at a shard file inside such a
    # directory. Confirm it loads as intended.
    dataset = tf.data.Dataset.load('../dataset/13802249122632658895/00000000.shard/00000000.snapshot')

    # NOTE(review): shuffle() reshuffles on every iteration by default. The
    # order written by save() therefore need not match the order that
    # split_dataset() later sees.
    dataset = dataset.shuffle(buffer_size=shuffle_buffer)
    dataset.save('/Users/denysseredenko/Desktop/privat/StreetSignRecognitionTensor/dataset_cropped_50x50')

    train_split, test_split = tf.keras.utils.split_dataset(dataset, left_size=train_fraction)

    # Same batching/prefetching pipeline for both splits (train first, then test).
    train_data, test_data = (
        split.batch(batch_size).prefetch(tf.data.AUTOTUNE)
        for split in (train_split, test_split)
    )

    model = create_model(image_side)
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=['accuracy'],
    )

    model.fit(train_data, epochs=n_epochs)

    _, test_acc = model.evaluate(test_data, verbose=2)
    print('\nTest accuracy:', test_acc)