Using tfrecords

Getting started with tfrecords

Posted by xhhszc on March 14, 2021



1. Generating tfrecords data

Write a Python script, gen_data.py, that generates the data:

# file of gen_data.py
from pyspark import SparkConf, SparkContext
from pyspark.sql import SparkSession
conf = SparkConf()
spark = SparkSession.builder.config(conf=conf).getOrCreate()

# create data
data = [('James','','Smith','1991-04-01','M',3000),
  ('Michael','Rose','','2000-05-19','M',4000),
  ('Robert','','Williams','1978-09-05','M',4000),
  ('Maria','Anne','Jones','1967-12-01','F',4000),
  ('Jen','Mary','Brown','1980-02-17','F',-1)
]
columns = ["firstname","middlename","lastname","dob","gender","salary"]

# make a dataframe from data
# Note: tfrecords cannot store 2-D arrays or lists of strings (e.g. ["abc", "sss"]).
# If the dataframe contains such data, every column must first be converted to
# float, string, or an array of floats (a casting sketch follows this script).
df = spark.createDataFrame(data=data, schema=columns)

"""
+---------+----------+--------+----------+------+------+
|firstname|middlename|lastname|dob       |gender|salary|
+---------+----------+--------+----------+------+------+
|James    |          |Smith   |1991-04-01|M     |3000  |
|Michael  |Rose      |        |2000-05-19|M     |4000  |
|Robert   |          |Williams|1978-09-05|M     |4000  |
|Maria    |Anne      |Jones   |1967-12-01|F     |4000  |
|Jen      |Mary      |Brown   |1980-02-17|F     |-1    |
+---------+----------+--------+----------+------+------+
"""

# write the dataframe to an HDFS path
your_path = "/user/test/"
df.repartition(10).write.mode('overwrite').format("tfrecords").option("recordType", "Example").save(your_path)
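As the note in the script warns, the connector cannot store 2-D arrays or lists of strings. A minimal casting sketch, assuming a hypothetical dataframe df2 with an integer-array column scores and a string-list column tags (neither is part of the example above):

from pyspark.sql import functions as F
from pyspark.sql.types import ArrayType, FloatType

# cast an array<int> column to array<float> so the connector can store it as a FloatList
df2 = df2.withColumn("scores", F.col("scores").cast(ArrayType(FloatType())))
# collapse a list of strings into a single comma-separated string
df2 = df2.withColumn("tags", F.concat_ws(",", F.col("tags")))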

Run the script with spark-submit:

spark-submit \
--driver-memory 20g \
--executor-cores 4 \
--executor-memory 11g \
--conf spark.dynamicAllocation.minExecutors=100 \
--conf spark.dynamicAllocation.maxExecutors=150 \
--conf spark.default.parallelism=1200 \
--conf spark.executor.memoryOverhead=3096 \
--queue "your_queue_name" \
--jars spark-connector_2.11-1.10.0.jar \
gen_data.py

The jar spark-connector_2.11-1.10.0.jar (the spark-tensorflow-connector) is what enables writing the dataframe in the tfrecords format.
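With the same jar on the classpath, the connector can also read tfrecords back into a dataframe, which is handy for sanity-checking the write. A minimal sketch, reusing the spark session and path from gen_data.py:

# read the tfrecords back and inspect a few rows
df_check = spark.read.format("tfrecords").option("recordType", "Example").load("/user/test/")
df_check.show(5)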

2. Reading tfrecords data

import tensorflow as tf

def _parse_function(example_proto):
    features = {"firstname": tf.io.VarLenFeature(tf.string),
                "middlename": tf.io.VarLenFeature(tf.string),
                "lastname": tf.io.VarLenFeature(tf.string),
                "dob": tf.io.VarLenFeature(tf.string),
                "gender":tf.io.VarLenFeature(tf.string),
                "salary":tf.io.FixedLenFeature((1), tf.float32)}
                # tf.io.FixedLenFeature(shape, type), shape可以为二维,例如(3,2)
    parsed_feature = tf.io.parse_single_example(example_proto, features)
    return parsed_feature['firstname'], parsed_feature['middlename'], parsed_feature['lastname'], parsed_feature['dob'], parsed_feature['gender'], parsed_feature['salary']

def parse_dataset(data_file_path, batch_size=32, epochs=1):
    # tf.data.experimental.AUTOTUNE is a sentinel value (-1), so comparing it
    # against a thread count never triggers; cap the thread count directly.
    num_threads = 4
    files_name = data_file_path + "/*"
    data_files = tf.data.Dataset.list_files(files_name)
    dataset = data_files.interleave(tf.data.TFRecordDataset, cycle_length=num_threads)
    dataset = dataset.shuffle(batch_size * 10)  # shuffle the records
    dataset = dataset.repeat(epochs)  # repeat the data for `epochs` epochs of training
    data_parsed = dataset.map(_parse_function, num_parallel_calls=num_threads)
    data_parsed = data_parsed.batch(batch_size)
    data_parsed = data_parsed.prefetch(1)
    # TF1-style one-shot iterator (tf.compat.v1 keeps this working under TF2)
    return tf.compat.v1.data.make_one_shot_iterator(data_parsed)


def get_dataset():
    # TF1-style graph execution; under TF2, call
    # tf.compat.v1.disable_eager_execution() before building the dataset.
    iterator = parse_dataset(data_file_path="/user/test/")
    next_element = iterator.get_next()
    with tf.compat.v1.Session() as sess:
        try:
            while True:
                sample = sess.run(next_element)
                # sample[0] is firstname, ..., sample[5] is salary.
        except tf.errors.OutOfRangeError:
            print("end of the dataset")


get_dataset()
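One caveat worth spelling out: every field parsed with tf.io.VarLenFeature comes back as a tf.SparseTensor rather than a dense tensor. If the model expects dense inputs, a minimal conversion sketch (an addition of mine, reusing parsed_feature from inside _parse_function above) is:

# densify a VarLenFeature string field, padding missing positions with ""
dense_firstname = tf.sparse.to_dense(parsed_feature['firstname'], default_value="")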