import { call, put, select, takeEvery, takeLatest } from 'typed-redux-saga';
import { PayloadAction } from '@reduxjs/toolkit';
import intersection from 'lodash/intersection';

import {
  loadModelHandler,
  modelLoadedHandler,
  updateModelHandler,
} from '../models/models.saga';
import {
  KEYPOINTS_DETECTION_FAMILY_NAME,
  MODEL_MESSAGES,
} from './keypointsDetection.constants';
import { ImageTool, ImageToolTitleMap } from '../tools.constants';
import {
  addLabels,
  addLabel,
  loadDetections,
  loadKeypointsDetectionModel,
  loadKeypointsDetectionModelError,
  loadKeypointsDetectionModelStart,
  loadKeypointsDetectionModelSuccess,
  loadKeypointsFailure,
  loadKeypointsSuccess,
  setInstanceThreshold,
  setKeypointThreshold,
  setMaxDetections,
  updateKeypointsDetectionModel,
  updateKeypointsDetectionModelSuccess,
  confirmAddLabel,
  confirmInstanceThreshold,
  confirmKeypointsThreshold,
  confirmMaxDetections,
  resetData,
} from './keypointsDetection.slice';
import {
  createModelChangePattern,
  ModelChangePayload,
} from '../models/models.constants';
import {
  keypointsDetectionInstanceThresholdSelector,
  keypointsDetectionKeypointThresholdSelector,
  keypointsDetectionMaxDetectionsSelector,
  keypointsDetectionModelIdSelector,
  keypointsDetectionDetectedDataSelector,
  keypointsModelLoadedSelector,
  keypointsDetectionDetectedDataByIdSelector,
} from './keypointsDetection.selectors';
import {
  apiLoadKeypointsDetection,
  apiLoadKeypointsDetectionModel,
} from '../../../../../../api/requests/projectTools';
import { imageViewImageIdSelector } from '../../currentImage/currentImage.selectors';
import { getErrorMessage } from '../../../../../../api/utils';
import { addLabels as commonAddLabels } from '../../labels/labels.slice';
import { uuidv4 } from '../../../../../../util/uuidv4';
import { LabelType } from '../../../../../../api/constants/label';
import { resetActiveTool, setActiveTool } from '../tools.slice';
import {
  KEYPOINTS_DETECTION_INSTANCE_THRESHOLD_MAX_VALUE,
  KEYPOINTS_DETECTION_INSTANCE_THRESHOLD_MIN_VALUE,
  KEYPOINTS_DETECTION_KEYPOINT_THRESHOLD_MAX_VALUE,
  KEYPOINTS_DETECTION_KEYPOINT_THRESHOLD_MIN_VALUE,
  KEYPOINTS_DETECTION_MAX_DETECTIONS,
  KEYPOINTS_DETECTION_MIN_DETECTIONS,
} from '../../../../../../constants/keypointsDetection';
import { activeProjectIdSelector } from '../../../../project/project.selectors';
import { keypointsSchemasSelector } from '../../../../project/annotationTaxonomy/keypointsSchemas/keypointsSchemas.selectors';
import { KCRelationsSelector } from '../../../../project/annotationTaxonomy/KCRelations/KCRelations.selectors';
import { handleError } from '../../../../commonFeatures/errorHandler/errorHandler.actions';
import { MODEL_LOADED, MODEL_UPDATED } from '../../../../ws/ws.constants';

/**
 * Loads the keypoints detection model for the active project.
 *
 * All the work is delegated to the shared `loadModelHandler`. This saga only
 * supplies the keypoints-detection specific API call, messages, tool
 * identifiers and the start/success/error action creators.
 */
function* loadKeypointsDetectionModelHandler() {
  const tool = ImageTool.KeypointsDetection;
  const projectId = yield* select(activeProjectIdSelector);

  yield* call(loadModelHandler, {
    projectId,
    toolId: tool,
    toolName: ImageToolTitleMap[tool],
    messages: MODEL_MESSAGES,
    loadApiCall: apiLoadKeypointsDetectionModel,
    modelLoadStartAction: loadKeypointsDetectionModelStart,
    modelLoadSuccessAction: loadKeypointsDetectionModelSuccess,
    modelLoadErrorAction: loadKeypointsDetectionModelError,
  });
}

/**
 * Handles `MODEL_LOADED` model-change actions for the keypoints detection
 * model family.
 *
 * The payload's id/status/progress/modelUseIE fields are forwarded to the
 * shared `modelLoadedHandler`, together with this tool's model selectors and
 * its `updateKeypointsDetectionModel` action creator.
 *
 * @param action - model-change action carrying a `ModelChangePayload`.
 */
function* keypointsDetectionModelLoadedHandler(
  action: PayloadAction<ModelChangePayload>,
) {
  const { payload } = action;

  yield* call(modelLoadedHandler, {
    id: payload.id,
    status: payload.status,
    progress: payload.progress,
    modelUseIE: payload.modelUseIE,
    modelIdSelector: keypointsDetectionModelIdSelector,
    modelLoadedSelector: keypointsModelLoadedSelector,
    updateAction: updateKeypointsDetectionModel,
  });
}

/**
 * Handles `MODEL_UPDATED` model-change actions for the keypoints detection
 * model family.
 *
 * Only the status/progress/id/modelUseIE fields of the payload are
 * re-dispatched, as an `updateKeypointsDetectionModel` action.
 *
 * @param action - model-change action carrying a `ModelChangePayload`.
 */
function* keypointsDetectionModelUpdatedHandler(
  action: PayloadAction<ModelChangePayload>,
) {
  const { id, status, progress, modelUseIE } = action.payload;
  const updateAction = updateKeypointsDetectionModel({
    status,
    progress,
    id,
    modelUseIE,
  });

  yield* put(updateAction);
}

/**
 * Reacts to `updateKeypointsDetectionModel` by delegating to the shared
 * `updateModelHandler`.
 *
 * The handler is given this tool's model-load action
 * (`loadKeypointsDetectionModel`) and its success action
 * (`updateKeypointsDetectionModelSuccess`).
 *
 * @param action - the dispatched `updateKeypointsDetectionModel` action.
 */
function* updateKeypointsDetectionModelHandler(
  action: ActionType<typeof updateKeypointsDetectionModel>,
) {
  const { payload } = action;

  yield* call(updateModelHandler, {
    id: payload.id,
    status: payload.status,
    progress: payload.progress,
    modelUseIE: payload.modelUseIE,
    loadAction: loadKeypointsDetectionModel,
    successAction: updateKeypointsDetectionModelSuccess,
  });
}

/**
 * Requests keypoints detections for the current image and stores them.
 *
 * The request uses the configured instance/keypoint confidence thresholds,
 * the max-detections limit and the selected model. Each returned detection
 * is enriched with:
 * - `schemaId`: id of the first keypoints schema that shares at least one
 *   keypoint class id with the detection (undefined if none matches);
 * - `classId`: `labelClassId` of the KC relation linked to that schema with
 *   `labelClassOrder === 0` (a temporary solution);
 * - `id`: a freshly generated uuid.
 *
 * Returns early without a request when there is no current image id or the
 * model id is `null`. On failure, both `loadKeypointsFailure` and
 * `handleError` are dispatched with the same message.
 */
function* loadDetectionsHandler() {
  const projectId = yield* select(activeProjectIdSelector);
  const imageId = yield* select(imageViewImageIdSelector);
  const instanceConfidenceThreshold = yield* select(
    keypointsDetectionInstanceThresholdSelector,
  );
  const keypointConfidenceThreshold = yield* select(
    keypointsDetectionKeypointThresholdSelector,
  );
  const maxDetectionsPerImage = yield* select(
    keypointsDetectionMaxDetectionsSelector,
  );
  const modelId = yield* select(keypointsDetectionModelIdSelector);
  const schemas = yield* select(keypointsSchemasSelector);
  const kcRelations = yield* select(KCRelationsSelector);

  if (!imageId || modelId === null) return;

  try {
    const detectionParams = {
      instanceConfidenceThreshold,
      keypointConfidenceThreshold,
      maxDetectionsPerImage,
      modelId,
    };
    const response = yield* call(
      apiLoadKeypointsDetection,
      projectId,
      imageId,
      detectionParams,
    );

    const enrichedData = response.data.map((detection) => {
      const matchingSchema = schemas.find((candidate) => {
        const schemaClassIds = candidate.keypointClasses.map(
          (keypointClass) => keypointClass.id,
        );
        return (
          intersection(schemaClassIds, detection.keypointClassIds).length > 0
        );
      });
      const schemaId = matchingSchema?.id;

      // temp solution: use the label class linked to the schema at order 0
      const primaryRelation = kcRelations.find(
        (relation) =>
          relation.keypointSchemaId === schemaId &&
          relation.labelClassOrder === 0,
      );

      return {
        ...detection,
        schemaId,
        classId: primaryRelation?.labelClassId,
        id: uuidv4(),
      };
    });

    yield* put(loadKeypointsSuccess(enrichedData));
  } catch (error) {
    const message = getErrorMessage(
      error,
      'Keypoints detection failed to detect objects',
    );

    yield* put(loadKeypointsFailure({ message }));
    yield* put(handleError({ message, error }));
  }
}

/**
 * Converts every currently detected keypoints instance into a keypoints
 * label and adds them all through the common labels slice.
 *
 * Every label and every keypoint gets a fresh uuid. Keypoints are marked
 * visible, and each keypoint is paired positionally with the detection's
 * `keypointClassIds`. Afterwards the active tool is switched to
 * `ImageTool.Default`. Does nothing when there are no detections.
 */
function* addLabelsHandler() {
  const detections = yield* select(keypointsDetectionDetectedDataSelector);

  if (detections.length === 0) return;

  // The map stays inline so the label objects are contextually typed by
  // `commonAddLabels` (keeps enum-literal fields like `type` narrow).
  yield* put(
    commonAddLabels(
      detections.map((detection) => ({
        id: uuidv4(),
        classId: detection.classId,
        keypoints: detection.keypoints.map(([x, y], pointIndex) => ({
          id: uuidv4(),
          x,
          y,
          visible: true,
          keypointClassId: detection.keypointClassIds[pointIndex],
        })),
        toolUsed: ImageTool.Keypoints,
        type: LabelType.Keypoints,
      })),
    ),
  );
  yield* put(setActiveTool(ImageTool.Default));
}

/**
 * Adds a single detected keypoints instance as a keypoints label, then
 * dispatches `confirmAddLabel` for it.
 *
 * The detection is looked up by `action.payload.detectionId`. If no detection
 * with that id is found, nothing is dispatched.
 *
 * @param action - the dispatched `addLabel` action.
 */
function* addLabelHandler(action: ActionType<typeof addLabel>) {
  const { detectionId } = action.payload;
  const detection = yield* select((state: RootState) =>
    keypointsDetectionDetectedDataByIdSelector(state, detectionId),
  );

  if (!detection) return;

  const { classId, keypoints, keypointClassIds } = detection;

  yield* put(
    commonAddLabels([
      {
        id: uuidv4(),
        classId,
        // keypoint i is paired with keypointClassIds[i]
        keypoints: keypoints.map(([x, y], pointIndex) => ({
          id: uuidv4(),
          x,
          y,
          visible: true,
          keypointClassId: keypointClassIds[pointIndex],
        })),
        toolUsed: ImageTool.Keypoints,
        type: LabelType.Keypoints,
      },
    ]),
  );

  yield* put(confirmAddLabel({ detectionId }));
}

/**
 * Clamps the requested instance confidence threshold, commits it and
 * reloads detections.
 *
 * The value is clamped to
 * [KEYPOINTS_DETECTION_INSTANCE_THRESHOLD_MIN_VALUE,
 *  KEYPOINTS_DETECTION_INSTANCE_THRESHOLD_MAX_VALUE] (raised to the minimum
 * first, then capped at the maximum), committed via
 * `confirmInstanceThreshold`, and followed by a `loadDetections` dispatch.
 *
 * @param action - the dispatched `setInstanceThreshold` action.
 */
function* setInstanceThresholdHandler(
  action: ActionType<typeof setInstanceThreshold>,
) {
  const requested = action.payload.threshold;
  const atLeastMin =
    requested < KEYPOINTS_DETECTION_INSTANCE_THRESHOLD_MIN_VALUE
      ? KEYPOINTS_DETECTION_INSTANCE_THRESHOLD_MIN_VALUE
      : requested;
  const threshold =
    atLeastMin > KEYPOINTS_DETECTION_INSTANCE_THRESHOLD_MAX_VALUE
      ? KEYPOINTS_DETECTION_INSTANCE_THRESHOLD_MAX_VALUE
      : atLeastMin;

  yield* put(confirmInstanceThreshold({ threshold }));
  yield* put(loadDetections());
}

/**
 * Clamps the requested keypoint confidence threshold, commits it and
 * reloads detections.
 *
 * The value is clamped to
 * [KEYPOINTS_DETECTION_KEYPOINT_THRESHOLD_MIN_VALUE,
 *  KEYPOINTS_DETECTION_KEYPOINT_THRESHOLD_MAX_VALUE] (raised to the minimum
 * first, then capped at the maximum), committed via
 * `confirmKeypointsThreshold`, and followed by a `loadDetections` dispatch.
 *
 * @param action - the dispatched `setKeypointThreshold` action.
 */
function* setKeypointThresholdHandler(
  action: ActionType<typeof setKeypointThreshold>,
) {
  const requested = action.payload.threshold;
  const atLeastMin =
    requested < KEYPOINTS_DETECTION_KEYPOINT_THRESHOLD_MIN_VALUE
      ? KEYPOINTS_DETECTION_KEYPOINT_THRESHOLD_MIN_VALUE
      : requested;
  const threshold =
    atLeastMin > KEYPOINTS_DETECTION_KEYPOINT_THRESHOLD_MAX_VALUE
      ? KEYPOINTS_DETECTION_KEYPOINT_THRESHOLD_MAX_VALUE
      : atLeastMin;

  yield* put(confirmKeypointsThreshold({ threshold }));
  yield* put(loadDetections());
}

/**
 * Clamps the requested maximum number of detections per image, commits it
 * and reloads detections.
 *
 * The value is clamped to
 * [KEYPOINTS_DETECTION_MIN_DETECTIONS, KEYPOINTS_DETECTION_MAX_DETECTIONS]
 * (raised to the minimum first, then capped at the maximum), committed via
 * `confirmMaxDetections`, and followed by a `loadDetections` dispatch.
 *
 * @param action - the dispatched `setMaxDetections` action.
 */
function* setMaxDetectionsHandler(action: ActionType<typeof setMaxDetections>) {
  const requested = action.payload.detections;
  const atLeastMin =
    requested < KEYPOINTS_DETECTION_MIN_DETECTIONS
      ? KEYPOINTS_DETECTION_MIN_DETECTIONS
      : requested;
  const detections =
    atLeastMin > KEYPOINTS_DETECTION_MAX_DETECTIONS
      ? KEYPOINTS_DETECTION_MAX_DETECTIONS
      : atLeastMin;

  yield* put(confirmMaxDetections({ detections }));
  yield* put(loadDetections());
}

/**
 * Clears this tool's data by dispatching `resetData`. Subscribed to both
 * `setActiveTool` and `resetActiveTool` in `keypointsDetectionSaga`.
 */
function* setActiveToolHandler() {
  const resetAction = resetData();
  yield* put(resetAction);
}

/**
 * Root saga for the keypoints detection tool: subscribes every handler in
 * this module to its trigger action.
 *
 * `loadDetections` is watched with `takeLatest` rather than `takeEvery`.
 * Each threshold or max-detections change dispatches `loadDetections`, so
 * several detection requests can be in flight at the same time. With
 * `takeEvery`, an older request that happens to resolve last would overwrite
 * the detections for the current settings. `takeLatest` cancels the
 * previously running `loadDetectionsHandler` task, so only the newest
 * request's result is stored. A cancelled task never reaches its
 * success/failure puts. NOTE(review): the underlying HTTP request itself is
 * presumably not aborted — confirm whether `apiLoadKeypointsDetection`
 * supports cancellation.
 */
export function* keypointsDetectionSaga() {
  yield* takeEvery(setInstanceThreshold, setInstanceThresholdHandler);
  yield* takeEvery(setKeypointThreshold, setKeypointThresholdHandler);
  yield* takeEvery(setMaxDetections, setMaxDetectionsHandler);
  yield* takeEvery(
    loadKeypointsDetectionModel,
    loadKeypointsDetectionModelHandler,
  );
  yield* takeEvery(
    createModelChangePattern(KEYPOINTS_DETECTION_FAMILY_NAME, MODEL_LOADED),
    keypointsDetectionModelLoadedHandler,
  );
  yield* takeEvery(
    createModelChangePattern(KEYPOINTS_DETECTION_FAMILY_NAME, MODEL_UPDATED),
    keypointsDetectionModelUpdatedHandler,
  );
  yield* takeEvery(
    updateKeypointsDetectionModel,
    updateKeypointsDetectionModelHandler,
  );
  // Only the most recent detection request may write results (see above).
  yield* takeLatest(loadDetections, loadDetectionsHandler);
  yield* takeEvery(addLabels, addLabelsHandler);
  yield* takeEvery(addLabel, addLabelHandler);
  yield* takeEvery([setActiveTool, resetActiveTool], setActiveToolHandler);
}
