Files
StreetSignRecognitionTensor/classes/DataLoader.py
Denys Seredenko c148b03ab5 Ready Model
2024-12-18 12:15:41 +01:00

96 lines
2.9 KiB
Python

import os
import csv
import cv2
import tensorflow as tf
import numpy as np
class DataSetLoader:
    """Index a directory of street-sign images and load it as a dataset.

    The loader walks a dataset directory, derives an integer class label
    for every image from keywords found in its path, writes that index to
    a CSV file, and can turn such a CSV file back into a
    ``tf.data.Dataset`` of resized, normalised images with their labels.
    """

    # Edge length in pixels used when a target dimension is given as None.
    _DEFAULT_SIZE = 25

    # Lower-case file-name suffixes that mark a file as an image.
    _IMAGE_EXTENSIONS = (".bmp", ".jpeg", ".jpg", ".png")

    # (keyword, label) pairs tested in this order; the first keyword that
    # occurs in the path decides the label.
    _CONCEPT_KEYWORDS = (
        ("fahrtrichtung_links", 1),
        ("fahrtrichtung_rechts", 2),
        ("rechts_vor_links", 3),
        ("stop", 4),
        ("vorfahrt_gewaehren", 5),
        ("vorfahrtsstrasse", 6),
    )

    # Label returned when no keyword occurs in the path.
    _UNKNOWN_CONCEPT = 7

    def __init__(self, p: str, x: int, y: int) -> None:
        """Store the dataset location and the target image size.

        Args:
            p: Path to the dataset; it must exist on disk.
            x: Target image width in pixels; ``None`` selects 25.
            y: Target image height in pixels; ``None`` selects 25.

        Raises:
            FileNotFoundError: If ``p`` is ``None`` or does not exist.
        """
        # The None test has to come first: os.path.exists(None) raises
        # TypeError instead of returning False.
        if p is None or not os.path.exists(p):
            raise FileNotFoundError("Path is not correct")
        self.path_to_dataset = p
        self.x_size = x if x is not None else self._DEFAULT_SIZE
        self.y_size = y if y is not None else self._DEFAULT_SIZE

    def get_concept(self, path) -> int:
        """Return the class label encoded in ``path``.

        The keywords are searched as plain substrings of the whole path
        string, in a fixed order, and the first hit wins.

        NOTE(review): because the whole path is searched, a keyword that
        happens to occur in a parent directory or file name also matches.

        Args:
            path: Path of an image file.

        Returns:
            A label from 1 to 6 for a recognised keyword, otherwise 7.
        """
        for keyword, label in self._CONCEPT_KEYWORDS:
            if keyword in path:
                return label
        return self._UNKNOWN_CONCEPT

    def get_classified_csv(self):
        """Write a CSV index of every image below the dataset directory.

        The file is called ``all_data_cropped.csv`` and is placed in the
        parent of the directory that contains this module. It is
        ``;``-delimited, starts with the header ``name;concept`` and holds
        one row per image: the image path and its label from
        ``get_concept``.
        """
        rows = []
        for dirpath, _dirnames, filenames in os.walk(self.path_to_dataset):
            for fname in filenames:
                # Compare the real suffix, ignoring case, so "a.png.txt" is
                # rejected and "a.PNG" is accepted.
                if not fname.lower().endswith(self._IMAGE_EXTENSIONS):
                    continue
                image_path = os.path.join(dirpath, fname)
                rows.append((image_path, self.get_concept(image_path)))

        output_path = os.path.abspath(
            os.path.join(__file__, "..", "..", "all_data_cropped.csv")
        )
        # newline="" lets the csv module control line endings; without it
        # every row is followed by an empty line on Windows.
        with open(output_path, "w", newline="") as csv_output:
            csv_writer = csv.writer(csv_output, delimiter=";")
            csv_writer.writerow(["name", "concept"])
            csv_writer.writerows(rows)

    def create_dataset_from_csv(self, path: str):
        """Load the images listed in a CSV index into a dataset.

        The file is expected to be ``;``-delimited with one header line,
        followed by rows of ``image path;integer label``. Images that fail
        to load are reported on stdout and left out.

        Args:
            path: Path of the CSV index to read.

        Returns:
            The result of ``tf.data.Dataset.from_tensor_slices`` applied to
            the pair (array of images, array of labels).

        Raises:
            ValueError: If a label column does not hold an integer.
            IndexError: If a non-empty row has fewer than two columns.
        """
        image_paths = []
        labels = []
        with open(path, "r", newline="") as csv_data:
            csv_reader = csv.reader(csv_data, delimiter=";")
            # Skip the header; the default keeps an empty file from
            # raising StopIteration.
            next(csv_reader, None)
            for row in csv_reader:
                if not row:
                    # Blank lines appear in files that were written
                    # without newline="" on Windows.
                    continue
                image_paths.append(row[0])
                labels.append(int(row[1]))

        all_images = []
        all_labels = []
        for idx, (img_path, label) in enumerate(zip(image_paths, labels)):
            try:
                image = self.load_image_with_opencv(img_path)
            except Exception as e:
                # Deliberately broad: one unreadable or corrupt image must
                # not abort loading the rest of the dataset.
                print(f"Error loading image {img_path}: {e}")
                continue
            all_images.append(image)
            all_labels.append(label)
            tf.print(idx)  # progress: index of the row just loaded

        return tf.data.Dataset.from_tensor_slices(
            (np.array(all_images), np.array(all_labels))
        )

    def load_image_with_opencv(self, file_path):
        """Read one image and prepare it for the dataset.

        The image is read with ``cv2.imread``, converted from BGR to RGB,
        resized to ``(x_size, y_size)`` and converted to ``float32``
        divided by 255.0.

        Args:
            file_path: Path of the image file.

        Returns:
            The processed image as a ``float32`` NumPy array.

        Raises:
            OSError: If ``cv2.imread`` could not read the file.
        """
        image = cv2.imread(file_path)
        if image is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise OSError(f"Could not load image: {file_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = cv2.resize(image, (self.x_size, self.y_size))
        return image.astype(np.float32) / 255.0