import { ModelFamilyStatus as ModelFamilyApi } from '../codegen';
import {
  ModelFamily,
  ModelFamilyStatus,
  ModelFamilyTrainUnit,
} from '../constants/modelFamily';

/**
 * All values of the `ModelFamily` enum, in the order `Object.values` returns
 * them (the enum object's property order).
 *
 * NOTE(review): this relies on `ModelFamily` being a string enum. A numeric
 * enum's object also holds reverse-mapped member names, which would then
 * appear in this list — confirm it stays a string enum.
 */
export const modelFamilies = Object.values(ModelFamily);
/**
 * Human-readable label for every `ModelFamily`.
 *
 * Typed as a full `Record<ModelFamily, string>`, so the compiler requires an
 * entry for each family and rejects unknown keys.
 */
export const modelFamilyLabels: Record<ModelFamily, string> = {
  [ModelFamily.Attributer]: 'Attribute prediction',
  [ModelFamily.Classifier]: 'Label class prediction',
  [ModelFamily.Detector]: 'Object detection',
  [ModelFamily.ImageTagger]: 'Image tag prediction',
  [ModelFamily.Segmentor]: 'Instance segmentation',
  [ModelFamily.SemanticSegmentor]: 'Semantic segmentation',
  [ModelFamily.SegmentPredictor]: 'Segment prediction',
  [ModelFamily.KeypointDetector]: 'Keypoints detection',
};

/**
 * Family identifiers used for inference, as upper-case string literals.
 *
 * Declared `as const` so this is a readonly tuple whose element type is the
 * union of these literals; `InferenceFamily` is derived from it.
 * These strings are independent of the `ModelFamily` enum values.
 */
export const InferenceFamilies = [
  'CLASSIFIER',
  'ATTRIBUTER',
  'DETECTOR',
  'SEGMENTOR',
  'SEMANTIC_SEGMENTOR',
  'TAGGER',
] as const;

/**
 * Inference families flagged as "upcoming". All current entries also appear
 * in `InferenceFamilies`. What "upcoming" means to consumers is not visible
 * in this module.
 *
 * NOTE(review): this tuple is not type-checked against `InferenceFamilies`,
 * so a typo or stale entry here would compile. Consider
 * `as const satisfies readonly InferenceFamily[]` if the project's
 * TypeScript version supports `satisfies` (4.9+).
 */
export const UpcomingInferenceFamilies = [
  'CLASSIFIER',
  'ATTRIBUTER',
  'TAGGER',
] as const;

/** Union of the string literals listed in `InferenceFamilies`. */
export type InferenceFamily = (typeof InferenceFamilies)[number];

/**
 * Human-readable label for every `InferenceFamily`. The `Record` type makes
 * the compiler require an entry for each family.
 *
 * NOTE(review): CLASSIFIER ('Classification') and TAGGER ('Tagger') are
 * worded differently from the matching `modelFamilyLabels` entries
 * ('Label class prediction', 'Image tag prediction') — confirm the
 * divergence is intentional.
 */
export const inferenceModelFamilyLabels: Record<InferenceFamily, string> = {
  ATTRIBUTER: 'Attribute prediction',
  CLASSIFIER: 'Classification',
  DETECTOR: 'Object detection',
  SEGMENTOR: 'Instance segmentation',
  SEMANTIC_SEGMENTOR: 'Semantic segmentation',
  TAGGER: 'Tagger',
};

/**
 * Returns the training unit associated with a model family.
 *
 * `Classifier` and `Attributer` map to `ModelFamilyTrainUnit.Label`; every
 * other family maps to `ModelFamilyTrainUnit.Image`.
 *
 * Every `ModelFamily` member is listed explicitly, so adding a new member
 * becomes a compile error here until its train unit is chosen, instead of
 * silently falling through to `Image`.
 *
 * @param family - The model family to look up.
 * @returns The train unit for `family`.
 */
export const modelFamilyToTrainUnit = (
  family: ModelFamily,
): ModelFamilyTrainUnit => {
  switch (family) {
    case ModelFamily.Classifier:
    case ModelFamily.Attributer:
      return ModelFamilyTrainUnit.Label;
    case ModelFamily.Detector:
    case ModelFamily.ImageTagger:
    case ModelFamily.Segmentor:
    case ModelFamily.SemanticSegmentor:
    case ModelFamily.SegmentPredictor:
    case ModelFamily.KeypointDetector:
      return ModelFamilyTrainUnit.Image;
    default: {
      // Once every ModelFamily member is handled above, `family` narrows to
      // `never` here; a newly added member makes this assignment a type error.
      const unhandled: never = family;
      void unhandled;
      // Runtime fallback for values outside the enum (e.g. unvalidated data),
      // preserving the previous catch-all behaviour.
      return ModelFamilyTrainUnit.Image;
    }
  }
};

/**
 * Color for each `ModelFamilyStatus`, written as MUI palette paths
 * (`'<palette>.<shade>'`, e.g. 'primary.light'). Presumably passed to an MUI
 * `color`/`sx` prop — confirm at call sites.
 *
 * NOTE(review): `Done` maps to 'error.main', the same color as `Failed`.
 * Confirm this is intended rather than, e.g., 'success.main'.
 */
export const modelFamilyStatusColorsMui: Record<ModelFamilyStatus, string> = {
  [ModelFamilyStatus.Activated]: 'primary.light',
  [ModelFamilyStatus.NotActivated]: 'secondary.light',
  [ModelFamilyStatus.Failed]: 'error.main',
  [ModelFamilyStatus.Done]: 'error.main',
  [ModelFamilyStatus.Training]: 'warning.main',
};

/**
 * Same colors as `modelFamilyStatusColorsMui`, but keyed by
 * `ModelFamilyApi.status` (the `ModelFamilyStatus` type from `../codegen`)
 * instead of the local `ModelFamilyStatus` enum.
 *
 * This duplicate exists until everything is migrated to React Query. Until
 * then, changes must be mirrored by hand in both maps.
 *
 * NOTE(review): `DONE` maps to 'error.main', the same color as `FAILED`.
 * Confirm this is intended.
 */
export const modelFamilyApiStatusColorsMui: Record<
  ModelFamilyApi.status,
  string
> = {
  [ModelFamilyApi.status.ACTIVATED]: 'primary.light',
  [ModelFamilyApi.status.NOT_ACTIVATED]: 'secondary.light',
  [ModelFamilyApi.status.FAILED]: 'error.main',
  [ModelFamilyApi.status.DONE]: 'error.main',
  [ModelFamilyApi.status.TRAINING]: 'warning.main',
};

/**
 * Progress color (MUI palette path) for each `ModelFamilyStatus`.
 *
 * It matches `modelFamilyStatusColorsMui` except for `NotActivated`, which is
 * 'success.main' here instead of 'secondary.light'.
 *
 * NOTE(review): `NotActivated` → 'success.main' while `Done` → 'error.main'
 * looks inverted. Confirm against the intended design.
 */
export const modelFamilyStatusProgressColors: Record<
  ModelFamilyStatus,
  string
> = {
  [ModelFamilyStatus.Activated]: 'primary.light',
  [ModelFamilyStatus.NotActivated]: 'success.main',
  [ModelFamilyStatus.Failed]: 'error.main',
  [ModelFamilyStatus.Done]: 'error.main',
  [ModelFamilyStatus.Training]: 'warning.main',
};
