import '@tensorflow/tfjs-backend-wasm';
import * as bodySegmentation from '@tensorflow-models/body-segmentation';
import {initializeTFBackend, fileToHTMLImage, urlToFile, formatError} from './utils';

/**
 * Default options handed to `segmentPeople` when `predict` is called without
 * an explicit config. The inferred type of this object is also the type of
 * `predict`'s `config` parameter.
 *
 * NOTE(review): `multiSegmentation`, `segmentBodyParts`,
 * `segmentationThreshold` and `internalResolution` look like BodyPix-style
 * options — confirm the MediaPipeSelfieSegmentation segmenter created in this
 * file actually honours them rather than ignoring them.
 */
const defaultConfig = {
  flipHorizontal: false,
  multiSegmentation: false,
  segmentBodyParts: false,
  segmentationThreshold: 0.7,
  internalResolution: 'full',
};

/**
 * Cache the loaded segmentation model globally so we only load once.
 */
let segmentationModel: bodySegmentation.BodySegmenter | null = null;

/**
 * In-flight model load, shared by every caller that arrives before the first
 * load has settled. Reset to null once that load resolves or rejects, so a
 * failed load can be retried by a later call.
 */
let pendingSegmentationModel: Promise<bodySegmentation.BodySegmenter> | null = null;

/**
 * Initializes the TF backend, then creates a MediaPipeSelfieSegmentation
 * segmenter using the 'tfjs' runtime and the 'general' model type.
 */
async function createSegmentationModel(): Promise<bodySegmentation.BodySegmenter> {
  await initializeTFBackend();

  return bodySegmentation.createSegmenter(
    bodySegmentation.SupportedModels.MediaPipeSelfieSegmentation,
    {
      runtime: 'tfjs',
      modelType: 'general',
    }
  );
}

/**
 * Returns the cached segmenter, loading it on first use.
 *
 * Concurrent callers that arrive while a load is still in progress all await
 * the same promise instead of each initializing the backend and creating
 * their own segmenter.
 *
 * @returns The shared segmenter instance.
 * @throws Error with the message produced by `formatError(error, 'load model')`
 *   when backend initialization or segmenter creation fails.
 */
async function loadModel(): Promise<bodySegmentation.BodySegmenter> {
  if (segmentationModel) {
    return segmentationModel;
  }

  if (!pendingSegmentationModel) {
    pendingSegmentationModel = createSegmentationModel();
  }
  const pending = pendingSegmentationModel;

  try {
    const model = await pending;
    segmentationModel = model;
    return model;
  } catch (error) {
    console.error('Detailed error loading segmentation model:', error);
    throw new Error(formatError(error, 'load model'));
  } finally {
    // Only clear the slot if it still holds the load we awaited; a newer
    // retry started by another caller must not be discarded.
    if (pendingSegmentationModel === pending) {
      pendingSegmentationModel = null;
    }
  }
}

/**
 * Runs people segmentation over `image`.
 *
 * The segmenter is obtained through `loadModel`, and `config` (falling back
 * to `defaultConfig`) is forwarded untouched to `segmentPeople`.
 */
async function predict(image: HTMLImageElement, config = defaultConfig) {
  const segmenter = await loadModel();
  const people = await segmenter.segmentPeople(image, config);
  return people;
}

/**
 * Creates a copy of the image's pixels in which every background pixel is
 * fully transparent.
 *
 * Only the first entry of `segmentation` is used. A pixel is treated as
 * background when the red channel of the corresponding mask pixel is exactly 0;
 * all other pixels keep their original alpha.
 *
 * @param image Source image; it is drawn onto a scratch canvas sized
 *   `image.width` x `image.height`.
 * @param segmentation Result of `predict` for the same image.
 * @returns The image's pixel data with background alpha set to 0.
 * @throws Error when `segmentation` is empty, when a 2D canvas context cannot
 *   be obtained, or when the mask dimensions differ from the image dimensions.
 */
async function createTransparentBackground(
  image: HTMLImageElement,
  segmentation: Awaited<ReturnType<typeof predict>>
): Promise<ImageData> {
  // Fail with a readable message instead of a TypeError on `undefined.mask`.
  const [person] = segmentation;
  if (!person) {
    throw new Error('No person detected in the image');
  }

  const canvas = document.createElement('canvas');
  canvas.width = image.width;
  canvas.height = image.height;
  const ctx = canvas.getContext('2d');

  if (!ctx) {
    throw new Error('Failed to get canvas context');
  }

  // Draw original image
  ctx.drawImage(image, 0, 0);
  const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);
  const pixels = imageData.data;

  // Get mask data
  const mask = await person.mask.toImageData();

  // The loop below pairs mask and image pixels by flat index, which is only
  // meaningful when both buffers describe the same grid.
  if (mask.width !== imageData.width || mask.height !== imageData.height) {
    throw new Error(
      `Segmentation mask size ${mask.width}x${mask.height} does not match ` +
        `image size ${imageData.width}x${imageData.height}`
    );
  }
  const maskData = mask.data;

  // Make background transparent: walk both RGBA buffers one pixel (4 bytes)
  // at a time and zero the alpha byte wherever the mask's red byte is 0.
  for (let offset = 0; offset < maskData.length; offset += 4) {
    if (maskData[offset] === 0) {
      pixels[offset + 3] = 0;
    }
  }

  return imageData;
}

/**
 * Strips the background from an image file and returns the result as a PNG
 * data URL.
 *
 * The file is decoded with `fileToHTMLImage`, segmented with `predict`, and
 * the masked pixels from `createTransparentBackground` are painted onto a
 * canvas sized to the image. Any failure along the way is logged and rethrown
 * as a new Error whose message comes from
 * `formatError(error, 'remove background')`.
 */
export async function removeBackground(file: File): Promise<string> {
  try {
    const sourceImage = await fileToHTMLImage(file);
    const people = await predict(sourceImage);

    if (people.length === 0) {
      throw new Error('No person detected in the image');
    }

    // Output surface that will hold the cut-out before it is serialized.
    const output = document.createElement('canvas');
    output.width = sourceImage.width;
    output.height = sourceImage.height;

    const context = output.getContext('2d');
    if (!context) {
      throw new Error('Failed to get canvas context');
    }

    const cutout = await createTransparentBackground(sourceImage, people);
    context.putImageData(cutout, 0, 0);

    return output.toDataURL('image/png');
  } catch (error) {
    console.error('Error in removeBackground:', error);
    throw new Error(formatError(error, 'remove background'));
  }
}

/**
 * Removes the background from a single input given either as a `File` or as
 * a string that `urlToFile` can turn into one.
 *
 * Resolves to a one-element array containing the data URL returned by
 * `removeBackground`.
 *
 * Failures while resolving the input are logged and rethrown as a new Error
 * built from `formatError(error, 'remove background')`. A rejection coming
 * from `removeBackground` itself is passed through unchanged.
 */
export async function getBackgroundRemoved(input: string | File): Promise<string[]> {
  let source: File;

  try {
    source = input instanceof File ? input : await urlToFile(input);
  } catch (error) {
    console.error('Error in getBackgroundRemoved:', error);
    throw new Error(formatError(error, 'remove background'));
  }

  return Promise.all([removeBackground(source)]);
}
