This commit is contained in:
parent
d4ef9c2654
commit
5dae5b86c0
128
app.py
128
app.py
|
|
@ -6,13 +6,11 @@ import json
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
import base64
|
import base64
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import paho.mqtt.client as mqtt
|
import paho.mqtt.client as mqtt
|
||||||
from flask import Flask, render_template, jsonify, request, Response
|
from flask import Flask, render_template, jsonify, request, Response
|
||||||
|
|
||||||
# test
|
|
||||||
# Import Config, Manager, and NEW Inference Worker
|
# Import Config, Manager, and NEW Inference Worker
|
||||||
from config import Config
|
from config import Config
|
||||||
from manager import CameraManager
|
from manager import CameraManager
|
||||||
|
|
@ -24,23 +22,24 @@ logging.basicConfig(
|
||||||
format='%(asctime)s [%(levelname)s] %(message)s',
|
format='%(asctime)s [%(levelname)s] %(message)s',
|
||||||
handlers=[logging.StreamHandler(sys.stdout)]
|
handlers=[logging.StreamHandler(sys.stdout)]
|
||||||
)
|
)
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
|
|
||||||
# --- Initialize Components ---
|
# --- Initialize Components ---
|
||||||
camera_manager = CameraManager()
|
camera_manager = CameraManager()
|
||||||
inference_worker = InferenceWorker() # <--- NEW
|
inference_worker = InferenceWorker()
|
||||||
inference_worker.start() # <--- Start the background thread
|
inference_worker.start()
|
||||||
|
|
||||||
# --- MQTT Setup ---
|
# --- MQTT Setup ---
|
||||||
mqtt_client = mqtt.Client()
|
mqtt_client = mqtt.Client()
|
||||||
|
|
||||||
if Config.MQTT_USERNAME and Config.MQTT_PASSWORD:
|
if Config.MQTT_USERNAME and Config.MQTT_PASSWORD:
|
||||||
mqtt_client.username_pw_set(Config.MQTT_USERNAME, Config.MQTT_PASSWORD)
|
mqtt_client.username_pw_set(Config.MQTT_USERNAME, Config.MQTT_PASSWORD)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
mqtt_client.connect(Config.MQTT_BROKER, Config.MQTT_PORT, 60)
|
mqtt_client.connect(Config.MQTT_BROKER, Config.MQTT_PORT, 60)
|
||||||
mqtt_client.loop_start() # START THE LOOP HERE
|
mqtt_client.loop_start()
|
||||||
logger.info(f"Connected to MQTT Broker at {Config.MQTT_BROKER}:{Config.MQTT_PORT}")
|
logger.info(f"Connected to MQTT Broker at {Config.MQTT_BROKER}:{Config.MQTT_PORT}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to connect to MQTT Broker: {e}")
|
logger.error(f"Failed to connect to MQTT Broker: {e}")
|
||||||
|
|
@ -62,27 +61,31 @@ def crop_image_for_ui(image, roi_list, scaleX, scaleY):
|
||||||
pass
|
pass
|
||||||
return cropped_images
|
return cropped_images
|
||||||
|
|
||||||
def publish_detected_number(camera_id, detected_number):
|
def publish_detected_number(camera_id, detected_number, confidence=None):
|
||||||
"""Publish result to MQTT."""
|
"""Publish result to MQTT with optional confidence score."""
|
||||||
topic = f"{Config.MQTT_TOPIC}/{camera_id}"
|
topic = f"{Config.MQTT_TOPIC}/{camera_id}"
|
||||||
payload = json.dumps({"value": detected_number})
|
payload_dict = {"value": detected_number}
|
||||||
|
if confidence is not None:
|
||||||
|
payload_dict["confidence"] = round(confidence, 2)
|
||||||
|
|
||||||
|
payload = json.dumps(payload_dict)
|
||||||
try:
|
try:
|
||||||
mqtt_client.publish(topic, payload)
|
mqtt_client.publish(topic, payload)
|
||||||
logger.info(f"Published to {topic}: {detected_number}")
|
log_msg = f"Published to {topic}: {detected_number}"
|
||||||
|
if confidence is not None:
|
||||||
|
log_msg += f" (Conf: {confidence:.2f})"
|
||||||
|
logger.info(log_msg)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"MQTT Publish failed: {e}")
|
logger.error(f"MQTT Publish failed: {e}")
|
||||||
|
|
||||||
# --- Main Processing Loop (Refactored) ---
|
# --- Main Processing Loop (Refactored) ---
|
||||||
# Add this global dictionary at the top of app.py (near other globals)
|
|
||||||
last_processed_time = {}
|
last_processed_time = {}
|
||||||
|
|
||||||
# Update process_all_cameras function
|
|
||||||
def process_all_cameras():
|
def process_all_cameras():
|
||||||
"""
|
"""
|
||||||
Revised Loop with Rate Limiting
|
Revised Loop with Rate Limiting
|
||||||
"""
|
"""
|
||||||
# Configurable interval (seconds)
|
DETECTION_INTERVAL = 10 # Configurable interval (seconds)
|
||||||
DETECTION_INTERVAL = 10
|
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
|
|
@ -94,9 +97,11 @@ def process_all_cameras():
|
||||||
|
|
||||||
cam_id = result['camera_id']
|
cam_id = result['camera_id']
|
||||||
val = result['value']
|
val = result['value']
|
||||||
|
conf = result.get('confidence')
|
||||||
|
|
||||||
|
# Result queue now only contains validated (range + confidence checked) values
|
||||||
camera_manager.results[cam_id] = val
|
camera_manager.results[cam_id] = val
|
||||||
publish_detected_number(cam_id, val)
|
publish_detected_number(cam_id, val, conf)
|
||||||
|
|
||||||
# --- Part 2: Feed Frames ---
|
# --- Part 2: Feed Frames ---
|
||||||
camera_manager.load_roi_config()
|
camera_manager.load_roi_config()
|
||||||
|
|
@ -120,7 +125,6 @@ def process_all_cameras():
|
||||||
continue
|
continue
|
||||||
|
|
||||||
frame = stream.read()
|
frame = stream.read()
|
||||||
|
|
||||||
if frame is None:
|
if frame is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|
@ -145,9 +149,7 @@ def process_all_cameras():
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
time.sleep(5)
|
time.sleep(5)
|
||||||
|
|
||||||
|
# --- Flask Routes ---
|
||||||
# --- Flask Routes (Unchanged logic, just imports) ---
|
|
||||||
|
|
||||||
@app.route('/')
|
@app.route('/')
|
||||||
def index():
|
def index():
|
||||||
return render_template('index.html')
|
return render_template('index.html')
|
||||||
|
|
@ -180,10 +182,10 @@ def snapshot(camera_id):
|
||||||
|
|
||||||
@app.route('/rois/<camera_id>', methods=['GET'])
|
@app.route('/rois/<camera_id>', methods=['GET'])
|
||||||
def get_rois(camera_id):
|
def get_rois(camera_id):
|
||||||
# ... (Same logic as Step 3, just ensure it uses camera_manager) ...
|
|
||||||
try:
|
try:
|
||||||
camera_manager.load_roi_config()
|
camera_manager.load_roi_config()
|
||||||
all_rois = camera_manager.rois
|
all_rois = camera_manager.rois
|
||||||
|
|
||||||
img_width = request.args.get("img_width", type=float)
|
img_width = request.args.get("img_width", type=float)
|
||||||
img_height = request.args.get("img_height", type=float)
|
img_height = request.args.get("img_height", type=float)
|
||||||
|
|
||||||
|
|
@ -199,6 +201,7 @@ def get_rois(camera_id):
|
||||||
|
|
||||||
scaleX = img_width / real_w
|
scaleX = img_width / real_w
|
||||||
scaleY = img_height / real_h
|
scaleY = img_height / real_h
|
||||||
|
|
||||||
scaled_rois = []
|
scaled_rois = []
|
||||||
for roi in all_rois.get(camera_id, []):
|
for roi in all_rois.get(camera_id, []):
|
||||||
scaled_rois.append({
|
scaled_rois.append({
|
||||||
|
|
@ -215,7 +218,6 @@ def get_rois(camera_id):
|
||||||
|
|
||||||
@app.route("/save_rois", methods=["POST"])
|
@app.route("/save_rois", methods=["POST"])
|
||||||
def save_rois_api():
|
def save_rois_api():
|
||||||
# ... (Same logic as Step 3) ...
|
|
||||||
data = request.json
|
data = request.json
|
||||||
camera_id = data.get("camera_id")
|
camera_id = data.get("camera_id")
|
||||||
new_rois = data.get("rois")
|
new_rois = data.get("rois")
|
||||||
|
|
@ -244,12 +246,12 @@ def save_rois_api():
|
||||||
"height": int(round(roi["height"] * scaleY)),
|
"height": int(round(roi["height"] * scaleY)),
|
||||||
"angle": roi["angle"]
|
"angle": roi["angle"]
|
||||||
})
|
})
|
||||||
|
|
||||||
camera_manager.rois[camera_id] = scaled_rois
|
camera_manager.rois[camera_id] = scaled_rois
|
||||||
return jsonify(camera_manager.save_roi_config())
|
return jsonify(camera_manager.save_roi_config())
|
||||||
|
|
||||||
@app.route('/crop', methods=['POST'])
|
@app.route('/crop', methods=['POST'])
|
||||||
def crop():
|
def crop():
|
||||||
# Helper for UI
|
|
||||||
data = request.json
|
data = request.json
|
||||||
camera_id = data.get('camera_id')
|
camera_id = data.get('camera_id')
|
||||||
scaleX = data.get('scaleX', 1)
|
scaleX = data.get('scaleX', 1)
|
||||||
|
|
@ -259,7 +261,6 @@ def crop():
|
||||||
if frame is None: return jsonify({'error': 'No frame'}), 500
|
if frame is None: return jsonify({'error': 'No frame'}), 500
|
||||||
|
|
||||||
roi_list = camera_manager.rois.get(camera_id, [])
|
roi_list = camera_manager.rois.get(camera_id, [])
|
||||||
# Use the local UI helper function
|
|
||||||
cropped_images = crop_image_for_ui(frame, roi_list, scaleX, scaleY)
|
cropped_images = crop_image_for_ui(frame, roi_list, scaleX, scaleY)
|
||||||
|
|
||||||
cropped_base64_list = []
|
cropped_base64_list = []
|
||||||
|
|
@ -267,14 +268,14 @@ def crop():
|
||||||
ret, buffer = cv2.imencode('.jpg', cropped_img)
|
ret, buffer = cv2.imencode('.jpg', cropped_img)
|
||||||
if ret:
|
if ret:
|
||||||
cropped_base64_list.append(base64.b64encode(buffer).decode('utf-8'))
|
cropped_base64_list.append(base64.b64encode(buffer).decode('utf-8'))
|
||||||
|
|
||||||
return jsonify({'cropped_images': cropped_base64_list})
|
return jsonify({'cropped_images': cropped_base64_list})
|
||||||
|
|
||||||
@app.route('/detect_digits', methods=['POST'])
|
@app.route('/detect_digits', methods=['POST'])
|
||||||
def detect_digits():
|
def detect_digits():
|
||||||
"""Manual trigger: Runs inference immediately and returns result."""
|
"""Manual trigger: Runs inference immediately and returns result with validation."""
|
||||||
data = request.json
|
data = request.json
|
||||||
camera_id = data.get('camera_id')
|
camera_id = data.get('camera_id')
|
||||||
|
|
||||||
if not camera_id:
|
if not camera_id:
|
||||||
return jsonify({'error': 'Invalid camera ID'}), 400
|
return jsonify({'error': 'Invalid camera ID'}), 400
|
||||||
|
|
||||||
|
|
@ -288,49 +289,70 @@ def detect_digits():
|
||||||
if not roi_list:
|
if not roi_list:
|
||||||
return jsonify({'error': 'No ROIs defined'}), 400
|
return jsonify({'error': 'No ROIs defined'}), 400
|
||||||
|
|
||||||
# 3. Crop (Using the UI helper is fine here)
|
# 3. Crop
|
||||||
cropped_images = crop_image_for_ui(frame, roi_list, scaleX=1, scaleY=1)
|
cropped_images = crop_image_for_ui(frame, roi_list, scaleX=1, scaleY=1)
|
||||||
if not cropped_images:
|
if not cropped_images:
|
||||||
return jsonify({'error': 'Failed to crop ROIs'}), 500
|
return jsonify({'error': 'Failed to crop ROIs'}), 500
|
||||||
|
|
||||||
# 4. Run Inference Synchronously
|
|
||||||
# Note: We access the worker directly.
|
|
||||||
# Thread safety: 'predict_batch' uses 'self.interpreter'.
|
|
||||||
# If the background thread is also using it, TFLite might complain or crash.
|
|
||||||
# PROPER FIX: Pause the worker or use a Lock.
|
|
||||||
|
|
||||||
# Since adding a Lock is complex now, a simple hack is to just add it to the queue
|
|
||||||
# and WAIT for the result? No, that's hard to correlate.
|
|
||||||
|
|
||||||
# SAFE APPROACH: Use a Lock in InferenceWorker.
|
|
||||||
# For now, let's assume TFLite is robust enough or race conditions are rare for manual clicks.
|
|
||||||
# CALL THE PUBLIC METHOD:
|
|
||||||
try:
|
try:
|
||||||
detected_digits = inference_worker.predict_batch(cropped_images)
|
# 4. Run Inference Synchronously (using the new method signature)
|
||||||
|
# Returns list of dicts: {'digit': 'X', 'confidence': 0.XX}
|
||||||
|
predictions = inference_worker.predict_batch(cropped_images)
|
||||||
|
|
||||||
valid_digits = [d for d in detected_digits if d.isdigit()]
|
valid_digits_str = []
|
||||||
|
confidences = []
|
||||||
|
rejected_reasons = []
|
||||||
|
|
||||||
if not valid_digits:
|
# 5. Validation Logic (Mirroring _worker_loop logic)
|
||||||
return jsonify({'error': 'No valid digits detected', 'raw': detected_digits}), 500
|
CONFIDENCE_THRESHOLD = inference_worker.CONFIDENCE_THRESHOLD
|
||||||
|
MIN_VALUE = inference_worker.MIN_VALUE
|
||||||
|
MAX_VALUE = inference_worker.MAX_VALUE
|
||||||
|
|
||||||
final_number = int("".join(valid_digits))
|
for i, p in enumerate(predictions):
|
||||||
|
if p['confidence'] < CONFIDENCE_THRESHOLD:
|
||||||
|
msg = f"Digit {i} ('{p['digit']}') rejected: conf {p['confidence']:.2f} < {CONFIDENCE_THRESHOLD}"
|
||||||
|
rejected_reasons.append(msg)
|
||||||
|
logger.warning(f"[Manual] {msg}")
|
||||||
|
else:
|
||||||
|
valid_digits_str.append(p['digit'])
|
||||||
|
confidences.append(p['confidence'])
|
||||||
|
|
||||||
# Publish and Update State
|
if len(valid_digits_str) != len(predictions):
|
||||||
publish_detected_number(camera_id, final_number)
|
return jsonify({
|
||||||
|
'error': 'Low confidence detection',
|
||||||
|
'details': rejected_reasons,
|
||||||
|
'raw': predictions
|
||||||
|
}), 400
|
||||||
|
|
||||||
|
final_number_str = "".join(valid_digits_str)
|
||||||
|
try:
|
||||||
|
final_number = int(final_number_str)
|
||||||
|
|
||||||
|
# Range Check
|
||||||
|
if not (MIN_VALUE <= final_number <= MAX_VALUE):
|
||||||
|
msg = f"Value {final_number} out of range ({MIN_VALUE}-{MAX_VALUE})"
|
||||||
|
logger.warning(f"[Manual] {msg}")
|
||||||
|
return jsonify({'error': 'Value out of range', 'value': final_number}), 400
|
||||||
|
|
||||||
|
# Valid result
|
||||||
|
avg_conf = float(np.mean(confidences))
|
||||||
|
publish_detected_number(camera_id, final_number, avg_conf)
|
||||||
camera_manager.results[camera_id] = final_number
|
camera_manager.results[camera_id] = final_number
|
||||||
|
|
||||||
logger.info(f"Manual detection for {camera_id}: {final_number}")
|
|
||||||
|
|
||||||
return jsonify({
|
return jsonify({
|
||||||
'detected_digits': valid_digits,
|
'detected_digits': valid_digits_str,
|
||||||
'final_number': final_number
|
'final_number': final_number,
|
||||||
|
'confidences': confidences,
|
||||||
|
'avg_confidence': avg_conf
|
||||||
})
|
})
|
||||||
|
|
||||||
|
except ValueError:
|
||||||
|
return jsonify({'error': 'Could not parse digits', 'raw': valid_digits_str}), 500
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error during manual detection: {e}")
|
logger.error(f"Error during manual detection: {e}")
|
||||||
return jsonify({'error': str(e)}), 500
|
return jsonify({'error': str(e)}), 500
|
||||||
|
|
||||||
|
|
||||||
@app.route('/update_camera_config', methods=['POST'])
|
@app.route('/update_camera_config', methods=['POST'])
|
||||||
def update_camera_config():
|
def update_camera_config():
|
||||||
data = request.json
|
data = request.json
|
||||||
|
|
@ -339,13 +361,7 @@ def update_camera_config():
|
||||||
|
|
||||||
# --- Main ---
|
# --- Main ---
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
# Threading:
|
|
||||||
# 1. Video Threads (in Manager)
|
|
||||||
# 2. Inference Thread (in Worker)
|
|
||||||
# 3. Main Loop (process_all_cameras - handles feeding)
|
|
||||||
|
|
||||||
t = threading.Thread(target=process_all_cameras, daemon=True)
|
t = threading.Thread(target=process_all_cameras, daemon=True)
|
||||||
t.start()
|
t.start()
|
||||||
|
|
||||||
logger.info("Starting Flask Server...")
|
logger.info("Starting Flask Server...")
|
||||||
app.run(host='0.0.0.0', port=5000, threaded=True)
|
app.run(host='0.0.0.0', port=5000, threaded=True)
|
||||||
|
|
|
||||||
78
inference.py
78
inference.py
|
|
@ -17,9 +17,13 @@ class InferenceWorker:
|
||||||
self.interpreter = None
|
self.interpreter = None
|
||||||
self.input_details = None
|
self.input_details = None
|
||||||
self.output_details = None
|
self.output_details = None
|
||||||
|
|
||||||
self.lock = threading.Lock()
|
self.lock = threading.Lock()
|
||||||
|
|
||||||
|
# Validation thresholds
|
||||||
|
self.CONFIDENCE_THRESHOLD = 0.80 # Minimum confidence (0-1) to accept a digit
|
||||||
|
self.MIN_VALUE = 5 # Minimum allowed temperature value
|
||||||
|
self.MAX_VALUE = 100 # Maximum allowed temperature value
|
||||||
|
|
||||||
# Load Model
|
# Load Model
|
||||||
self.load_model()
|
self.load_model()
|
||||||
|
|
||||||
|
|
@ -34,7 +38,6 @@ class InferenceWorker:
|
||||||
# Store original input shape for resizing logic
|
# Store original input shape for resizing logic
|
||||||
self.original_input_shape = self.input_details[0]['shape']
|
self.original_input_shape = self.input_details[0]['shape']
|
||||||
logger.info(f"Model loaded. Default input shape: {self.original_input_shape}")
|
logger.info(f"Model loaded. Default input shape: {self.original_input_shape}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.critical(f"Failed to load TFLite model: {e}")
|
logger.critical(f"Failed to load TFLite model: {e}")
|
||||||
self.interpreter = None
|
self.interpreter = None
|
||||||
|
|
@ -81,19 +84,49 @@ class InferenceWorker:
|
||||||
crops = self._crop_rois(frame, rois)
|
crops = self._crop_rois(frame, rois)
|
||||||
if not crops: continue
|
if not crops: continue
|
||||||
|
|
||||||
# 2. Batch Predict (Optimized Step)
|
# 2. Batch Predict (Returns dicts with 'digit' and 'confidence')
|
||||||
digits = self.predict_batch(crops)
|
predictions = self.predict_batch(crops)
|
||||||
|
|
||||||
# 3. Combine
|
# 3. Validation Logic
|
||||||
valid_digits = [d for d in digits if d.isdigit()]
|
valid_digits_str = []
|
||||||
if len(valid_digits) == len(digits) and len(valid_digits) > 0:
|
confidences = []
|
||||||
final_number = int("".join(valid_digits))
|
|
||||||
|
|
||||||
|
# Check individual digit confidence
|
||||||
|
all_confident = True
|
||||||
|
for p in predictions:
|
||||||
|
if p['confidence'] < self.CONFIDENCE_THRESHOLD:
|
||||||
|
logger.warning(f"[{cam_id}] Rejected digit '{p['digit']}' due to low confidence: {p['confidence']:.2f}")
|
||||||
|
all_confident = False
|
||||||
|
break
|
||||||
|
valid_digits_str.append(p['digit'])
|
||||||
|
confidences.append(p['confidence'])
|
||||||
|
|
||||||
|
if not all_confident:
|
||||||
|
continue # Skip this frame entirely if any digit is uncertain
|
||||||
|
|
||||||
|
if not valid_digits_str:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Parse number
|
||||||
|
try:
|
||||||
|
final_number_str = "".join(valid_digits_str)
|
||||||
|
final_number = int(final_number_str)
|
||||||
|
|
||||||
|
# Check Range
|
||||||
|
if self.MIN_VALUE <= final_number <= self.MAX_VALUE:
|
||||||
|
avg_conf = float(np.mean(confidences))
|
||||||
self.result_queue.put({
|
self.result_queue.put({
|
||||||
'camera_id': cam_id,
|
'camera_id': cam_id,
|
||||||
'value': final_number,
|
'value': final_number,
|
||||||
'digits': valid_digits
|
'digits': valid_digits_str,
|
||||||
|
'confidence': avg_conf
|
||||||
})
|
})
|
||||||
|
logger.info(f"[{cam_id}] Valid reading: {final_number} (Avg Conf: {avg_conf:.2f})")
|
||||||
|
else:
|
||||||
|
logger.warning(f"[{cam_id}] Value {final_number} out of range ({self.MIN_VALUE}-{self.MAX_VALUE}). Ignored.")
|
||||||
|
|
||||||
|
except ValueError:
|
||||||
|
logger.warning(f"[{cam_id}] Could not parse digits into integer: {valid_digits_str}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Inference error for {cam_id}: {e}")
|
logger.error(f"Inference error for {cam_id}: {e}")
|
||||||
|
|
@ -111,8 +144,8 @@ class InferenceWorker:
|
||||||
return cropped_images
|
return cropped_images
|
||||||
|
|
||||||
def predict_batch(self, images):
|
def predict_batch(self, images):
|
||||||
"""Run inference on a batch of images at once."""
|
"""Run inference on a batch of images at once. Returns list of dicts: {'digit': str, 'confidence': float}"""
|
||||||
with self.lock: # <--- Add this wrapper
|
with self.lock:
|
||||||
if not self.interpreter: return []
|
if not self.interpreter: return []
|
||||||
|
|
||||||
num_images = len(images)
|
num_images = len(images)
|
||||||
|
|
@ -124,7 +157,6 @@ class InferenceWorker:
|
||||||
# Preprocess all images into a single batch array
|
# Preprocess all images into a single batch array
|
||||||
# Shape: [N, 32, 20, 3] (assuming model expects 32x20 rgb)
|
# Shape: [N, 32, 20, 3] (assuming model expects 32x20 rgb)
|
||||||
batch_input = []
|
batch_input = []
|
||||||
|
|
||||||
target_h, target_w = 32, 20 # Based on your previous code logic
|
target_h, target_w = 32, 20 # Based on your previous code logic
|
||||||
|
|
||||||
for img in images:
|
for img in images:
|
||||||
|
|
@ -146,7 +178,7 @@ class InferenceWorker:
|
||||||
# 1. Resize input tensor
|
# 1. Resize input tensor
|
||||||
self.interpreter.resize_tensor_input(input_index, [num_images, target_h, target_w, 3])
|
self.interpreter.resize_tensor_input(input_index, [num_images, target_h, target_w, 3])
|
||||||
|
|
||||||
# 2. Re-allocate tensors (This is expensive! See note below)
|
# 2. Re-allocate tensors
|
||||||
self.interpreter.allocate_tensors()
|
self.interpreter.allocate_tensors()
|
||||||
|
|
||||||
# 3. Run Inference
|
# 3. Run Inference
|
||||||
|
|
@ -155,11 +187,21 @@ class InferenceWorker:
|
||||||
|
|
||||||
# 4. Get Results
|
# 4. Get Results
|
||||||
output_data = self.interpreter.get_tensor(output_index)
|
output_data = self.interpreter.get_tensor(output_index)
|
||||||
|
# Result shape is [N, 10] (logits or probabilities for 10 digits)
|
||||||
|
|
||||||
# Result shape is [N, 10] (probabilities for 10 digits)
|
results = []
|
||||||
predictions = []
|
|
||||||
for i in range(num_images):
|
for i in range(num_images):
|
||||||
digit_class = np.argmax(output_data[i])
|
# Calculate softmax to get probabilities (if model output is logits)
|
||||||
predictions.append(str(digit_class))
|
# If model output is already softmax, this is redundant but usually harmless if sum is approx 1
|
||||||
|
logits = output_data[i]
|
||||||
|
probs = np.exp(logits) / np.sum(np.exp(logits))
|
||||||
|
|
||||||
return predictions
|
digit_class = np.argmax(probs)
|
||||||
|
confidence = probs[digit_class]
|
||||||
|
|
||||||
|
results.append({
|
||||||
|
'digit': str(digit_class),
|
||||||
|
'confidence': float(confidence)
|
||||||
|
})
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue