import threading
import queue
import time
import logging

import cv2
import numpy as np
import tflite_runtime.interpreter as tflite

from config import Config

logger = logging.getLogger(__name__)


class InferenceWorker:
    def __init__(self):
        self.input_queue = queue.Queue(maxsize=10)
        self.result_queue = queue.Queue()
        self.running = False
        self.interpreter = None
        self.input_details = None
        self.output_details = None
        self.lock = threading.Lock()

        # Load model
        self.load_model()

    def load_model(self):
        try:
            logger.info(f"Loading TFLite model from: {Config.MODEL_PATH}")
            self.interpreter = tflite.Interpreter(model_path=Config.MODEL_PATH)
            self.interpreter.allocate_tensors()
            self.input_details = self.interpreter.get_input_details()
            self.output_details = self.interpreter.get_output_details()

            # Store the original input shape for the dynamic-resizing logic
            self.original_input_shape = self.input_details[0]['shape']
            logger.info(f"Model loaded. Default input shape: {self.original_input_shape}")
        except Exception as e:
            logger.critical(f"Failed to load TFLite model: {e}")
            self.interpreter = None

    def start(self):
        if self.running:
            return
        self.running = True
        threading.Thread(target=self._worker_loop, daemon=True).start()
        logger.info("Inference worker started.")

    def add_task(self, camera_id, rois, frame):
        """Add a task (non-blocking). Drops the frame if the queue is full."""
        if not self.interpreter:
            return
        try:
            task = {
                'camera_id': camera_id,
                'rois': rois,
                'frame': frame,
                'timestamp': time.time()
            }
            self.input_queue.put(task, block=False)
        except queue.Full:
            pass  # Drop the frame rather than block the capture thread

    def get_result(self):
        try:
            return self.result_queue.get(block=False)
        except queue.Empty:
            return None

    def _worker_loop(self):
        while self.running:
            try:
                task = self.input_queue.get(timeout=1)
            except queue.Empty:
                continue

            cam_id = task['camera_id']
            rois = task['rois']
            frame = task['frame']

            try:
                # 1. Crop all ROIs
                crops = self._crop_rois(frame, rois)
                if not crops:
                    continue

                # 2. Batch predict (optimized step)
                digits = self.predict_batch(crops)

                # 3. Combine per-ROI digits into one reading; only publish
                #    if every ROI produced a valid digit
                valid_digits = [d for d in digits if d.isdigit()]
                if len(valid_digits) == len(digits) and len(valid_digits) > 0:
                    final_number = int("".join(valid_digits))
                    self.result_queue.put({
                        'camera_id': cam_id,
                        'value': final_number,
                        'digits': valid_digits
                    })
            except Exception as e:
                logger.error(f"Inference error for {cam_id}: {e}")

    def _crop_rois(self, image, roi_list):
        cropped_images = []
        for roi in roi_list:
            try:
                x, y, w, h = roi['x'], roi['y'], roi['width'], roi['height']
                cropped = image[y:y+h, x:x+w]
                if cropped.size > 0:
                    cropped_images.append(cropped)
            except Exception:
                pass
        return cropped_images

    def predict_batch(self, images):
        """Run inference on a batch of images at once."""
        with self.lock:  # Serialize access: the TFLite interpreter is not thread-safe
            if not self.interpreter:
                return []

            num_images = len(images)
            if num_images == 0:
                return []

            input_index = self.input_details[0]['index']
            output_index = self.output_details[0]['index']

            # Preprocess all images into a single batch array.
            # Shape: [N, 32, 20, 3] (assuming the model expects 32x20 RGB)
            batch_input = []
            target_h, target_w = 32, 20  # Based on your previous code logic
            for img in images:
                # Resize to the model's expected spatial dimensions
                roi_resized = cv2.resize(img, (target_w, target_h))
                # OpenCV gives BGR; convert to RGB
                roi_rgb = cv2.cvtColor(roi_resized, cv2.COLOR_BGR2RGB)
                # Cast to float32 (add a / 255.0 here if your model expects
                # inputs normalized to [0, 1])
                roi_norm = roi_rgb.astype(np.float32)
                batch_input.append(roi_norm)

            # Create the batch tensor
            input_tensor = np.array(batch_input)

            # --- DYNAMIC RESIZING ---
            # TFLite models have a fixed input size (usually batch=1).
            # We must resize the input tensor to match our current batch size (N).

            # 1. Resize the input tensor
            self.interpreter.resize_tensor_input(input_index, [num_images, target_h, target_w, 3])

            # 2. Re-allocate tensors. This is expensive: it runs on every call
            #    because N can change between batches. If your batch size is
            #    stable, track the last N and resize/allocate only when it changes.
            self.interpreter.allocate_tensors()

            # 3. Run inference
            self.interpreter.set_tensor(input_index, input_tensor)
            self.interpreter.invoke()

            # 4. Get results. Output shape is [N, 10] (probabilities for 10 digits)
            output_data = self.interpreter.get_tensor(output_index)

            predictions = []
            for i in range(num_images):
                digit_class = np.argmax(output_data[i])
                predictions.append(str(digit_class))
            return predictions
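

# --- Usage sketch ---
# Illustrative only: a minimal sketch of how a capture loop feeds the worker
# and drains results. Everything below is an assumption, not part of the
# module above: a webcam at index 0, a valid Config.MODEL_PATH, and
# placeholder ROI coordinates for a hypothetical 3-digit readout.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    worker = InferenceWorker()
    worker.start()

    # Hypothetical ROIs: one 20x32 box per digit, left to right
    rois = [
        {'x': 100, 'y': 80, 'width': 20, 'height': 32},
        {'x': 125, 'y': 80, 'width': 20, 'height': 32},
        {'x': 150, 'y': 80, 'width': 20, 'height': 32},
    ]

    cap = cv2.VideoCapture(0)
    try:
        while True:
            ok, frame = cap.read()
            if not ok:
                break

            # Producer side: non-blocking; frames are dropped when the
            # worker is saturated instead of stalling capture
            worker.add_task('cam0', rois, frame)

            # Consumer side: drain any finished reading
            result = worker.get_result()
            if result:
                print(f"{result['camera_id']}: {result['value']}")

            time.sleep(0.1)  # ~10 Hz polling
    finally:
        cap.release()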