EdgeAI_Digit_Recognition/inference.py

166 lines
5.7 KiB
Python

import threading
import queue
import time
import logging
import cv2
import numpy as np
import tflite_runtime.interpreter as tflite
from config import Config
# Logger named after this module; used by InferenceWorker for lifecycle and error messages.
logger = logging.getLogger(__name__)
class InferenceWorker:
    """Background TFLite digit recogniser fed by a bounded task queue.

    Producers call ``add_task`` with a frame and the ROIs of its digits.
    ``start`` launches a daemon thread that does the following for each task:

    * crops every ROI;
    * classifies the crops in one batch;
    * when every ROI yields a digit, puts a ``{'camera_id', 'value', 'digits'}``
      dict on ``result_queue``.

    Consumers poll for those results with ``get_result``.
    """

    # Fallback (height, width) for digit crops, used when the loaded model's
    # input shape is unavailable or is not an NHWC, 3-channel shape.
    DEFAULT_INPUT_HW = (32, 20)

    def __init__(self):
        # Bounded so a slow model drops frames (see add_task) instead of
        # buffering them without limit.
        self.input_queue = queue.Queue(maxsize=10)
        self.result_queue = queue.Queue()
        self.running = False
        self.interpreter = None
        self.input_details = None
        self.output_details = None
        self.original_input_shape = None
        # Input shape the interpreter's tensors are currently allocated for;
        # predict_batch skips resize/allocate when the next batch matches it.
        self._allocated_shape = None
        # Guards the interpreter and its tensor metadata.
        self.lock = threading.Lock()
        # Load Model
        self.load_model()

    def load_model(self):
        """Load the TFLite model at ``Config.MODEL_PATH`` and allocate its tensors.

        On failure the error is logged with its traceback and
        ``self.interpreter`` is set to None. A None interpreter makes
        ``add_task`` and ``predict_batch`` no-ops.
        """
        try:
            logger.info("Loading TFLite model from: %s", Config.MODEL_PATH)
            interpreter = tflite.Interpreter(model_path=Config.MODEL_PATH)
            interpreter.allocate_tensors()
            input_details = interpreter.get_input_details()
            output_details = interpreter.get_output_details()
            input_shape = input_details[0]['shape']
        except Exception as e:
            logger.critical("Failed to load TFLite model: %s", e, exc_info=True)
            with self.lock:
                self.interpreter = None
            return
        # Publish the new model atomically with respect to predict_batch.
        with self.lock:
            self.input_details = input_details
            self.output_details = output_details
            # Store original input shape for resizing logic
            self.original_input_shape = input_shape
            self._allocated_shape = tuple(int(d) for d in input_shape)
            self.interpreter = interpreter
        logger.info("Model loaded. Default input shape: %s", input_shape)

    def start(self):
        """Launch the daemon worker thread; a no-op if already running."""
        if self.running:
            return
        self.running = True
        threading.Thread(
            target=self._worker_loop, name="InferenceWorker", daemon=True
        ).start()
        logger.info("Inference worker started.")

    def add_task(self, camera_id, rois, frame):
        """Queue a frame for recognition without blocking.

        Args:
            camera_id: identifier copied into the result dict.
            rois: list of ROI mappings with 'x', 'y', 'width', 'height',
                one per digit position, in reading order.
            frame: image array to crop the ROIs from.

        Returns:
            bool: True if the task was queued. False if no model is loaded or
            the queue is full; the frame is then dropped so producers never stall.
        """
        if self.interpreter is None:
            return False
        task = {
            'camera_id': camera_id,
            'rois': rois,
            'frame': frame,
            'timestamp': time.time(),
        }
        try:
            self.input_queue.put(task, block=False)
        except queue.Full:
            return False
        return True

    def get_result(self):
        """Return the next result dict without blocking, or None if none is ready."""
        try:
            return self.result_queue.get(block=False)
        except queue.Empty:
            return None

    def _worker_loop(self):
        """Consume tasks until ``running`` is cleared and publish recognised numbers."""
        while self.running:
            try:
                task = self.input_queue.get(timeout=1)
            except queue.Empty:
                continue
            try:
                result = self._process_task(task)
                if result is not None:
                    self.result_queue.put(result)
            except Exception:
                # Keep the worker alive: one bad frame must not stop the thread.
                logger.exception("Inference error for %s", task.get('camera_id'))
            finally:
                self.input_queue.task_done()

    def _process_task(self, task):
        """Crop, classify and combine one task.

        Returns:
            dict | None: the result dict, or None when the reading is not
            reliable. That covers an empty ROI list, any ROI that could not be
            cropped, no model, or a non-digit prediction.
        """
        rois = list(task['rois'])
        crops = self._crop_rois(task['frame'], rois)
        # Each ROI is one digit position. If any crop was dropped, joining the
        # rest would yield a number with missing digits, so reject the frame.
        if not crops or len(crops) != len(rois):
            return None
        digits = self.predict_batch(crops)
        if len(digits) != len(crops) or not all(d.isdigit() for d in digits):
            return None
        return {
            'camera_id': task['camera_id'],
            # int() drops leading zeros; 'digits' keeps every position.
            'value': int("".join(digits)),
            'digits': digits,
        }

    def _crop_rois(self, image, roi_list):
        """Crop each ROI from ``image``.

        ROI fields 'x', 'y', 'width', 'height' are coerced with ``int()``, so
        float coordinates are accepted.

        Top/left bounds are clamped at 0 so a negative coordinate cannot wrap
        around to the opposite edge. numpy slicing already clamps the
        bottom/right bounds.

        Malformed ROIs, a non-image ``image``, and empty crops are skipped, so
        the result can be shorter than ``roi_list``.
        """
        cropped_images = []
        for roi in roi_list:
            try:
                x, y = int(roi['x']), int(roi['y'])
                w, h = int(roi['width']), int(roi['height'])
                cropped = image[max(y, 0):max(y + h, 0), max(x, 0):max(x + w, 0)]
            except (KeyError, TypeError, ValueError, IndexError):
                continue
            if cropped.size > 0:
                cropped_images.append(cropped)
        return cropped_images

    def _input_hw(self):
        """Return the (height, width) crops are resized to.

        Prefers the model's own input shape and falls back to
        ``DEFAULT_INPUT_HW``.
        """
        shape = getattr(self, 'original_input_shape', None)
        # Trust the model shape only when it looks like [batch, H, W, 3],
        # matching the RGB NHWC tensor built in predict_batch.
        if shape is not None and len(shape) == 4 and int(shape[3]) == 3:
            h, w = int(shape[1]), int(shape[2])
            if h > 0 and w > 0:
                return h, w
        return self.DEFAULT_INPUT_HW

    def predict_batch(self, images):
        """Run inference on a batch of images at once.

        Args:
            images: list of BGR image arrays, e.g. crops from ``_crop_rois``.

        Returns:
            list[str]: the argmax class index of each image as a string, in
            input order. ``[]`` when ``images`` is empty or no model is loaded.
        """
        if not images or self.interpreter is None:
            return []
        num_images = len(images)
        target_h, target_w = self._input_hw()

        # Preprocessing does not touch the interpreter, so keep it outside the lock.
        # Pixels are cast to float32 WITHOUT rescaling (no /255).
        # NOTE(review): confirm this matches the model's training preprocessing.
        input_tensor = np.stack([
            cv2.cvtColor(cv2.resize(img, (target_w, target_h)), cv2.COLOR_BGR2RGB)
            for img in images
        ]).astype(np.float32)
        requested_shape = (num_images, target_h, target_w, 3)

        with self.lock:
            interpreter = self.interpreter
            if interpreter is None:
                return []
            input_index = self.input_details[0]['index']
            output_index = self.output_details[0]['index']
            # Resizing the input and re-allocating tensors is expensive, so do
            # it only when the batch shape differs from the allocated one.
            if self._allocated_shape != requested_shape:
                interpreter.resize_tensor_input(input_index, list(requested_shape))
                interpreter.allocate_tensors()
                self._allocated_shape = requested_shape
            interpreter.set_tensor(input_index, input_tensor)
            interpreter.invoke()
            # get_tensor returns a copy, so it stays valid after the lock is released.
            output_data = interpreter.get_tensor(output_index)

        # Presumably [N, num_classes] scores: take the argmax of each image's row.
        scores = np.asarray(output_data)[:num_images].reshape(num_images, -1)
        return [str(int(c)) for c in np.argmax(scores, axis=1)]