import numpy as np
import tensorflow as tf
from bionic_apps.ai.interface import TFBaseNet
from tensorflow.keras.utils import plot_model
from tensorflow.keras.regularizers import l2


# Pozicionális kódolás
def positional_encoding(length, depth):
    depth = depth // 2  # A mélység felezése (szétosztva sin és cos számára)
    positions = tf.cast(tf.range(length)[:, tf.newaxis], dtype=tf.float32)  # Pozíció vektor létrehozása oszlopként
    depths = tf.cast(tf.range(depth)[np.newaxis, :], dtype=tf.float32) / depth  # Mélység vektor létrehozása sorban
    angle_rates = 1 / (10000 ** depths)  # Szögsebességek kiszámítása
    angle_rads = positions * angle_rates  # Szögek kiszámítása
    pos_encoding = tf.concat([tf.sin(angle_rads), tf.cos(angle_rads)], axis=-1)  # Sin és cos összefűzése
    return pos_encoding  # Visszatér a pozicionális kódolás tenzora



# BaseAttention réteg
class BaseAttention(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__()  # Szülő osztály inicializálása
        self.mha = tf.keras.layers.MultiHeadAttention(**kwargs)  # Többszörös fejű figyelmi réteg inicializálása
        self.layernorm = tf.keras.layers.LayerNormalization()  # Rétegnormalizálás inicializálása
        self.add_layer = tf.keras.layers.Add()  # Add réteg inicializálása a reziduális kapcsolat számára

# GlobalSelfAttention réteg
class GlobalSelfAttention(BaseAttention):
    def call(self, x):
        attn_output = self.mha(query=x, value=x, key=x)  # Többszörös fejű önfigyelés alkalmazása
        x = self.add_layer([x, attn_output])  # Reziduális kapcsolat hozzáadása
        x = self.layernorm(x)  # Rétegnormalizálás alkalmazása
        return x  # Visszatér a feldolgozott tenzorral

# FeedForward réteg L2 regulárizációval
class FeedForward(tf.keras.layers.Layer):
    def __init__(self, d_model, dff, dropout_rate=0.5, l2_reg=0.01):
        super().__init__()  # Szülő osztály inicializálása
        self.seq = tf.keras.Sequential([  # Szekvenciális réteg inicializálása
            tf.keras.layers.Dense(dff, activation='relu', kernel_regularizer=l2(l2_reg)),  # Dense réteg hozzáadása ReLU aktivációval
            tf.keras.layers.Dense(d_model, kernel_regularizer=l2(l2_reg)),
            tf.keras.layers.Dropout(dropout_rate)  # Dropout réteg hozzáadása
        ])
        self.add_layer = tf.keras.layers.Add()  # Add réteg inicializálása a reziduális kapcsolat számára
        self.layer_norm = tf.keras.layers.LayerNormalization()  # Rétegnormalizálás inicializálása

    def call(self, x):
        x = self.add_layer([x, self.seq(x)])
        x = self.layer_norm(x)  # Rétegnormalizálás alkalmazása
        return x

# PositionalEmbedding réteg L2 regulárizációval
class PositionalEmbedding(tf.keras.layers.Layer):
    def __init__(self, d_model, max_len=2048, l2_reg=0.01):
        super().__init__()  # Szülő osztály inicializálása
        self.d_model = d_model  # Modell méretének beállítása
        self.embedding = tf.keras.layers.Dense(d_model, use_bias=False, kernel_regularizer=l2(l2_reg))  # Sűrű réteg inicializálása az embedding számára L2 regulárizációval
        self.pos_encoding = positional_encoding(max_len, d_model)  # Pozicionális kódolás generálása

    def call(self, x):
        seq_len = tf.shape(x)[1]  # A szekvencia hosszának lekérése
        x = self.embedding(x)  # Embedding alkalmazása
        x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))  # Embeddingek skálázása
        pos_encoding = self.pos_encoding[:seq_len, :]
        pos_encoding = tf.expand_dims(pos_encoding, 0)
        pos_encoding = tf.tile(pos_encoding, [tf.shape(x)[0], 1, 1])
        x = x + pos_encoding  # Pozicionális kódolás hozzáadása az embeddingekhez
        return x

# Encoder réteg
class EncoderLayer(tf.keras.layers.Layer):
    def __init__(self, d_model, num_heads, dff, dropout_rate=0.5, l2_reg=0.01):
        super().__init__()  # Szülő osztály inicializálása
        self.self_attention = GlobalSelfAttention(num_heads=num_heads, key_dim=d_model, dropout=dropout_rate)
        self.ffn = FeedForward(d_model, dff, dropout_rate, l2_reg)

    def call(self, x):
        x = self.self_attention(x)  # Önfigyelés alkalmazása
        x = self.ffn(x)  # Feedforward hálózat alkalmazása
        return x

# Encoder
class Encoder(tf.keras.layers.Layer):
    def __init__(self, num_layers, d_model, num_heads, dff, dropout_rate=0.5, l2_reg=0.01):
        super().__init__()  # Szülő osztály inicializálása
        self.d_model = d_model  # Modell méretének beállítása
        self.num_layers = num_layers  # Rétegek számának beállítása
        self.enc_layers = [EncoderLayer(d_model, num_heads, dff, dropout_rate, l2_reg) for _ in range(num_layers)]
        self.dropout = tf.keras.layers.Dropout(dropout_rate)  # Dropout réteg inicializálása

    def call(self, x):
        for i in range(self.num_layers):  # Iteráció a rétegeken
            x = self.enc_layers[i](x)  # Minden encoder réteg alkalmazása
        return x

# Transformer modell
class Transformer(TFBaseNet):
    def __init__(self, input_shape, output_shape, num_layers=4, d_model=512, num_heads=4, dff=2048, dropout_rate=0.2, l2_reg=0.01, **kwargs):
        self.input_shape = input_shape  # Input alakjának beállítása
        self.output_shape = output_shape  # Output alakjának beállítása
        self.num_layers = num_layers  # Rétegek számának beállítása
        self.d_model = d_model  # Modell méretének beállítása
        self.num_heads = num_heads  # Figyelmi fejek számának beállítása
        self.dff = dff  # Feedforward hálózat méretének beállítása
        self.dropout_rate = dropout_rate  # Dropout arány beállítása
        self.l2_reg = l2_reg  # L2 regulárizáció faktor beállítása
        super().__init__(input_shape=input_shape, output_shape=output_shape, **kwargs)  # Szülő osztály inicializálása

    def _build_graph(self):
        input_tensor = tf.keras.layers.Input(shape=self.input_shape)  # Input layer definiálása

        pos_embedding_layer = PositionalEmbedding(d_model=self.d_model, max_len=self.input_shape[0], l2_reg=self.l2_reg)
        x = pos_embedding_layer(input_tensor)

        # Encoder rétegek alkalmazása
        encoder = Encoder(num_layers=self.num_layers, d_model=self.d_model, num_heads=self.num_heads, dff=self.dff, dropout_rate=self.dropout_rate, l2_reg=self.l2_reg)
        encoder_output = encoder(x)

        encoder_output = tf.keras.layers.GlobalAveragePooling1D()(encoder_output)  # Globális átlagos pooling alkalmazása
        output = tf.keras.layers.Dense(self.output_shape, activation='softmax', kernel_regularizer=l2(self.l2_reg))(encoder_output)

        return input_tensor, output

    def plot_model(self, filename="model.png"):
        # Biztosítja, hogy a modell létre van hozva
        if not hasattr(self, '_model'):
            input_tensor, outputs = self._build_graph()
            self._create_model(input_tensor, outputs)
        try:
            plot_model(self._model, to_file=filename, show_shapes=True, show_layer_names=True)  # Modell ábrázolása és mentése
        except Exception as e:
            print(f"Hiba a modell ábrázolása során: {e}")

input_shape = (64, 160)  # Példa input alak
output_shape = 4  # Példa output alak
transformer_model = Transformer(input_shape, output_shape)
transformer_model.plot_model("transformer_model.png")