import invert from 'lodash/invert';

import { ModelFamilyStatus as ModelFamilyStatusApi } from '../codegen';

/**
 * Client-side model family identifiers (string enum).
 *
 * Each member is paired with one generated API `modelFamily` value in
 * `modelFamilyApiToModelFamily` further down this file.
 */
export enum ModelFamily {
  Attributer = 'attributer',
  Classifier = 'classifier',
  Detector = 'detector',
  ImageTagger = 'image-tagger',
  Segmentor = 'segmentor',
  SemanticSegmentor = 'semantic-segmentor',
  // NOTE(review): key and value disagree — the `SegmentPredictor` key carries
  // the 'activity-recognizer' string and is paired with the API's
  // ACTIVITY_RECOGNIZER family. Check callers before renaming either side.
  SegmentPredictor = 'activity-recognizer',
  KeypointDetector = 'keypoint-detector',
}

/**
 * Client-side status values for a model family.
 *
 * Not to be confused with the generated `ModelFamilyStatus` type, which this
 * file imports under the alias `ModelFamilyStatusApi`.
 *
 * NOTE(review): the values are human-readable strings ('Not activated'
 * contains a space) rather than identifiers — presumably rendered directly;
 * confirm before changing any value.
 */
export enum ModelFamilyStatus {
  Activated = 'Activated',
  NotActivated = 'Not activated',
  Failed = 'Failed',
  Training = 'Training',
  Done = 'Done',
}

/**
 * Unit options for model family training: per image or per label.
 *
 * NOTE(review): what quantity is counted in this unit is not visible in this
 * file — confirm against usage.
 */
export enum ModelFamilyTrainUnit {
  Image = 'image',
  Label = 'label',
}

/**
 * Target metric options for a model family: loss or accuracy.
 *
 * NOTE(review): presumably the metric tracked/optimized during training —
 * confirm against usage.
 */
export enum ModelFamilyTargetMetric {
  Loss = 'Loss',
  Accuracy = 'Accuracy',
}

/**
 * Generated API model families grouped as image families.
 *
 * Together with `VideoModelFamilies` (which holds ACTIVITY_RECOGNIZER), this
 * covers every API family that `modelFamilyApiToModelFamily` maps.
 * The element order below is kept as-is because consumers may iterate it.
 */
export const ImageModelFamilies: ModelFamilyStatusApi.modelFamily[] = [
  ModelFamilyStatusApi.modelFamily.ATTRIBUTER,
  ModelFamilyStatusApi.modelFamily.CLASSIFIER,
  ModelFamilyStatusApi.modelFamily.DETECTOR,
  ModelFamilyStatusApi.modelFamily.IMAGE_TAGGER,
  ModelFamilyStatusApi.modelFamily.SEGMENTOR,
  ModelFamilyStatusApi.modelFamily.SEMANTIC_SEGMENTOR,
  ModelFamilyStatusApi.modelFamily.KEYPOINT_DETECTOR,
];

/**
 * Generated API model families grouped as video families — currently only
 * ACTIVITY_RECOGNIZER (paired with `ModelFamily.SegmentPredictor`).
 */
export const VideoModelFamilies: ModelFamilyStatusApi.modelFamily[] = [
  ModelFamilyStatusApi.modelFamily.ACTIVITY_RECOGNIZER,
];

// TODO: remove this when switched to React Query
/**
 * Maps each generated API `modelFamily` value to its client `ModelFamily`.
 *
 * Typed as a full `Record` keyed by the API enum, so the compiler rejects this
 * literal if any API family is missing an entry.
 *
 * Naming mismatch to be aware of: API ACTIVITY_RECOGNIZER maps to
 * `ModelFamily.SegmentPredictor` (string value 'activity-recognizer').
 *
 * Values must stay unique. `modelFamilyToApiModelFamily` is derived by
 * inverting this object, and a duplicated value would silently keep only
 * the last key there. Entry order is left untouched because it determines
 * key enumeration order here and in the inverted map.
 */
export const modelFamilyApiToModelFamily: Record<
  ModelFamilyStatusApi.modelFamily,
  ModelFamily
> = {
  [ModelFamilyStatusApi.modelFamily.CLASSIFIER]: ModelFamily.Classifier,
  [ModelFamilyStatusApi.modelFamily.SEMANTIC_SEGMENTOR]:
    ModelFamily.SemanticSegmentor,
  [ModelFamilyStatusApi.modelFamily.SEGMENTOR]: ModelFamily.Segmentor,
  [ModelFamilyStatusApi.modelFamily.DETECTOR]: ModelFamily.Detector,
  [ModelFamilyStatusApi.modelFamily.IMAGE_TAGGER]: ModelFamily.ImageTagger,
  [ModelFamilyStatusApi.modelFamily.ATTRIBUTER]: ModelFamily.Attributer,
  [ModelFamilyStatusApi.modelFamily.ACTIVITY_RECOGNIZER]:
    ModelFamily.SegmentPredictor,
  [ModelFamilyStatusApi.modelFamily.KEYPOINT_DETECTOR]:
    ModelFamily.KeypointDetector,
};

// lodash `invert` swaps keys and values but types the result as a plain
// string dictionary, so the precise key/value types are re-applied below.
const invertedModelFamilyLookup = invert(modelFamilyApiToModelFamily);

/**
 * Reverse lookup of `modelFamilyApiToModelFamily`: client `ModelFamily` value
 * to generated API `modelFamily` value.
 *
 * The type assertion is unchecked by the compiler. It is only accurate while
 * every `ModelFamily` value appears exactly once as a value in the forward
 * map; `invert` keeps the last key for a duplicated value.
 */
export const modelFamilyToApiModelFamily =
  invertedModelFamilyLookup as Record<
    ModelFamily,
    ModelFamilyStatusApi.modelFamily
  >;
