import numpy as np

class NeuralNet:
    """Fixed-weight feed-forward net: 3 inputs -> 6 tanh units -> 3-way softmax.

    The weights are hard-coded in ``__init__``; nothing here trains them.
    ``forward`` returns class probabilities in the order LONG, SHORT, NEUTRAL.

    Inputs may be scalars or broadcast-compatible NumPy arrays. With array
    inputs, ``forward`` returns one probability row per broadcast element.
    """

    def __init__(self):
        # Layer 1: hidden = tanh(w1 @ x + b1), where x = [a, norm_b, norm_price].
        self.w1 = np.zeros((6, 3))
        self.b1 = np.zeros(6)
        # Layer 2: logits = w2 @ hidden + b2, one row per output class.
        self.w2 = np.zeros((3, 6))
        self.b2 = np.zeros(3)

        # Units 0 and 1 mix all three inputs, with mirrored signs on `a` and on
        # the price input. Units 2-5 each read exactly one input.
        self.w1[0] = [2.0, 0.5, -0.3]; self.b1[0] = -0.5
        self.w1[1] = [-2.0, 0.5, 0.3]; self.b1[1] = -0.5
        self.w1[2][0] = -1.0;          self.b1[2] = 0.0
        self.w1[3][1] = 4.0;           self.b1[3] = -1.0
        self.w1[4][2] = 2.0;           self.b1[4] = -0.5
        self.w1[5][2] = -2.0;          self.b1[5] = -0.5

        # Sparse output wiring: LONG <- hidden 0, 3, 5; SHORT <- hidden 1, 3, 4;
        # NEUTRAL <- hidden 2, 3. All other weights (and all of b2) stay 0.
        self.w2[0][[0,3,5]] = [1.2, 0.8, 1.0]
        self.w2[1][[1,3,4]] = [1.2, 0.8, 1.0]
        self.w2[2][[2,3]]   = [1.0, -0.5]

    def normalizeB(self, b):
        """Scale ``b`` so that 0.2 maps to 1.0, then clip to [0, 1].

        Works element-wise on scalars and arrays.
        """
        v = b / 0.2
        return np.clip(v, 0.0, 1.0)

    def normalizePrice(self, price, ref):
        """Return the percent deviation of ``price`` from ``ref``, clipped to [-2, 2].

        Computes ``(price - ref) / (ref * 0.01)``. Where ``ref <= 0``, a tiny
        positive denominator (0.0001) replaces it, so any noticeable deviation
        saturates at +/-2 instead of dividing by zero or flipping sign.
        Works element-wise on scalars and broadcast-compatible arrays.
        """
        price = np.asarray(price, dtype=float)
        ref = np.asarray(ref, dtype=float)
        denom = ref * 0.01
        # np.where (not a Python `if`) so that array-valued `ref` is handled
        # per element instead of raising "truth value is ambiguous".
        denom = np.where(denom <= 0, 0.0001, denom)
        x = (price - ref) / denom
        return np.clip(x, -2.0, 2.0)

    def forward(self, a, b, price, ref_price):
        """Run the network and return softmax probabilities.

        Args:
            a: Raw first feature, fed to the net un-normalized.
            b: Second feature; normalized with ``normalizeB``.
            price: Current price; normalized against ``ref_price`` with
                ``normalizePrice``.
            ref_price: Reference price for the percent-deviation feature.

        Returns:
            An array with a trailing axis of length 3 (LONG, SHORT, NEUTRAL)
            that sums to 1 along that axis. Scalar inputs give shape ``(3,)``.
            Array inputs give shape ``(..., 3)``.
        """
        # Features live on the LAST axis so that batches broadcast correctly
        # against the bias vectors. Scalar inputs yield a plain (3,) vector.
        x = np.stack(
            np.broadcast_arrays(
                np.asarray(a, dtype=float),
                self.normalizeB(b),
                self.normalizePrice(price, ref_price),
            ),
            axis=-1,
        )

        hidden = np.tanh(x @ self.w1.T + self.b1)
        z = hidden @ self.w2.T + self.b2
        # Subtract the per-row max before exp for numerical stability.
        expz = np.exp(z - z.max(axis=-1, keepdims=True))
        out = expz / expz.sum(axis=-1, keepdims=True)

        return out  # LONG, SHORT, NEUTRAL
