import os
from collections.abc import Iterable, Mapping

import joblib
import pandas as pd
from sklearn.ensemble import RandomForestClassifier

from core.config import settings


class RandomForestRiskService:
    """Train, persist, and score a random-forest risk classifier.

    The classifier is fitted on the fixed, ordered feature list in
    ``feature_names`` and persisted with joblib at ``model_path``. Scoring a
    single feature row yields the probability of class ``1``, a block
    decision against a threshold, and the model's top feature importances.
    """

    def __init__(self) -> None:
        # Location of the persisted model artifact, resolved once from settings.
        self.model_path = os.path.join(settings.model_dir, settings.random_forest_model_file)
        # Feature columns in the exact order used for both fitting (train) and
        # scoring (predict); the model is fitted on this ordering.
        self.feature_names = [
            "dispatch_weight",
            "received_weight",
            "weight_delta",
            "expected_duration_minutes",
            "actual_duration_minutes",
            "duration_delta",
            "deviation_distance_km",
            "unexpected_stop_minutes",
            "face_mismatch",
        ]
        # Cached model, set by train() or lazily by load().
        self._model: RandomForestClassifier | None = None

    def train(self, df: pd.DataFrame, label_column: str = "is_theft") -> None:
        """Fit a new classifier on ``df``, persist it, and cache it on the instance.

        Args:
            df: Training data; must contain every column in ``feature_names``
                plus ``label_column``.
            label_column: Name of the target column. Missing labels are
                treated as ``0`` and the column is cast to ``int``.

        Raises:
            KeyError: If ``df`` lacks any required column (raised by pandas).
        """
        # Create the directory the artifact is actually written into, so the
        # dump below cannot fail on a missing parent directory.
        os.makedirs(os.path.dirname(self.model_path) or ".", exist_ok=True)
        train_df = df[self.feature_names + [label_column]].copy()
        train_df[label_column] = train_df[label_column].fillna(0).astype(int)
        x = train_df[self.feature_names]
        y = train_df[label_column]
        model = RandomForestClassifier(
            n_estimators=300,
            max_depth=10,
            random_state=42,
            # Reweight classes per bootstrap sample to compensate for
            # imbalanced labels (positives are presumably rare — confirm).
            class_weight="balanced_subsample",
        )
        model.fit(x, y)
        joblib.dump(model, self.model_path)
        self._model = model

    def load(self) -> "RandomForestClassifier":
        """Return the cached model, loading it from ``model_path`` on first use.

        Raises:
            FileNotFoundError: If no model artifact exists at ``model_path``.
        """
        if self._model is None:
            try:
                # NOTE: joblib.load unpickles the file and can execute arbitrary
                # code; only point model_path at trusted artifacts.
                self._model = joblib.load(self.model_path)
            except FileNotFoundError as err:
                raise FileNotFoundError(
                    f"No trained model at {self.model_path!r}; call train() first."
                ) from err
        return self._model

    def predict(
        self,
        feature_row: Mapping[str, float],
        threshold: float = 0.6,
        *,
        top_k: int = 5,
    ) -> tuple[float, bool, dict[str, float]]:
        """Score a single feature row.

        Args:
            feature_row: Mapping with a value for every name in
                ``feature_names``; extra keys are ignored.
            threshold: Probability at or above which the row is blocked.
            top_k: Number of highest-importance features to report.

        Returns:
            ``(prob, blocked, contributors)``: ``prob`` is the predicted
            probability of class ``1`` (``0.0`` if the model never saw that
            class), ``blocked`` is ``prob >= threshold``, and ``contributors``
            maps up to ``top_k`` feature names to their importance, highest
            first.

        Raises:
            KeyError: If ``feature_row`` is missing any required feature.

        NOTE: ``contributors`` comes from the model's global
        ``feature_importances_``; it does not depend on ``feature_row`` and is
        identical for every prediction made by a given model.
        """
        model = self.load()
        missing = [name for name in self.feature_names if name not in feature_row]
        if missing:
            raise KeyError(f"feature_row is missing features: {missing}")
        x = pd.DataFrame(
            [[feature_row[name] for name in self.feature_names]],
            columns=self.feature_names,
        )
        proba_row = model.predict_proba(x)[0]
        # predict_proba columns follow model.classes_; without classes_, fall
        # back to "column i is label i" (the original positional assumption).
        classes = getattr(model, "classes_", range(len(proba_row)))
        prob = self._positive_probability(classes, proba_row)
        blocked = prob >= threshold
        importances = getattr(model, "feature_importances_", [0.0] * len(self.feature_names))
        contributors = self._top_features(self.feature_names, importances, top_k)
        return prob, blocked, contributors

    @staticmethod
    def _positive_probability(
        classes: Iterable[object],
        proba_row: Iterable[float],
        positive_label: int = 1,
    ) -> float:
        """Return the probability paired with ``positive_label`` (0.0 if absent).

        A model fitted on a single class emits one probability column, so the
        positive column must be found by label, not assumed to be index 1.
        """
        for label, p in zip(classes, proba_row):
            if label == positive_label:
                return float(p)
        return 0.0

    @staticmethod
    def _top_features(
        names: Iterable[str],
        scores: Iterable[float],
        top_k: int,
    ) -> dict[str, float]:
        """Map the ``top_k`` highest-scoring names to their scores, best first.

        Ties keep their original relative order because ``sorted`` is stable.
        """
        ranked = sorted(zip(names, scores), key=lambda pair: pair[1], reverse=True)
        return {name: float(score) for name, score in ranked[: max(top_k, 0)]}
