# AUTOGENERATED! DO NOT EDIT! File to edit: 01_camera_projection.ipynb (unless otherwise specified).

__all__ = ['camera_projection']

# Cell

import numpy as np
import scipy as sp
import cv2
from cv2 import aruco
import apriltag
import time
import yaml

import pytransform3d.rotations as pr
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import matplotlib as mpl

import os
import sys
import gdown
from zipfile import ZipFile

from scipy.spatial.transform import Rotation as R
from numpy.linalg import inv

# Cell

class camera_projection:
    """Estimate an AprilTag pose from an RGB frame and project 3D points onto it.

    The methods build on each other's state and are meant to be called in
    this order::

        read_camera_info() -> read_images(idx) -> apriltag_detection()
            -> solvePnP() -> draw_point(tag_2_inv, base2joint)
    """

    def __init__(self, camera_info_path='ViperX_apriltags/camera_info.yaml',
                 img_path='ViperX_apriltags/rgb/',
                 depth_path='ViperX_apriltags/depth/',
                 tag_size=0.0415):
        """Store dataset locations and the tag geometry.

        Args:
            camera_info_path: YAML file holding ``camera_matrix`` and
                ``distortion_coefficients`` entries.
            img_path: Directory containing the RGB frames ``<idx>.png``.
            depth_path: Directory containing the depth frames ``<idx>.png``.
            tag_size: Edge length of the tag. The translation estimated by
                ``solvePnP`` is expressed in the same unit
                (presumably metres -- TODO confirm against the dataset).
        """
        self.camera_info_path = camera_info_path
        self.img_path = img_path
        self.depth_path = depth_path
        self.tag_size = tag_size
        # Half edge length: tag corners sit at (+-s, +-s, 0) in the tag frame.
        self.s = 0.5 * self.tag_size

    def _frame_paths(self, idx):
        """Return ``(rgb_file, depth_file)`` for frame ``idx`` without touching state."""
        name = str(idx) + '.png'
        return os.path.join(self.img_path, name), os.path.join(self.depth_path, name)

    def _tag_object_points(self):
        """Return the four tag corners in the tag frame as a (4, 3) float array."""
        s = self.s
        return np.array([[-s, -s, 0.0],
                         [ s, -s, 0.0],
                         [ s,  s, 0.0],
                         [-s,  s, 0.0]], dtype=np.float64)

    def read_camera_info(self):
        """Load the camera intrinsics and distortion coefficients from YAML.

        Sets ``camera_matrix`` (3x3), ``dist_coeffs`` (1xN) and
        ``cameraParams_Intrinsic`` (``[fx, fy, cx, cy]``).

        Raises:
            OSError: If the YAML file cannot be opened.
            yaml.YAMLError: If the file is not valid YAML. The original code
                printed this error and then failed with an unrelated
                ``UnboundLocalError``; it now propagates directly.
        """
        with open(self.camera_info_path, "r") as stream:
            camera_data = yaml.safe_load(stream)

        # Force float64: integer-valued YAML entries would otherwise yield an
        # integer array, which OpenCV's calibration functions do not accept.
        self.camera_matrix = np.array(camera_data['camera_matrix']['data'],
                                      dtype=np.float64).reshape(3, 3)
        # OpenCV distortion models have 4, 5, 8, 12 or 14 coefficients, so the
        # column count is taken from the data instead of being fixed at 5.
        self.dist_coeffs = np.array(camera_data['distortion_coefficients']['data'],
                                    dtype=np.float64).reshape(1, -1)
        # cameraParams_Intrinsic = [camera_fx, camera_fy,
        #                           camera_cx, camera_cy]
        self.cameraParams_Intrinsic = [self.camera_matrix[0, 0], self.camera_matrix[1, 1],
                                       self.camera_matrix[0, 2], self.camera_matrix[1, 2]]

    def read_images(self, idx):
        """Load the RGB and depth frame number ``idx``.

        Sets ``img`` (BGR as read by OpenCV), ``gray``, ``img_dst`` (RGB copy
        used for drawing), ``depth``, and records the files that were read in
        ``img_file`` / ``depth_file``. ``img_path`` and ``depth_path`` keep
        pointing at the directories, so the method can be called repeatedly
        with different indices.

        Args:
            idx: Frame index; the files are ``<img_path>/<idx>.png`` and
                ``<depth_path>/<idx>.png``.

        Raises:
            FileNotFoundError: If the RGB frame cannot be read.
        """
        img_file, depth_file = self._frame_paths(idx)

        img = cv2.imread(img_file)
        if img is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError("could not read RGB image: {}".format(img_file))

        self.img_file = img_file
        self.depth_file = depth_file
        self.img = img
        self.gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        # Keep the stored bit depth and channel count of the depth frame. The
        # original passed the negated flag ``-cv2.IMREAD_ANYDEPTH``, which sets
        # these two bits along with unrelated ones.
        # ``depth`` is None when the depth file is missing (unchanged behaviour).
        self.depth = cv2.imread(depth_file, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_ANYCOLOR)
        self.img_dst = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    def apriltag_detection(self):
        """Run the ``tag36h11`` AprilTag detector on ``self.gray``.

        Stores the detections in ``detection_results``.
        """
        print("[INFO] detecting AprilTags...")
        options = apriltag.DetectorOptions(families="tag36h11")
        detector = apriltag.Detector(options)
        # The debug image returned alongside the detections is not used.
        self.detection_results, _dimg = detector.detect(self.gray, return_image=True)
        print("[INFO] {} total AprilTags detected".format(len(self.detection_results)))

    def solvePnP(self):
        """Estimate the pose of the first detected tag in the camera frame.

        Sets ``r_vec`` / ``t_vec`` (as returned by ``cv2.solvePnP``),
        ``tag_pose`` (4x4 homogeneous transform built from them) and
        ``tag_dist`` (Euclidean norm of ``t_vec``).

        Raises:
            IndexError: If ``detection_results`` is empty.
        """
        if len(self.detection_results) == 0:
            raise IndexError("no AprilTag detections available; "
                             "run apriltag_detection() on a frame containing a tag")

        img_pts = np.asarray(self.detection_results[0].corners,
                             dtype=np.float64).reshape(1, 4, 2)
        # NOTE(review): the object-point order must match the detector's
        # corner order -- confirm against the apriltag bindings in use.
        obj_pts = self._tag_object_points()

        _, self.r_vec, self.t_vec = cv2.solvePnP(obj_pts, img_pts, self.camera_matrix,
                                                 self.dist_coeffs,
                                                 flags=cv2.SOLVEPNP_ITERATIVE)
        R_mat, _ = cv2.Rodrigues(self.r_vec)
        T = np.hstack((R_mat, self.t_vec)).reshape(3, 4)
        # Previously computed and discarded; kept on the instance so callers
        # can use the result.
        self.tag_pose = np.vstack((T, [0, 0, 0, 1])).reshape(4, 4)
        self.tag_dist = np.linalg.norm(self.t_vec)

    def draw_point(self, tag_2_inv, base2joint):
        """Project the origin of ``tag_2_inv @ base2joint`` and mark it on ``img_dst``.

        Args:
            tag_2_inv: 4x4 homogeneous transform, left factor of the product.
            base2joint: 4x4 homogeneous transform, right factor of the product.

        Returns:
            ``self.img_dst`` with a filled circle of radius 5 drawn in place at
            the projected pixel. ``img_dst`` is RGB, so ``(255, 0, 0)`` is red.
        """
        # --------------- project a point ---------------
        tag2joint = np.matmul(tag_2_inv, base2joint)
        obj_pts = np.asarray(tag2joint[:3, 3], dtype=np.float64).reshape(1, 3)
        proj_img_pts, _jac = cv2.projectPoints(obj_pts, self.r_vec, self.t_vec,
                                               self.camera_matrix, self.dist_coeffs)
        # Flatten to two scalars: calling int() on 1-element arrays is
        # deprecated since NumPy 1.25.
        u, v = np.asarray(proj_img_pts).reshape(2)
        # --------------- draw a point ---------------
        draw_image = cv2.circle(self.img_dst, (int(u), int(v)),
                                radius=5, color=(255, 0, 0), thickness=-1)
        return draw_image