import {useEffect, useMemo, useRef} from "react";
import * as THREE from "three";
import {useFBO} from "@react-three/drei";
import {getRandomData} from "./simulation/simulationMaterial";
import {createPortal, useFrame} from "@react-three/fiber";
import fragmentShader from "./fragment";
import vertexShader from "./vertex";

// Vertex positions (x, y, z) of a quad made of two triangles spanning
// (-1, -1) to (1, 1) on the z = 0 plane. Used as the geometry of the mesh
// rendered into the off-screen simulation scene.
const positions = new Float32Array(
    [
        // first triangle
        [-1, -1, 0],
        [1, -1, 0],
        [1, 1, 0],
        // second triangle
        [-1, -1, 0],
        [1, 1, 0],
        [-1, 1, 0]
    ].flat()
);

// Texture coordinates (u, v), one pair per quad vertex, in the same order as
// the entries of `positions`.
const uvs = new Float32Array(
    [
        // first triangle
        [0, 1],
        [1, 1],
        [1, 0],
        // second triangle
        [0, 1],
        [1, 0],
        [0, 0]
    ].flat()
);


export const FBOParticles = ({mouseRef}) => {
    const size = 128;

    const points = useRef();
    const simulationMaterialRef = useRef();
    const timeRef = useRef(0); // new
    const flagRef = useRef(false);

    const sceneRef = useRef(new THREE.Scene());
    const scene = sceneRef.current;
    const cameraRef = useRef(
        new THREE.OrthographicCamera(-1, 1, 1, -1, 1 / Math.pow(2, 53), 1)
    );
    const camera = cameraRef.current;

    const renderTargetA = useFBO(size, size, {
        minFilter: THREE.NearestFilter,
        magFilter: THREE.NearestFilter,
        format: THREE.RGBAFormat,
        stencilBuffer: false,
        type: THREE.FloatType
    });
    const renderTargetB = useFBO(size, size, {
        minFilter: THREE.NearestFilter,
        magFilter: THREE.NearestFilter,
        format: THREE.RGBAFormat,
        stencilBuffer: false,
        type: THREE.FloatType
    });

    const particlesPosition = useMemo(() => {
        const length = size * size;
        const particles = new Float32Array(length * 3);
        for (let i = 0; i < length; i++) {
            let i3 = i * 3;
            particles[i3 + 0] = (i % size) / size;
            particles[i3 + 1] = i / size / size;
        }
        return particles;
    }, [size]);

    const uniforms = useMemo(
        () => ({
            uPositions: {
                value: null
            },
            uMouse: {value: {x: 0, y: 0}}
        }),
        []
    );

    // new ------
    // const elapsedTime = timeRef.current;
    const simulationTexture = useMemo(() => {
        const positionsTexture = new THREE.DataTexture(
            getRandomData(size, size),
            size,
            size,
            THREE.RGBAFormat,
            THREE.FloatType
        );
        positionsTexture.needsUpdate = true;
        return positionsTexture;
    }, [size]);


    useFrame((state, delta) => {
        const {gl} = state;
        const flag = flagRef.current;
        timeRef.current += delta; // new
        simulationMaterialRef.current.uniforms.uTime.value = timeRef.current;
        points.current.material.uniforms.uMouse.value = mouseRef.current;
        simulationMaterialRef.current.uniforms.uMouse.value = mouseRef.current;

        // Render the scene to the current render target and get the texture
        gl.setRenderTarget(flag ? renderTargetA : renderTargetB);
        gl.clear();
        gl.render(scene, camera);
        gl.setRenderTarget(null);

        points.current.material.uniforms.uPositions.value = flag
            ? renderTargetA.texture
            : renderTargetB.texture;

        // Pass the previous texture to the simulation material
        simulationMaterialRef.current.uniforms.positions.value = flag
            ? renderTargetA.texture
            : renderTargetB.texture;

        flagRef.current = !flag;
    });

    const portal = useMemo(() => {
        const elapsedTime = timeRef.current;
        const scene = sceneRef.current;
        return createPortal(
            <mesh>
                <simulationMaterial
                    ref={simulationMaterialRef}
                    args={[elapsedTime, simulationTexture]}
                />
                <bufferGeometry>
                    <bufferAttribute
                        attach="attributes-position"
                        count={positions.length / 3}
                        array={positions}
                        itemSize={3}
                    />
                    <bufferAttribute
                        attach="attributes-uv"
                        count={uvs.length / 2}
                        array={uvs}
                        itemSize={2}
                    />
                </bufferGeometry>
            </mesh>,
            scene
        );
    }, [simulationTexture]);

    return (
        <>
            {portal}
            <points ref={points}>
                <bufferGeometry>
                    <bufferAttribute
                        attach="attributes-position"
                        count={particlesPosition.length / 3}
                        array={particlesPosition}
                        itemSize={3}
                    />
                </bufferGeometry>
                <shaderMaterial
                    blending={THREE.AdditiveBlending}
                    depthWrite={false}
                    fragmentShader={fragmentShader}
                    vertexShader={vertexShader}
                    uniforms={uniforms}
                />
            </points>
        </>
    );
};