diff --git a/physiopro/model/eventprediction.py b/physiopro/model/eventprediction.py index 6973ece..aaf0167 100644 --- a/physiopro/model/eventprediction.py +++ b/physiopro/model/eventprediction.py @@ -409,8 +409,11 @@ def forward(self, inputs): if self.temporal_encoding: tem_enc = self.temporal_enc(event_time, non_pad_mask) enc_output += tem_enc - - enc_out = self.network(enc_output) + + try: # for Contiformer, it requires event_time as input + enc_out = self.network(enc_output, t=event_time) + except: + enc_out = self.network(enc_output) if self.hyper_paras["use_rnn"]: enc_out = self.rnn(enc_output, non_pad_mask)