327 lines
11 KiB
Python
327 lines
11 KiB
Python
import logging
|
|
import queue
|
|
import threading
|
|
import time
|
|
|
|
import cv2
|
|
import numpy as np
|
|
import tflite_runtime.interpreter as tflite
|
|
|
|
from config import Config
|
|
|
|
# ------------------------------------------------------------------------------
|
|
# 1. USER CONFIGURATION (Edit these values here)
|
|
# ------------------------------------------------------------------------------
|
|
|
|
# Minimum confidence (0-1) to accept a digit.
# - Higher (0.85-0.90) reduces false positives like "1010" from noise.
# - Lower (0.70-0.75) helps with weak/dark digits.
# NOTE(review): 0.1 is far below the 0.70-0.90 range documented above, so
# the low-confidence rejection path in the worker is effectively disabled.
# Confirm this is an intentional debugging/tuning value, not a leftover.
CONFIDENCE_THRESHOLD = 0.1

# Minimum and Maximum expected values for the number.
# Readings outside [MIN_VALUE, MAX_VALUE] are reported as 'error' results
# by the worker loop instead of 'success'.
MIN_VALUE = 5
MAX_VALUE = 100
|
|
|
|
# ------------------------------------------------------------------------------
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
def _cfg(*names, default=None):
|
|
for n in names:
|
|
if hasattr(Config, n):
|
|
return getattr(Config, n)
|
|
return default
|
|
|
|
|
|
class InferenceWorker:
    """Background OCR worker: crops ROIs from frames and runs TFLite digit inference."""

    def __init__(self, debug_log: bool = False):
        """Set up queues, telemetry counters, and load the TFLite model.

        Args:
            debug_log: When True, emit verbose per-task diagnostics.
        """
        self.debug_log = bool(debug_log)

        # Bounded input queue: producers drop frames (see add_task) rather
        # than block when the worker falls behind.
        self.input_queue = queue.Queue(maxsize=10)
        # Unbounded result queue, drained by callers via get_result().
        self.result_queue = queue.Queue()
        self.running = False

        # TFLite state; populated by load_model(), None on load failure.
        self.interpreter = None
        self.input_details = None
        self.output_details = None
        # Serializes access to the interpreter (used by predict_batch).
        self.lock = threading.Lock()

        # Debug counters / telemetry
        self.task_seq = 0
        self.dropped_tasks = 0
        self.processed_tasks = 0
        self.last_invoke_secs = None

        # Set thresholds from top-level variables
        self.CONFIDENCE_THRESHOLD = CONFIDENCE_THRESHOLD
        self.MIN_VALUE = MIN_VALUE
        self.MAX_VALUE = MAX_VALUE

        self.load_model()
|
|
|
|
def load_model(self):
    """Load the TFLite model named in Config and allocate its tensors.

    On any failure the error is logged and self.interpreter is left as
    None, which disables add_task() and predict_batch().
    """
    try:
        model_path = _cfg("MODEL_PATH", "MODELPATH", default=None)
        logger.info("Loading TFLite model from: %s", model_path)

        interpreter = tflite.Interpreter(model_path=model_path)
        interpreter.allocate_tensors()

        self.interpreter = interpreter
        self.input_details = interpreter.get_input_details()
        self.output_details = interpreter.get_output_details()

        # Remember the model's default input shape for reference.
        self.original_input_shape = self.input_details[0]['shape']
        if self.debug_log:
            logger.info("Model loaded. Default input shape: %s", self.original_input_shape)

    except Exception as e:
        logger.critical("Failed to load TFLite model: %s", e)
        self.interpreter = None
|
|
|
|
def start(self):
|
|
if self.running:
|
|
return
|
|
self.running = True
|
|
threading.Thread(target=self._worker_loop, daemon=True).start()
|
|
logger.info("Inference worker started.")
|
|
|
|
def add_task(self, camera_id, rois, frame, frame_std=None):
|
|
"""Add task (non-blocking)."""
|
|
if not self.interpreter:
|
|
return
|
|
|
|
self.task_seq += 1
|
|
task = {
|
|
'camera_id': camera_id,
|
|
'rois': rois,
|
|
'frame': frame,
|
|
'timestamp': time.time(),
|
|
'task_id': self.task_seq,
|
|
'frame_std': frame_std,
|
|
}
|
|
|
|
try:
|
|
self.input_queue.put(task, block=False)
|
|
except queue.Full:
|
|
self.dropped_tasks += 1
|
|
logger.warning(
|
|
"add_task drop cam=%s qsize=%d dropped=%d",
|
|
camera_id,
|
|
self.input_queue.qsize(),
|
|
self.dropped_tasks,
|
|
)
|
|
|
|
def get_result(self):
|
|
try:
|
|
return self.result_queue.get(block=False)
|
|
except queue.Empty:
|
|
return None
|
|
|
|
def _put_result(self, d):
|
|
try:
|
|
self.result_queue.put(d, block=False)
|
|
except Exception:
|
|
logger.exception("Failed to enqueue result")
|
|
|
|
def _worker_loop(self):
    """Daemon loop: pull tasks, run OCR inference, publish one result per task.

    Every dequeued task yields exactly one dict on result_queue whose
    'type' is 'success' or 'error'; 'timing_s' carries per-stage
    durations in seconds. Runs until self.running goes False.
    """
    while self.running:
        try:
            # 1s timeout so the loop periodically re-checks self.running.
            task = self.input_queue.get(timeout=1)
        except queue.Empty:
            continue

        cam_id = task['camera_id']
        rois = task['rois']
        frame = task['frame']
        task_id = task.get('task_id')
        task_ts = task.get('timestamp')

        if self.debug_log:
            try:
                # Age of the frame by the time the worker picks it up.
                age_s = (time.time() - task_ts) if task_ts else None
                logger.info(
                    "Worker got task cam=%s task_id=%s age_s=%s frame_std=%s rois=%d in_q=%d",
                    cam_id,
                    task_id,
                    (f"{age_s:.3f}" if age_s is not None else "n/a"),
                    task.get('frame_std'),
                    len(rois) if rois else 0,
                    self.input_queue.qsize(),
                )
            except Exception:
                # Debug telemetry must never kill the worker loop.
                pass

        try:
            t0 = time.time()
            crops = self._crop_rois(frame, rois)
            t_crop = time.time()

            if not crops:
                self._put_result({
                    'type': 'error',
                    'camera_id': cam_id,
                    'message': 'No ROIs cropped',
                    'task_id': task_id,
                    'task_ts': task_ts,
                    'timing_s': {'crop': t_crop - t0, 'total': t_crop - t0},
                })
                continue

            predictions = self.predict_batch(crops)
            t_pred = time.time()

            valid_digits_str = []
            confidences = []
            low_conf_details = []

            # NOTE(review): every digit is appended regardless of its
            # confidence; a low-confidence digit only adds an entry to
            # low_conf_details, which then turns the whole reading into
            # an error result below. Confirm this all-or-nothing policy
            # is intended.
            for i, p in enumerate(predictions):
                if p['confidence'] < self.CONFIDENCE_THRESHOLD:
                    low_conf_details.append(
                        f"Digit {i} conf {p['confidence']:.2f} < {self.CONFIDENCE_THRESHOLD}"
                    )
                valid_digits_str.append(p['digit'])
                confidences.append(p['confidence'])

            if low_conf_details:
                # Any under-threshold digit rejects the entire reading.
                self._put_result({
                    'type': 'error',
                    'camera_id': cam_id,
                    'message': f"Low confidence: {', '.join(low_conf_details)}",
                    'digits': valid_digits_str,
                    'task_id': task_id,
                    'task_ts': task_ts,
                    'timing_s': {'crop': t_crop - t0, 'predict': t_pred - t_crop, 'total': t_pred - t0},
                })
                continue

            if not valid_digits_str:
                self._put_result({
                    'type': 'error',
                    'camera_id': cam_id,
                    'message': 'No digits produced',
                    'task_id': task_id,
                    'task_ts': task_ts,
                    'timing_s': {'crop': t_crop - t0, 'predict': t_pred - t_crop, 'total': t_pred - t0},
                })
                continue

            # Concatenate per-digit predictions into the final reading.
            final_number_str = "".join(valid_digits_str)

            try:
                final_number = int(final_number_str)
            except ValueError:
                self._put_result({
                    'type': 'error',
                    'camera_id': cam_id,
                    'message': f"Parse error: {valid_digits_str}",
                    'task_id': task_id,
                    'task_ts': task_ts,
                    'timing_s': {'crop': t_crop - t0, 'predict': t_pred - t_crop, 'total': t_pred - t0},
                })
                continue

            # Plausibility gate: only values in [MIN_VALUE, MAX_VALUE]
            # count as a successful reading.
            if self.MIN_VALUE <= final_number <= self.MAX_VALUE:
                avg_conf = float(np.mean(confidences)) if confidences else None
                self._put_result({
                    'type': 'success',
                    'camera_id': cam_id,
                    'value': final_number,
                    'digits': valid_digits_str,
                    'confidence': avg_conf,
                    'task_id': task_id,
                    'task_ts': task_ts,
                    'timing_s': {'crop': t_crop - t0, 'predict': t_pred - t_crop, 'total': t_pred - t0},
                })
                self.processed_tasks += 1
            else:
                self._put_result({
                    'type': 'error',
                    'camera_id': cam_id,
                    'message': f"Value {final_number} out of range ({self.MIN_VALUE}-{self.MAX_VALUE})",
                    'value': final_number,
                    'task_id': task_id,
                    'task_ts': task_ts,
                    'timing_s': {'crop': t_crop - t0, 'predict': t_pred - t_crop, 'total': t_pred - t0},
                })

        except Exception:
            # Boundary handler: log the full traceback and keep the
            # worker thread alive for subsequent tasks.
            logger.exception("Inference error cam=%s task_id=%s", cam_id, task_id)
            self._put_result({
                'type': 'error',
                'camera_id': cam_id,
                'message': 'Exception during inference; see logs',
                'task_id': task_id,
                'task_ts': task_ts,
            })
|
|
|
|
def _crop_rois(self, image, roi_list):
|
|
cropped_images = []
|
|
for roi in roi_list:
|
|
try:
|
|
x, y, w, h = roi['x'], roi['y'], roi['width'], roi['height']
|
|
cropped = image[y:y + h, x:x + w]
|
|
if cropped.size > 0:
|
|
cropped_images.append(cropped)
|
|
except Exception:
|
|
pass
|
|
return cropped_images
|
|
|
|
def predict_batch(self, images):
    """Run inference on a batch of images.

    Each crop is resized to 20x32 (width x height), converted BGR->RGB
    and cast to float32, then fed to the TFLite model as a single batch.

    Args:
        images: List of BGR numpy arrays (ROI crops).

    Returns:
        List of dicts {'digit': str, 'confidence': float}, one per input
        image. Empty list when no interpreter is loaded or *images* is
        empty.
    """
    # The TFLite interpreter is not thread-safe; serialize all access.
    with self.lock:
        if not self.interpreter:
            return []

        num_images = len(images)
        if num_images == 0:
            return []

        input_index = self.input_details[0]['index']
        output_index = self.output_details[0]['index']

        batch_input = []
        # Fixed model input geometry (height, width).
        target_h, target_w = 32, 20

        for img in images:
            # cv2.resize takes (width, height).
            roi_resized = cv2.resize(img, (target_w, target_h))
            roi_rgb = cv2.cvtColor(roi_resized, cv2.COLOR_BGR2RGB)
            # NOTE(review): cast to float32 without scaling to [0, 1] —
            # confirm the model really expects raw 0-255 pixel values.
            roi_norm = roi_rgb.astype(np.float32)
            batch_input.append(roi_norm)

        input_tensor = np.array(batch_input)

        # Keep current behavior (resize+allocate per batch). Debug timing is optional.
        self.interpreter.resize_tensor_input(input_index, [num_images, target_h, target_w, 3])
        self.interpreter.allocate_tensors()

        self.interpreter.set_tensor(input_index, input_tensor)

        t0 = time.time()
        self.interpreter.invoke()
        # Telemetry: duration of the last invoke() call, in seconds.
        self.last_invoke_secs = time.time() - t0

        if self.debug_log and self.last_invoke_secs and self.last_invoke_secs > 1.0:
            logger.warning("Slow invoke: %.3fs (batch=%d)", self.last_invoke_secs, num_images)

        output_data = self.interpreter.get_tensor(output_index)

        results = []
        for i in range(num_images):
            # Per-image class scores. Presumably raw logits — if the
            # model already outputs probabilities, the softmax below is
            # redundant (harmless for argmax, but it would distort the
            # reported confidence). TODO confirm against the model.
            logits = output_data[i]

            # Numerically stable softmax
            logits = logits - np.max(logits)
            ex = np.exp(logits)
            denom = np.sum(ex)
            probs = (ex / denom) if denom != 0 else np.zeros_like(ex)

            digit_class = int(np.argmax(probs))
            confidence = float(probs[digit_class]) if probs.size else 0.0

            results.append({'digit': str(digit_class), 'confidence': confidence})

        return results
|