"""
Load the pandaset dataset and create a 3LC Table.

Uses the file ./pandaset_scan_summary.json (generated by ./scan_pandaset.py) for global metadata.
"""

from __future__ import annotations

import json
from pathlib import Path

import numpy as np
import pandas as pd
import pandaset
import tlc
import tqdm

# Global dataset metadata produced by ./scan_pandaset.py, expected next to this file.
# JSON is read explicitly as UTF-8 so the result does not depend on the platform's
# default text encoding.
with open(Path(__file__).parent / "pandaset_scan_summary.json", encoding="utf-8") as f:
    scan_summary = json.load(f)

# All PandaSet sequences have 80 frames
FRAMES_PER_SEQUENCE = 80

# Global isotropic bounds for PandaSet, built from the per-axis world min/max stored in
# the scan summary. Elsewhere in this file the result is unpacked positionally and
# indexed as (x_min, x_max, y_min, y_max, z_min, z_max).
PANDASET_BOUNDS = tlc.GeometryHelper.create_isotropic_bounds_3d(
    scan_summary["bounds_world"]["x"]["min"],
    scan_summary["bounds_world"]["x"]["max"],
    scan_summary["bounds_world"]["y"]["min"],
    scan_summary["bounds_world"]["y"]["max"],
    scan_summary["bounds_world"]["z"]["min"],
    scan_summary["bounds_world"]["z"]["max"],
)

# Matrix to ensure the vehicle drives along +X (dataset forward is -Y)
# NOTE(review): this matrix swaps the x and y axes. That is a reflection (determinant -1),
# not a rotation, and it maps +Y (not -Y) onto +X -- confirm which forward axis and
# handedness are intended.
R_ALIGN = np.array(
    [
        [0.0, 1.0, 0.0],  # x' = y
        [1.0, 0.0, 0.0],  # y' = x
        [0.0, 0.0, 1.0],  # z' = z
    ],
    dtype=np.float32,
)

# Semantic segmentation classes as {class id: name}. JSON object keys are strings, so
# they are converted back to integers here.
SEMSEG_CLASSES = {int(k): v for k, v in scan_summary["semseg_classes"].items()}

# Cuboid classes as {label name: index}, where the index is the label's position in the
# scan summary's "unique_cuboid_labels" list.
CUBOID_CLASSES = {name: i for i, name in enumerate(scan_summary["unique_cuboid_labels"])}

# The four mappings below translate cuboid attribute strings to integer codes, with
# "N/A" mapped to -1. In the code visible in this file they are referenced only by the
# commented-out cuboid attribute handling.
object_motion_classes = {
    "N/A": -1,
    "Parked": 0,
    "Stopped": 1,
    "Moving": 2,
}

rider_status_classes = {
    "N/A": -1,
    "With Rider": 0,
    "Without Rider": 1,
}

pedestrian_behavior_classes = {
    "N/A": -1,
    "Sitting": 0,
    "Lying": 1,
    "Walking": 2,
    "Standing": 3,
}

pedestrian_age_classes = {
    "N/A": -1,
    "Adult": 0,
    "Child": 1,
}


def get_lidar_schema() -> tlc.Schema:
    """Build the column schema used for a LiDAR point-cloud column.

    The schema describes bulk-data 3D vertices carrying three per-vertex
    channels: ``intensity`` and ``distance`` (float32 lists) and ``semseg``
    (categorical labels drawn from ``SEMSEG_CLASSES``).

    Returns:
        A freshly constructed ``tlc.Geometry3DSchema``.
    """
    per_vertex_channels = {
        "intensity": tlc.Float32ListSchema(),
        "distance": tlc.Float32ListSchema(),
        "semseg": tlc.CategoricalLabelListSchema(SEMSEG_CLASSES),
    }
    return tlc.Geometry3DSchema(
        include_3d_vertices=True,
        per_vertex_schemas=per_vertex_channels,
        is_bulk_data=True,
    )


def get_bb_schema() -> tlc.Schema:
    """Build the column schema for the oriented 3D bounding boxes (cuboids).

    The class list is taken from the keys of ``CUBOID_CLASSES``. Per-instance
    cuboid attributes are intentionally not part of the schema for now.

    Returns:
        A freshly constructed ``tlc.OrientedBoundingBoxes3DSchema``.
    """
    cuboid_class_names = CUBOID_CLASSES.keys()

    # Cuboid attributes are not included for now. If they are needed, pass a
    # ``per_instance_schemas`` argument along these lines:
    #
    #   per_instance_schemas={
    #       "uuid": tlc.StringListSchema(writable=False),
    #       "stationary": tlc.BoolListSchema(),
    #       "camera_used": tlc.Int32ListSchema(writable=False),
    #       "object_motion": tlc.CategoricalLabelListSchema({v: k for k, v in object_motion_classes.items()}),
    #       "rider_status": tlc.CategoricalLabelListSchema({v: k for k, v in rider_status_classes.items()}),
    #       "pedestrian_behavior": tlc.CategoricalLabelListSchema(
    #           {v: k for k, v in pedestrian_behavior_classes.items()}
    #       ),
    #       "pedestrian_age": tlc.CategoricalLabelListSchema({v: k for k, v in pedestrian_age_classes.items()}),
    #       "sensor_id": tlc.Int32ListSchema(writable=False),
    #       "sibling_id": tlc.StringListSchema(writable=False),
    #   }
    return tlc.OrientedBoundingBoxes3DSchema(classes=cuboid_class_names)


def get_camera_schema(camera_name: str) -> tlc.Schema:
    """Build the image-URL schema for one camera column.

    The camera's calibration is looked up in the scan summary and attached to
    the schema as metadata: its intrinsics, and the ``T_cam_from_lidar``
    extrinsics stored under ``lidar_0``.

    Args:
        camera_name: Key of the camera in the scan summary (e.g. ``"front_camera"``).

    Returns:
        A ``tlc.ImageUrlSchema`` whose u/v number roles are prefixed with the camera name.

    Raises:
        KeyError: If the scan summary has no calibration entry for ``camera_name``.
    """
    intrinsics = scan_summary["camera_intrinsics"][camera_name]
    lidar_extrinsics = scan_summary["extrinsics_cam_from_lidar"][camera_name]["lidar_0"]
    calibration = {
        "intrinsics": intrinsics,
        "extrinsics": lidar_extrinsics["T_cam_from_lidar"],
    }
    return tlc.ImageUrlSchema(
        metadata=calibration,
        number_role_u=f"{camera_name}_u",
        number_role_v=f"{camera_name}_v",
    )


def load_car(data_path: str) -> tuple[dict, tlc.Schema]:
    """Load the ego-vehicle mesh and the schema describing it.

    The mesh is read from ``car/NormalCar2.obj`` below ``data_path``, scaled by
    1.25 and passed through an axis permutation (x' = z, y' = x, z' = y)
    before being placed inside ``PANDASET_BOUNDS``.

    Args:
        data_path: Directory that contains the ``car`` folder.

    Returns:
        A ``(row, schema)`` pair: the geometry converted with ``to_row()`` and a
        bulk-data ``tlc.Geometry3DSchema`` with vertices, triangles and
        per-triangle ``red``/``green``/``blue`` float channels.
    """
    obj_file = Path(data_path) / "car/NormalCar2.obj"

    # Homogeneous 4x4 axis permutation applied to the mesh on load.
    axis_permutation = np.array(
        [
            [0.0, 0.0, 1.0, 0.0],  # x' = z
            [1.0, 0.0, 0.0, 0.0],  # y' = x
            [0.0, 1.0, 0.0, 0.0],  # z' = y
            [0.0, 0.0, 0.0, 1.0],  # homogeneous row
        ],
        dtype=np.float32,
    )
    mesh = tlc.GeometryHelper.load_obj_geometry(
        obj_file.as_posix(),
        1.25,
        axis_permutation,
        PANDASET_BOUNDS,
    )

    color_channels = ("red", "green", "blue")
    mesh_schema = tlc.Geometry3DSchema(
        include_3d_vertices=True,
        include_triangles=True,
        per_triangle_schemas={channel: tlc.Float32ListSchema() for channel in color_channels},
        is_bulk_data=True,
    )
    return mesh.to_row(), mesh_schema


def _numeric_stem(path: Path) -> int:
    """Return a file's stem parsed as an integer (sort key for numbered frame files)."""
    return int(path.stem)


def _sorted_camera_frames(camera, max_frames: int | None) -> list[Path]:
    """List a camera's ``*.jpg`` files in numeric frame order, truncated to ``max_frames``."""
    # Reads the devkit camera object's private ``_directory`` attribute so that only
    # the image paths are collected; no image is loaded here.
    frames = sorted(Path(camera._directory).glob("*.jpg"), key=_numeric_stem)
    return frames[:max_frames]


def _world_to_aligned_ego(lidar_pose) -> tuple[np.ndarray, np.ndarray]:
    """Return ``(R_inv, t_inv)`` mapping world coordinates to the axis-aligned ego frame."""
    # Create world to ego transform from position and heading
    pose_mat = pandaset.geometry._heading_position_to_mat(lidar_pose["heading"], lidar_pose["position"])
    T_inv = np.linalg.inv(pose_mat)

    # Apply R_ALIGN on top of the world->ego rotation and translation.
    R_inv = R_ALIGN @ T_inv[:3, :3]
    t_inv = R_ALIGN @ T_inv[:3, 3]
    return R_inv, t_inv


def _build_lidar_geometry(
    point_cloud,
    semseg_values: np.ndarray,
    R_inv: np.ndarray,
    t_inv: np.ndarray,
) -> tlc.Geometry3DInstances:
    """Transform one LiDAR sweep to ego coordinates and wrap it with its per-vertex extras."""
    values = point_cloud.values

    # Transform LiDAR points from world to ego coordinates: X_e = R_inv * X_w + t_inv
    verts = values[:, :3].astype(np.float32, copy=False)
    verts = (R_inv @ verts.T + t_inv.reshape(3, 1)).T.astype(np.float32, copy=False)

    # NOTE(review): column 5 is exported as "distance". The PandaSet devkit README
    # appears to describe the LiDAR columns as x, y, z, i, t, d with `d` being the
    # sensor id -- confirm that this column really holds a distance.
    geometry = tlc.Geometry3DInstances.create_empty(
        *PANDASET_BOUNDS, per_vertex_extras_keys=["intensity", "distance", "semseg"]
    )
    geometry.add_instance(
        verts,
        per_vertex_extras={
            "intensity": values[:, 3].astype(np.float32, copy=False),
            "distance": values[:, 5].astype(np.float32, copy=False),
            "semseg": semseg_values,
        },
    )
    return geometry


def load_pandaset(
    dataset_root: Path,
    max_sequences: int | None = None,
    max_frames: int | None = None,
    table_name: str = "pandaset",
    dataset_name: str = "pandaset",
    project_name: str = "pandaset",
    data_path: str = "./data",
    tlc_project_root: str | None = None,
) -> tlc.Table:
    """Read PandaSet sequences and write them, one row per frame, to a 3LC Table.

    Only sequences that have semantic segmentation are used. Each row holds the
    sequence and frame ids, both LiDAR sweeps (in ego coordinates), the cuboids
    (in ego coordinates), the ego-vehicle mesh and the six camera image paths.

    Args:
        dataset_root: Root directory of the PandaSet dataset.
        max_sequences: Upper bound on the number of sequences read (``None`` = all).
        max_frames: Upper bound on the number of frames per sequence (``None`` = all).
        table_name: Name of the table to write.
        dataset_name: Dataset name passed to the table writer.
        project_name: Project name passed to the table writer.
        data_path: Directory containing the ``car`` mesh folder (see ``load_car``).
        tlc_project_root: Root URL passed to the table writer.

    Returns:
        The finalized table.
    """
    dataset = pandaset.DataSet(dataset_root)
    car, car_schema = load_car(data_path)

    table_writer = tlc.TableWriter(
        table_name=table_name,
        dataset_name=dataset_name,
        project_name=project_name,
        column_schemas={
            "lidar_0": get_lidar_schema(),
            "lidar_1": get_lidar_schema(),
            "bbs": get_bb_schema(),
            "car": car_schema,
            "front_camera": get_camera_schema("front_camera"),
            "front_right_camera": get_camera_schema("front_right_camera"),
            "right_camera": get_camera_schema("right_camera"),
            "front_left_camera": get_camera_schema("front_left_camera"),
            "left_camera": get_camera_schema("left_camera"),
            "back_camera": get_camera_schema("back_camera"),
        },
        root_url=tlc_project_root,
    )

    # Camera columns in the order they are written to each row.
    camera_names = (
        "back_camera",
        "front_camera",
        "front_left_camera",
        "front_right_camera",
        "left_camera",
        "right_camera",
    )

    sequences = dataset.sequences(with_semseg=True)[:max_sequences]  # 103 sequences total, 76 with semseg
    total_sequences = len(sequences)

    # Progress bar uses sequences as the main unit for stable ETA
    frames_per_seq_effective = FRAMES_PER_SEQUENCE if max_frames is None else min(FRAMES_PER_SEQUENCE, max_frames)

    # The context manager guarantees the bar is closed even if a frame fails to load.
    with tqdm.tqdm(total=total_sequences, desc="Sequences", unit="seq") as pbar:
        for sequence_idx, sequence_id in enumerate(sequences):
            # Update description per sequence and show loading state
            pbar.set_description(f"Seq {sequence_idx + 1}/{total_sequences} ({sequence_id})")
            pbar.set_postfix_str(f"loading sequence… 0/{frames_per_seq_effective} frames")

            sequence = dataset[sequence_id]
            sequence.load_lidar()
            sequence.load_cuboids()
            sequence.load_semseg()

            # LiDAR 0 (mechanical 360° LiDAR)
            sequence.lidar.set_sensor(0)
            point_cloud_0 = sequence.lidar[:max_frames]
            lidar_poses_0 = sequence.lidar.poses[:max_frames]

            # LiDAR 1 (front-facing long range)
            sequence.lidar.set_sensor(1)
            point_cloud_1 = sequence.lidar[:max_frames]

            semseg_all = sequence.semseg[:max_frames]
            cuboids_all = sequence.cuboids[:max_frames]

            cameras = sequence.camera
            camera_frames = [_sorted_camera_frames(cameras[name], max_frames) for name in camera_names]

            # zip stops at the shortest stream, so frames beyond it are not written.
            frame_iter = zip(
                point_cloud_0,
                lidar_poses_0,
                point_cloud_1,
                semseg_all,
                cuboids_all,
                *camera_frames,
            )
            frames_total = len(point_cloud_0)
            pbar.set_postfix_str(f"frames 0/{frames_total}")

            for frame_id, (pc_0, lidar_pose, pc_1, semseg, cuboids, *camera_paths) in enumerate(frame_iter):
                R_inv, t_inv = _world_to_aligned_ego(lidar_pose)

                # One flat label array covers both sensors: the first block is assigned
                # to LiDAR 0, the remainder to LiDAR 1.
                semseg_flat = semseg.values.astype(np.int32, copy=False).reshape(-1)
                num_points_0 = pc_0.shape[0]

                geometry_0 = _build_lidar_geometry(pc_0, semseg_flat[:num_points_0], R_inv, t_inv)
                geometry_1 = _build_lidar_geometry(pc_1, semseg_flat[num_points_0:], R_inv, t_inv)

                # Transform cuboids from world to ego coordinates and prepare for table writing
                obbs = transform_cuboids(cuboids, R_inv, t_inv)

                row = {
                    "sequence_id": sequence_id,
                    "frame_id": frame_id,
                    "lidar_0": geometry_0.to_row(),
                    "lidar_1": geometry_1.to_row(),
                    "bbs": obbs.to_row(),
                    "car": car,
                }
                row.update((name, path.as_posix()) for name, path in zip(camera_names, camera_paths))

                # Add a row to the Table—this redirects bulk data to the correct
                # file on disk and writes references to the Table.
                table_writer.add_row(row)

                # Update frame progress within the current sequence (does not advance the bar)
                pbar.set_postfix_str(f"frames {frame_id + 1}/{frames_total}")

            # Unload the sequence to free memory
            dataset.unload(sequence_id)

            # Advance the bar by one completed sequence
            pbar.update(1)

    return table_writer.finalize()


def transform_cuboids(
    cuboids: pd.DataFrame,
    R_inv: np.ndarray,
    t_inv: np.ndarray,
) -> tlc.OBB3DInstances:
    """Convert one frame of world-frame cuboids into ego-frame oriented bounding boxes.

    Only the columns that are actually exported are read: ``position.{x,y,z}``,
    ``dimensions.{x,y,z}``, ``yaw`` and ``label``. Per-cuboid attributes are not
    exported, so a frame that lacks one of those attribute columns no longer
    fails with a ``KeyError``.

    Args:
        cuboids: Cuboid annotations of a single frame.
        R_inv: 3x3 world-to-ego rotation.
        t_inv: World-to-ego translation (3 elements).

    Returns:
        A ``tlc.OBB3DInstances`` with one instance per row of ``cuboids``.

    Raises:
        KeyError: If one of the required columns listed above is missing.
    """
    # Vectorized world->ego transform for centers and yaw
    centers_world = np.stack([cuboids[f"position.{axis}"].values for axis in "xyz"], axis=1)
    sizes = np.stack([cuboids[f"dimensions.{axis}"].values for axis in "xyz"], axis=1)
    yaw_world = cuboids["yaw"].values.astype(np.float32, copy=False)

    # Apply world->ego: X_e = R_inv * X_w + t_inv
    centers_ego = centers_world @ R_inv.T + t_inv.reshape(1, 3)

    # Transform yaw by pushing the heading direction through R_inv and measuring its
    # angle in the ego x/y plane.
    cos_w = np.cos(yaw_world, dtype=np.float64)
    sin_w = np.sin(yaw_world, dtype=np.float64)
    zeros_w = np.zeros_like(yaw_world, dtype=np.float64)
    dir_world = np.stack([cos_w, sin_w, zeros_w], axis=1)
    dir_ego = dir_world @ R_inv.T
    yaw_ego = np.arctan2(dir_ego[:, 1], dir_ego[:, 0]).astype(np.float32, copy=False)

    obbs = tlc.OBB3DInstances.create_empty(
        x_min=PANDASET_BOUNDS[0],
        x_max=PANDASET_BOUNDS[1],
        y_min=PANDASET_BOUNDS[2],
        y_max=PANDASET_BOUNDS[3],
        z_min=PANDASET_BOUNDS[4],
        z_max=PANDASET_BOUNDS[5],
    )

    # Cuboid attributes are not included for now. If they are needed, they live in the
    # columns "uuid", "stationary", "camera_used", "attributes.object_motion",
    # "attributes.rider_status", "attributes.pedestrian_behavior",
    # "attributes.pedestrian_age", "cuboids.sensor_id" and "cuboids.sibling_id".
    # Some frames do not have all of these columns, so read them defensively
    # (e.g. check ``name in cuboids.columns``) before adding them as instance extras.
    for center, size, yaw_val, label in zip(centers_ego, sizes, yaw_ego, cuboids["label"].values):
        obbs.add_instance(
            # Only yaw is provided; the two trailing slots are left as NaN
            # (presumably the remaining rotation angles -- TODO confirm).
            obb=np.array([*center, *size, yaw_val, np.nan, np.nan]),
            # A label that is not in CUBOID_CLASSES is passed on as None.
            label=CUBOID_CLASSES.get(label),
        )

    return obbs


if __name__ == "__main__":
    # Machine-specific locations; adjust these before running the script.
    project_root = "D:/3LC-projects"
    pandaset_root = Path("D:/Data/pandaset")
    car_data_dir = tlc.Url("<TEST_DATA>/data").to_absolute().to_str()

    # Convert the first ten sequences (all frames of each) and show the resulting table.
    pandaset_table = load_pandaset(
        pandaset_root,
        max_sequences=10,
        max_frames=None,
        table_name="pandaset",
        dataset_name="pandaset",
        project_name="pandaset",
        data_path=car_data_dir,
        tlc_project_root=project_root,
    )
    print(pandaset_table)
