React state is always one step behind when making predictions on an uploaded image using TensorFlow.js

Update: It seems to work fine on mobile browsers, but not on desktop browsers.

I’m developing a React app that classifies Pokemon from images using TensorFlow.js.

  • What I want – Upload an image of a Pokemon, generate predictions for the same Pokemon.

  • What is actually happening – When I upload an image to make a prediction, the output is always for the previous image. So the 1st prediction is always garbage (a random Pokemon). The prediction I get when I upload the 2nd Pokemon is always for the 1st Pokemon; uploading the 3rd Pokemon gives the prediction for the 2nd Pokemon, and so on.

Here are the relevant pieces of code –

  1. I first check whether the model is present in IndexedDB; if so, I load it into the `model` state. If not, I fetch it from a server and store it in that state. This is what the first useEffect does on the first render of the page.

  2. I use another useEffect that runs whenever findState.uploadedImage changes. This state lives in the Redux Toolkit store.

Here’s a short demo of the problem => https://youtu.be/MX70zbupNWQ

Here’s the app URL => https://poke-zoo.herokuapp.com/

Here’s the Github repo => https://github.com/theairbend3r/poke-zoo/tree/master/frontend/src/features/find

Here’s the file SearchOutput.js. This fetches the model and makes predictions.

const SearchOutput = () => {
  const findState = useSelector(selectorFind)
  const dispatch = useDispatch()
  const imageRef = useRef(null)

  const [model, setModel] = useState(null)
  const [predictions, setPredictions] = useState([])

  const MODEL_HTTP_URL = "http://localhost:3001/api/pokeml/classify"
  const MODEL_INDEXEDDB_URL = "indexeddb://poke-model"

  useEffect(() => {
    async function fetchModel() {
      try {
        const localClassifierModel = await tf.loadLayersModel(
          MODEL_INDEXEDDB_URL
        )

        setModel(localClassifierModel)
        console.log("Model loaded from IndexedDB")
      } catch (e) {
        const classifierModel = await tf.loadLayersModel(MODEL_HTTP_URL)
        setModel(classifierModel)

        await classifierModel.save(MODEL_INDEXEDDB_URL)

        console.error(e)
      }
    }
    fetchModel()
  }, [])

  const getTopKPred = (pred, k) => {
    const predIdx = []
    const predNames = []

    const topkPred = [...pred].sort((a, b) => b - a).slice(0, k)

    topkPred.map(i => predIdx.push(pred.indexOf(i)))
    predIdx.map(i => predNames.push(idx2class[i]))

    return predNames
  }

  useEffect(() => {
    async function makePredictions() {
      if (imageRef && model) {
        try {
          const imgTensor = tf.browser
            .fromPixels(imageRef.current)
            .resizeNearestNeighbor([160, 160])
            .toFloat()
            .sub(127.5)
            .div(127.5)
            .expandDims()

          const y_pred = await model.predict(imgTensor).dataSync()
          const topkPredNames = getTopKPred(y_pred, 5)

          console.log(topkPredNames)
          return topkPredNames
        } catch (e) {
          console.log("Unable to run predictions.")
        }
      }
    }
    makePredictions()
  }, [findState.uploadedImage])

  return (.............)
}

Here’s the file findSlice.js that stores the input image into the redux state.

import { createSlice } from "@reduxjs/toolkit"
import axios from "axios"

// Starting state of the `find` slice of the store.
const initialState = {
  uploadedImage: "",
  model: null,
  matchesFound: [],
}

/**
 * `find` slice: stores the user's uploaded image and an optional model.
 * The reducers assign to `state` directly, relying on the Immer draft that
 * createSlice passes in.
 */
export const findSlice = createSlice({
  name: "find",
  initialState,
  reducers: {
    /** Replaces the current image with `action.payload.uploadedImage`. */
    storeInputImage(state, { payload }) {
      state.uploadedImage = payload.uploadedImage
    },
    /**
     * Replaces the stored model with `action.payload.model`.
     * NOTE(review): a tf.LayersModel is not serializable; if the store uses
     * configureStore's default middleware, the serializability check will
     * warn about it — confirm whether this action is still needed.
     */
    setModel(state, { payload }) {
      state.model = payload.model
    },
  },
})

/** Picks the `find` slice out of the root state. */
export const selectorFind = ({ find }) => find

export const { storeInputImage, setModel } = findSlice.actions
export default findSlice.reducer

Edit: Based on the suggestions below, I tried the following. It has not solved the problem, though; I'm putting it here for reference.

/** @jsx jsx */
import { jsx, css } from "@emotion/core"
import tw from "twin.macro"
import axios from "axios"
import * as tf from "@tensorflow/tfjs"

import { useSelector, useDispatch } from "react-redux"
import { selectorFind, setModel } from "./findSlice"
import { useEffect, useState, useRef, useCallback } from "react"
import idx2class from "./components/classIdxDict"

const SearchOutput = () => {
  const findState = useSelector(selectorFind)
  const [imageRef, setImageRef] = useState(null)

  const onChangeRef = useCallback(node => {
    setImageRef(node)
  }, [])

  const [model, setModel] = useState(null)
  const [predictions, setPredictions] = useState([])

  const MODEL_HTTP_URL = "api/pokeml/classify"
  const MODEL_INDEXEDDB_URL = "indexeddb://poke-model"

  useEffect(() => {
    async function fetchModel() {
      try {
        const localClassifierModel = await tf.loadLayersModel(
          MODEL_INDEXEDDB_URL
        )

        setModel(localClassifierModel)
        console.log("Model loaded from IndexedDB")
      } catch (e) {
        const classifierModel = await tf.loadLayersModel(MODEL_HTTP_URL)
        setModel(classifierModel)

        await classifierModel.save(MODEL_INDEXEDDB_URL)

        console.error(e)
      }
    }
    fetchModel()
  }, [])

  const getTopKPred = (pred, k) => {
    const predIdx = []
    const predNames = []

    const topkPred = [...pred].sort((a, b) => b - a).slice(0, k)

    topkPred.map(i => predIdx.push(pred.indexOf(i)))
    predIdx.map(i => predNames.push(idx2class[i]))

    return predNames
  }

  useEffect(() => {
    async function makePredictions() {
      if (imageRef && model) {
        try {
          const imgTensor = tf.browser
            .fromPixels(imageRef)
            .resizeNearestNeighbor([160, 160])
            .toFloat()
            .sub(127.5)
            .div(127.5)
            .expandDims()

          const y_pred = await model.predict(imgTensor).data()
          const topkPredNames = getTopKPred(y_pred, 5)

          console.log(topkPredNames)
          return topkPredNames
        } catch (e) {
          console.log("Unable to run predictions.", e)
        }
      }
    }
    makePredictions()
  }, [findState.uploadedImage, imageRef])

  return (
    <div>
          {findState.uploadedImage && (
            <img
              ref={onChangeRef}
              src={findState.uploadedImage}
              width="600"
              height="600"
            />
          )}
    </div>
  )
}

export default SearchOutput

Source: ReactJs