<template>
  <div>
    <!-- Browser tab title -->
    <vue-headful title="KneeNet Biomarkers" />

    <!-- Intro text, license notice and image picker -->
    <div class="header-image py-4">
      <b-container class="text-light my-5">
        <b-row class="form-inline" align-h="center">
          <b-col class="mb-2 text-center" cols="12" md="8">
            <h2>Our KneeNet tool automatically evaluates knee osteoarthritis severity in X-rays using the Kellgren-Lawrence scoring system.</h2>
            <br />
            <div style="margin: 0.5em;"><a href="/license/" style="text-decoration: underline">Use of our software implies your acceptance of our license agreement. Please read it here.</a></div>
            <h4>Try it out! Pick an image to analyze</h4>
          </b-col>
        </b-row>

        <!-- File picker: disabled while `mobilenet` is falsy (model not loaded yet) -->
        <b-row align-h="center">
          <b-col class="mx-sm-3 mb-2" cols="12" sm="10" md="8" lg="6">
            <b-form-file
              v-model="newImages"
              placeholder="Choose or drag-and-drop a file..."
              :state="Boolean(newImages)"
              :disabled="!mobilenet"
              multiple
            />
          </b-col>
        </b-row>

        <b-row align-h="center">
          <b-col class="text-center" style="font-weight: bold;">
            The analysis is run on your computer - no images are sent out!
          </b-col>
        </b-row>
      </b-container>
    </div>

    <!-- Model / processing status line -->
    <Status :networkStatus="networkStatus" />

    <!-- Results: rendered once at least one image entry exists -->
    <b-container>
      <Results
        v-if="images.length > 0"
        v-model="selected"
        :images="images"
        :width="imageSize"
        :height="imageSize"
      />
    </b-container>
  </div>
</template>

<script>
// TensorFlow.js model and inference setup
import * as tf from '@tensorflow/tfjs'
import { IMAGENET_CLASSES } from '@/ai/imagenet_classes.js'
// URL of the model loaded with tf.loadLayersModel(). NOTE(review): the
// MOBILENET_ prefix looks historical — the URL points at the KneeNet model.
const MOBILENET_MODEL_PATH = 'https://s3.amazonaws.com/health-ai/models/kneenet/model.json';
// Side length in pixels of the square image fed to the network
// (inputs are reshaped to [1, IMAGE_SIZE, IMAGE_SIZE, 3]).
const IMAGE_SIZE = 299

import tst from '@/components/Toasts/Toasts.js'
import Results from '@/components/Results'
import Status from '@/components/Status'
import store from '@/store.js'
import router from '@/router.js'
import { loadImage, imageToCanvas, splitCanvasInTwoImages } from '@/util/ImageCanvasConvertor'
import { countKneesInCanvas } from '@/util/ImageCalculations'

/**
 * Converts an image tensor into an ImageData that can be drawn with
 * CanvasRenderingContext2D.putImageData(). No live code in this component
 * calls it; it is a debugging aid.
 *
 * Supported shapes:
 *   [height, width] or [height, width, 1]  grayscale, replicated into R, G, B
 *   [height, width, 3]                     RGB
 *   [height, width, 4]                     RGBA (copied as-is)
 * Alpha is set to 255 unless the tensor provides it. Values are stored in a
 * Uint8ClampedArray, so they are clamped to [0, 255] and rounded — scale
 * [0, 1] data by 255 before calling.
 *
 * @param {tf.Tensor} tensor image tensor (shape as above)
 * @returns {ImageData} image of size width x height
 * @throws {Error} if the channel count is not 1, 3 or 4
 */
function toPixels(tensor) {
  const height = tensor.shape[0];
  const width = tensor.shape[1];
  const channels = tensor.shape.length === 2 ? 1 : tensor.shape[2];
  const pixels = tensor.dataSync();

  // ImageData takes (width, height), while tensor shapes are [height, width, ...].
  const imageData = new ImageData(width, height);
  const out = imageData.data;
  const pixelCount = width * height;

  if (channels === 1) {
    for (let i = 0; i < pixelCount; i++) {
      out[i * 4 + 0] = pixels[i];
      out[i * 4 + 1] = pixels[i];
      out[i * 4 + 2] = pixels[i];
      out[i * 4 + 3] = 255;
    }
  } else if (channels === 3) {
    for (let i = 0; i < pixelCount; i++) {
      out[i * 4 + 0] = pixels[i * 3 + 0];
      out[i * 4 + 1] = pixels[i * 3 + 1];
      out[i * 4 + 2] = pixels[i * 3 + 2];
      out[i * 4 + 3] = 255;
    }
  } else if (channels === 4) {
    // ImageData.data is a read-only property: copy into the buffer instead of
    // reassigning it.
    out.set(pixels);
  } else {
    throw new Error(`toPixels: unsupported channel count ${channels}`);
  }
  return imageData;
}

/**
 * Converts an image into the normalized square tensor the model consumes.
 *
 * Steps:
 *  1. Crop rows so the kept region is `width * min_hw_ratio` rows tall,
 *     trimming equal amounts from the top and bottom.
 *  2. Resample that region onto an IMAGE_SIZE x IMAGE_SIZE canvas.
 *  3. Take channel 0 (red) of the resampled pixels as the grayscale image.
 *  4. Standardize it (subtract mean, divide by standard deviation), then
 *     min-max scale the result into [0, 1]. Min-max scaling is invariant to
 *     the preceding affine standardization, so this equals min-max scaling the
 *     raw channel. The standardization is kept to reproduce the original
 *     floating-point values.
 *  5. Replicate the channel into three identical channels.
 *
 * Intermediate tensors are not disposed here; the callers in this file wrap
 * the call in tf.tidy().
 *
 * @param {HTMLImageElement} imgElement decoded image (naturalWidth/naturalHeight
 *   are read and it is passed to drawImage)
 * @returns {tf.Tensor3D} float tensor of shape [IMAGE_SIZE, IMAGE_SIZE, 3]
 */
function preprocess(imgElement) {
  const min_hw_ratio = 1;
  const output_width = IMAGE_SIZE;
  const output_height = IMAGE_SIZE;
  // Guards the two divisions below against a zero denominator. That only
  // happens for a perfectly uniform image, where the old code produced NaN.
  // Any non-uniform 8-bit image yields values far above this, so its results
  // are unchanged.
  const epsilon = 1e-6;

  // Crop: keep `c * min_hw_ratio` rows, trimming equally from top and bottom.
  // NOTE(review): for landscape images r_to_delete is negative, so the source
  // rectangle extends past the image. Canvas clips it and the uncovered area
  // stays transparent (zero pixels). Confirm this padding is intended.
  const r = imgElement.naturalHeight;
  const c = imgElement.naturalWidth;
  const r_to_keep = c * min_hw_ratio;
  const r_to_delete = r - r_to_keep;
  const remove_from_top = Math.ceil(r_to_delete / 2);
  const remove_from_bottom = Math.floor(r_to_delete / 2);

  // Resample the cropped region to output_width x output_height.
  const canvas = document.createElement('canvas');
  const ctx = canvas.getContext('2d');
  canvas.width = output_width;
  canvas.height = output_height;
  ctx.drawImage(
    imgElement,
    0, remove_from_top, c, r - remove_from_bottom - remove_from_top,
    0, 0, canvas.width, canvas.height
  );

  // Grayscale: channel 0 (red) of the RGB pixels -> [H, W].
  const imgData = ctx.getImageData(0, 0, output_width, output_height);
  const image_rgb = tf.browser.fromPixels(imgData).toFloat();
  const image_gray = tf.unstack(image_rgb, 2)[0];

  // Standardize to zero mean / unit standard deviation.
  const image_centered = tf.sub(image_gray, image_gray.mean());
  const img_std = image_centered.square().mean().sqrt();
  const image_standardized = tf.div(image_centered, tf.maximum(img_std, epsilon));

  // Min-max scale into [0, 1]. The range stays a tensor (no arraySync) so no
  // blocking device->CPU readback is needed.
  const image_min = image_standardized.min();
  const image_range = tf.sub(image_standardized.max(), image_min);
  const image_clean = tf.div(tf.sub(image_standardized, image_min), tf.maximum(image_range, epsilon));

  // Replicate the channel: [H, W] -> [H, W, 3].
  return tf.stack([image_clean, image_clean, image_clean], 2);
}

/**
 * Runs the model on a single image element and returns the top-K classes,
 * sorted by descending score.
 *
 * NOTE(review): nothing in this file calls this helper; the component's
 * processImage() performs the same steps inline. The previous version
 * referenced undefined names (status, mobilenet, getTopKClasses,
 * TOPK_PREDICTIONS, showResults) and threw a ReferenceError when called.
 * The model is now passed in explicitly.
 *
 * @param {HTMLImageElement} imgElement image to classify (see preprocess)
 * @param {tf.LayersModel} model loaded model, e.g. from tf.loadLayersModel
 * @param {number} [topK=5] maximum number of classes to return
 * @returns {Promise<Array<{className: number, probability: number}>>}
 *   className is the output index, probability the model's score for it
 * @throws {Error} if no model is given
 */
async function predict(imgElement, model, topK = 5) {
  if (!model) {
    throw new Error('predict: a loaded model is required');
  }

  // tf.tidy frees every intermediate tensor except the returned logits.
  const logits = tf.tidy(() => {
    const batched = preprocess(imgElement).reshape([1, IMAGE_SIZE, IMAGE_SIZE, 3]);
    return model.predict(batched);
  });

  let values;
  try {
    values = await logits.data();
  } finally {
    // The logits survive tf.tidy, so release them explicitly.
    tf.dispose(logits);
  }

  return Array.from(values, (probability, className) => ({ className, probability }))
    .sort((a, b) => b.probability - a.probability)
    .slice(0, Math.max(0, topK));
}

export default {
  name: 'Demo',
  components: {
    Results,
    Status
  },
  data () {
    return {
      // Loaded model from tf.loadLayersModel, or `false` until it resolves.
      // NOTE(review): the name looks historical (MobileNet demo); it is used by
      // the template (`:disabled="!mobilenet"`), so it is kept.
      mobilenet: false,
      // v-model of <b-form-file>; the `newImages` watcher processes each selection.
      newImages: false,
      // Entries rendered by <Results>: { filename, img, res, time }.
      images: [],
      // Index into `images` of the selected entry (v-model of <Results>).
      selected: 0,
      // Message shown by the <Status> component.
      networkStatus: 'Loading... (wait up to 30 seconds)'
    }
  },
  computed: {
    // Exposes IMAGE_SIZE to the template as the <Results> width/height.
    imageSize () {
      return IMAGE_SIZE
    }
  },
  async created () {
    // Load the model in the background; the example entries below are pushed
    // immediately. The file picker stays disabled until `mobilenet` is set.
    tf.loadLayersModel(MOBILENET_MODEL_PATH)
      .then((model) => {
        this.networkStatus = "Ready! You can upload images now"
        this.mobilenet = model
      })
      .catch((err) => {
        // Without this the status would read "Loading..." forever.
        console.error(err)
        this.networkStatus = 'Could not load the model. Please reload the page to try again.'
      })

    // Pre-computed example results shown before any upload.
    // NOTE(review): `filename` says .png while the asset is .jpg, and
    // `className` is a string here but a numeric index in getTopKClasses().
    // Confirm what <Results> expects.
    this.images.push({
      filename: "true_label_2.png",
      img: {src: require('@/assets/img/test/true_label_2.jpg')},
      res: [
        {className: "2", probability: 0.5556}
      ],
      time: 0
    })

    this.images.push({
      filename: "true_label_4.png",
      img: {src: require('@/assets/img/test/true_label_4.jpg')},
      res: [
        {className: "4", probability: 0.7899}
      ],
      time: 0
    })
  },
  watch: {
    // Analyze newly selected files one after another. Processing runs
    // sequentially so status messages and the `selected` index do not
    // interleave between overlapping runs.
    async newImages (value) {
      // The picker may reset its model to a falsy value when cleared.
      if (!value) {
        return
      }
      for (const file of Array.from(value)) {
        await this.processImage(file)
      }
    }
  },
  methods: {
    /**
     * Returns the `topK` highest-scoring outputs of a model prediction.
     * Does not dispose `logits`; the caller owns it.
     *
     * @param {tf.Tensor} logits model output (read via logits.data())
     * @param {number} topK number of entries to return; clamped to the number
     *   of outputs
     * @returns {Promise<Array<{className: number, probability: number}>>}
     *   sorted by descending score; className is the output index
     */
    getTopKClasses (logits, topK) {
      return logits.data().then(values => {
        const ranked = Array.from(values, (value, index) => ({ value, index }))
          .sort((a, b) => b.value - a.value)
        const count = Math.max(0, Math.min(topK, ranked.length))
        return ranked.slice(0, count).map(({ value, index }) => ({
          className: index,
          probability: value
        }))
      })
    },
    /**
     * Classifies one uploaded file and appends the result(s) to `images`.
     *
     * When countKneesInCanvas() does not return exactly 1, the image is split
     * with splitCanvasInTwoImages(). Each half is classified separately under
     * a `left_`/`right_` filename prefix. Errors are logged and reported
     * through `networkStatus` instead of being rethrown, so the watcher can
     * continue with the next file.
     *
     * @param {File} file file chosen in <b-form-file>
     * @returns {Promise<void>}
     */
    async processImage (file) {
      this.networkStatus = "Processing a new image"

      try {
        // NOTE(review): this object URL is never revoked. The loaded image is
        // stored in `images` and may be displayed via its src later, so
        // revoking here could break rendering. Confirm before adding
        // URL.revokeObjectURL.
        const rootImage = await loadImage(URL.createObjectURL(file))

        // Timing covers the analysis only, not the file decode above.
        const startTime = performance.now()

        const canvas = imageToCanvas(rootImage)
        const imagesToProcess = countKneesInCanvas(canvas) === 1
          ? [{ name: file.name, image: rootImage }]
          : (await splitCanvasInTwoImages(canvas)).map((image, index) => ({
            image,
            name: (index === 0 ? 'left_' : 'right_') + file.name
          }))

        for (const { name, image } of imagesToProcess) {
          const logits = tf.tidy(() => {
            const batched = preprocess(image).reshape([1, IMAGE_SIZE, IMAGE_SIZE, 3])
            return this.mobilenet.predict(batched)
          })

          let classes
          try {
            classes = await this.getTopKClasses(logits, 5)
          } finally {
            // tf.tidy keeps the returned tensor alive; free it once read.
            tf.dispose(logits)
          }

          this.images.push({
            filename: name,
            img: image,
            res: classes,
            time: performance.now() - startTime
          })
        }

        const totalTime = performance.now() - startTime
        this.networkStatus = `Done in ${Math.floor(totalTime)}ms. Ready for more images`
      } catch (err) {
        console.error(err)
        this.networkStatus = `Could not analyze ${file.name}. Ready for more images`
      } finally {
        // Select the newest entry once the DOM has updated.
        this.$nextTick(() => {
          this.selected = this.images.length - 1
        })
      }
    }
  }
}
</script>
