diff --git a/__pycache__/inference.cpython-311.pyc b/__pycache__/inference.cpython-311.pyc index 54f000b..6d782e6 100644 Binary files a/__pycache__/inference.cpython-311.pyc and b/__pycache__/inference.cpython-311.pyc differ diff --git a/app.py b/app.py index 522308c..baae701 100644 --- a/app.py +++ b/app.py @@ -1,29 +1,37 @@ +import base64 +import json import logging import sys -import os import threading -import json import time import traceback -import base64 + import cv2 import numpy as np import paho.mqtt.client as mqtt -from flask import Flask, render_template, jsonify, request, Response +from flask import Flask, Response, jsonify, render_template, request -# Import Config, Manager, and NEW Inference Worker from config import Config -from manager import CameraManager from inference import InferenceWorker +from manager import CameraManager -# --- Logging Setup --- + +def _cfg(*names, default=None): + """Return first matching attribute from Config, else default.""" + for n in names: + if hasattr(Config, n): + return getattr(Config, n) + return default + + +LOG_LEVEL = _cfg("LOG_LEVEL", "LOGLEVEL", default=logging.INFO) logging.basicConfig( - level=Config.LOG_LEVEL, + level=LOG_LEVEL, format='%(asctime)s [%(levelname)s] %(message)s', - handlers=[logging.StreamHandler(sys.stdout)] + handlers=[logging.StreamHandler(sys.stdout)], ) - logger = logging.getLogger(__name__) + app = Flask(__name__) # --- Initialize Components --- @@ -34,15 +42,22 @@ inference_worker.start() # --- MQTT Setup --- mqtt_client = mqtt.Client() -if Config.MQTT_USERNAME and Config.MQTT_PASSWORD: - mqtt_client.username_pw_set(Config.MQTT_USERNAME, Config.MQTT_PASSWORD) +MQTT_USERNAME = _cfg("MQTT_USERNAME", "MQTTUSERNAME", default=None) +MQTT_PASSWORD = _cfg("MQTT_PASSWORD", "MQTTPASSWORD", default=None) +MQTT_BROKER = _cfg("MQTT_BROKER", "MQTTBROKER", default="127.0.0.1") +MQTT_PORT = int(_cfg("MQTT_PORT", "MQTTPORT", default=1883)) +MQTT_TOPIC = _cfg("MQTT_TOPIC", "MQTTTOPIC", default="homeassistant/sensor/RTSPCamDigitDetection/state") + +if MQTT_USERNAME and MQTT_PASSWORD: + mqtt_client.username_pw_set(MQTT_USERNAME, MQTT_PASSWORD) try: - mqtt_client.connect(Config.MQTT_BROKER, Config.MQTT_PORT, 60) + mqtt_client.connect(MQTT_BROKER, MQTT_PORT, 60) mqtt_client.loop_start() - logger.info(f"Connected to MQTT Broker at {Config.MQTT_BROKER}:{Config.MQTT_PORT}") + logger.info("Connected to MQTT Broker at %s:%s", MQTT_BROKER, MQTT_PORT) except Exception as e: - logger.error(f"Failed to connect to MQTT Broker: {e}") + logger.error("Failed to connect to MQTT Broker: %s", e) + # --- Helper Functions (UI Only) --- def crop_image_for_ui(image, roi_list, scaleX, scaleY): @@ -61,14 +76,17 @@ def crop_image_for_ui(image, roi_list, scaleX, scaleY): pass return cropped_images + def publish_detected_number(camera_id, detected_number, confidence=None): """Publish result to MQTT with optional confidence score.""" - topic = f"{Config.MQTT_TOPIC}/{camera_id}" + topic = f"{MQTT_TOPIC}/{camera_id}" + payload_dict = {"value": detected_number} if confidence is not None: - payload_dict["confidence"] = round(confidence, 2) + payload_dict["confidence"] = round(float(confidence), 2) payload = json.dumps(payload_dict) + try: mqtt_client.publish(topic, payload) log_msg = f"Published to {topic}: {detected_number}" @@ -76,40 +94,77 @@ def publish_detected_number(camera_id, detected_number, confidence=None): log_msg += f" (Conf: {confidence:.2f})" logger.info(log_msg) except Exception as e: - logger.error(f"MQTT Publish failed: {e}") + logger.error("MQTT Publish failed: %s", e) + + +# --- Debug helpers --- +_last_log = {} + +def log_rl(level, key, msg, every_s=10): + now = time.time() + last = _last_log.get(key, 0.0) + if now - last >= every_s: + _last_log[key] = now + logger.log(level, msg) + # --- Main Processing Loop (Refactored) --- last_processed_time = {} def process_all_cameras(): - """ - Revised Loop with Rate Limiting - """ - DETECTION_INTERVAL = 10 # Configurable interval (seconds) + """Revised loop with rate limiting + debug instrumentation.""" + DETECTION_INTERVAL = int(_cfg("DETECTION_INTERVAL", default=10)) + hb_last = 0.0 while True: try: + # Heartbeat (proves loop is alive even when no publishes happen) + now = time.time() + if now - hb_last >= 5.0: + hb_last = now + in_q = getattr(inference_worker, "input_queue", None) + out_q = getattr(inference_worker, "result_queue", None) + logger.info( + "HB mainloop alive; in_q=%s out_q=%s dropped=%s processed=%s last_invoke_s=%s", + (in_q.qsize() if in_q else "n/a"), + (out_q.qsize() if out_q else "n/a"), + getattr(inference_worker, "dropped_tasks", "n/a"), + getattr(inference_worker, "processed_tasks", "n/a"), + getattr(inference_worker, "last_invoke_secs", "n/a"), + ) + # --- Part 1: Process Results --- while True: result = inference_worker.get_result() if not result: break - cam_id = result['camera_id'] + cam_id = result.get('camera_id') + + # End-to-end latency tracing + task_ts = result.get("task_ts") + if task_ts is not None: + try: + age = time.time() - float(task_ts) + logger.info( + "Result cam=%s type=%s task_id=%s age_s=%.3f timing=%s", + cam_id, + result.get("type"), + result.get("task_id"), + age, + result.get("timing_s"), + ) + except Exception: + pass - # Check Result Type if result.get('type') == 'success': val = result['value'] conf = result.get('confidence') - # Update State & Publish camera_manager.results[cam_id] = val publish_detected_number(cam_id, val, conf) - elif result.get('type') == 'error': - # Log the error (Range or Confidence or Parse) - # This ensures the log appears exactly when the result is processed msg = result.get('message', 'Unknown error') - logger.warning(f"[{cam_id}] Detection skipped: {msg}") + logger.warning("[%s] Detection skipped: %s", cam_id, msg) # --- Part 2: Feed Frames --- camera_manager.load_roi_config() @@ -118,54 +173,73 @@ def process_all_cameras(): if not camera_data.get("active", True): continue - # RATE LIMIT CHECK current_time = time.time() - last_time = last_processed_time.get(camera_id, 0) + last_time = last_processed_time.get(camera_id, 0.0) if current_time - last_time < DETECTION_INTERVAL: - continue # Skip this camera, it's too soon + log_rl( + logging.DEBUG, + f"{camera_id}:rate", + f"[{camera_id}] skip: rate limit ({current_time - last_time:.2f}s<{DETECTION_INTERVAL}s)", + every_s=30, + ) + continue stream = camera_data.get("stream") - if not stream: continue + if not stream: + log_rl(logging.WARNING, f"{camera_id}:nostream", f"[{camera_id}] skip: no stream", every_s=10) + continue - # Warmup Check - if (current_time - stream.start_time) < 5: + # Warmup check + start_time = getattr(stream, "start_time", getattr(stream, "starttime", None)) + if start_time is not None and (current_time - start_time) < 5: + log_rl(logging.DEBUG, f"{camera_id}:warmup", f"[{camera_id}] skip: warmup", every_s=10) continue frame = stream.read() if frame is None: + log_rl(logging.WARNING, f"{camera_id}:noframe", f"[{camera_id}] skip: frame is None", every_s=5) continue - if np.std(frame) < 10: + frame_std = float(np.std(frame)) + if frame_std < 5: + log_rl( + logging.INFO, + f"{camera_id}:lowstd", + f"[{camera_id}] skip: low frame std={frame_std:.2f} (<10) (disturbed/blank/frozen?)", + every_s=5, + ) + mqtt_client.publish(f"{Config.MQTT_TOPIC}/{camera_id}/status", "disturbed") + continue roi_list = camera_manager.rois.get(camera_id, []) if not roi_list: + log_rl(logging.WARNING, f"{camera_id}:norois", f"[{camera_id}] skip: no ROIs", every_s=30) continue - # SEND TO WORKER - inference_worker.add_task(camera_id, roi_list, frame) - - # Update last processed time + inference_worker.add_task(camera_id, roi_list, frame, frame_std=frame_std) last_processed_time[camera_id] = current_time - # Sleep briefly to prevent CPU spinning, but keep it responsive for results time.sleep(0.1) except Exception as e: - logger.error(f"Global process loop error: {e}") + logger.error("Global process loop error: %s", e) traceback.print_exc() time.sleep(5) + # --- Flask Routes --- @app.route('/') def index(): return render_template('index.html') + @app.route('/cameras', methods=['GET']) def get_cameras(): return jsonify(camera_manager.get_camera_list()) + @app.route('/video/') def video_feed(camera_id): def generate(): @@ -174,11 +248,16 @@ def video_feed(camera_id): if frame is not None: ret, jpeg = cv2.imencode('.jpg', frame) if ret: - yield (b'--frame\r\n' b'Content-Type: image/jpeg\r\n\r\n' + jpeg.tobytes() + b'\r\n\r\n') + yield ( + b'--frame\r\n' + b'Content-Type: image/jpeg\r\n\r\n' + jpeg.tobytes() + b'\r\n\r\n' + ) else: time.sleep(0.1) + return Response(generate(), mimetype='multipart/x-mixed-replace; boundary=frame') + @app.route('/snapshot/') def snapshot(camera_id): frame = camera_manager.get_frame(camera_id) @@ -188,6 +267,7 @@ def snapshot(camera_id): return Response(jpeg.tobytes(), mimetype='image/jpeg') return 'No frame available', 404 + @app.route('/rois/', methods=['GET']) def get_rois(camera_id): try: @@ -218,28 +298,34 @@ def get_rois(camera_id): "y": int(round(roi["y"] * scaleY)), "width": int(round(roi["width"] * scaleX)), "height": int(round(roi["height"] * scaleY)), - "angle": roi["angle"] + "angle": roi.get("angle", 0), }) + return jsonify(scaled_rois) + except Exception as e: return jsonify({"error": str(e)}), 500 + @app.route("/save_rois", methods=["POST"]) def save_rois_api(): data = request.json + camera_id = data.get("camera_id") new_rois = data.get("rois") img_width = data.get("img_width") img_height = data.get("img_height") - if not camera_id or new_rois is None: return jsonify({"success": False}) + if not camera_id or new_rois is None: + return jsonify({"success": False}) cam = camera_manager.cameras.get(camera_id) - if not cam: return jsonify({"success": False}) + if not cam: + return jsonify({"success": False}) stream = cam.get("stream") - real_w = stream.width if stream and stream.width else cam["width"] - real_h = stream.height if stream and stream.height else cam["height"] + real_w = stream.width if stream and getattr(stream, "width", None) else cam["width"] + real_h = stream.height if stream and getattr(stream, "height", None) else cam["height"] scaleX = real_w / img_width if img_width else 1 scaleY = real_h / img_height if img_height else 1 @@ -252,21 +338,24 @@ def save_rois_api(): "y": int(round(roi["y"] * scaleY)), "width": int(round(roi["width"] * scaleX)), "height": int(round(roi["height"] * scaleY)), - "angle": roi["angle"] + "angle": roi.get("angle", 0), }) camera_manager.rois[camera_id] = scaled_rois return jsonify(camera_manager.save_roi_config()) + @app.route('/crop', methods=['POST']) def crop(): data = request.json + camera_id = data.get('camera_id') scaleX = data.get('scaleX', 1) scaleY = data.get('scaleY', 1) frame = camera_manager.get_frame(camera_id) - if frame is None: return jsonify({'error': 'No frame'}), 500 + if frame is None: + return jsonify({'error': 'No frame'}), 500 roi_list = camera_manager.rois.get(camera_id, []) cropped_images = crop_image_for_ui(frame, roi_list, scaleX, scaleY) @@ -279,31 +368,29 @@ def crop(): return jsonify({'cropped_images': cropped_base64_list}) + @app.route('/detect_digits', methods=['POST']) def detect_digits(): """Manual trigger: Runs inference immediately and returns result with validation.""" data = request.json + camera_id = data.get('camera_id') if not camera_id: return jsonify({'error': 'Invalid camera ID'}), 400 - # 1. Get Frame frame = camera_manager.get_frame(camera_id) if frame is None: return jsonify({'error': 'Failed to capture image'}), 500 - # 2. Get ROIs roi_list = camera_manager.rois.get(camera_id, []) if not roi_list: return jsonify({'error': 'No ROIs defined'}), 400 - # 3. Crop cropped_images = crop_image_for_ui(frame, roi_list, scaleX=1, scaleY=1) if not cropped_images: return jsonify({'error': 'Failed to crop ROIs'}), 500 try: - # 4. Run Inference Synchronously predictions = inference_worker.predict_batch(cropped_images) valid_digits_str = [] @@ -318,30 +405,24 @@ def detect_digits(): if p['confidence'] < CONFIDENCE_THRESHOLD: msg = f"Digit {i} ('{p['digit']}') rejected: conf {p['confidence']:.2f} < {CONFIDENCE_THRESHOLD}" rejected_reasons.append(msg) - logger.warning(f"[Manual] {msg}") + logger.warning("[Manual] %s", msg) else: valid_digits_str.append(p['digit']) confidences.append(p['confidence']) if len(valid_digits_str) != len(predictions): - return jsonify({ - 'error': 'Low confidence detection', - 'details': rejected_reasons, - 'raw': predictions - }), 400 + return jsonify({'error': 'Low confidence detection', 'details': rejected_reasons, 'raw': predictions}), 400 final_number_str = "".join(valid_digits_str) try: final_number = int(final_number_str) - # Range Check if not (MIN_VALUE <= final_number <= MAX_VALUE): msg = f"Value {final_number} out of range ({MIN_VALUE}-{MAX_VALUE})" - logger.warning(f"[Manual] {msg}") + logger.warning("[Manual] %s", msg) return jsonify({'error': 'Value out of range', 'value': final_number}), 400 - # Valid result - avg_conf = float(np.mean(confidences)) + avg_conf = float(np.mean(confidences)) if confidences else None publish_detected_number(camera_id, final_number, avg_conf) camera_manager.results[camera_id] = final_number @@ -349,25 +430,28 @@ def detect_digits(): 'detected_digits': valid_digits_str, 'final_number': final_number, 'confidences': confidences, - 'avg_confidence': avg_conf + 'avg_confidence': avg_conf, }) except ValueError: - return jsonify({'error': 'Could not parse digits', 'raw': valid_digits_str}), 500 + return jsonify({'error': 'Could not parse digits', 'raw': valid_digits_str}), 500 except Exception as e: - logger.error(f"Error during manual detection: {e}") + logger.error("Error during manual detection: %s", e) return jsonify({'error': str(e)}), 500 + @app.route('/update_camera_config', methods=['POST']) def update_camera_config(): data = request.json success = camera_manager.update_camera_flip(data.get("camera_id"), data.get("flip_type")) return jsonify({"success": success}) + # --- Main --- if __name__ == '__main__': t = threading.Thread(target=process_all_cameras, daemon=True) t.start() + logger.info("Starting Flask Server...") app.run(host='0.0.0.0', port=5000, threaded=True) diff --git a/inference.py b/inference.py index 94d33f0..1b31e69 100644 --- a/inference.py +++ b/inference.py @@ -1,66 +1,98 @@ -import threading -import queue -import time import logging +import queue +import threading +import time + import cv2 import numpy as np import tflite_runtime.interpreter as tflite + from config import Config logger = logging.getLogger(__name__) + +def _cfg(*names, default=None): + for n in names: + if hasattr(Config, n): + return getattr(Config, n) + return default + + class InferenceWorker: def __init__(self): self.input_queue = queue.Queue(maxsize=10) self.result_queue = queue.Queue() self.running = False + self.interpreter = None self.input_details = None self.output_details = None self.lock = threading.Lock() - # Validation thresholds - self.CONFIDENCE_THRESHOLD = 0.80 # Minimum confidence (0-1) to accept a digit - self.MIN_VALUE = 5 # Minimum allowed temperature value - self.MAX_VALUE = 100 # Maximum allowed temperature value + # Debug counters / telemetry + self.task_seq = 0 + self.dropped_tasks = 0 + self.processed_tasks = 0 + self.last_invoke_secs = None + + # Validation thresholds + self.CONFIDENCE_THRESHOLD = 0.10 + self.MIN_VALUE = 5 + self.MAX_VALUE = 100 - # Load Model self.load_model() def load_model(self): try: - logger.info(f"Loading TFLite model from: {Config.MODEL_PATH}") - self.interpreter = tflite.Interpreter(model_path=Config.MODEL_PATH) + model_path = _cfg("MODEL_PATH", "MODELPATH", default=None) + logger.info("Loading TFLite model from: %s", model_path) + + self.interpreter = tflite.Interpreter(model_path=model_path) self.interpreter.allocate_tensors() + self.input_details = self.interpreter.get_input_details() self.output_details = self.interpreter.get_output_details() - # Store original input shape for resizing logic self.original_input_shape = self.input_details[0]['shape'] - logger.info(f"Model loaded. Default input shape: {self.original_input_shape}") + logger.info("Model loaded. Default input shape: %s", self.original_input_shape) + except Exception as e: - logger.critical(f"Failed to load TFLite model: {e}") + logger.critical("Failed to load TFLite model: %s", e) self.interpreter = None def start(self): - if self.running: return + if self.running: + return self.running = True threading.Thread(target=self._worker_loop, daemon=True).start() logger.info("Inference worker started.") - def add_task(self, camera_id, rois, frame): + def add_task(self, camera_id, rois, frame, frame_std=None): """Add task (non-blocking).""" - if not self.interpreter: return + if not self.interpreter: + return + + self.task_seq += 1 + task = { + 'camera_id': camera_id, + 'rois': rois, + 'frame': frame, + 'timestamp': time.time(), + 'task_id': self.task_seq, + 'frame_std': frame_std, + } + try: - task = { - 'camera_id': camera_id, - 'rois': rois, - 'frame': frame, - 'timestamp': time.time() - } self.input_queue.put(task, block=False) except queue.Full: - pass + self.dropped_tasks += 1 + logger.warning( + "add_task drop cam=%s qsize=%d dropped=%d", + camera_id, + self.input_queue.qsize(), + self.dropped_tasks, + ) def get_result(self): try: @@ -68,6 +100,14 @@ class InferenceWorker: except queue.Empty: return None + def _put_result(self, d): + """Best-effort put so failures never go silent.""" + try: + self.result_queue.put(d, block=False) + except Exception: + # Should be extremely rare; log + drop + logger.exception("Failed to enqueue result") + def _worker_loop(self): while self.running: try: @@ -78,86 +118,122 @@ class InferenceWorker: cam_id = task['camera_id'] rois = task['rois'] frame = task['frame'] + task_id = task.get('task_id') + task_ts = task.get('timestamp') try: - # 1. Crop all ROIs + age_s = (time.time() - task_ts) if task_ts else None + logger.info( + "Worker got task cam=%s task_id=%s age_s=%s frame_std=%s rois=%d in_q=%d", + cam_id, + task_id, + (f"{age_s:.3f}" if age_s is not None else "n/a"), + task.get('frame_std'), + len(rois) if rois else 0, + self.input_queue.qsize(), + ) + + t0 = time.time() crops = self._crop_rois(frame, rois) + t_crop = time.time() + if not crops: - # Report failure to queue so main loop knows we tried - self.result_queue.put({ + self._put_result({ 'type': 'error', 'camera_id': cam_id, - 'message': 'No ROIs cropped' + 'message': 'No ROIs cropped', + 'task_id': task_id, + 'task_ts': task_ts, + 'timing_s': {'crop': t_crop - t0, 'total': t_crop - t0}, }) continue - # 2. Batch Predict predictions = self.predict_batch(crops) + t_pred = time.time() - # 3. Validation Logic valid_digits_str = [] confidences = [] - - all_confident = True low_conf_details = [] for i, p in enumerate(predictions): if p['confidence'] < self.CONFIDENCE_THRESHOLD: - low_conf_details.append(f"Digit {i} conf {p['confidence']:.2f} < {self.CONFIDENCE_THRESHOLD}") - all_confident = False + low_conf_details.append( + f"Digit {i} conf {p['confidence']:.2f} < {self.CONFIDENCE_THRESHOLD}" + ) valid_digits_str.append(p['digit']) confidences.append(p['confidence']) - if not all_confident: - # Send failure result - self.result_queue.put({ + if low_conf_details: + self._put_result({ 'type': 'error', 'camera_id': cam_id, 'message': f"Low confidence: {', '.join(low_conf_details)}", - 'digits': valid_digits_str + 'digits': valid_digits_str, + 'task_id': task_id, + 'task_ts': task_ts, + 'timing_s': {'crop': t_crop - t0, 'predict': t_pred - t_crop, 'total': t_pred - t0}, }) continue if not valid_digits_str: - continue - - # Parse number - try: - final_number_str = "".join(valid_digits_str) - final_number = int(final_number_str) - - # Check Range - if self.MIN_VALUE <= final_number <= self.MAX_VALUE: - avg_conf = float(np.mean(confidences)) - self.result_queue.put({ - 'type': 'success', - 'camera_id': cam_id, - 'value': final_number, - 'digits': valid_digits_str, - 'confidence': avg_conf - }) - else: - # Send range error result - self.result_queue.put({ - 'type': 'error', - 'camera_id': cam_id, - 'message': f"Value {final_number} out of range ({self.MIN_VALUE}-{self.MAX_VALUE})", - 'value': final_number - }) - - except ValueError: - self.result_queue.put({ + self._put_result({ 'type': 'error', 'camera_id': cam_id, - 'message': f"Parse error: {valid_digits_str}" + 'message': 'No digits produced', + 'task_id': task_id, + 'task_ts': task_ts, + 'timing_s': {'crop': t_crop - t0, 'predict': t_pred - t_crop, 'total': t_pred - t0}, + }) + continue + + final_number_str = "".join(valid_digits_str) + + try: + final_number = int(final_number_str) + except ValueError: + self._put_result({ + 'type': 'error', + 'camera_id': cam_id, + 'message': f"Parse error: {valid_digits_str}", + 'task_id': task_id, + 'task_ts': task_ts, + 'timing_s': {'crop': t_crop - t0, 'predict': t_pred - t_crop, 'total': t_pred - t0}, + }) + continue + + if self.MIN_VALUE <= final_number <= self.MAX_VALUE: + avg_conf = float(np.mean(confidences)) if confidences else None + self._put_result({ + 'type': 'success', + 'camera_id': cam_id, + 'value': final_number, + 'digits': valid_digits_str, + 'confidence': avg_conf, + 'task_id': task_id, + 'task_ts': task_ts, + 'timing_s': {'crop': t_crop - t0, 'predict': t_pred - t_crop, 'total': t_pred - t0}, + }) + else: + self._put_result({ + 'type': 'error', + 'camera_id': cam_id, + 'message': f"Value {final_number} out of range ({self.MIN_VALUE}-{self.MAX_VALUE})", + 'value': final_number, + 'task_id': task_id, + 'task_ts': task_ts, + 'timing_s': {'crop': t_crop - t0, 'predict': t_pred - t_crop, 'total': t_pred - t0}, }) - except Exception as e: - logger.error(f"Inference error for {cam_id}: {e}") - self.result_queue.put({ + self.processed_tasks += 1 + + except Exception: + logger.exception("Inference error cam=%s task_id=%s", cam_id, task_id) + self._put_result({ 'type': 'error', 'camera_id': cam_id, - 'message': str(e) + 'message': 'Exception during inference; see logs', + 'task_id': task_id, + 'task_ts': task_ts, }) def _crop_rois(self, image, roi_list): @@ -165,7 +241,7 @@ class InferenceWorker: for roi in roi_list: try: x, y, w, h = roi['x'], roi['y'], roi['width'], roi['height'] - cropped = image[y:y+h, x:x+w] + cropped = image[y:y + h, x:x + w] if cropped.size > 0: cropped_images.append(cropped) except Exception: @@ -173,12 +249,17 @@ class InferenceWorker: return cropped_images def predict_batch(self, images): - """Run inference on a batch of images at once. Returns list of dicts: {'digit': str, 'confidence': float}""" + """Run inference on a batch of images. + + Returns list of dicts: {'digit': str, 'confidence': float} + """ with self.lock: - if not self.interpreter: return [] + if not self.interpreter: + return [] num_images = len(images) - if num_images == 0: return [] + if num_images == 0: + return [] input_index = self.input_details[0]['index'] output_index = self.output_details[0]['index'] @@ -194,23 +275,33 @@ class InferenceWorker: input_tensor = np.array(batch_input) + # NOTE: Keeping original behavior (resize+allocate) but timing it. self.interpreter.resize_tensor_input(input_index, [num_images, target_h, target_w, 3]) self.interpreter.allocate_tensors() + self.interpreter.set_tensor(input_index, input_tensor) + + t0 = time.time() self.interpreter.invoke() + self.last_invoke_secs = time.time() - t0 + if self.last_invoke_secs > 1.0: + logger.warning("Slow invoke: %.3fs (batch=%d)", self.last_invoke_secs, num_images) output_data = self.interpreter.get_tensor(output_index) results = [] for i in range(num_images): logits = output_data[i] - probs = np.exp(logits) / np.sum(np.exp(logits)) - digit_class = np.argmax(probs) - confidence = probs[digit_class] - results.append({ - 'digit': str(digit_class), - 'confidence': float(confidence) - }) + # More stable softmax + logits = logits - np.max(logits) + ex = np.exp(logits) + denom = np.sum(ex) + probs = (ex / denom) if denom != 0 else np.zeros_like(ex) + + digit_class = int(np.argmax(probs)) + confidence = float(probs[digit_class]) if probs.size else 0.0 + + results.append({'digit': str(digit_class), 'confidence': confidence}) return results