Bora 2026-01-01 11:20:05 +01:00
parent 78487918e4
commit 8d1b45ce73
3 changed files with 323 additions and 148 deletions

app.py

@@ -1,29 +1,37 @@
import base64
import json
import logging
import os
import sys
import threading
import time
import traceback

import cv2
import numpy as np
import paho.mqtt.client as mqtt
from flask import Flask, Response, jsonify, render_template, request

# Import Config, Manager, and NEW Inference Worker
from config import Config
from inference import InferenceWorker
from manager import CameraManager
# --- Logging Setup ---
def _cfg(*names, default=None):
"""Return first matching attribute from Config, else default."""
for n in names:
if hasattr(Config, n):
return getattr(Config, n)
return default
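# Example of the alias lookup above (hypothetical Config attributes): if Config
# defines MQTTPORT = 1884 but not MQTT_PORT, then
# _cfg("MQTT_PORT", "MQTTPORT", default=1883) returns 1884; with neither
# attribute present it falls back to 1883.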
LOG_LEVEL = _cfg("LOG_LEVEL", "LOGLEVEL", default=logging.INFO)
logging.basicConfig(
    level=LOG_LEVEL,
    format='%(asctime)s [%(levelname)s] %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger(__name__)
app = Flask(__name__)
# --- Initialize Components ---
@@ -34,15 +42,22 @@ inference_worker.start()
# --- MQTT Setup ---
mqtt_client = mqtt.Client()
MQTT_USERNAME = _cfg("MQTT_USERNAME", "MQTTUSERNAME", default=None)
MQTT_PASSWORD = _cfg("MQTT_PASSWORD", "MQTTPASSWORD", default=None)
MQTT_BROKER = _cfg("MQTT_BROKER", "MQTTBROKER", default="127.0.0.1")
MQTT_PORT = int(_cfg("MQTT_PORT", "MQTTPORT", default=1883))
MQTT_TOPIC = _cfg("MQTT_TOPIC", "MQTTTOPIC", default="homeassistant/sensor/RTSPCamDigitDetection/state")
if MQTT_USERNAME and MQTT_PASSWORD:
    mqtt_client.username_pw_set(MQTT_USERNAME, MQTT_PASSWORD)
try:
    mqtt_client.connect(MQTT_BROKER, MQTT_PORT, 60)
    mqtt_client.loop_start()
    logger.info("Connected to MQTT Broker at %s:%s", MQTT_BROKER, MQTT_PORT)
except Exception as e:
    logger.error("Failed to connect to MQTT Broker: %s", e)
# --- Helper Functions (UI Only) ---
def crop_image_for_ui(image, roi_list, scaleX, scaleY):
@@ -61,14 +76,17 @@ def crop_image_for_ui(image, roi_list, scaleX, scaleY):
pass
return cropped_images
def publish_detected_number(camera_id, detected_number, confidence=None):
"""Publish result to MQTT with optional confidence score."""
topic = f"{Config.MQTT_TOPIC}/{camera_id}"
topic = f"{MQTT_TOPIC}/{camera_id}"
payload_dict = {"value": detected_number}
if confidence is not None:
payload_dict["confidence"] = round(confidence, 2)
payload_dict["confidence"] = round(float(confidence), 2)
payload = json.dumps(payload_dict)
try:
mqtt_client.publish(topic, payload)
log_msg = f"Published to {topic}: {detected_number}"
@@ -76,40 +94,77 @@ def publish_detected_number(camera_id, detected_number, confidence=None):
log_msg += f" (Conf: {confidence:.2f})"
logger.info(log_msg)
except Exception as e:
logger.error(f"MQTT Publish failed: {e}")
logger.error("MQTT Publish failed: %s", e)
# --- Debug helpers ---
_last_log = {}
def log_rl(level, key, msg, every_s=10):
now = time.time()
last = _last_log.get(key, 0.0)
if now - last >= every_s:
_last_log[key] = now
logger.log(level, msg)
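# Example: the call below logs at most once every 30 s for the key "cam1:rate",
# no matter how often the loop reaches it ("cam1" is illustrative):
#   log_rl(logging.DEBUG, "cam1:rate", "[cam1] skip: rate limit", every_s=30)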
# --- Main Processing Loop (Refactored) ---
last_processed_time = {}
def process_all_cameras():
"""
Revised Loop with Rate Limiting
"""
DETECTION_INTERVAL = 10 # Configurable interval (seconds)
"""Revised loop with rate limiting + debug instrumentation."""
DETECTION_INTERVAL = int(_cfg("DETECTION_INTERVAL", default=10))
hb_last = 0.0
while True:
try:
# Heartbeat (proves loop is alive even when no publishes happen)
now = time.time()
if now - hb_last >= 5.0:
hb_last = now
in_q = getattr(inference_worker, "input_queue", None)
out_q = getattr(inference_worker, "result_queue", None)
logger.info(
"HB mainloop alive; in_q=%s out_q=%s dropped=%s processed=%s last_invoke_s=%s",
(in_q.qsize() if in_q else "n/a"),
(out_q.qsize() if out_q else "n/a"),
getattr(inference_worker, "dropped_tasks", "n/a"),
getattr(inference_worker, "processed_tasks", "n/a"),
getattr(inference_worker, "last_invoke_secs", "n/a"),
)
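            # Example heartbeat line as formatted above (illustrative values):
            #   HB mainloop alive; in_q=0 out_q=1 dropped=3 processed=117 last_invoke_s=0.42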
# --- Part 1: Process Results ---
while True:
result = inference_worker.get_result()
if not result:
break
                cam_id = result.get('camera_id')
# End-to-end latency tracing
task_ts = result.get("task_ts")
if task_ts is not None:
try:
age = time.time() - float(task_ts)
logger.info(
"Result cam=%s type=%s task_id=%s age_s=%.3f timing=%s",
cam_id,
result.get("type"),
result.get("task_id"),
age,
result.get("timing_s"),
)
except Exception:
pass
# Check Result Type
if result.get('type') == 'success':
val = result['value']
conf = result.get('confidence')
# Update State & Publish
camera_manager.results[cam_id] = val
publish_detected_number(cam_id, val, conf)
elif result.get('type') == 'error':
# Log the error (Range or Confidence or Parse)
# This ensures the log appears exactly when the result is processed
msg = result.get('message', 'Unknown error')
logger.warning(f"[{cam_id}] Detection skipped: {msg}")
logger.warning("[%s] Detection skipped: %s", cam_id, msg)
# --- Part 2: Feed Frames ---
camera_manager.load_roi_config()
@@ -118,54 +173,73 @@ def process_all_cameras():
if not camera_data.get("active", True):
continue
# RATE LIMIT CHECK
current_time = time.time()
                last_time = last_processed_time.get(camera_id, 0.0)
                if current_time - last_time < DETECTION_INTERVAL:
                    log_rl(
                        logging.DEBUG,
                        f"{camera_id}:rate",
                        f"[{camera_id}] skip: rate limit ({current_time - last_time:.2f}s<{DETECTION_INTERVAL}s)",
                        every_s=30,
                    )
                    continue
stream = camera_data.get("stream")
                if not stream:
                    log_rl(logging.WARNING, f"{camera_id}:nostream", f"[{camera_id}] skip: no stream", every_s=10)
                    continue
                # Warmup check
                start_time = getattr(stream, "start_time", getattr(stream, "starttime", None))
                if start_time is not None and (current_time - start_time) < 5:
                    log_rl(logging.DEBUG, f"{camera_id}:warmup", f"[{camera_id}] skip: warmup", every_s=10)
                    continue
frame = stream.read()
if frame is None:
log_rl(logging.WARNING, f"{camera_id}:noframe", f"[{camera_id}] skip: frame is None", every_s=5)
continue
                frame_std = float(np.std(frame))
                if frame_std < 5:
                    log_rl(
                        logging.INFO,
                        f"{camera_id}:lowstd",
                        f"[{camera_id}] skip: low frame std={frame_std:.2f} (<5) (disturbed/blank/frozen?)",
                        every_s=5,
                    )
                    mqtt_client.publish(f"{MQTT_TOPIC}/{camera_id}/status", "disturbed")
                    continue
roi_list = camera_manager.rois.get(camera_id, [])
if not roi_list:
log_rl(logging.WARNING, f"{camera_id}:norois", f"[{camera_id}] skip: no ROIs", every_s=30)
continue
# SEND TO WORKER
                inference_worker.add_task(camera_id, roi_list, frame, frame_std=frame_std)
last_processed_time[camera_id] = current_time
# Sleep briefly to prevent CPU spinning, but keep it responsive for results
time.sleep(0.1)
except Exception as e:
logger.error(f"Global process loop error: {e}")
logger.error("Global process loop error: %s", e)
traceback.print_exc()
time.sleep(5)
# --- Flask Routes ---
@app.route('/')
def index():
return render_template('index.html')
@app.route('/cameras', methods=['GET'])
def get_cameras():
return jsonify(camera_manager.get_camera_list())
@app.route('/video/<camera_id>')
def video_feed(camera_id):
def generate():
@@ -174,11 +248,16 @@ def video_feed(camera_id):
if frame is not None:
ret, jpeg = cv2.imencode('.jpg', frame)
if ret:
                    yield (
                        b'--frame\r\n'
                        b'Content-Type: image/jpeg\r\n\r\n' + jpeg.tobytes() + b'\r\n\r\n'
                    )
else:
time.sleep(0.1)
return Response(generate(), mimetype='multipart/x-mixed-replace; boundary=frame')
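# A minimal sketch of what each multipart part above looks like on the wire
# (boundary "frame" matches the mimetype above; illustrative layout):
#   --frame\r\n
#   Content-Type: image/jpeg\r\n\r\n
#   <JPEG bytes>\r\n\r\n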
@app.route('/snapshot/<camera_id>')
def snapshot(camera_id):
frame = camera_manager.get_frame(camera_id)
@@ -188,6 +267,7 @@ def snapshot(camera_id):
return Response(jpeg.tobytes(), mimetype='image/jpeg')
return 'No frame available', 404
@app.route('/rois/<camera_id>', methods=['GET'])
def get_rois(camera_id):
try:
@@ -218,28 +298,34 @@ def get_rois(camera_id):
"y": int(round(roi["y"] * scaleY)),
"width": int(round(roi["width"] * scaleX)),
"height": int(round(roi["height"] * scaleY)),
"angle": roi["angle"]
"angle": roi.get("angle", 0),
})
return jsonify(scaled_rois)
except Exception as e:
return jsonify({"error": str(e)}), 500
@app.route("/save_rois", methods=["POST"])
def save_rois_api():
data = request.json
camera_id = data.get("camera_id")
new_rois = data.get("rois")
img_width = data.get("img_width")
img_height = data.get("img_height")
    if not camera_id or new_rois is None:
        return jsonify({"success": False})
    cam = camera_manager.cameras.get(camera_id)
    if not cam:
        return jsonify({"success": False})
stream = cam.get("stream")
    real_w = stream.width if stream and getattr(stream, "width", None) else cam["width"]
    real_h = stream.height if stream and getattr(stream, "height", None) else cam["height"]
scaleX = real_w / img_width if img_width else 1
scaleY = real_h / img_height if img_height else 1
@@ -252,21 +338,24 @@ def save_rois_api():
"y": int(round(roi["y"] * scaleY)),
"width": int(round(roi["width"] * scaleX)),
"height": int(round(roi["height"] * scaleY)),
"angle": roi["angle"]
"angle": roi.get("angle", 0),
})
camera_manager.rois[camera_id] = scaled_rois
return jsonify(camera_manager.save_roi_config())
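# Example of the rescaling above (hypothetical sizes): a 1920x1080 stream edited
# on a 640x360 canvas gives scaleX = 1920/640 = 3.0 and scaleY = 1080/360 = 3.0,
# so an ROI drawn at x=100 in the UI is stored at x=300 in frame coordinates.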
@app.route('/crop', methods=['POST'])
def crop():
data = request.json
camera_id = data.get('camera_id')
scaleX = data.get('scaleX', 1)
scaleY = data.get('scaleY', 1)
frame = camera_manager.get_frame(camera_id)
    if frame is None:
        return jsonify({'error': 'No frame'}), 500
roi_list = camera_manager.rois.get(camera_id, [])
cropped_images = crop_image_for_ui(frame, roi_list, scaleX, scaleY)
@@ -279,31 +368,29 @@ def crop():
return jsonify({'cropped_images': cropped_base64_list})
@app.route('/detect_digits', methods=['POST'])
def detect_digits():
"""Manual trigger: Runs inference immediately and returns result with validation."""
data = request.json
camera_id = data.get('camera_id')
if not camera_id:
return jsonify({'error': 'Invalid camera ID'}), 400
# 1. Get Frame
frame = camera_manager.get_frame(camera_id)
if frame is None:
return jsonify({'error': 'Failed to capture image'}), 500
# 2. Get ROIs
roi_list = camera_manager.rois.get(camera_id, [])
if not roi_list:
return jsonify({'error': 'No ROIs defined'}), 400
# 3. Crop
cropped_images = crop_image_for_ui(frame, roi_list, scaleX=1, scaleY=1)
if not cropped_images:
return jsonify({'error': 'Failed to crop ROIs'}), 500
try:
# 4. Run Inference Synchronously
predictions = inference_worker.predict_batch(cropped_images)
valid_digits_str = []
@@ -318,30 +405,24 @@ def detect_digits():
if p['confidence'] < CONFIDENCE_THRESHOLD:
msg = f"Digit {i} ('{p['digit']}') rejected: conf {p['confidence']:.2f} < {CONFIDENCE_THRESHOLD}"
rejected_reasons.append(msg)
logger.warning(f"[Manual] {msg}")
logger.warning("[Manual] %s", msg)
else:
valid_digits_str.append(p['digit'])
confidences.append(p['confidence'])
if len(valid_digits_str) != len(predictions):
            return jsonify({'error': 'Low confidence detection', 'details': rejected_reasons, 'raw': predictions}), 400
final_number_str = "".join(valid_digits_str)
try:
final_number = int(final_number_str)
# Range Check
if not (MIN_VALUE <= final_number <= MAX_VALUE):
msg = f"Value {final_number} out of range ({MIN_VALUE}-{MAX_VALUE})"
logger.warning(f"[Manual] {msg}")
logger.warning("[Manual] %s", msg)
return jsonify({'error': 'Value out of range', 'value': final_number}), 400
# Valid result
            avg_conf = float(np.mean(confidences)) if confidences else None
publish_detected_number(camera_id, final_number, avg_conf)
camera_manager.results[camera_id] = final_number
@@ -349,25 +430,28 @@ def detect_digits():
'detected_digits': valid_digits_str,
'final_number': final_number,
'confidences': confidences,
                'avg_confidence': avg_conf,
})
except ValueError:
            return jsonify({'error': 'Could not parse digits', 'raw': valid_digits_str}), 500
except Exception as e:
logger.error(f"Error during manual detection: {e}")
logger.error("Error during manual detection: %s", e)
return jsonify({'error': str(e)}), 500
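# Example manual trigger for the route above (hypothetical camera id; host/port
# match app.run() at the bottom of this file):
#   curl -X POST http://localhost:5000/detect_digits \
#        -H 'Content-Type: application/json' -d '{"camera_id": "cam1"}'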
@app.route('/update_camera_config', methods=['POST'])
def update_camera_config():
data = request.json
success = camera_manager.update_camera_flip(data.get("camera_id"), data.get("flip_type"))
return jsonify({"success": success})
# --- Main ---
if __name__ == '__main__':
t = threading.Thread(target=process_all_cameras, daemon=True)
t.start()
logger.info("Starting Flask Server...")
app.run(host='0.0.0.0', port=5000, threaded=True)

inference.py

@@ -1,66 +1,98 @@
import logging
import queue
import threading
import time
import cv2
import numpy as np
import tflite_runtime.interpreter as tflite
from config import Config
logger = logging.getLogger(__name__)
def _cfg(*names, default=None):
for n in names:
if hasattr(Config, n):
return getattr(Config, n)
return default
class InferenceWorker:
def __init__(self):
self.input_queue = queue.Queue(maxsize=10)
self.result_queue = queue.Queue()
self.running = False
self.interpreter = None
self.input_details = None
self.output_details = None
self.lock = threading.Lock()
        # Debug counters / telemetry
        self.task_seq = 0
        self.dropped_tasks = 0
        self.processed_tasks = 0
        self.last_invoke_secs = None
        # Validation thresholds
        self.CONFIDENCE_THRESHOLD = 0.10  # Minimum confidence (0-1) to accept a digit
        self.MIN_VALUE = 5  # Minimum allowed temperature value
        self.MAX_VALUE = 100  # Maximum allowed temperature value
# Load Model
self.load_model()
def load_model(self):
try:
logger.info(f"Loading TFLite model from: {Config.MODEL_PATH}")
self.interpreter = tflite.Interpreter(model_path=Config.MODEL_PATH)
model_path = _cfg("MODEL_PATH", "MODELPATH", default=None)
logger.info("Loading TFLite model from: %s", model_path)
self.interpreter = tflite.Interpreter(model_path=model_path)
self.interpreter.allocate_tensors()
self.input_details = self.interpreter.get_input_details()
self.output_details = self.interpreter.get_output_details()
# Store original input shape for resizing logic
self.original_input_shape = self.input_details[0]['shape']
logger.info(f"Model loaded. Default input shape: {self.original_input_shape}")
logger.info("Model loaded. Default input shape: %s", self.original_input_shape)
except Exception as e:
logger.critical(f"Failed to load TFLite model: {e}")
logger.critical("Failed to load TFLite model: %s", e)
self.interpreter = None
def start(self):
        if self.running:
            return
self.running = True
threading.Thread(target=self._worker_loop, daemon=True).start()
logger.info("Inference worker started.")
    def add_task(self, camera_id, rois, frame, frame_std=None):
        """Add task (non-blocking)."""
        if not self.interpreter:
            return
        self.task_seq += 1
        task = {
            'camera_id': camera_id,
            'rois': rois,
            'frame': frame,
            'timestamp': time.time(),
            'task_id': self.task_seq,
            'frame_std': frame_std,
        }
        try:
            self.input_queue.put(task, block=False)
        except queue.Full:
            self.dropped_tasks += 1
            logger.warning(
                "add_task drop cam=%s qsize=%d dropped=%d",
                camera_id,
                self.input_queue.qsize(),
                self.dropped_tasks,
            )
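    # Note on the drop policy in add_task: put(..., block=False) on a full
    # queue raises queue.Full and the newest frame is discarded, so under
    # sustained overload the worker keeps draining the frames already queued
    # instead of stalling the capture loop (drop-newest backpressure).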
def get_result(self):
try:
@@ -68,6 +100,14 @@ class InferenceWorker:
except queue.Empty:
return None
def _put_result(self, d):
"""Best-effort put so failures never go silent."""
try:
self.result_queue.put(d, block=False)
except Exception:
# Should be extremely rare; log + drop
logger.exception("Failed to enqueue result")
def _worker_loop(self):
while self.running:
try:
@@ -78,86 +118,122 @@ class InferenceWorker:
cam_id = task['camera_id']
rois = task['rois']
frame = task['frame']
task_id = task.get('task_id')
task_ts = task.get('timestamp')
try:
                age_s = (time.time() - task_ts) if task_ts else None
logger.info(
"Worker got task cam=%s task_id=%s age_s=%s frame_std=%s rois=%d in_q=%d",
cam_id,
task_id,
(f"{age_s:.3f}" if age_s is not None else "n/a"),
task.get('frame_std'),
len(rois) if rois else 0,
self.input_queue.qsize(),
)
t0 = time.time()
crops = self._crop_rois(frame, rois)
t_crop = time.time()
if not crops:
# Report failure to queue so main loop knows we tried
                    self._put_result({
'type': 'error',
'camera_id': cam_id,
                        'message': 'No ROIs cropped',
'task_id': task_id,
'task_ts': task_ts,
'timing_s': {'crop': t_crop - t0, 'total': t_crop - t0},
})
continue
# 2. Batch Predict
predictions = self.predict_batch(crops)
t_pred = time.time()
# 3. Validation Logic
                valid_digits_str = []
                confidences = []
                low_conf_details = []
                for i, p in enumerate(predictions):
                    if p['confidence'] < self.CONFIDENCE_THRESHOLD:
                        low_conf_details.append(
                            f"Digit {i} conf {p['confidence']:.2f} < {self.CONFIDENCE_THRESHOLD}"
                        )
                    valid_digits_str.append(p['digit'])
                    confidences.append(p['confidence'])
                if low_conf_details:
                    self._put_result({
'type': 'error',
'camera_id': cam_id,
'message': f"Low confidence: {', '.join(low_conf_details)}",
                        'digits': valid_digits_str,
'task_id': task_id,
'task_ts': task_ts,
'timing_s': {'crop': t_crop - t0, 'predict': t_pred - t_crop, 'total': t_pred - t0},
})
continue
                if not valid_digits_str:
                    self._put_result({
                        'type': 'error',
                        'camera_id': cam_id,
                        'message': 'No digits produced',
                        'task_id': task_id,
                        'task_ts': task_ts,
                        'timing_s': {'crop': t_crop - t0, 'predict': t_pred - t_crop, 'total': t_pred - t0},
                    })
                    continue
final_number_str = "".join(valid_digits_str)
try:
final_number = int(final_number_str)
except ValueError:
self._put_result({
'type': 'error',
'camera_id': cam_id,
'message': f"Parse error: {valid_digits_str}",
'task_id': task_id,
'task_ts': task_ts,
'timing_s': {'crop': t_crop - t0, 'predict': t_pred - t_crop, 'total': t_pred - t0},
})
continue
if self.MIN_VALUE <= final_number <= self.MAX_VALUE:
avg_conf = float(np.mean(confidences)) if confidences else None
self._put_result({
'type': 'success',
'camera_id': cam_id,
'value': final_number,
'digits': valid_digits_str,
'confidence': avg_conf,
'task_id': task_id,
'task_ts': task_ts,
'timing_s': {'crop': t_crop - t0, 'predict': t_pred - t_crop, 'total': t_pred - t0},
})
else:
self._put_result({
'type': 'error',
'camera_id': cam_id,
'message': f"Value {final_number} out of range ({self.MIN_VALUE}-{self.MAX_VALUE})",
'value': final_number,
'task_id': task_id,
'task_ts': task_ts,
'timing_s': {'crop': t_crop - t0, 'predict': t_pred - t_crop, 'total': t_pred - t0},
})
                self.processed_tasks += 1
            except Exception:
                logger.exception("Inference error cam=%s task_id=%s", cam_id, task_id)
                self._put_result({
                    'type': 'error',
                    'camera_id': cam_id,
                    'message': 'Exception during inference; see logs',
                    'task_id': task_id,
                    'task_ts': task_ts,
                })
def _crop_rois(self, image, roi_list):
@@ -165,7 +241,7 @@ class InferenceWorker:
for roi in roi_list:
try:
x, y, w, h = roi['x'], roi['y'], roi['width'], roi['height']
                cropped = image[y:y + h, x:x + w]
if cropped.size > 0:
cropped_images.append(cropped)
except Exception:
@@ -173,12 +249,17 @@ class InferenceWorker:
return cropped_images
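    # Note: NumPy slice bounds are clamped to the array shape, so an ROI that
    # extends past the frame edge yields a smaller (possibly empty) crop rather
    # than raising; the `cropped.size > 0` check above filters out empty crops.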
def predict_batch(self, images):
"""Run inference on a batch of images at once. Returns list of dicts: {'digit': str, 'confidence': float}"""
"""Run inference on a batch of images.
Returns list of dicts: {'digit': str, 'confidence': float}
"""
with self.lock:
            if not self.interpreter:
                return []
            num_images = len(images)
            if num_images == 0:
                return []
input_index = self.input_details[0]['index']
output_index = self.output_details[0]['index']
@@ -194,23 +275,33 @@ class InferenceWorker:
input_tensor = np.array(batch_input)
# NOTE: Keeping original behavior (resize+allocate) but timing it.
self.interpreter.resize_tensor_input(input_index, [num_images, target_h, target_w, 3])
self.interpreter.allocate_tensors()
self.interpreter.set_tensor(input_index, input_tensor)
t0 = time.time()
self.interpreter.invoke()
self.last_invoke_secs = time.time() - t0
if self.last_invoke_secs > 1.0:
logger.warning("Slow invoke: %.3fs (batch=%d)", self.last_invoke_secs, num_images)
output_data = self.interpreter.get_tensor(output_index)
results = []
for i in range(num_images):
logits = output_data[i]
                # Numerically stable softmax: shift by the max logit before exp
                logits = logits - np.max(logits)
                ex = np.exp(logits)
                denom = np.sum(ex)
                probs = (ex / denom) if denom != 0 else np.zeros_like(ex)
                digit_class = int(np.argmax(probs))
                confidence = float(probs[digit_class]) if probs.size else 0.0
                results.append({'digit': str(digit_class), 'confidence': confidence})
return results
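# Worked example of the stable softmax in predict_batch (illustrative numbers):
#   logits [2.0, 1.0, 0.1] -> shifted [0.0, -1.0, -1.9]
#   exp    [1.000, 0.368, 0.150], sum = 1.517
#   probs  [0.659, 0.242, 0.099] => digit '0' with confidence ~0.66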