问题 如何只替换numpy数组中大于某个值的前n个元素?
我有一个阵列 myA
喜欢这个:
array([ 7, 4, 5, 8, 3, 10])
如果我想替换大于值的所有值 val
到0,我可以简单地做:
myA[myA > val] = 0
这给了我想要的输出(为 val = 5
):
array([0, 4, 5, 0, 3, 0])
但是,我的目标是不是全部而是仅替换第一个 n
此数组的元素大于值 val
。
因此,如果 n = 2
我期望的结果看起来像这样(10
是第三个元素,因此不应被替换):
array([ 0, 4, 5, 0, 3, 10])
一个简单的实现将是:
import numpy as np
myA = np.array([7, 4, 5, 8, 3, 10])
n = 2
val = 5
# track the number of replacements
repl = 0
for ind, vali in enumerate(myA):
if vali > val:
myA[ind] = 0
repl += 1
if repl == n:
break
这可行,但也许有人可以用聪明的方式掩盖!?
9582
2018-01-26 14:52
起源
答案:
以下应该有效:
myA[(myA > val).nonzero()[0][:2]] = 0
以来 非零 将返回布尔数组所在的索引 myA > val
是非零的,例如 True
。
例如:
In [1]: myA = array([ 7, 4, 5, 8, 3, 10])
In [2]: myA[(myA > 5).nonzero()[0][:2]] = 0
In [3]: myA
Out[3]: array([ 0, 4, 5, 0, 3, 10])
5
2018-01-26 15:12
答案:
以下应该有效:
myA[(myA > val).nonzero()[0][:2]] = 0
以来 非零 将返回布尔数组所在的索引 myA > val
是非零的,例如 True
。
例如:
In [1]: myA = array([ 7, 4, 5, 8, 3, 10])
In [2]: myA[(myA > 5).nonzero()[0][:2]] = 0
In [3]: myA
Out[3]: array([ 0, 4, 5, 0, 3, 10])
5
2018-01-26 15:12
最终解决方案非常简单:
import numpy as np
myA = np.array([7, 4, 5, 8, 3, 10])
n = 2
val = 5
myA[np.where(myA > val)[0][:n]] = 0
print(myA)
输出:
[ 0 4 5 0 3 10]
2
2018-01-26 15:10
这是另一种可能性(未经测试),可能并不比 nonzero
:
def truncate_mask(m, stop):
m = m.astype(bool, copy=False) # if we allow non-bool m, the next line becomes nonsense
return m & (np.cumsum(m) <= stop)
myA[truncate_mask(myA > val, n)] = 0
通过避免构建和使用显式索引,您最终可能会获得更好的性能......但是您必须对其进行测试才能找到答案。
编辑1: 当我们谈论可能性时,您也可以尝试:
def truncate_mask(m, stop):
m = m.astype(bool, copy=True) # note we need to copy m here to safely modify it
m[np.searchsorted(np.cumsum(m), stop):] = 0
return m
编辑2(第二天): 我刚刚测试了这个,似乎就是这样 cumsum
实际上比 nonzero
,至少与 各种价值观 我正在使用(所以上述两种方法都不值得使用)。出于好奇,我也尝试了numba:
import numba
@numba.jit
def set_first_n_gt_thresh(a, val, thresh, n):
ii = 0
while n>0 and ii < len(a):
if a[ii] > thresh:
a[ii] = val
n -= 1
ii += 1
这只迭代数组一次,或者说它只迭代数组的必要部分一次,甚至不接触后一部分。这为您提供了非常出色的小型性能 n
,但即使是最糟糕的情况 n>=len(a)
这种方法更快。
2
2018-01-26 18:28
您可以使用相同的解决方案 这里 转换你 np.array
至 pd.Series
:
s = pd.Series([ 7, 4, 5, 8, 3, 10])
n = 2
m = 5
s[s[s>m].iloc[:n].index] = 0
In [416]: s
Out[416]:
0 0
1 4
2 5
3 0
4 3
5 10
dtype: int64
分步说明:
In [426]: s > m
Out[426]:
0 True
1 False
2 False
3 True
4 False
5 True
dtype: bool
In [428]: s[s>m].iloc[:n]
Out[428]:
0 7
3 8
dtype: int64
In [429]: s[s>m].iloc[:n].index
Out[429]: Int64Index([0, 3], dtype='int64')
In [430]: s[s[s>m].iloc[:n].index]
Out[430]:
0 7
3 8
dtype: int64
输出 In[430]
看起来一样 In[428]
但是在428它是一个副本和430原始系列。
如果你需要的话 np.array
你可以用 values
方法:
In [418]: s.values
Out[418]: array([ 0, 4, 5, 0, 3, 10], dtype=int64)
1
2018-01-26 15:00