import {
  serialization,
  loadLayersModel,
  LayersModel,
  Tensor,
  Rank,
} from '@tensorflow/tfjs'

import MovingAvgProcessor, {
  MovingAvgProcessorInteface,
} from './moveAvgProcessor'
import TSM from '../tensorflow/TSM'
import AttentionMask from '../tensorflow/AttentionMask'
import { TensorStoreInterface } from './tensorStore'
import tensorStore from './tensorStore'

/**
 * Contract for a post-processing stage that consumes one batch supplied in
 * two forms and produces no return value.
 */
export interface PosprocessorInteface {
  /**
   * Processes a single batch.
   *
   * @param normalizedBatch - First tensor input; presumably the normalized
   *   version of the batch (TODO confirm against the caller).
   * @param rawBatch - Second tensor input; presumably the un-normalized
   *   version of the same batch (TODO confirm against the caller).
   */
  compute(normalizedBatch: Tensor<Rank>, rawBatch: Tensor<Rank>): void
}

/**
 * Runs the vppg layers model on incoming batches and forwards the model
 * output to a tensor store.
 *
 * The model is not loaded by the constructor; call {@link Posprocessor.loadModel}
 * first. Until a model is present, {@link Posprocessor.compute} is a no-op.
 */
class Posprocessor implements PosprocessorInteface {
  /** Sink that receives the synchronously read model output. */
  tensorStore: TensorStoreInterface
  /** Moving-average helper; this class only creates and resets it. */
  rppgAvgProcessor: MovingAvgProcessorInteface
  /** Moving-average helper; this class only creates and resets it. */
  respAvgProcessor: MovingAvgProcessor
  /** Loaded layers model, or `null` until `loadModel` has completed. */
  vppgModel: LayersModel | null

  /**
   * @param tensorStore - Store that `compute` pushes model output into.
   */
  constructor(tensorStore: TensorStoreInterface) {
    this.tensorStore = tensorStore
    this.rppgAvgProcessor = new MovingAvgProcessor()
    this.respAvgProcessor = new MovingAvgProcessor()
    this.vppgModel = null
  }

  /** Resets both moving-average processors. The loaded model is kept. */
  reset = (): void => {
    this.rppgAvgProcessor.reset()
    this.respAvgProcessor.reset()
  }

  /**
   * Loads the vppg model from `${PUBLIC_URL}/models/vppg/model.json` if no
   * model is held yet. The custom `TSM` and `AttentionMask` classes are
   * registered with the tfjs serialization registry before loading.
   *
   * @returns A promise resolving to `true` once a model is available. The
   *   promise rejects if loading the model fails.
   */
  loadModel = async (): Promise<boolean> => {
    if (this.vppgModel === null) {
      serialization.registerClass(TSM)
      serialization.registerClass(AttentionMask)
      this.vppgModel = await loadLayersModel(
        `${process.env.PUBLIC_URL}/models/vppg/model.json`
      )
      console.log('vppgModel loaded succesfully')
    }
    return true
  }

  /**
   * Feeds `[normalizedBatch, rawBatch]` to the model and hands the
   * synchronously read output to `tensorStore.addRppgPltData`.
   *
   * Does nothing when no model has been loaded. The prediction tensor is
   * always disposed, even when reading it or storing it throws.
   *
   * @param normalizedBatch - First model input.
   * @param rawBatch - Second model input.
   */
  compute = (normalizedBatch: Tensor<Rank>, rawBatch: Tensor<Rank>): void => {
    if (!this.vppgModel) {
      return
    }

    const rppg = this.vppgModel.predict([
      normalizedBatch,
      rawBatch,
    ]) as Tensor<Rank>

    try {
      this.tensorStore.addRppgPltData(rppg.dataSync())
    } finally {
      // Release the tensor on every path; otherwise a failure in dataSync()
      // or in the store would leak the prediction's backing memory.
      rppg.dispose()
    }
  }
}

// Module-wide singleton wired to the shared tensor store.
const posprocessor = new Posprocessor(tensorStore)

export default posprocessor
