<template>
  <div>
    <vue-headful title="KneeNet Biomarkers"/>

    <!-- Hero banner: tool description, license link and the file picker. -->
    <div class="header-image py-4">
      <b-container class="text-light my-5">
        <b-row align-h="center" class="form-inline">
          <b-col cols="12" md="8" class="mb-2 text-center">
            <h2>Our T2 Mapper tool automatically segments the femoral cartilage in multi-echo spin echo T2-weighted knee MRIs.</h2>
            <br/>
            <div style="margin: 0.5em;"><a style="text-decoration: underline" href="/license/">Use of our software implies your acceptance of our license agreement. Please read it here.</a></div>
            <h4>Try it out! Pick an image* to analyze.</h4>
          </b-col>
        </b-row>

        <!-- Picker is disabled until the model is ready and while an upload is processed. -->
        <div class="d-flex justify-content-center align-items-center mb-2">
          <b-col cols="12" sm="10" md="8" lg="6">
            <b-form-file
              v-model="newImages"
              :state="Boolean(newImages)"
              placeholder="Choose or drag-and-drop a file..."
              :disabled="loadingFile || !ready"
              multiple/>
          </b-col>

          <!-- Persistent live region so screen readers announce the busy state. -->
          <div role="status">
            <template v-if="loadingFile">
              <i class="fas fa-spinner fa-spin fa-2x" aria-hidden="true"></i>
              <span class="sr-only">Analyzing the selected file(s)...</span>
            </template>
          </div>
        </div>

        <b-row align-h="center">
          <b-col class="text-center" style="font-weight: bold;">
            <div>
            The analysis is run on your computer - no images are sent out!
            </div>
            <div>
            Takes ~3 min for an MRI with 30 slices on a typical computer
            </div>
          </b-col>
        </b-row>
      </b-container>
    </div>

    <Status
      :networkStatus="networkStatus"/>

    <!-- Result block -->
    <b-container>
      <ResultsMRI
        v-if="(images.length > 0)"
        v-model="selected"
        :images="images"
        :width="imageSize"
        :height="imageSize"
        @download="onDownloadMask"/>

      <!--
      <canvas ref="canvas1"/>
      <canvas ref="canvas2"/>
      <canvas ref="canvas3"/>
      <canvas ref="canvas4"/>
      -->
    </b-container>

    <!-- Input-format footnote (referenced by the "*" above) -->
    <div>
      <b-container>
        <b-row align-h="center">
          <b-col class="text-center">
            <div>
            *This web app is designed to take a .zip file as input. Within that .zip file, there should be a set of DICOM files, with each file representing one echo time of one slice of one 3D MESE MRI volume. There should be no other files in that .zip besides these MESE DICOM files. This web app is only designed for sagittal-view MESE MRIs.
            </div>
          </b-col>
        </b-row>
      </b-container>
    </div>
  </div>
</template>

<script>
import axios from 'axios'
import Zip from 'jszip'
import { saveAs } from 'file-saver'
import * as tf from '@tensorflow/tfjs'
import { mapState } from 'vuex'

import DICOMConverter from '@/util/DICOMConverter'
import ResultsMRI from '@/components/ResultsMRI'
import Status from '@/components/Status'
import longToByteArray from '@/util/longToByteArray'

// URL of the TensorFlow.js Layers model loaded with tf.loadLayersModel in
// created(). Despite the "MOBILENET" name, the storage path suggests this is the
// automatic knee-MRI segmentation model (name kept for compatibility).
const MOBILENET_MODEL_PATH = 'https://storage.googleapis.com/automatic_knee_mri_segmentation/tfjs_model/model.json'
// Side length in pixels of the square slices fed to the model: each slice is
// drawn onto an IMAGE_SIZE x IMAGE_SIZE canvas and reshaped to
// [1, IMAGE_SIZE, IMAGE_SIZE, 1] before prediction.
const IMAGE_SIZE = 384

/**
 * Prepare one greyscale MRI slice for the segmentation network.
 *
 * 1. Outlier trimming (in place on `imgData`): red values below the 3rd or
 *    above the 97th percentile are clamped to that percentile, and G/B of the
 *    clamped pixels are overwritten with the same value.
 * 2. The red channel is converted to a float tensor of shape [height, width].
 * 3. Intensities are rescaled so the 25th percentile maps to -1 and the 75th
 *    percentile maps to +1.
 *
 * Step 3 used to be a median-centre -> divide -> rescale chain built on four
 * full-image tf.topk sorts. Every stage is affine, so the centring and the
 * intermediate division cancel out and the result is exactly
 * 2 * (x - p25) / (p75 - p25) - 1 (up to float rounding), with p25/p75 taken at
 * the same ascending indices the old chain effectively used. Computing it
 * directly avoids the GPU sorts and the intermediate division, which produced
 * NaN/Infinity whenever the lower-quartile value equalled the median.
 *
 * @param {ImageData} imgData RGBA pixels, assumed greyscale (R == G == B). Mutated.
 * @param {Object} refs component $refs; only used by the commented-out debug drawing.
 * @returns {tf.Tensor} float32 tensor of shape [imgData.height, imgData.width].
 *   Intermediate tensors are not freed here, so call it inside tf.tidy.
 */
function preprocess (imgData, refs) {
  const data = imgData.data
  const pixelCount = imgData.width * imgData.height

  // Index of the `fraction` quantile in an ascending array of `length` items,
  // clamped so small inputs cannot index past the end.
  const quantileIndex = (length, fraction) =>
    Math.min(length - 1, Math.round(length * fraction))

  // --- 1. Trim outliers ------------------------------------------------------
  // Sort key per pixel: byte 0 of the packed RGBA word, via longToByteArray as
  // before, but computed once per pixel instead of twice per comparison.
  const packed = new Uint32Array(data.buffer)
  const keys = new Float64Array(packed.length)
  for (let p = 0; p < packed.length; p++) {
    keys[p] = longToByteArray(packed[p])[0]
  }
  keys.sort() // typed-array sort is numeric ascending, like the old comparator

  const percentile3 = keys[quantileIndex(keys.length, 0.03)]
  const percentile97 = keys[quantileIndex(keys.length, 0.97)]

  for (let pt = 0; pt < pixelCount * 4; pt += 4) {
    const value = data[pt]
    let clamped = value
    if (value < percentile3) {
      clamped = percentile3
    } else if (value > percentile97) {
      clamped = percentile97
    }
    if (clamped !== value) {
      data[pt] = clamped
      data[pt + 1] = clamped
      data[pt + 2] = clamped
    }
  }

  /*
  const ctx = refs.canvas2.getContext('2d')
  ctx.putImageData(imgData, 0, 0)
  */

  // --- 2./3. Tensor conversion and robust rescaling --------------------------
  // Quartiles of the trimmed red channel, read straight from the pixel buffer;
  // these are the same integer values tf.browser.fromPixels puts in the tensor.
  const redValues = new Float32Array(pixelCount)
  for (let p = 0; p < pixelCount; p++) {
    redValues[p] = data[p * 4]
  }
  redValues.sort()
  const p25 = redValues[quantileIndex(pixelCount, 0.25)]
  const p75 = redValues[quantileIndex(pixelCount, 0.75)]

  // If the quartiles coincide the division below would yield NaN/Infinity;
  // fall back to a unit spread so the model still receives finite input.
  const spread = (p75 - p25) || 1

  const rgb = tf.browser.fromPixels(imgData).toFloat()
  const grey = tf.unstack(rgb, 2)[0] // first (red) channel -> [height, width]

  // Map p25 -> -1 and p75 -> +1.
  return tf.sub(tf.mul(tf.div(tf.sub(grey, p25), spread), 2), 1)
}

export default {
  name: 'DemoMRI',
  components: {
    ResultsMRI,
    Status
  },
  data () {
    // NOTE: the loaded tf.LayersModel is deliberately NOT declared here. Vue
    // deep-observes objects returned from data(), which for a model means
    // walking and wrapping every layer, weight and tensor. It is kept as a plain
    // instance property (`this.mobilenet`, assigned in created()) instead.
    return {
      // true once the segmentation model has loaded
      ready: false,
      // true while an upload is being decoded and segmented
      loadingFile: false,
      // v-model of <b-form-file>; false until the user picks something
      newImages: false,
      // one entry per processed upload, built in processImages()
      images: [],
      // v-model of <ResultsMRI>
      selected: 0,
      // message rendered by <Status>
      networkStatus: 'Loading... (wait up to 30 seconds)'
    }
  },
  computed: {
    // Width/height passed to <ResultsMRI>; same size the slices are resampled to.
    imageSize () {
      return IMAGE_SIZE
    }
  },
  /**
   * Load the segmentation model. Upload stays disabled (`ready` false) until
   * it resolves; a failed load is reported through `networkStatus`.
   */
  async created () {
    this.mobilenet = null
    try {
      const model = await tf.loadLayersModel(MOBILENET_MODEL_PATH)
      this.mobilenet = model
      this.networkStatus = "Ready! You can upload images now"
      this.ready = true
    } catch (err) {
      console.error('Failed to load the segmentation model', err)
      this.networkStatus = 'Could not load the segmentation model. Please check your connection and reload the page.'
    }
  },
  // Release the model's weight tensors; otherwise every visit re-loads and leaks them.
  beforeDestroy () {
    if (this.mobilenet) {
      this.mobilenet.dispose()
      this.mobilenet = null
    }
  },
  watch: {
    /**
     * Process each new selection from the file picker. Clearing the picker
     * (null / empty list) is ignored; failures are logged and shown in <Status>.
     */
    newImages: async function (value) {
      if (!value || value.length === 0) {
        return
      }

      this.loadingFile = true
      try {
        await this.processImages(value)
        this.networkStatus = "Ready! You can upload images now"
      } catch (err) {
        console.error('Failed to process the selected file(s)', err)
        this.networkStatus = 'Could not analyze the selected file(s): ' + ((err && err.message) || err)
      } finally {
        this.loadingFile = false
      }
    }
  },
  methods: {
    /**
     * Run the segmentation model on one slice.
     *
     * @param {ImageData} imgData IMAGE_SIZE x IMAGE_SIZE RGBA slice; preprocess()
     *   clips its outliers in place.
     * @returns {tf.Tensor} raw model output. The caller owns it and must dispose() it.
     */
    predict (imgData) {
      return tf.tidy(() => {
        const img = preprocess(imgData, this.$refs)
        const batched = img.reshape([1, IMAGE_SIZE, IMAGE_SIZE, 1])
        return this.mobilenet.predict(batched)
      })
    },
    /**
     * Decode the uploaded archive(s), keep one echo per slice location, segment
     * each kept slice and append the result to `this.images`.
     *
     * @param {File[]} value .zip archives of DICOM files from <b-form-file>.
     * @throws {Error} if no DICOM with an echo time is found, or a decoded slice
     *   cannot be loaded as an image.
     */
    async processImages (value) {
      if (!value || value.length === 0) {
        return
      }

      // 1. Decode every DICOM in every archive (skipping folders and .DS_Store).
      let dicoms = []
      for (const archive of value) {
        const arc = await Zip.loadAsync(archive)
        const files = Object.values(arc.files)
          .filter(file => !file.dir && !file.name.includes('DS_Store'))

        for (const file of files) {
          const blob = await arc.file(file.name).async("blob")
          blob.name = file.name
          const dicom = await DICOMConverter(blob, 0, IMAGE_SIZE)
          // concat also flattens the result should the converter return an array
          dicoms = dicoms.concat(dicom)
        }
      }

      // 2. Group by slice location and keep, per slice, the echo whose
      //    echoTime is closest to 20. Groups keep first-seen order before the
      //    (stable) sort, exactly as with the previous array-based grouping.
      dicoms = dicoms.filter(d => d.tags.echoTime)
      const bySliceLocation = new Map()
      for (const d of dicoms) {
        d.tags.echoDistance = Math.abs(20 - d.tags.echoTime)
        const group = bySliceLocation.get(d.tags.sliceLocation)
        if (group) {
          group.push(d)
        } else {
          bySliceLocation.set(d.tags.sliceLocation, [d])
        }
      }
      dicoms = [...bySliceLocation.entries()]
        .sort(([locA], [locB]) => locA - locB)
        .map(([, elements]) => elements.sort((a, b) => a.tags.echoDistance - b.tags.echoDistance)[0])

      if (dicoms.length === 0) {
        throw new Error('no DICOM files with an echo time were found')
      }

      // 3. Segment each kept slice and render its mask as a translucent overlay.
      const sampleCanvas = document.createElement('canvas')
      sampleCanvas.width = IMAGE_SIZE
      sampleCanvas.height = IMAGE_SIZE
      const canvasCtx = sampleCanvas.getContext('2d')

      const maskCanvas = document.createElement('canvas')
      maskCanvas.width = IMAGE_SIZE
      maskCanvas.height = IMAGE_SIZE
      const maskCanvasCtx = maskCanvas.getContext('2d')

      for (const dicom of dicoms) {
        const img = new Image()
        await new Promise((resolve, reject) => {
          img.onload = () => resolve()
          // Without this the upload would hang forever on an undecodable slice.
          img.onerror = () => reject(new Error('could not load a decoded DICOM slice'))
          img.src = dicom.src
        })

        canvasCtx.drawImage(img, 0, 0, img.naturalWidth, img.naturalHeight, 0, 0, IMAGE_SIZE, IMAGE_SIZE)
        const sliceData = canvasCtx.getImageData(0, 0, IMAGE_SIZE, IMAGE_SIZE)

        // predict() hands over the output tensor; free it as soon as it is read.
        const prediction = this.predict(sliceData)
        let mask
        try {
          mask = prediction.dataSync()
        } finally {
          prediction.dispose()
        }

        // A new ImageData is transparent black, so only pixels that are not
        // below the 0.5 threshold need writing (maroon, half-transparent).
        const maskData = new ImageData(IMAGE_SIZE, IMAGE_SIZE)
        for (let j = 0; j < mask.length; j++) {
          if (mask[j] < 0.5) {
            continue
          }
          const pt = j * 4
          maskData.data[pt] = 127
          maskData.data[pt + 3] = 127
        }

        maskCanvasCtx.putImageData(maskData, 0, 0)
        dicom.maskSrc = maskCanvas.toDataURL('image/png')
      }

      // 4. Export name from the first archive: "scan.zip" -> "scan_mask.zip";
      //    a name without an extension is kept whole.
      const firstName = value[0].name
      const dot = firstName.lastIndexOf('.')
      const exportName = (dot === -1 ? firstName : firstName.slice(0, dot)) + '_mask.zip'

      this.images.push({
        filename: value.map(v => v.name).join(','),
        exportName,
        dicomIndex: 0,
        exporting: false,
        dicoms
      })
    },
    /**
     * Zip the rendered mask PNGs of `this.images[index]` and save the archive
     * under its `exportName`. `exporting` is true for the duration.
     *
     * @param {number} index position in `this.images` (emitted by <ResultsMRI>).
     */
    async onDownloadMask (index) {
      const data = this.images[index]

      data.exporting = true

      try {
        const zip = new Zip()

        for (const dicom of data.dicoms) {
          // maskSrc is a PNG data: URL; read it back as raw bytes.
          const response = await axios.get(dicom.maskSrc, {
            responseType: 'arraybuffer'
          })

          zip.file(dicom.blob.name + '.png', response.data, {
            binary: true
          })
        }

        const zipFile = await zip.generateAsync({ type: 'blob' })
        saveAs(zipFile, data.exportName)
      } finally {
        data.exporting = false
      }
    }
  }
}
</script>

