tfsnippet dataflow使用记录

人生之路坎坎坷坷,跌跌撞撞在所难免。但是,不论跌了多少次,你都必须坚强勇敢地站起来。任何时候,无论你面临着生命的何等困惑抑或经受着多少挫折,无论道路多艰难,希望变得如何渺茫,请你不要绝望,再试一次,坚持到底,成功终将属于勇不言败的你。

导读:本篇文章讲解 tfsnippet dataflow使用记录,希望对大家有帮助,欢迎收藏,转发!站点地址:www.bmabk.com,来源:原文

1 需求背景

在修改代码时InterFusion,使用训练数据训练时,之前只能用一个实体划分滑动窗口,现在需要对多个实体划分滑动窗口,原本这部分实现用到了tfsnippet.dataflows相关的一些API。

2 tfsnippet.dataflows

简介
主要用到了这三个类,这些类主要用于训练前生成batch形式的数据处理,类似pytorch的dataloader,在不同数据形式之间提供了方便的转换方式。

3 之前代码理解

修改代码主要在InterFusion/algorithm/utils.py这个文件下的

def get_sliding_window_data_flow(window_size, batch_size, x, u=None, y=None, shuffle=False, skip_incomplete=False) -> spt.DataFlow:
       n = len(x)
       # 得到在原始时间序列中从window_size - 1到n的索引列表,便于之后生成滑动窗口
       seq = np.arange(window_size - 1, n, dtype=np.int32).reshape([-1, 1])
       # 创建一个DataFlow的子类ArrayFlow对象,传入的arrays对象就是seq
       seq_df: spt.DataFlow = spt.DataFlow.arrays(
           [seq], shuffle=shuffle, skip_incomplete=skip_incomplete, batch_size=batch_size)
       # 偏移量列表,和之后的每个滑动窗口对应,用于根据一个点访问原时间序列数据中前一个窗口范围的点
       offset = np.arange(-window_size + 1, 1, dtype=np.int32)

       if y is not None:
           if u is not None:
               df = seq_df.map(lambda idx: (x[idx + offset], u[idx + offset], y[idx + offset]))
           else:
               df = seq_df.map(lambda idx: (x[idx + offset], y[idx + offset]))
       else:
           if u is not None:
	           # 和具体模型相关,执行到这个分支,将ArrayFlow通过map API转化为MapperFlow对象,具体就是将每个点转化为在原时间序列中该点及之前一个时间窗口的数据列表
               df = seq_df.map(lambda idx: (x[idx + offset], u[idx + offset]))
           else:
               df = seq_df.map(lambda idx: (x[idx + offset],))

       return df

4 代码修改

之前是根据一个array划分为一个和这个array对应的一系列滑动窗口的生成器。
现在需要根据多个array,及一个array list,划分出一系列滑动窗口的生成器,划分的滑动窗口要在每一个array内连续,思路是将每个array划分的窗口生成器gather起来。

def get_sliding_window_data_flow(window_size, batch_size, x, u=None, y=None, shuffle=False, skip_incomplete=False) -> spt.DataFlow:
    # 对多个数据同时划分时间窗口并合并
    if type(x) == list and type(u) == list:
        x_list, u_list, df_list = x, u, []
        for x, u in zip(x_list, u_list):
            n = len(x)
            seq = np.arange(window_size - 1, n, dtype=np.int32).reshape([-1, 1])
            seq_df: spt.DataFlow = spt.DataFlow.arrays(
                [seq], shuffle=shuffle, skip_incomplete=skip_incomplete, batch_size=batch_size)
            offset = np.arange(-window_size + 1, 1, dtype=np.int32)
            if y is not None:
                if u is not None:
                    df = seq_df.map(lambda idx: (x[idx + offset], u[idx + offset], y[idx + offset]))
                else:
                    df = seq_df.map(lambda idx: (x[idx + offset], y[idx + offset]))
            else:
                if u is not None:
                    df = seq_df.map(lambda idx: (x[idx + offset], u[idx + offset]))
                else:
                    df = seq_df.map(lambda idx: (x[idx + offset],))
            df_list.append(df)
        # 将spt.dataflows.MapperFlow的一个listgather起来
        res_df = spt.dataflows.MapperFlow.gather(df_list)
        return res_df
    else:
        n = len(x)
        seq = np.arange(window_size - 1, n, dtype=np.int32).reshape([-1, 1])
        seq_df: spt.DataFlow = spt.DataFlow.arrays(
            [seq], shuffle=shuffle, skip_incomplete=skip_incomplete, batch_size=batch_size)
        offset = np.arange(-window_size + 1, 1, dtype=np.int32)

        if y is not None:
            if u is not None:
                df = seq_df.map(lambda idx: (x[idx + offset], u[idx + offset], y[idx + offset]))
            else:
                df = seq_df.map(lambda idx: (x[idx + offset], y[idx + offset]))
        else:
            if u is not None:
                df = seq_df.map(lambda idx: (x[idx + offset], u[idx + offset]))
            else:
                df = seq_df.map(lambda idx: (x[idx + offset],))

        return df

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。

文章由极客之音整理,本文链接:https://www.bmabk.com/index.php/post/133480.html

(0)
飞熊的头像飞熊bm

相关推荐

发表回复

登录后才能评论
极客之音——专业性很强的中文编程技术网站,欢迎收藏到浏览器,订阅我们!