from tsai.models.TST import *
Miscellaneous
This contains a set of experiments.
InputWrapper
InputWrapper (arch, c_in, c_out, seq_len, new_c_in=None, new_seq_len=None, **kwargs)
Same as `nn.Module`, but subclasses do not need to call `super().__init__()`.
# Example: wrap TST with InputWrapper and check the output shape.
# Dummy batch built with shape (16, 1, 1000).
xb = torch.randn(16, 1, 1000)
# Arguments follow the signature
# InputWrapper(arch, c_in, c_out, seq_len, new_c_in=None, new_seq_len=None).
model = InputWrapper(TST, c_in=1, c_out=4, seq_len=1000, new_c_in=10, new_seq_len=224)
# Expected output: one row per sample in the batch, c_out=4 columns.
test_eq(model.to(xb.device)(xb).shape, (16, 4))
ResidualWrapper
ResidualWrapper (model)
Same as `nn.Module`, but subclasses do not need to call `super().__init__()`.
RecursiveWrapper
RecursiveWrapper (model, n_steps, anchored=False)
Same as `nn.Module`, but subclasses do not need to call `super().__init__()`.
# Example: wrap a TST model with RecursiveWrapper and check the output shape.
# Dummy batch built with shape (16, 1, 20).
xb = torch.randn(16, 1, 20)
# TST(1, 1, 20): presumably c_in=1, c_out=1, seq_len=20 (TODO confirm).
# Signature: RecursiveWrapper(model, n_steps, anchored=False).
model = RecursiveWrapper(TST(1, 1, 20), n_steps=5)
# Expected output: (batch, n_steps), presumably one prediction per recursive step.
test_eq(model.to(xb.device)(xb).shape, (16, 5))