问题 如何只替换numpy数组中大于某个值的前n个元素?


我有一个阵列 myA 喜欢这个:

array([ 7,  4,  5,  8,  3, 10])

如果我想替换大于值的所有值 val 到0,我可以简单地做:

myA[myA > val] = 0

这给了我想要的输出(为 val = 5):

 array([0, 4, 5, 0, 3, 0])

但是,我的目标是不是全部而是仅替换第一个 n 此数组的元素大于值 val

因此,如果 n = 2 我期望的结果看起来像这样(10 是第三个元素,因此不应被替换):

array([ 0,  4,  5,  0,  3, 10])

一个简单的实现将是:

import numpy as np

myA = np.array([7, 4, 5, 8, 3, 10])
n = 2
val = 5

# track the number of replacements
repl = 0

for ind, vali in enumerate(myA):

    if vali > val:

        myA[ind] = 0
        repl += 1

        if repl == n:
            break

这可行,但也许有人可以用聪明的方式掩盖!?


9582
2018-01-26 14:52


起源



答案:


以下应该有效:

myA[(myA > val).nonzero()[0][:2]] = 0

以来 非零 将返回布尔数组所在的索引 myA > val 是非零的,例如 True

例如:

In [1]: myA = array([ 7,  4,  5,  8,  3, 10])

In [2]: myA[(myA > 5).nonzero()[0][:2]] = 0

In [3]: myA
Out[3]: array([ 0,  4,  5,  0,  3, 10])

5
2018-01-26 15:12



非常优雅,谢谢。我现在支持它,可能会在以后根据其他答案的质量接受它。 - Cleb


答案:


以下应该有效:

myA[(myA > val).nonzero()[0][:2]] = 0

以来 非零 将返回布尔数组所在的索引 myA > val 是非零的,例如 True

例如:

In [1]: myA = array([ 7,  4,  5,  8,  3, 10])

In [2]: myA[(myA > 5).nonzero()[0][:2]] = 0

In [3]: myA
Out[3]: array([ 0,  4,  5,  0,  3, 10])

5
2018-01-26 15:12



非常优雅,谢谢。我现在支持它,可能会在以后根据其他答案的质量接受它。 - Cleb


最终解决方案非常简单:

import numpy as np
myA = np.array([7, 4, 5, 8, 3, 10])
n = 2
val = 5

myA[np.where(myA > val)[0][:n]] = 0

print(myA)

输出:

[ 0  4  5  0  3 10]

2
2018-01-26 15:10



当n = 3时,这似乎失败了。 - Cleb
是的,那应该是 np.where(mask)[0][n:] - George Petrov
是的,现在它也可以正常工作,所以我也支持它,并且可能会在以后根据其他答案的质量接受它。 - Cleb
非常好。你可以像JuniorCompressor的答案那样做一个单行: myA[np.where(myA > val)[0][:n]] = 0 - mtrw
@mtrw是的,我可以,但是 VisibleDeprecationWarning: boolean index did not match indexed array along dimension 0; dimension is 6 but corresponding boolean dimension is 2 myA[mask[np.where(mask)[0][n:]]] = 0 提出警告。 - George Petrov


这是另一种可能性(未经测试),可能并不比 nonzero

def truncate_mask(m, stop):
  m = m.astype(bool, copy=False) #  if we allow non-bool m, the next line becomes nonsense
  return m & (np.cumsum(m) <= stop)

myA[truncate_mask(myA > val, n)] = 0

通过避免构建和使用显式索引,您最终可能会获得更好的性能......但是您必须对其进行测试才能找到答案。

编辑1: 当我们谈论可能性时,您也可以尝试:

def truncate_mask(m, stop):
   m = m.astype(bool, copy=True) #  note we need to copy m here to safely modify it
   m[np.searchsorted(np.cumsum(m), stop):] = 0
   return m

编辑2(第二天): 我刚刚测试了这个,似乎就是这样 cumsum 实际上比 nonzero,至少与 各种价值观 我正在使用(所以上述两种方法都不值得使用)。出于好奇,我也尝试了numba:

import numba

@numba.jit
def set_first_n_gt_thresh(a, val, thresh, n):
    ii = 0
    while n>0 and ii < len(a):
        if a[ii] > thresh:
            a[ii] = val
            n -= 1
        ii += 1

这只迭代数组一次,或者说它只迭代数组的必要部分一次,甚至不接触后一部分。这为您提供了非常出色的小型性能 n,但即使是最糟糕的情况 n>=len(a) 这种方法更快。


2
2018-01-26 18:28



代码中有一个小错误:应该是 <=stop 代替 stop,它似乎工作正常。感谢您的建议,我也赞成了它。 - Cleb
啊,是的(我正在考虑切片表示法)。我在此添加了第二个变体,其中也可能包含错误。 - dan-man


您可以使用相同的解决方案 这里 转换你 np.array 至 pd.Series

s = pd.Series([ 7,  4,  5,  8,  3, 10])
n = 2
m = 5
s[s[s>m].iloc[:n].index] = 0

In [416]: s
Out[416]:
0     0
1     4
2     5
3     0
4     3
5    10
dtype: int64

分步说明:

In [426]: s > m
Out[426]:
0     True
1    False
2    False
3     True
4    False
5     True
dtype: bool

In [428]: s[s>m].iloc[:n]
Out[428]:
0    7
3    8
dtype: int64

In [429]: s[s>m].iloc[:n].index
Out[429]: Int64Index([0, 3], dtype='int64')

In [430]: s[s[s>m].iloc[:n].index]
Out[430]:
0    7
3    8
dtype: int64

输出 In[430] 看起来一样 In[428] 但是在428它是一个副本和430原始系列。

如果你需要的话 np.array 你可以用 values 方法:

In [418]: s.values
Out[418]: array([ 0,  4,  5,  0,  3, 10], dtype=int64)

1
2018-01-26 15:00



好的,工作正常!还要感谢详细解释。我现在支持它,可能会在以后接受它,具体取决于其他答案的质量。 - Cleb