# EdgeAI_Digit_Recognition/inference.py

import logging
import queue
import threading
import time

import cv2
import numpy as np
import tflite_runtime.interpreter as tflite

from config import Config

logger = logging.getLogger(__name__)

def _cfg(*names, default=None):
    """Return the first Config attribute found among *names*, else *default*."""
    for n in names:
        if hasattr(Config, n):
            return getattr(Config, n)
    return default

class InferenceWorker:
    """Background TFLite inference worker.

    Tasks are submitted via add_task() and results are polled via
    get_result(); both calls are non-blocking on the caller's side.
    """

    def __init__(self):
        self.input_queue = queue.Queue(maxsize=10)
        self.result_queue = queue.Queue()
        self.running = False
        self.interpreter = None
        self.input_details = None
        self.output_details = None
        self.lock = threading.Lock()
        # Debug counters / telemetry
        self.task_seq = 0
        self.dropped_tasks = 0
        self.processed_tasks = 0
        self.last_invoke_secs = None
        # Validation thresholds
        self.CONFIDENCE_THRESHOLD = 0.10
        self.MIN_VALUE = 5
        self.MAX_VALUE = 100
        self.load_model()

    def load_model(self):
        try:
            model_path = _cfg("MODEL_PATH", "MODELPATH", default=None)
            logger.info("Loading TFLite model from: %s", model_path)
            self.interpreter = tflite.Interpreter(model_path=model_path)
            self.interpreter.allocate_tensors()
            self.input_details = self.interpreter.get_input_details()
            self.output_details = self.interpreter.get_output_details()
            self.original_input_shape = self.input_details[0]['shape']
            logger.info("Model loaded. Default input shape: %s", self.original_input_shape)
        except Exception as e:
            logger.critical("Failed to load TFLite model: %s", e)
            self.interpreter = None
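
    # Note: on load failure, self.interpreter stays None, which makes
    # add_task() below a silent no-op; the rest of the pipeline keeps running
    # but produces no results for any camera.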

    def start(self):
        if self.running:
            return
        self.running = True
        threading.Thread(target=self._worker_loop, daemon=True).start()
        logger.info("Inference worker started.")

    def add_task(self, camera_id, rois, frame, frame_std=None):
        """Queue an inference task (non-blocking); drops it if the queue is full."""
        if not self.interpreter:
            return
        self.task_seq += 1
        task = {
            'camera_id': camera_id,
            'rois': rois,
            'frame': frame,
            'timestamp': time.time(),
            'task_id': self.task_seq,
            'frame_std': frame_std,
        }
        try:
            self.input_queue.put(task, block=False)
        except queue.Full:
            self.dropped_tasks += 1
            logger.warning(
                "add_task drop cam=%s qsize=%d dropped=%d",
                camera_id,
                self.input_queue.qsize(),
                self.dropped_tasks,
            )
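
    # Example (sketch): a capture loop would typically submit frames as
    #
    #     worker.add_task(cam_id, rois, frame, frame_std=float(frame.std()))
    #
    # Because the put never blocks, a slow model drops frames rather than
    # stalling the capture thread; drops show up in self.dropped_tasks.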

    def get_result(self):
        """Return the next result dict, or None if nothing is pending."""
        try:
            return self.result_queue.get(block=False)
        except queue.Empty:
            return None

    def _put_result(self, d):
        """Best-effort put so failures never go silent."""
        try:
            self.result_queue.put(d, block=False)
        except Exception:
            # Should be extremely rare; log + drop
            logger.exception("Failed to enqueue result")
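
    # Example (sketch): a caller polls results from its own loop, e.g.
    #
    #     result = worker.get_result()
    #     if result and result['type'] == 'success':
    #         handle_value(result['camera_id'], result['value'])
    #
    # where handle_value is a hypothetical callback, not part of this module.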

    def _worker_loop(self):
        while self.running:
            try:
                task = self.input_queue.get(timeout=1)
            except queue.Empty:
                continue
            cam_id = task['camera_id']
            rois = task['rois']
            frame = task['frame']
            task_id = task.get('task_id')
            task_ts = task.get('timestamp')
            try:
                age_s = (time.time() - task_ts) if task_ts else None
                logger.info(
                    "Worker got task cam=%s task_id=%s age_s=%s frame_std=%s rois=%d in_q=%d",
                    cam_id,
                    task_id,
                    (f"{age_s:.3f}" if age_s is not None else "n/a"),
                    task.get('frame_std'),
                    len(rois) if rois else 0,
                    self.input_queue.qsize(),
                )
                t0 = time.time()
                crops = self._crop_rois(frame, rois)
                t_crop = time.time()
                if not crops:
                    self._put_result({
                        'type': 'error',
                        'camera_id': cam_id,
                        'message': 'No ROIs cropped',
                        'task_id': task_id,
                        'task_ts': task_ts,
                        'timing_s': {'crop': t_crop - t0, 'total': t_crop - t0},
                    })
                    continue
                predictions = self.predict_batch(crops)
                t_pred = time.time()
                valid_digits_str = []
                confidences = []
                low_conf_details = []
                for i, p in enumerate(predictions):
                    if p['confidence'] < self.CONFIDENCE_THRESHOLD:
                        low_conf_details.append(
                            f"Digit {i} conf {p['confidence']:.2f} < {self.CONFIDENCE_THRESHOLD}"
                        )
                    valid_digits_str.append(p['digit'])
                    confidences.append(p['confidence'])
                if low_conf_details:
                    self._put_result({
                        'type': 'error',
                        'camera_id': cam_id,
                        'message': f"Low confidence: {', '.join(low_conf_details)}",
                        'digits': valid_digits_str,
                        'task_id': task_id,
                        'task_ts': task_ts,
                        'timing_s': {'crop': t_crop - t0, 'predict': t_pred - t_crop, 'total': t_pred - t0},
                    })
                    continue
                if not valid_digits_str:
                    self._put_result({
                        'type': 'error',
                        'camera_id': cam_id,
                        'message': 'No digits produced',
                        'task_id': task_id,
                        'task_ts': task_ts,
                        'timing_s': {'crop': t_crop - t0, 'predict': t_pred - t_crop, 'total': t_pred - t0},
                    })
                    continue
                final_number_str = "".join(valid_digits_str)
                try:
                    final_number = int(final_number_str)
                except ValueError:
                    self._put_result({
                        'type': 'error',
                        'camera_id': cam_id,
                        'message': f"Parse error: {valid_digits_str}",
                        'task_id': task_id,
                        'task_ts': task_ts,
                        'timing_s': {'crop': t_crop - t0, 'predict': t_pred - t_crop, 'total': t_pred - t0},
                    })
                    continue
                if self.MIN_VALUE <= final_number <= self.MAX_VALUE:
                    avg_conf = float(np.mean(confidences)) if confidences else None
                    self._put_result({
                        'type': 'success',
                        'camera_id': cam_id,
                        'value': final_number,
                        'digits': valid_digits_str,
                        'confidence': avg_conf,
                        'task_id': task_id,
                        'task_ts': task_ts,
                        'timing_s': {'crop': t_crop - t0, 'predict': t_pred - t_crop, 'total': t_pred - t0},
                    })
                else:
                    self._put_result({
                        'type': 'error',
                        'camera_id': cam_id,
                        'message': f"Value {final_number} out of range ({self.MIN_VALUE}-{self.MAX_VALUE})",
                        'value': final_number,
                        'task_id': task_id,
                        'task_ts': task_ts,
                        'timing_s': {'crop': t_crop - t0, 'predict': t_pred - t_crop, 'total': t_pred - t0},
                    })
                self.processed_tasks += 1
            except Exception:
                logger.exception("Inference error cam=%s task_id=%s", cam_id, task_id)
                self._put_result({
                    'type': 'error',
                    'camera_id': cam_id,
                    'message': 'Exception during inference; see logs',
                    'task_id': task_id,
                    'task_ts': task_ts,
                })
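
    # Result schema (as produced above): every dict carries 'type'
    # ('success' | 'error'), 'camera_id', 'task_id', and 'task_ts'; success
    # results add 'value', 'digits', and 'confidence', error results add
    # 'message', and most carry a 'timing_s' breakdown of crop/predict/total
    # seconds (the exception path omits it).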

    def _crop_rois(self, image, roi_list):
        cropped_images = []
        for roi in roi_list:
            try:
                x, y, w, h = roi['x'], roi['y'], roi['width'], roi['height']
                cropped = image[y:y + h, x:x + w]
                if cropped.size > 0:
                    cropped_images.append(cropped)
            except Exception:
                # Malformed ROI dicts are skipped; log at debug so the drop
                # is not completely silent.
                logger.debug("Skipping invalid ROI: %r", roi)
        return cropped_images

    def predict_batch(self, images):
        """Run inference on a batch of images.

        Returns a list of dicts: {'digit': str, 'confidence': float}.
        """
        with self.lock:
            if not self.interpreter:
                return []
            num_images = len(images)
            if num_images == 0:
                return []
            input_index = self.input_details[0]['index']
            output_index = self.output_details[0]['index']
            batch_input = []
            target_h, target_w = 32, 20
            for img in images:
                roi_resized = cv2.resize(img, (target_w, target_h))
                roi_rgb = cv2.cvtColor(roi_resized, cv2.COLOR_BGR2RGB)
                # No /255 scaling: the model is assumed to take raw 0-255 floats.
                roi_norm = roi_rgb.astype(np.float32)
                batch_input.append(roi_norm)
            input_tensor = np.array(batch_input)
            # NOTE: Keeping original behavior (resize + allocate on every call)
            # but timing the invoke; resizing the input tensor forces a
            # re-allocation for each batch.
            self.interpreter.resize_tensor_input(input_index, [num_images, target_h, target_w, 3])
            self.interpreter.allocate_tensors()
            self.interpreter.set_tensor(input_index, input_tensor)
            t0 = time.time()
            self.interpreter.invoke()
            self.last_invoke_secs = time.time() - t0
            if self.last_invoke_secs > 1.0:
                logger.warning("Slow invoke: %.3fs (batch=%d)", self.last_invoke_secs, num_images)
            output_data = self.interpreter.get_tensor(output_index)
            results = []
            for i in range(num_images):
                logits = output_data[i]
                # Numerically stable softmax: subtract the max before exp.
                logits = logits - np.max(logits)
                ex = np.exp(logits)
                denom = np.sum(ex)
                probs = (ex / denom) if denom != 0 else np.zeros_like(ex)
                digit_class = int(np.argmax(probs))
                confidence = float(probs[digit_class]) if probs.size else 0.0
                results.append({'digit': str(digit_class), 'confidence': confidence})
            return results
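

if __name__ == "__main__":
    # Minimal smoke-test sketch: it assumes Config.MODEL_PATH points at a
    # valid TFLite model and uses a blank placeholder frame with a
    # hypothetical ROI, so the prediction itself is meaningless; it only
    # exercises the queueing, threading, and result paths.
    logging.basicConfig(level=logging.INFO)
    worker = InferenceWorker()
    worker.start()
    frame = np.zeros((480, 640, 3), dtype=np.uint8)  # placeholder BGR frame
    rois = [{'x': 100, 'y': 100, 'width': 40, 'height': 64}]
    worker.add_task(camera_id="cam0", rois=rois, frame=frame)
    time.sleep(2.0)  # give the daemon thread time to process one task
    print(worker.get_result())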