test_eq(get_embed_size(35), 6)
Model utilities
Utility functions used to build PyTorch timeseries models.
apply_idxs
apply_idxs (o, idxs)
Function to apply indices to zarr, dask and numpy arrays
SeqTokenizer
SeqTokenizer (c_in, embed_dim, token_size=60, norm=False)
Generates non-overlapping tokens from sub-sequences within a sequence by applying a sliding window
get_embed_size
get_embed_size (n_cat, rule='log2')
has_weight_or_bias
has_weight_or_bias (l)
has_weight
has_weight (l)
has_bias
has_bias (l)
is_conv
is_conv (l)
is_affine_layer
is_affine_layer (l)
is_conv_linear
is_conv_linear (l)
is_bn
is_bn (l)
is_linear
is_linear (l)
is_layer
is_layer (*args)
get_layers
get_layers (model, cond=<function noop>, full=True)
check_weight
check_weight (m, cond=<function noop>, verbose=False)
check_bias
check_bias (m, cond=<function noop>, verbose=False)
get_nf
get_nf (m)
Get nf from model’s first linear layer in head
ts_splitter
ts_splitter (m)
Split a model between body and head.
transfer_weights
transfer_weights (model, weights_path:pathlib.Path, device:torch.device=None, exclude_head:bool=True)
Utility function that makes it easy to transfer weights between models. Taken from the great self-supervised repository created by Kerem Turgutlu. https://github.com/KeremTurgutlu/self_supervised/blob/d87ebd9b4961c7da0efd6073c42782bbc61aaa2e/self_supervised/utils.py
build_ts_model
build_ts_model (arch, c_in=None, c_out=None, seq_len=None, d=None, dls=None, device=None, verbose=False, s_cat_idxs=None, s_cat_embeddings=None, s_cat_embedding_dims=None, s_cont_idxs=None, o_cat_idxs=None, o_cat_embeddings=None, o_cat_embedding_dims=None, o_cont_idxs=None, patch_len=None, patch_stride=None, fusion_layers=128, fusion_act='relu', fusion_dropout=0.0, fusion_use_bn=True, pretrained=False, weights_path=None, exclude_head=True, cut=-1, init=None, arch_config={}, **kwargs)
count_parameters
count_parameters (model, trainable=True)
build_tsimage_model
build_tsimage_model (arch, c_in=None, c_out=None, dls=None, pretrained=False, device=None, verbose=False, init=None, arch_config={}, **kwargs)
build_tabular_model
build_tabular_model (arch, dls, layers=None, emb_szs=None, n_out=None, y_range=None, device=None, arch_config={}, **kwargs)
from tsai.data.external import get_UCR_data
from tsai.data.core import TSCategorize, get_ts_dls
from tsai.data.preprocessing import TSStandardize
from tsai.models.InceptionTime import *
X, y, splits = get_UCR_data('NATOPS', split_data=False)
tfms = [None, TSCategorize()]
batch_tfms = TSStandardize()
dls = get_ts_dls(X, y, splits, tfms=tfms, batch_tfms=batch_tfms)
model = build_ts_model(InceptionTime, dls=dls)
test_eq(count_parameters(model), 460038)
get_clones
get_clones (module, N)
m = nn.Conv1d(3,4,3)
get_clones(m, 3)
ModuleList(
(0-2): 3 x Conv1d(3, 4, kernel_size=(3,), stride=(1,))
)
split_model
split_model (m)
output_size_calculator
output_size_calculator (mod, c_in, seq_len=None)
c_in = 3
seq_len = 30
m = nn.Conv1d(3, 12, kernel_size=3, stride=2)
new_c_in, new_seq_len = output_size_calculator(m, c_in, seq_len)
test_eq((new_c_in, new_seq_len), (12, 14))
[W NNPACK.cpp:64] Could not initialize NNPACK! Reason: Unsupported hardware.
change_model_head
change_model_head (model, custom_head, **kwargs)
Replaces a model’s head with a custom head, as long as the model has `head`, `head_nf`, `c_out` and `seq_len` attributes.
true_forecaster
true_forecaster (o, split, horizon=1)
naive_forecaster
naive_forecaster (o, split, horizon=1)
a = np.random.rand(20).cumsum()
split = np.arange(10, 20)
a, naive_forecaster(a, split, 1), true_forecaster(a, split, 1)
(array([ 0.74775537, 1.41245663, 2.12445924, 2.8943163 , 3.56384351,
4.23789602, 4.83134182, 5.18560431, 5.30551186, 6.29076506,
6.58873471, 7.03661275, 7.0884361 , 7.57927022, 8.21911791,
8.59726773, 9.37382718, 10.17298849, 10.40118308, 10.82265631]),
array([ 6.29076506, 6.58873471, 7.03661275, 7.0884361 , 7.57927022,
8.21911791, 8.59726773, 9.37382718, 10.17298849, 10.40118308]),
array([ 6.58873471, 7.03661275, 7.0884361 , 7.57927022, 8.21911791,
8.59726773, 9.37382718, 10.17298849, 10.40118308, 10.82265631]))