# Files — Denys Seredenko, commit c148b03ab5 "Ready Model"
# 2024-12-18 12:15:41 +01:00 · 58 lines · 2.2 KiB · Python

from classes.DataLoader import DataSetLoader
import tensorflow as tf
import matplotlib.pyplot as plt
def create_model(image_size, num_classes=7, channels=3):
    """Build an uncompiled CNN classifier for square images.

    The network gets deeper as the input resolution grows, chosen so that
    every convolution/pooling layer still has a non-empty output:

    * ``image_size <= 10``: one 2x2 conv (32 filters) + 'same'-padded 2x2 pool.
    * ``10 < image_size < 22``: two 3x3 conv / 2x2 pool blocks (32, 64 filters).
    * ``image_size >= 22``: three 3x3 conv / 2x2 pool blocks (32, 64, 128).

    Every variant ends with global average pooling, a 128-unit ReLU dense
    layer and a softmax over ``num_classes`` outputs.

    Args:
        image_size: Height and width of the input images, in pixels.
        num_classes: Number of output classes (default 7, the original head).
        channels: Number of input channels (default 3).

    Returns:
        A ``tf.keras.Sequential`` model; the caller is responsible for
        compiling it.

    Raises:
        ValueError: If ``image_size`` is smaller than 2, because the first
            2x2 convolution would then have no valid output.
    """
    if image_size < 2:
        raise ValueError(f"image_size must be at least 2, got {image_size}")

    model = tf.keras.Sequential()
    # Declare the input explicitly rather than passing ``input_shape`` to the
    # first layer (that kwarg is deprecated in Keras 3).
    model.add(tf.keras.Input(shape=(image_size, image_size, channels)))

    if image_size <= 10:
        # Tiny inputs: a 2x2 kernel plus 'same' pooling keeps the map >= 1x1.
        model.add(tf.keras.layers.Conv2D(32, (2, 2), activation='relu'))
        model.add(tf.keras.layers.MaxPooling2D((2, 2), padding='same'))
    else:
        model.add(tf.keras.layers.Conv2D(32, (3, 3), activation='relu'))
        model.add(tf.keras.layers.MaxPooling2D((2, 2)))
        model.add(tf.keras.layers.Conv2D(64, (3, 3), activation='relu'))
        model.add(tf.keras.layers.MaxPooling2D((2, 2)))
        # A third 3x3 conv + 2x2 'valid' pool needs at least a 4x4 map after
        # the second block, which first holds at image_size == 22. With the
        # former "> 20" threshold, image_size == 21 left a 3x3 map, the conv
        # shrank it to 1x1 and the pool then had no valid output (crash).
        if image_size >= 22:
            model.add(tf.keras.layers.Conv2D(128, (3, 3), activation='relu'))
            model.add(tf.keras.layers.MaxPooling2D((2, 2)))

    model.add(tf.keras.layers.GlobalAveragePooling2D())
    model.add(tf.keras.layers.Dense(128, activation='relu'))
    model.add(tf.keras.layers.Dense(num_classes, activation='softmax'))
    return model
DATASET_PATH = '../dataset/13802249122632658895/00000000.shard/00000000.snapshot'
SHUFFLED_COPY_PATH = '/Users/denysseredenko/Desktop/privat/StreetSignRecognitionTensor/dataset_cropped_50x50'
SHUFFLE_BUFFER = 43349
TRAIN_FRACTION = 0.9
BATCH_SIZE = 5
EPOCHS = 7
# NOTE(review): must match the resolution of the stored dataset; the paths
# suggest 50x50 crops — confirm.
IMAGE_SIZE = 50


def main():
    """Train and evaluate ``create_model(IMAGE_SIZE)`` on the saved dataset.

    Loads the saved ``tf.data`` dataset, shuffles it once, writes the
    shuffled copy to ``SHUFFLED_COPY_PATH``, splits it into train/test sets
    (``TRAIN_FRACTION`` for training), trains for ``EPOCHS`` epochs and
    prints the test accuracy.
    """
    # The dataset was originally built from an image folder with:
    #   loader = DataSetLoader(p='/Users/denysseredenko/Desktop/ML-cropped/processed', x=50, y=50)
    #   ds = loader.create_dataset_from_csv('../all_data_cropped.csv')
    #   ds.save('/Users/denysseredenko/Desktop/privat/StreetSignRecognitionTensor/dataset_cropped_50x50')
    # NOTE(review): Dataset.load normally takes the directory that was passed
    # to Dataset.save; this path points inside a shard — confirm it loads as
    # intended.
    ds = tf.data.Dataset.load(DATASET_PATH)

    # Shuffle with a single fixed order. With the default
    # reshuffle_each_iteration=True, save() and split_dataset() below would
    # each iterate a *different* random order, so the saved copy would not
    # correspond to the train/test split actually used.
    ds = ds.shuffle(buffer_size=SHUFFLE_BUFFER, reshuffle_each_iteration=False)

    # NOTE(review): rewrites the dataset to an absolute, machine-specific path
    # on every run — confirm this is still wanted.
    ds.save(SHUFFLED_COPY_PATH)

    train_ds, test_ds = tf.keras.utils.split_dataset(ds, left_size=TRAIN_FRACTION)

    # Model.fit ignores its `shuffle` argument for tf.data inputs, so
    # reshuffle the training set explicitly; shuffle() reshuffles on every
    # iteration (i.e. every epoch) by default.
    train_data = (
        train_ds.shuffle(buffer_size=SHUFFLE_BUFFER)
        .batch(BATCH_SIZE)
        .prefetch(tf.data.AUTOTUNE)
    )
    test_data = test_ds.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

    model = create_model(IMAGE_SIZE)
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
        # Integer class labels; from_logits defaults to False, which matches
        # the softmax output layer.
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=['accuracy'],
    )
    model.fit(train_data, epochs=EPOCHS)

    _, test_acc = model.evaluate(test_data, verbose=2)
    print('\nTest accuracy:', test_acc)


if __name__ == '__main__':
    main()