问题 保留索引字符串对应的spark字符串索引器


Spark的StringIndexer非常有用,但是通常需要检索生成的索引值和原始字符串之间的对应关系,似乎应该有一种内置的方法来实现这一点。我将用这个简单的例子来说明 Spark文档

from pyspark.ml.feature import StringIndexer

df = sqlContext.createDataFrame(
    [(0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")],
    ["id", "category"])
indexer = StringIndexer(inputCol="category", outputCol="categoryIndex")
indexed_df = indexer.fit(df).transform(df)

这个简化的案例给了我们:

+---+--------+-------------+
| id|category|categoryIndex|
+---+--------+-------------+
|  0|       a|          0.0|
|  1|       b|          2.0|
|  2|       c|          1.0|
|  3|       a|          0.0|
|  4|       a|          0.0|
|  5|       c|          1.0|
+---+--------+-------------+

所有精细和花花公子,但对于许多用例,我想知道我的原始字符串和索引标签之间的映射。我能想到的最简单的方法就是这样:

   In [8]: indexed.select('category','categoryIndex').distinct().show()
+--------+-------------+
|category|categoryIndex|
+--------+-------------+
|       b|          2.0|
|       c|          1.0|
|       a|          0.0|
+--------+-------------+

如果我想的话,我可以将其存储为字典或类似的结果:

In [12]: mapping = {row.categoryIndex:row.category for row in
           indexed.select('category','categoryIndex').distinct().collect()}

In [13]: mapping
Out[13]: {0.0: u'a', 1.0: u'c', 2.0: u'b'}

我的问题是:由于这是一个如此常见的任务,我猜测(但当然可能是错误的)字符串索引器无论如何都以某种方式存储这个映射,有没有办法更简单地完成上述任务?

我的解决方案或多或少是直截了当的,但对于大型数据结构,这涉及一些额外的计算(也许)我可以避免。想法?


1015
2017-11-10 18:24


起源



答案:


标签映射可以从列元数据中提取:

meta = [
    f.metadata for f in indexed_df.schema.fields if f.name == "categoryIndex"
]
meta[0]
## {'ml_attr': {'name': 'category', 'type': 'nominal', 'vals': ['a', 'c', 'b']}}

哪里 ml_attr.vals 提供位置和标签之间的映射:

dict(enumerate(meta[0]["ml_attr"]["vals"]))
## {0: 'a', 1: 'c', 2: 'b'}

Spark 1.6+

您可以使用将数值转换为标签 IndexToString。这将使用如上所示的列元数据。

from pyspark.ml.feature import IndexToString

idx_to_string = IndexToString(
    inputCol="categoryIndex", outputCol="categoryValue")

idx_to_string.transform(indexed_df).drop("id").distinct().show()
## +--------+-------------+-------------+
## |category|categoryIndex|categoryValue|
## +--------+-------------+-------------+
## |       b|          2.0|            b|
## |       a|          0.0|            a|
## |       c|          1.0|            c|
## +--------+-------------+-------------+

Spark <= 1.5

这是一个肮脏的黑客,但你可以简单地从Java索引器中提取标签,如下所示:

from pyspark.ml.feature import StringIndexerModel

# A simple monkey patch so we don't have to _call_java later 
def labels(self):
    return self._call_java("labels")

StringIndexerModel.labels = labels

# Fit indexer model
indexer = StringIndexer(inputCol="category", outputCol="categoryIndex").fit(df)

# Extract mapping
mapping = dict(enumerate(indexer.labels()))
mapping
## {0: 'a', 1: 'c', 2: 'b'}

10
2017-11-24 21:10



只是 enumerate(indexer.labels()) 由于stringIndexer默认使用频率指数类别,因此不保证相同的顺序 - Ayush K Singh
Pyspark 1.6+解决方案如何呈现与简单不同的东西 indexed_df.drop('id').distinct().show()... category 和 categoryValue 是相同的 - blacksite


答案:


标签映射可以从列元数据中提取:

meta = [
    f.metadata for f in indexed_df.schema.fields if f.name == "categoryIndex"
]
meta[0]
## {'ml_attr': {'name': 'category', 'type': 'nominal', 'vals': ['a', 'c', 'b']}}

哪里 ml_attr.vals 提供位置和标签之间的映射:

dict(enumerate(meta[0]["ml_attr"]["vals"]))
## {0: 'a', 1: 'c', 2: 'b'}

Spark 1.6+

您可以使用将数值转换为标签 IndexToString。这将使用如上所示的列元数据。

from pyspark.ml.feature import IndexToString

idx_to_string = IndexToString(
    inputCol="categoryIndex", outputCol="categoryValue")

idx_to_string.transform(indexed_df).drop("id").distinct().show()
## +--------+-------------+-------------+
## |category|categoryIndex|categoryValue|
## +--------+-------------+-------------+
## |       b|          2.0|            b|
## |       a|          0.0|            a|
## |       c|          1.0|            c|
## +--------+-------------+-------------+

Spark <= 1.5

这是一个肮脏的黑客,但你可以简单地从Java索引器中提取标签,如下所示:

from pyspark.ml.feature import StringIndexerModel

# A simple monkey patch so we don't have to _call_java later 
def labels(self):
    return self._call_java("labels")

StringIndexerModel.labels = labels

# Fit indexer model
indexer = StringIndexer(inputCol="category", outputCol="categoryIndex").fit(df)

# Extract mapping
mapping = dict(enumerate(indexer.labels()))
mapping
## {0: 'a', 1: 'c', 2: 'b'}

10
2017-11-24 21:10



只是 enumerate(indexer.labels()) 由于stringIndexer默认使用频率指数类别,因此不保证相同的顺序 - Ayush K Singh
Pyspark 1.6+解决方案如何呈现与简单不同的东西 indexed_df.drop('id').distinct().show()... category 和 categoryValue 是相同的 - blacksite