こんな人にオススメ
pythonのmax関数とかで出てくるkey
って引数の役割って何?わざわざ使う必要ってあるの?
ということで、今回はpythonのmax()
やsorted()
で使用される引数key
について解説する。あまり使い所がないかもしれないけど、いざというときに役立つのかもしれない。
python環境は以下。
- Python 3.9.7
- numpy 1.21.2
作成したコード全文
下準備
import numpy as np def check(*args, func, key, **kwargs): ans = func(*args, key=key, **kwargs) return ans
まずは下準備としてのimport
と今回使用する関数。numpy
はNaN
を作成するときと平均する時のみ使用する。check
関数は作成しなくてもよかったけど、各関数を一括で見やすくなると思い定義した。
この関数の可変長引数args
に入れた要素に対して、func
で実行したい関数(max
とかsorted
とか)をkey
の条件で適用する。また、追加で引数が欲しい場合はkwargs
で追加できるようにした。
文字列の長さを勘定
まずはkey=len
として文字列の長さを勘定する。for
文の内包表記を使用してもいいし他の方法を使用しても可能だけど、key
を使うことでスッキリ書くことができる。
シンプルに文字列の長さを勘定
# 数値文字列の中で検索 lst = ['123', '123456', '1234'] print(lst) # ['123', '123456', '1234'] # 一番長い文字列 ans = check(lst, func=max, key=len) print(ans) # 123456 # 一番短い文字列 ans = check(lst, func=min, key=len) print(ans) # 123
ここでは文字列に数値を選択した。数値といっても文字にしてあるので長さを数えることは可能。max
を使用することで一番長い文字列を、min
では一番短い文字列を出力することができる。
max
の場合だと123456
、min
だと123
が出力される。シンプル。
複数要素を入れて勘定してもいい
# 単純な文字列でも可能 a = 'hoge' b = 'foo' ans = check(a, b, func=max, key=len) print(ans) # hoge ans = check(a, b, func=min, key=len) print(ans) # foo # 要素は増えてもいい c = 'hogefoohoge' ans = check(a, b, c, func=max, key=len) print(ans) # hogefoohoge
さっきは変数lst
に全ての要素を入れて長さを計ったけど、もちろんバラバラに指定してもいい。*args
を使ったのもそのため。max(a, b, c)
と同じ意味。
ここでは文字列3種類を使って文字列の長さを判定している。
数値の文字列と文字の組み合わせも可
print(lst + [c]) # ['123', '123456', '1234', 'hogefoohoge'] ans = check(lst + [c], func=max, key=len) print(ans) # hogefoohoge ans = check(lst + [c], func=min, key=len) print(ans) # 123
もちろん文字列なので、文字列の数値と単なる文字列を組み合わせてもいい。
数値が入っている場合の挙動
今度は文字列ではない、単なる数値が入っているときの挙動について解説。文字列の数字と文字列ではない数字(ここでは数値とする)では意味合いが全く異なるので注意。
数値だと長さは測れない
lst = ['1', 20, '100'] print(lst) # ['1', 20, '100'] # # intが入っているとエラー # ans = check(lst, func=max, key=len) # # TypeError: object of type 'int' has no len()
そもそも数値の場合は長さという概念がない。ということで、key=lenで長さを計ろうとするとエラーとなる。これは小数の場合でも同様。
# 小数も同様 lst = ['1', '20', 100.1] print(lst) # ['1', '20', 100.1] ans = check(lst, func=max, key=len) # TypeError: object of type 'float' has no len(
key
をint
, float
にすると大小の比較が可能
# 文字列の数字を変換すると、100.1が一番大きい lst = ['10', 20, 100.1, '001'] ans = check(lst, func=max, key=int) print(ans) # 100.1 # 001は数値に変換すると1になるので一倍小さい ans = check(lst, func=min, key=int) print(ans) # 001
一方で、数値の場合だとkey=int
が使える。int
にすると文字列の数字を全て数値に変換して、func
で指定した処理を実行してくれる。上の例だと最大値と最小値を計算してくれる。
'001'
に関してはint
にすると1
になるので、最小値で出力されるようになる。したがって、0.1
を配列に追加すると最小値は'001'
から0.1
になる。これは最大値の場合でもそう。
# 0.1を追加すると0.1が一番小さくなる ans = check(lst + [0.1], func=min, key=int) print(ans) # 0.1 # '1200'を追加すると1200が一番大きい ans = check(lst + ['1200'], func=max, key=int) print(ans) # 1200
なお、文字列をそのまま入れると、これは数値に変換することができないのでエラー。
# 文字列が入るとintは使えない ans = check(lst + ['a'], func=min, key=int) # ValueError: invalid literal for int() with base 10: 'a'
また、int
で比較するので100.1
も100.5
も同じ100
となる。しかし、最初に当てはまった要素が出力されるのでint
の場合は100.1
が出力される。
一方でkey=float
にすると小数も反映されるので、この場合は100.5
が最大となる。
# intで比較するので、100.1も100.5も100となる # しかし、出力されるのは初めに見つけたもの ans = check(lst + [100.5], func=max, key=int) print(ans) # 100.1 # key=floatにすると小数も含めての判定になるので100.5が一番大きい ans = check(lst + [100.5], func=max, key=float) print(ans) # 100.5
NaN
が入っていても大丈夫
lst = ['10', 20, 100.1, np.nan] print(lst) # ['10', 20, 100.1, nan]
要素の中にNaN
が入っている場合はどうだろうか。NaN
はfloat
の仲間なのでint
にできない。したがって、key=int
はエラー。
# NaNはfloatなのでintにできない ans = check(lst, func=max, key=int) # ValueError: cannot convert float NaN to integer
float
にすると解決する。変にNaN
が出力されることはない。なので、NaN
が出て欲しいなら予めNaN
を検出するようにnp.nan
などを実行する必要がある。
# key=floatにすると解決 ans = check(lst, func=max, key=float) print(ans) # 100.1 ans = check(lst, func=min, key=float) print(ans) # 10
sorted
に適用
lst = ['hoge', 'foo', 'hogefoohoge'] ans = check(lst, func=sorted, key=len) print(ans) # ['foo', 'hoge', 'hogefoohoge']
func
をsorted
にすることも可能。sorted
にすると昇順で並び替えられる。並び替えの方法をkey
で指定する。上の例では文字列の長さを基準に並び替えを行なっている。
もちろん降順にすることも可能。その場合は逆順を示すreverse=True
にする必要がある。
# 逆順も可能 ans = check(lst, func=sorted, key=len, reverse=True) print(ans) # ['hogefoohoge', 'hoge', 'foo']
2次元配列で適用
今までは1次元配列での適用だったが、2次元配列の場合はどうだろうか。2次元配列となると、合計値を出せたりどの要素を基準にするかといった要素が絡んでくるのでややこしくなる。
合計値で検索
lst = [[3, 7], [8, 5], [2, 9], [0, 10, 3]] print(lst) # [[3, 7], [8, 5], [2, 9], [0, 10, 3]] # 各listの合計値で判定 # 該当する要素が複数個ある場合は初めの要素が出力 ans = check(lst, func=max, key=sum) print(ans) # [8, 5]
合計値で判断するならkey=sum
とすればいい。上の場合だと[8, 5]
と[0, 10, 3]
がともに合計値13
で最大だが、初めにヒットするのは[8, 5]
なので、出力は[8, 5]
となる。
先頭の値でソート
# 0番目=先頭の数値を基準にソート ans = check(lst, func=sorted, key=lambda x: x[0]) print(ans) # [[0, 10, 3], [2, 9], [3, 7], [8, 5]]
先頭の数値を基準としてソートする場合はlambda
関数を使用して0番目の値を基準にするように指定する。もちろん自作関数でもいいけど、これについては後述。
x[1]
とすることで、1番目の要素を基準にソートすることができる。ただし、2番目の要素となると[0, 10, 3]以
外は範囲外となるためエラー。
# 1番目の数値を基準にソート ans = check(lst, func=sorted, key=lambda x: x[1]) print(ans) # [[8, 5], [3, 7], [2, 9], [0, 10, 3]] # 2番目の数値を基準にソート ans = check(lst, func=sorted, key=lambda x: x[2]) # IndexError: list index out of range
自作関数を使用
最後は自作関数を使用する方法。これまではmax
やlen
といった標準関数を使用しての処理だったが、平均値や余りの値で判断することも可能。
余りで判断
# 余りを出力する関数 def remainder(x): return x % 3
remainder関数は入力された数値の余りを出力する関数。この関数をkeyに使用して値を判断する。
# 3で割った余りが最小 lst = [6, 7, 11] # それぞれ余りが0, 1, 2 # lambda式で書いてもいいし関数にしてもいい ans = check(lst, func=min, key=lambda x: x % 3) ans = check(lst, func=min, key=remainder) print(ans) # 6
remainder
関数を使用してもいいけど、lambda
式でもいい。簡単な関数かつその時にしか使用しないならlambda
式でもいいだろう。
平均値で判断
def ave(x): return np.average(x)
平均値に関してはnumpy
の平均する関数を使用。別に合計値を個数で割ってもいい。こちらも先ほど同様にkey
に適用できる。
# 平均値は2, 3, 5.5, 4.3 lst = [[3, 1], [1, 5], [2, 9], [0, 10, 3]] # lambda式で書いてもいいし関数にしてもいい ans = check(lst, func=max, key=lambda x: sum(x) / len(x)) ans = check(lst, func=max, key=ave) print(ans) # [2, 9]
サクッと使用したい時に使えそう
今回はmax
関数やsorted
関数の引数key
について解説した。もちろんnumpy
などを使用すると簡単にできるのかもしれないけど、わざわざimport
する必要がなかったりサクッとしたいときに使えそう。
あとはいちいちnumpy
の関数の引数の使い方を調べなくてもkey=len
とすれば長さが計れるなどのお手軽さもあると思う。
関連記事
-
-
【python3.7以降&dictのkeys】ネスト(入れ子)されたdictのkeys一覧をカッコで出力
続きを見る
-
-
【astropy&単位】astropyで変数に単位を付与
続きを見る
-
-
【inspect&引数名取得】defのパラメータ名を取り出す
続きを見る
-
-
【python&関数化】defとかargsとかを使って関数を作成する
続きを見る