from keras import backend as K
from keras.initializations import zero
from keras.engine import InputSpec
from keras.models import Sequential
from keras.layers import LSTM, activations, Wrapper, Recurrent, Layer

# Flatten one level of nesting: [[a, b], [c]] -> [a, b, c].
flatten = lambda l: [item for sublist in l for item in sublist]


class Attention(Layer):
    """Soft-attention decoder wrapper for Keras 1.x recurrent layers.

    Called on ``[encoder_outputs, decoder_inputs]``. At every decoder
    time step an additive (Bahdanau-style) attention score is computed
    over the encoder outputs; the resulting context vector is mixed
    into the decoder input before it is fed through the wrapped
    recurrent layer stack.

    Args:
        lstm_gen: zero-argument callable returning a fresh recurrent
            layer instance (e.g. ``lambda: LSTM(...)``), called once
            per stacked layer.
        nlayers: number of stacked recurrent layers to create.
    """

    def __init__(self, lstm_gen, nlayers=1, **kwargs):
        self.supports_masking = True
        self.lstm_gen = lstm_gen
        self.nlayers = nlayers
        super(Attention, self).__init__(**kwargs)

    def build(self, input_shape):
        # input_shape is a pair: (encoder output shape, decoder input shape).
        def add_w(dims, init, name):
            return self.add_weight(dims, init, name=name.format(self.name),
                                   trainable=True)

        self.encoder_sh, self.decoder_sh = input_shape
        # Encoder outputs must be at least (samples, time, dims).
        assert len(self.encoder_sh) >= 3

        self.layers = [self.lstm_gen() for i in range(self.nlayers)]
        nb_samples, nb_time, nb_dims = self.decoder_sh
        first_l = self.layers[0]
        out_shape = self.get_output_shape_for(input_shape)
        for layer in self.layers:
            if not layer.built:
                layer.build(out_shape)

        init, out_dim = first_l.init, first_l.output_dim
        # W1 projects the encoder outputs and W2 the decoder hidden
        # state into a shared space scored by V (additive attention);
        # W3/b3 mix the attention context back into the decoder input.
        self.W1 = add_w((self.encoder_sh[-1], nb_dims), init, '{}_W1')
        self.W2 = add_w((out_dim, nb_dims), init, '{}_W2')
        self.b2 = add_w((nb_dims,), zero, '{}_b2')
        self.W3 = add_w((nb_dims + out_dim, out_dim), init, '{}_W3')
        self.b3 = add_w((out_dim,), zero, '{}_b3')
        self.V = add_w((nb_dims,), init, '{}_V')

        # Expose the wrapped layers' weights through this layer.
        self.trainable_weights += flatten(
            l.trainable_weights for l in self.layers)
        self.non_trainable_weights += flatten(
            l.non_trainable_weights for l in self.layers)
        super(Attention, self).build(input_shape)

    def step(self, x, states):
        """Single decoder time step, driven by K.rnn.

        The last two entries of `states` are the constants appended in
        get_constants: the precomputed encoder projection (encoder_o .
        W1) and the raw encoder outputs.
        """
        h = states[0]
        encoder_o = states[-1]
        xW1 = states[-2]

        # Additive attention scores over encoder time steps.
        hW2 = K.expand_dims(K.dot(h, self.W2) + self.b2, 1)
        u = K.tanh(xW1 + hW2)
        a = K.expand_dims(K.softmax(K.sum(self.V * u, 2)), -1)
        # Context vector: attention-weighted sum of encoder outputs.
        Xa = K.sum(a * encoder_o, 1)
        # Mix current input with the context before the RNN stack.
        h = K.dot(K.concatenate([x, Xa], axis=1), self.W3) + self.b3

        for layer in self.layers:
            h, new_states = layer.step(h, states)
        return h, new_states

    def get_output_shape_for(self, input_shape):
        # Output shape follows the wrapped RNN applied to the decoder
        # input shape (input_shape[1]).
        return self.layers[0].get_output_shape_for(input_shape[1])

    def compute_mask(self, input, mask):
        # mask is a pair (encoder mask, decoder mask); only the decoder
        # mask propagates.
        return self.layers[0].compute_mask(input, mask[1])

    def get_constants(self, encoder_o, constants):
        # Precompute the encoder projection once per call so step()
        # does not repeat it every time step.
        constants.append(K.dot(encoder_o, self.W1))
        constants.append(encoder_o)
        return constants

    def call(self, x, mask=None):
        first_l = self.layers[0]
        encoder_o, decoder_i = x
        if first_l.stateful:
            initial_states = first_l.states
        else:
            initial_states = first_l.get_initial_states(decoder_i)
        constants = self.get_constants(encoder_o,
                                       first_l.get_constants(decoder_i))
        preprocessed_input = first_l.preprocess_input(decoder_i)

        # NOTE(review): mask[1] assumes mask arrives as a per-input
        # list (encoder mask, decoder mask) — confirm against callers.
        last_output, outputs, states = K.rnn(
            self.step, preprocessed_input, initial_states,
            go_backwards=first_l.go_backwards,
            mask=mask[1], constants=constants,
            unroll=first_l.unroll, input_length=self.decoder_sh[1])
        if first_l.stateful:
            self.updates = []
            for i in range(len(states)):
                self.updates.append((first_l.states[i], states[i]))
        return outputs if first_l.return_sequences else last_output